mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[attempt 2] Compute contiguity symbolically to avoid DDEs, and introduce C++ sym_is_contiguous (#157472)
Summary: When we compute contiguity for a tensor with dynamic shapes, we proceed in three steps:
1) Try to compute it without guarding.
2) If all shapes are hinted, compute it, potentially adding guards.
3) If any input is not hinted, compute it symbolically.

sym_is_contiguous returns a SymBool, which is then either evaluated directly or resolved with guard_or_false to avoid data-dependent errors (DDEs), e.g.:
bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__);
is_contiguous_or_false is a helper function that does exactly that. This PR handles only default contiguity; follow-up changes will cover other memory formats such as channels_last. We use this pattern in several places in this PR to avoid DDEs.

Test Plan: contbuild & OSS CI

Rollback Plan:

Reviewed By: malfet

Differential Revision: D77639021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157472
Approved by: https://github.com/aorenste
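Below is a minimal sketch of the resulting C++ pattern, built around the exact call quoted in the summary; the surrounding add_one kernel and its use of at::Tensor::add are illustrative assumptions, not code from this PR:

#include <ATen/ATen.h>

// Hypothetical kernel illustrating the DDE-avoidance pattern described in
// the summary. sym_is_contiguous() yields a SymBool; guard_or_false()
// resolves it without adding a guard, falling back to `false` whenever the
// answer depends on unbacked symbolic shapes.
at::Tensor add_one(const at::Tensor& input) {
  bool is_contiguous =
      input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__);
  if (is_contiguous) {
    // Fast path: the buffer is known to be densely packed.
    return input.add(1);
  }
  // Safe path: contiguity is unknown or false, so materialize a
  // contiguous copy before operating on it.
  return input.contiguous().add(1);
}

With backed (hinted) shapes this behaves like a plain is_contiguous() check; with unbacked shapes it trades a possible data-dependent error for a conservative clone, which is exactly the behavior the test below exercises.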
committed by PyTorch MergeBot
parent d40aaa42ee
commit 7cfd054075
@@ -3336,8 +3336,8 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
 _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
 clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
 view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
-mul_19: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
-return (mul_19,)""", # noqa: B950
+mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
+return (mul_21,)""", # noqa: B950
             ignore_comments=True,
             ignore_empty_lines=True,
         )
@@ -3460,6 +3460,75 @@ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)",
         func(torch.ones(5, 6, 9, 8))
         self.assertEqual(cnt.frame_count, 3)
 
+    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
+    @fresh_cache()
+    def test_unbacked_contiguous(self):
+        cnt = CompileCounterWithBackend("inductor")
+
+        def func(x):
+            contig = x.contiguous()
+            return (contig + 1) * 100
+
+        compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
+
+        x = torch.randn(10, 10)
+        # make x not contiguous.
+        x = x.t_()
+        torch._dynamo.decorators.mark_unbacked(x, 0)
+        torch._dynamo.decorators.mark_unbacked(x, 1)
+        log_stream, ctx = logs_to_string(
+            "torch._inductor.compile_fx", "post_grad_graphs"
+        )
+        with ctx():
+            compiled_func(x)
+        self.assertEqual(compiled_func(x), func(x))
+        y = torch.rand(20, 20).t()
+        self.assertEqual(compiled_func(y), func(y))
+        self.assertEqual(cnt.frame_count, 1)
+        output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
+        self.assertExpectedInline(
+            output,
+            """\
+ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
+_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
+ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
+_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
+clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
+add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
+mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
+return (mul_6,)""", # noqa: B950
+            ignore_comments=True,
+            ignore_empty_lines=True,
+        )
+
+        log_stream, ctx = logs_to_string(
+            "torch._inductor.compile_fx", "post_grad_graphs"
+        )
+        with ctx():
+            # recompilation will happen due to stride specialization.
+            y = torch.rand(20, 20)
+            torch._dynamo.decorators.mark_unbacked(y, 0)
+            torch._dynamo.decorators.mark_unbacked(y, 1)
+            self.assertEqual(compiled_func(y), func(y))
+            self.assertEqual(cnt.frame_count, 2)
+
+        output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
+
+        # No clone this time since input is contiguous.
+        self.assertExpectedInline(
+            output,
+            """\
+ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
+_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
+ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
+_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
+add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
+mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
+return (mul_5,)""", # noqa: B950
+            ignore_comments=True,
+            ignore_empty_lines=True,
+        )
+
+
 instantiate_parametrized_tests(TestUnbacked)