Rewrite maybe_reduce more carefully for unbacked SymInt (#119562)

Fixes https://github.com/pytorch/pytorch/issues/119476
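
A note on the failure mode, with a hedged sketch (the exact trigger is in the
linked issue and the new test below; this toy program is an assumption, not the
original repro): when sizes flow through unbacked SymInts, the shape check in
validate_outputs -> InputMetadata::maybe_reduce could guard on data-dependent
equalities like "is this dim equal to 1?", which unbacked symbols cannot answer.

    import torch

    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile(backend="aot_eager", fullgraph=True)
    def f(x, i):
        # i0 and i1 are unbacked SymInts: their values are unknown at trace
        # time, so shape checks on the split outputs (and their gradients)
        # cannot be resolved by guarding on a concrete hint.
        i0, i1 = i.tolist()
        x0, x1 = x.split([i0, i1])
        return x0.sum() + x1.sum()

    f(torch.randn(9, requires_grad=True), torch.tensor([3, 6])).backward()

The rewrite below makes the check size-oblivious where possible and records
deferred runtime asserts otherwise, instead of guarding.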

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119562
Approved by: https://github.com/albanD
ghstack dependencies: #119559
Author:    Edward Z. Yang
Date:      2024-02-13 11:19:12 -05:00
Committer: PyTorch MergeBot
Parent:    28f299a870
Commit:    6665b96ebb

4 changed files with 91 additions and 18 deletions

c10/core/SymInt.h

@@ -413,4 +413,11 @@ inline SymBool sym_ge(const SymInt& a, const SymInt& b) {
   return a.sym_ge(b);
 }
 
+inline bool definitely_true(
+    const c10::SymBool& b,
+    const char* file,
+    int64_t line) {
+  return b.has_hint() && b.guard_bool(file, line);
+}
+
 } // namespace c10
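
For intuition, a rough Python analogue of this helper (a sketch only:
c10::definitely_true has no Python binding, and SymNode.has_hint /
SymNode.guard_bool are internal APIs):

    import torch

    def definitely_true(b) -> bool:
        # Plain bools pass through unchanged.
        if isinstance(b, bool):
            return b
        assert isinstance(b, torch.SymBool)
        node = b.node  # torch.SymBool wraps a SymNode
        # Only guard when a hint (a concrete example value) is available.
        # Unbacked expressions have no hint and conservatively report False
        # instead of raising a data-dependent guard error.
        return node.has_hint() and node.guard_bool(__file__, 0)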

c10/core/SymbolicShapeMeta.cpp

@@ -1,5 +1,6 @@
 #include <c10/core/Contiguity.h>
 #include <c10/core/MemoryFormat.h>
+#include <c10/core/SymInt.h>
 #include <c10/core/SymIntArrayRef.h>
 #include <c10/core/SymbolicShapeMeta.h>
@@ -130,17 +131,13 @@ DEFINE_SYMBOOL_COMPUTE(compute_non_overlapping_and_dense, is_non_overlapping_and
 // test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 to run
 // very slowly.
 
-static bool definitely_true(const SymBool& b) {
-  return b.has_hint() && b.guard_bool(__FILE__, __LINE__);
-}
-
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
   init_is_contiguous();
-  if (definitely_true(is_contiguous())) {
+  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   init_is_channels_last_contiguous();
-  if (definitely_true(is_channels_last_contiguous())) {
+  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | is_channels_last_contiguous() |
@@ -149,7 +146,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim4() const {
 
 SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
   init_is_channels_last_contiguous();
-  if (definitely_true(is_channels_last_contiguous())) {
+  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last_contiguous() & compute_channels_last_contiguous_3d();
@@ -157,7 +154,7 @@ SymBool SymbolicShapeMeta::compute_channels_last_contiguous_3d_dim5() const {
 
 SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
   init_is_channels_last_3d_contiguous();
-  if (definitely_true(is_channels_last_3d_contiguous())) {
+  if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last_3d_contiguous() &
@@ -165,20 +162,20 @@ SymBool SymbolicShapeMeta::compute_channels_last_2d_dim5() const {
 }
 
 SymBool SymbolicShapeMeta::compute_channels_last_3d_dim5() const {
-  if (definitely_true(is_channels_last())) {
+  if (definitely_true(is_channels_last(), __FILE__, __LINE__)) {
     return false;
   }
   return ~is_channels_last() & compute_strides_like_channels_last_3d();
 }
 
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
-  if (definitely_true(is_contiguous())) {
+  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
-  if (definitely_true(is_channels_last_contiguous())) {
+  if (definitely_true(is_channels_last_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
-  if (definitely_true(is_channels_last_3d_contiguous())) {
+  if (definitely_true(is_channels_last_3d_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | is_channels_last_contiguous() |
@@ -186,7 +183,7 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_dim5() const {
 }
 
 SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const {
-  if (definitely_true(is_contiguous())) {
+  if (definitely_true(is_contiguous(), __FILE__, __LINE__)) {
     return true;
   }
   return is_contiguous() | compute_non_overlapping_and_dense();
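
All the call sites above share one pattern, sketched here (assuming cheap()
and expensive() return SymBool-like values; the names are illustrative): take
the fast path only when the cheap property is definitely true under a hint,
otherwise keep the composite expression symbolic so unbacked symbols are never
guarded on.

    def compute_with_fast_path(cheap, expensive):
        c = cheap()
        # Short-circuiting here is what keeps tests like the exhaustive
        # unfold test mentioned above from running very slowly.
        if definitely_true(c):  # the helper sketched under SymInt.h above
            return True
        return c | expensive()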

test/dynamo/test_misc.py

@@ -8257,6 +8257,29 @@ def ___make_guard_fn():
         self.assertEqual(msg, "shape torch.Size([8, 8]) batch size 1.00")
         self.assertEqual(res, img1 + torch.sin(img1))
 
+    @torch._dynamo.config.patch(capture_scalar_outputs=True)
+    def test_validate_outputs_unbacked(self):
+        class SillyCat(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x0, x1, i):
+                ctx.save_for_backward(i)
+                return torch.cat([x0, x1])
+
+            @staticmethod
+            def backward(ctx, grad_out):
+                (i,) = ctx.saved_tensors
+                i0, i1 = i.tolist()
+                g_x0, g_x1 = grad_out.split([i0, i1])
+                return g_x0, g_x1, None
+
+        @torch.compile(backend="aot_eager", fullgraph=True)
+        def f(x, i):
+            i0, i1 = i.tolist()
+            x0, x1 = x.split([i0, i1])
+            return SillyCat.apply(x0, x1, i)
+
+        f(torch.randn(9, requires_grad=True), torch.tensor([3, 6]))
+
     def test_str_format_assert1(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(img):
torch/csrc/autograd/input_metadata.cpp

@@ -57,13 +57,59 @@ at::Tensor InputMetadata::maybe_reduce(
     const size_t i,
     at::Tensor grad,
     const std::function<std::string(const std::string&)>& format_error) const {
-  if (!is_same_shape(grad)) {
-    if (is_expandable_to_shape(grad)) {
-      return reduce_grad(grad);
+  auto fail = [&]() {
+    const auto message = incompatible_shape_error_message(i, grad);
+    TORCH_CHECK(false, format_error(message.str()));
+  };
+
+  // Nested tensor makes my brain explode, so I've just hard-coded the logic
+  // for this case, at risk of code duplication.  This logic does NOT do the
+  // careful oblivious logic as seen below
+  if (is_nested_ || is_cpp_nested_tensor() || grad.is_nested() ||
+      ::torch::autograd::is_cpp_nested_tensor(grad)) {
+    if (!is_same_shape(grad)) {
+      if (is_expandable_to_shape(grad)) {
+        return reduce_grad(grad);
+      } else {
+        fail();
+      }
     } else {
-      const auto message = incompatible_shape_error_message(i, grad);
-      TORCH_CHECK(false, format_error(message.str()));
+      return grad;
     }
   }
-  return grad;
+
+  auto shape = shape_as_dim_vector();
+  auto desired = grad.sym_sizes();
+
+  size_t ndim = shape.size();
+  size_t target_dim = desired.size();
+  if (ndim > target_dim) {
+    fail();
+  }
+  bool needs_reduce = false;
+  for (const auto i : c10::irange(ndim)) {
+    const auto& size = shape[ndim - i - 1];
+    const auto& target = desired[target_dim - i - 1];
+    // The conditions here are written carefully so that we are able to
+    // infer deferred runtime asserts
+    if (TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1))) {
+      // NB: we could short circuit this once needs_reduce is true but there's
+      // no point since the reduction function will guard on this anyway
+      if (!c10::definitely_true(size.sym_eq(target), __FILE__, __LINE__)) {
+        needs_reduce = true;
+      }
+    } else {
+      if (!size.sym_eq(target).expect_true(__FILE__, __LINE__)) {
+        fail();
+      }
+    }
+  }
+  if (ndim != target_dim) {
+    needs_reduce = true;
+  }
+  if (needs_reduce) {
+    return reduce_grad(grad);
+  } else {
+    return grad;
+  }
 }
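
For the new control flow at a glance, a hedged Python mirror of the decision
above (illustrative only: real sizes are SymInts, `size == 1` stands in for
TORCH_GUARD_SIZE_OBLIVIOUS(size.sym_eq(1)), the inner equality for
c10::definitely_true, and the final check for an expect_true deferred runtime
assert):

    def maybe_reduce_decision(shape, desired):
        # Gradients may gain broadcast dims on the left, never lose dims.
        ndim, target_dim = len(shape), len(desired)
        if ndim > target_dim:
            raise RuntimeError("incompatible shapes")
        needs_reduce = ndim != target_dim
        # Walk dimensions right-aligned, mirroring broadcasting rules.
        for size, target in zip(reversed(shape), reversed(desired)):
            if size == 1:
                # A size-1 input dim may have been broadcast; reduce the
                # grad unless the dims are definitely equal.
                if size != target:
                    needs_reduce = True
            else:
                # Non-broadcast dims must match exactly; for unbacked sizes
                # the C++ records a deferred runtime assert, not a guard.
                if size != target:
                    raise RuntimeError("incompatible shapes")
        return needs_reduce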