Don't aggressively rewrite asserts for symbolic expressions (#120564)

Fixes: https://github.com/pytorch/pytorch/issues/118417

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120564
Approved by: https://github.com/ezyang
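
In short: an assert whose condition is a pure symbolic shape expression is now installed as a guard (via expect_true) instead of being rewritten into an in-graph torch._assert_async call; data-dependent asserts are still traced into the graph. A minimal sketch of the observable behavior, distilled from the new test added below (the function name f is illustrative):

    import torch

    @torch.compile(dynamic=True)
    def f(x):
        # Pure shape expression: recorded as a shape guard, not traced into the graph.
        assert x.shape[0] > 3
        # Data-dependent condition: still lowered to torch._assert_async in the graph.
        assert x[0].sum() > 0
        return x.cos()

    f(torch.ones(6, 4))  # compiles; the shape assert only constrains s0 >= 4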
Author: Tugsbayasgalan Manlaibaatar
Date: 2024-02-29 18:00:07 -08:00
Committed by: PyTorch MergeBot
Parent: c844b377fa
Commit: f01a23d01b

3 changed files with 83 additions and 3 deletions

test/dynamo/test_export.py

@@ -2921,11 +2921,11 @@ def forward(self, x):
     @config.patch(assume_static_by_default=False)
     def test_export_persist_assert(self):
         def f(x):
-            assert x.shape[0] > 4, "Shape must be more than 4"
+            assert x[0].sum() > 4, "Shape must be more than 4"
             return x.cos() + x.sin()
 
         gm, guard = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
-            torch.randn(5, 4, 6)
+            torch.ones(5, 4, 6)
         )
 
         def has_aten_op(gm, op):
@@ -2941,7 +2941,7 @@ def forward(self, x):
         self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
 
         with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
-            gm(torch.randn(3, 4, 5))
+            gm(torch.zeros(3, 4, 5))
 
     @common_utils.parametrize(
         "type_fn",

test/dynamo/test_repros.py

@@ -4217,6 +4217,71 @@ class ReproTests(torch._dynamo.test_case.TestCase):
             T = IncByTwo
         self.assertEqual(fn(x), opt_fn(x))
 
+    def test_dont_aggressively_write_assert(self):
+        record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
+
+        @torch.compile(dynamic=True, backend=record_graph)
+        def f(x):
+            assert x.shape[0] > 3
+            assert x[0].sum() > 0
+            assert 1 % (x.shape[0] // 2) != 0
+            assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0
+            return x.cos()
+
+        f(torch.ones(6, 4))
+        graph = record_graph.graphs[0]
+
+        # It is a bit annoying that we generate useless statements for
+        # shape guards, but DCE should be able to remove them since
+        # there is no backed assert on them. The reason this is ok is
+        # because dynamo will only skip the assert statement, but not
+        # the instructions before it.
+        self.assertExpectedInline(
+            str(graph.code).strip(),
+            """\
+def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
+    l_x_ = L_x_
+    size = l_x_.size()
+    getitem = size[0]; size = None
+    gt = getitem > 3; getitem = None
+    getitem_2 = l_x_[0]
+    sum_1 = getitem_2.sum(); getitem_2 = None
+    gt_1 = sum_1 > 0; sum_1 = None
+    _assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = None
+    size_1 = l_x_.size()
+    getitem_3 = size_1[0]; size_1 = None
+    floordiv = getitem_3 // 2; getitem_3 = None
+    mod = 1 % floordiv; floordiv = None
+    ne = mod != 0; mod = None
+    size_2 = l_x_.size()
+    getitem_5 = size_2[0]; size_2 = None
+    floordiv_1 = getitem_5 // 2; getitem_5 = None
+    pow_1 = floordiv_1 ** 2; floordiv_1 = None
+    mul = 32 * pow_1; pow_1 = None
+    size_3 = l_x_.size()
+    getitem_7 = size_3[0]; size_3 = None
+    floordiv_2 = getitem_7 // 2; getitem_7 = None
+    mul_1 = 16 * floordiv_2; floordiv_2 = None
+    sub = mul - mul_1; mul = mul_1 = None
+    ne_1 = sub != 0; sub = None
+    cos = l_x_.cos(); l_x_ = None
+    return (cos,)""",
+        )
+
+        for node in graph.graph.nodes:
+            if "example_value" in node.meta and isinstance(
+                node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor
+            ):
+                shape_env = node.meta["example_value"].fake_mode.shape_env
+                lower_ranges = [val.lower for val in shape_env.var_to_range.values()]
+                self.assertTrue(lower_ranges == [4, 2])
+
+        @torch.compile(dynamic=True, backend=record_graph)
+        def f_fail(x):
+            assert x.shape[0] < 3
+
+        # We graph-break here, so the failure should be eager
+        with self.assertRaisesRegex(AssertionError, ""):
+            f_fail(torch.ones(6, 4))
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
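
As the inline comment in the test says, the dangling shape-expression nodes (gt, ne, ne_1) have no users, so standard FX dead code elimination should be able to drop them while keeping the side-effectful assert. A small sketch, assuming graph is the recorded GraphModule from the test:

    # Drops the unused size/getitem/gt/ne chains left over from shape guards;
    # torch._assert_async is registered as side-effectful, so it survives DCE.
    graph.graph.eliminate_dead_code()
    graph.recompile()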

torch/_dynamo/symbolic_convert.py

@@ -341,6 +341,21 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
             self.jump(inst)
             return
 
+        if isinstance(value, SymNodeVariable):
+            # If the assertion is a normal shape expression,
+            # just install a guard and bail out.
+            sym_expr = value.sym_num
+            if not isinstance(sym_expr, torch.SymBool):
+                sym_expr = sym_expr != 0
+
+            result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
+            if not result:
+                raise unimplemented(
+                    "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
+                )
+            self.jump(inst)
+            return
+
         scalar_to_tensor_proxy = self.output.create_proxy(
             "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
         )
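
The expect_true call records the symbolic expression as a guard in the ShapeEnv; if the example input already falsifies it, the branch reaches unimplemented, which graph-breaks so the original Python assert re-runs eagerly. A sketch of both paths, mirroring the f/f_fail pair in the new test (the names ok and bad are illustrative):

    import torch

    @torch.compile(dynamic=True)
    def ok(x):
        assert x.shape[0] > 3  # hint is True: installed as a guard
        return x.sin()

    @torch.compile(dynamic=True)
    def bad(x):
        assert x.shape[0] < 3  # hint is False: graph break, assert runs eagerly
        return x.sin()

    ok(torch.ones(6, 4))
    try:
        bad(torch.ones(6, 4))
    except AssertionError:
        print("failed eagerly, as in test_dont_aggressively_write_assert")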