Rewrite maybe_reduce more carefully for unbacked SymInt (#119562)

Fixes https://github.com/pytorch/pytorch/issues/119476

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119562
Approved by: https://github.com/albanD
ghstack dependencies: #119559
Edward Z. Yang
2024-02-13 11:19:12 -05:00
committed by PyTorch MergeBot
parent 28f299a870
commit 6665b96ebb
4 changed files with 91 additions and 18 deletions


@@ -57,13 +57,59 @@ at::Tensor InputMetadata::maybe_reduce(
    const size_t i,
    at::Tensor grad,
    const std::function<std::string(const std::string&)>& format_error) const {
  auto fail = [&]() {
    const auto message = incompatible_shape_error_message(i, grad);
    TORCH_CHECK(false, format_error(message.str()));
  };
  // Nested tensor makes my brain explode, so I've just hard-coded the logic
  // for this case, at risk of code duplication. This logic does NOT do the
  // careful oblivious logic as seen below
  if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
      ::torch::autograd::is_cpp_nested_tensor(grad)) {
    if (!is_same_shape(grad)) {
      if (is_expandable_to_shape(grad)) {
        return reduce_grad(grad);
      } else {
        fail();
      }
    } else {
      return grad;
    }
  }
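  // From here on the saved sizes and the grad sizes may be symbolic (and
  // possibly unbacked) SymInts, so the comparisons below are phrased to avoid
  // guarding on data-dependent values wherever possible.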
  auto shape = shape_as_dim_vector();
  auto desired = grad.sym_sizes();
  size_t ndim = shape.size();
  size_t target_dim = desired.size();
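  // A grad produced by broadcasting can gain leading dimensions relative to
  // the input it flows back to, but it can never have fewer dimensions.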
  if (ndim > target_dim) {
    fail();
  }
  bool needs_reduce = false;
  for (const auto i : c10::irange(ndim)) {
    const auto& size = shape[ndim - i - 1];
    const auto& target = desired[target_dim - i - 1];
    // The conditions here are written carefully so that we are able to
    // infer deferred runtime asserts
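    // (Size-oblivious: an unbacked size is assumed to be >= 2 for this check,
    // so it never takes the == 1 branch and never forces a guard on it.)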
    if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
      // NB: we could short circuit this once needs_reduce is true but there's
      // no point since the reduction function will guard on this anyway
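      // definitely_true only succeeds when size == target is statically
      // known; otherwise we conservatively reduce, which is a no-op if the
      // sizes turn out to be equal at runtime.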
      if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
        needs_reduce = true;
      }
    } else {
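      // The saved size is not statically 1 here (it may be an unbacked
      // SymInt); expect_true lets unbacked sizes through by installing a
      // deferred runtime assert that size == target, so fail() is only hit
      // when the mismatch is already known at trace time.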
      if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
        fail();
      }
    }
  }
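  // Extra leading dims in grad (from broadcasting) always have to be summed
  // out, independently of the per-dimension comparison above.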
  if (ndim != target_dim) {
    needs_reduce = true;
  }
  if (needs_reduce) {
    return reduce_grad(grad);
  } else {
    return grad;
  }
}
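
The subtle part of the rewrite is the per-dimension comparison, which has to choose between reducing, leaving the grad alone, and erroring, all without guarding on sizes that may be unbacked SymInts. The sketch below is not PyTorch code: it is a hypothetical, self-contained C++ model of that decision, where Size, dim_needs_reduce, and the deferred list stand in for c10::SymInt, the loop body above, and the deferred runtime asserts that expect_true would install.

#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

// A stand-in for a SymInt-valued size: either a concrete ("backed") value or
// an unbacked symbol whose value is only known at runtime.
struct Size {
  std::optional<long> backed; // nullopt models an unbacked SymInt
};

// Per-dimension decision, mirroring the loop in maybe_reduce.  Returns true
// if this dimension forces a sum-reduction of the gradient; `deferred`
// collects the runtime asserts that replace trace-time guards.
bool dim_needs_reduce(
    const Size& size,
    const Size& target,
    std::vector<std::string>& deferred) {
  // "Size-oblivious" check: only a size known to be 1 takes this branch; an
  // unbacked size is assumed to be >= 2 for this purpose.
  if (size.backed && *size.backed == 1) {
    // Skip the reduce only if target is definitely also 1; when target is
    // unbacked we conservatively reduce, which is a no-op if it is in fact 1.
    return !(target.backed && *target.backed == 1);
  }
  // Otherwise the sizes must match.  If either side is unbacked we cannot
  // decide now, so record a deferred runtime assert instead of guarding.
  if (size.backed && target.backed) {
    if (*size.backed != *target.backed) {
      throw std::runtime_error("grad is not expandable to the input shape");
    }
    return false;
  }
  deferred.push_back("assert at runtime: size == target");
  return false;
}

int main() {
  std::vector<std::string> deferred;
  // Input dim is statically 1, grad dim is an unbacked u0: reduce, no guard.
  std::cout << dim_needs_reduce({1}, {std::nullopt}, deferred) << '\n'; // 1
  // Input dim is an unbacked u1, grad dim is 8: defer "u1 == 8", no reduce.
  std::cout << dim_needs_reduce({std::nullopt}, {8}, deferred) << '\n'; // 0
  std::cout << deferred.size() << " deferred runtime assert(s)\n";      // 1
}

Read together with the in-code comments, the point of ordering the conditions this way is that a size that is statically 1 can only ever lead to a (possibly redundant) reduction, never an error, while every other case becomes an equality that is either checked at trace time or asserted at runtime, which is what lets deferred runtime asserts be inferred for unbacked SymInts.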