mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rewrite maybe_reduce more carefully for unbacked SymInt (#119562)
Fixes https://github.com/pytorch/pytorch/issues/119476 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/119562 Approved by: https://github.com/albanD ghstack dependencies: #119559
This commit is contained in:
committed by
PyTorch MergeBot
parent
28f299a870
commit
6665b96ebb
@ -57,13 +57,59 @@ at::Tensor InputMetadata::maybe_reduce(
|
||||
const size_t i,
|
||||
at::Tensor grad,
|
||||
const std::function<std::string(const std::string&)>& format_error) const {
|
||||
if (!is_same_shape(grad)) {
|
||||
if (is_expandable_to_shape(grad)) {
|
||||
return reduce_grad(grad);
|
||||
auto fail = [&]() {
|
||||
const auto message = incompatible_shape_error_message(i, grad);
|
||||
TORCH_CHECK(false, format_error(message.str()));
|
||||
};
|
||||
|
||||
// Nested tensor makes my brain explode, so I've just hard-coded the logic
|
||||
// for this case, at risk of code duplication. This logic does NOT do the
|
||||
// careful oblivious logic as seen below
|
||||
if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
|
||||
::torch::autograd::is_cpp_nested_tensor(grad)) {
|
||||
if (!is_same_shape(grad)) {
|
||||
if (is_expandable_to_shape(grad)) {
|
||||
return reduce_grad(grad);
|
||||
} else {
|
||||
fail();
|
||||
}
|
||||
} else {
|
||||
const auto message = incompatible_shape_error_message(i, grad);
|
||||
TORCH_CHECK(false, format_error(message.str()));
|
||||
return grad;
|
||||
}
|
||||
}
|
||||
|
||||
auto shape = shape_as_dim_vector();
|
||||
auto desired = grad.sym_sizes();
|
||||
|
||||
size_t ndim = shape.size();
|
||||
size_t target_dim = desired.size();
|
||||
if (ndim > target_dim) {
|
||||
fail();
|
||||
}
|
||||
bool needs_reduce = false;
|
||||
for (const auto i : c10::irange(ndim)) {
|
||||
const auto& size = shape[ndim - i - 1];
|
||||
const auto& target = desired[target_dim - i - 1];
|
||||
// The conditions here are written carefully so that we are able to
|
||||
// infer deferred runtime asserts
|
||||
if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
|
||||
// NB: we could short circuit this once needs_reduce is true but there's
|
||||
// no point since the reduction function will guard on this anyway
|
||||
if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
|
||||
needs_reduce = true;
|
||||
}
|
||||
} else {
|
||||
if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
|
||||
fail();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ndim != target_dim) {
|
||||
needs_reduce = true;
|
||||
}
|
||||
|
||||
if (needs_reduce) {
|
||||
return reduce_grad(grad);
|
||||
} else {
|
||||
return grad;
|
||||
}
|
||||
|
Reference in New Issue
Block a user