mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Don't aggressively rewrite asserts for symbolic expressions (#120564)
Fixes: https://github.com/pytorch/pytorch/issues/118417 Pull Request resolved: https://github.com/pytorch/pytorch/pull/120564 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
c844b377fa
commit
f01a23d01b
@ -2921,11 +2921,11 @@ def forward(self, x):
|
||||
@config.patch(assume_static_by_default=False)
|
||||
def test_export_persist_assert(self):
|
||||
def f(x):
|
||||
assert x.shape[0] > 4, "Shape must be more than 4"
|
||||
assert x[0].sum() > 4, "Shape must be more than 4"
|
||||
return x.cos() + x.sin()
|
||||
|
||||
gm, guard = torch._dynamo.export(f, aten_graph=True, tracing_mode="symbolic")(
|
||||
torch.randn(5, 4, 6)
|
||||
torch.ones(5, 4, 6)
|
||||
)
|
||||
|
||||
def has_aten_op(gm, op):
|
||||
@ -2941,7 +2941,7 @@ def forward(self, x):
|
||||
self.assertTrue(has_aten_op(gm, torch.ops.aten._assert_async.msg))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "Shape must be more than 4"):
|
||||
gm(torch.randn(3, 4, 5))
|
||||
gm(torch.zeros(3, 4, 5))
|
||||
|
||||
@common_utils.parametrize(
|
||||
"type_fn",
|
||||
|
@ -4217,6 +4217,71 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
||||
T = IncByTwo
|
||||
self.assertEqual(fn(x), opt_fn(x))
|
||||
|
||||
def test_dont_aggressively_write_assert(self):
|
||||
record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
|
||||
|
||||
@torch.compile(dynamic=True, backend=record_graph)
|
||||
def f(x):
|
||||
assert x.shape[0] > 3
|
||||
assert x[0].sum() > 0
|
||||
assert 1 % (x.shape[0] // 2) != 0
|
||||
assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0
|
||||
return x.cos()
|
||||
|
||||
f(torch.ones(6, 4))
|
||||
graph = record_graph.graphs[0]
|
||||
# It is bit annoying that we generate useless statements for
|
||||
# shape guards, but DCE should be able to remove them since t
|
||||
# there is no backed assert on them. The reason this is ok is
|
||||
# because dynamo will only skip the assert statement, but not
|
||||
# the instructions before it.
|
||||
self.assertExpectedInline(
|
||||
str(graph.code).strip(),
|
||||
"""\
|
||||
def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
l_x_ = L_x_
|
||||
size = l_x_.size()
|
||||
getitem = size[0]; size = None
|
||||
gt = getitem > 3; getitem = None
|
||||
getitem_2 = l_x_[0]
|
||||
sum_1 = getitem_2.sum(); getitem_2 = None
|
||||
gt_1 = sum_1 > 0; sum_1 = None
|
||||
_assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = None
|
||||
size_1 = l_x_.size()
|
||||
getitem_3 = size_1[0]; size_1 = None
|
||||
floordiv = getitem_3 // 2; getitem_3 = None
|
||||
mod = 1 % floordiv; floordiv = None
|
||||
ne = mod != 0; mod = None
|
||||
size_2 = l_x_.size()
|
||||
getitem_5 = size_2[0]; size_2 = None
|
||||
floordiv_1 = getitem_5 // 2; getitem_5 = None
|
||||
pow_1 = floordiv_1 ** 2; floordiv_1 = None
|
||||
mul = 32 * pow_1; pow_1 = None
|
||||
size_3 = l_x_.size()
|
||||
getitem_7 = size_3[0]; size_3 = None
|
||||
floordiv_2 = getitem_7 // 2; getitem_7 = None
|
||||
mul_1 = 16 * floordiv_2; floordiv_2 = None
|
||||
sub = mul - mul_1; mul = mul_1 = None
|
||||
ne_1 = sub != 0; sub = None
|
||||
cos = l_x_.cos(); l_x_ = None
|
||||
return (cos,)""",
|
||||
)
|
||||
for node in graph.graph.nodes:
|
||||
if "example_value" in node.meta and isinstance(
|
||||
node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor
|
||||
):
|
||||
shape_env = node.meta["example_value"].fake_mode.shape_env
|
||||
lower_ranges = [val.lower for val in shape_env.var_to_range.values()]
|
||||
self.assertTrue(lower_ranges == [4, 2])
|
||||
|
||||
@torch.compile(dynamic=True, backend=record_graph)
|
||||
def f_fail(x):
|
||||
assert x.shape[0] < 3
|
||||
|
||||
# We graph-break here, so the failure should be eager
|
||||
with self.assertRaisesRegex(AssertionError, ""):
|
||||
f_fail(torch.ones(6, 4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -341,6 +341,21 @@ def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
|
||||
self.jump(inst)
|
||||
return
|
||||
|
||||
if isinstance(value, SymNodeVariable):
|
||||
# if the assertion is normal shape expression.
|
||||
# just install guard and bail out.
|
||||
sym_expr = value.sym_num
|
||||
if not isinstance(sym_expr, torch.SymBool):
|
||||
sym_expr = sym_expr != 0
|
||||
|
||||
result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
|
||||
if not result:
|
||||
raise unimplemented(
|
||||
"Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
|
||||
)
|
||||
self.jump(inst)
|
||||
return
|
||||
|
||||
scalar_to_tensor_proxy = self.output.create_proxy(
|
||||
"call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
|
||||
)
|
||||
|
Reference in New Issue
Block a user