[attempt 2] Compute contiguity symbolically to avoid DDEs, and introduce C++ sym_is_contiguous (#157472)

Summary:
When we compute contiguity for a tensor with dynamic shapes, we proceed in order (a rough sketch follows the list):
1) Try to compute it without guarding.
2) If all shapes are hinted, compute it concretely, potentially adding guards.
3) If any shape is not hinted, compute it symbolically.
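
A rough, self-contained C++ model of that ordering (my own sketch for illustration, not PyTorch's actual implementation; an unhinted size/stride is modeled as an empty optional and the symbolic case is just a marker value):

  // Simplified model of the fallback order above; not the actual internals.
  #include <cstdint>
  #include <optional>
  #include <vector>

  // std::nullopt models an unbacked (unhinted) size or stride.
  using MaybeInt = std::optional<int64_t>;

  enum class Contiguity { True, False, Symbolic };

  Contiguity compute_contiguity(const std::vector<MaybeInt>& sizes,
                                const std::vector<MaybeInt>& strides) {
    // 1) Guard-free fast path: a rank-0 tensor is trivially contiguous.
    if (sizes.empty()) {
      return Contiguity::True;
    }
    // 3) If any size/stride is unhinted, defer to a symbolic answer
    //    (the real code would build and return a SymBool expression).
    for (size_t i = 0; i < sizes.size(); ++i) {
      if (!sizes[i].has_value() || !strides[i].has_value()) {
        return Contiguity::Symbolic;
      }
    }
    // 2) Everything is hinted: compute concretely (the real code may add guards).
    int64_t expected = 1;
    for (size_t j = sizes.size(); j-- > 0;) {
      if (*sizes[j] == 1) {
        continue;  // size-1 dims never affect contiguity
      }
      if (*strides[j] != expected) {
        return Contiguity::False;
      }
      expected *= *sizes[j];
    }
    return Contiguity::True;
  }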

sym_is_contiguous returns a SymBool, which can either be evaluated directly or have guard_or_false called
on it to avoid data-dependent errors (DDEs).

For example:
 bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__);
is_contiguous_or_false is a helper function that wraps this pattern.
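
As a hedged usage sketch, an op can branch on the result and fall back when contiguity is undecidable (my_op, fast_path, and generic_path below are hypothetical placeholders; only sym_is_contiguous and guard_or_false come from this PR):

  // Hypothetical usage sketch; fast_path/generic_path are illustrative stubs.
  #include <ATen/ATen.h>

  // Stand-ins for a specialized kernel and an always-correct fallback.
  at::Tensor fast_path(const at::Tensor& t) { return t + 1; }
  at::Tensor generic_path(const at::Tensor& t) { return t.contiguous() + 1; }

  at::Tensor my_op(const at::Tensor& input) {
    // guard_or_false returns false instead of raising a data-dependent error
    // when contiguity cannot be decided from unbacked symbolic shapes, so we
    // conservatively take the generic path in that case.
    if (input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__)) {
      return fast_path(input);
    }
    return generic_path(input);
  }

Falling back on an undecidable query is safe because the generic path handles contiguous and non-contiguous inputs alike.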

In this PR I only handle default contiguity; I will follow up with changes for other memory formats like channels_last.
We use this pattern in several places in this PR to avoid DDEs.

Test Plan:
contbuild & OSS CI,

Rollback Plan:

Reviewed By: malfet

Differential Revision: D77639021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157472
Approved by: https://github.com/aorenste
Laith Sakka
2025-07-02 23:12:29 +00:00
committed by PyTorch MergeBot
parent d40aaa42ee
commit 7cfd054075
34 changed files with 390 additions and 114 deletions


@@ -3336,8 +3336,8 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
-mul_19: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
-return (mul_19,)""", # noqa: B950
+mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
+return (mul_21,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
@@ -3460,6 +3460,75 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
func(torch.ones(5, 6, 9, 8))
self.assertEqual(cnt.frame_count, 3)
    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
    @fresh_cache()
    def test_unbacked_contiguous(self):
        cnt = CompileCounterWithBackend("inductor")

        def func(x):
            contig = x.contiguous()
            return (contig + 1) * 100

        compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
        x = torch.randn(10, 10)
        # make x not contiguous.
        x = x.t_()
        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(x, 1)
        log_stream, ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )
        with ctx():
            compiled_func(x)
            self.assertEqual(compiled_func(x), func(x))
            y = torch.rand(20, 20).t()
            self.assertEqual(compiled_func(y), func(y))
            self.assertEqual(cnt.frame_count, 1)

        output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            output,
            """\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
return (mul_6,)""", # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        log_stream, ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )
        with ctx():
            # recompilation will happen due to stride specialization.
            y = torch.rand(20, 20)
            torch._dynamo.decorators.mark_unbacked(y, 0)
            torch._dynamo.decorators.mark_unbacked(y, 1)
            self.assertEqual(compiled_func(y), func(y))
            self.assertEqual(cnt.frame_count, 2)

        output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        # No clone this time since input is contiguous.
        self.assertExpectedInline(
            output,
            """\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
return (mul_5,)""", # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )
instantiate_parametrized_tests(TestUnbacked)