shape guards (#161178)

Summary: This PR introduces shape guards to export. Previously, only value ranges, equalities, and specializations were tracked for symbolic expressions, and a forward pre-hook checked them at runtime. Now we instead generate a function that checks shape guards and call it inside the exported program.
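
A minimal sketch of the new behavior (assuming a single tensor input `x` whose first dimension is bounded above by 3; names mirror the tests below). The unlifted module gains a `_guards_fn` submodule whose forward asserts each guard, and the graph calls it right after the placeholders:

```python
# Roughly what the generated _guards_fn.forward looks like (a sketch):
def _(*args):
    torch._assert(args[0].size()[0] <= 3, "Guard failed: x.size()[0] <= 3")
    return

# The unlifted graph then contains, right after the placeholders:
#     _guards_fn = self._guards_fn(x); _guards_fn = None
```

A failed guard now raises AssertionError rather than the RuntimeError previously raised by the forward pre-hook, which is why the expected exceptions in the tests below change.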

Test Plan:
updated several tests

Rollback Plan:

Differential Revision: D80713603

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161178
Approved by: https://github.com/tugsbayasgalan
Avik Chaudhuri
2025-09-08 22:44:05 +00:00
committed by PyTorch MergeBot
parent 2c538c9acf
commit 711c8c821e
29 changed files with 617 additions and 123 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["oncall: export"]
import copy
import re
import tempfile
import unittest
@ -407,7 +408,12 @@ class TestDraftExport(TestCase):
inp = (torch.ones(3, 3),)
ep = draft_export(M(), inp, dynamic_shapes={"a": {0: Dim("a0")}})
ep = draft_export(
M(),
inp,
dynamic_shapes={"a": {0: Dim("a0")}},
prefer_deferred_runtime_asserts_over_guards=True,
)
report = ep._report
self.assertEqual(len(report.failures), 1)
@ -417,7 +423,11 @@ class TestDraftExport(TestCase):
self.assertEqual(ep.module()(*inp), M()(*inp))
inp = (torch.randn(4, 3),)
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(
AssertionError,
re.escape("Guard failed: a.size()[0] <= 3"),
):
# expected <= 3, but got 4
ep.module()(*inp)
def test_side_effect1(self):

View File

@ -319,6 +319,7 @@ def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
linear_weight = self.linear.weight
linear_bias = self.linear.bias
_guards_fn = self._guards_fn(x); _guards_fn = None
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
return pytree.tree_unflatten((linear,), self._out_spec)""",
)

View File

@ -13,7 +13,7 @@ import traceback
import unittest
import warnings
import weakref
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from re import escape
from typing import Dict, List, Union
@ -514,7 +514,13 @@ class TestExport(TestCase):
# )
def _check_dynamic_shapes_specs_and_shapes(
self, model, inputs, specs, passing_shapes, failing_shapes, test_serdes=False
self,
model,
inputs,
specs,
passing_shapes,
failing_shapes,
test_serdes=False,
):
from torch._export.serde.dynamic_shapes import (
_dump_dynamic_shapes,
@ -556,7 +562,7 @@ class TestExport(TestCase):
ep.module()(*test_inputs)
for shapes in failing_shapes:
test_inputs = _construct_inputs(shapes)
with self.assertRaises(RuntimeError):
with self.assertRaisesRegex(AssertionError, "Guard failed"):
ep.module()(*test_inputs)
def test_basic(self):
@ -635,7 +641,7 @@ class TestExport(TestCase):
from torch.fx.traceback import NodeSourceAction
for node in gm.graph.nodes:
if node.op in ("placeholder", "output"):
if node.op in ("placeholder", "output", "call_module"):
continue
if "weight" in node.name or "bias" in node.name:
self.assertTrue(
@ -664,7 +670,7 @@ class TestExport(TestCase):
graph_id = id(ep2.graph)
for node in gm2.graph.nodes:
if node.op in ("placeholder", "output"):
if node.op in ("placeholder", "output", "call_module"):
continue
if "weight" in node.name or "bias" in node.name:
@ -927,7 +933,8 @@ graph():
"""\
graph():
%lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
%x : [num_users=1] = placeholder[target=x]
%x : [num_users=2] = placeholder[target=x]
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_0), kwargs = {})
return (add,)""",
)
@ -1491,7 +1498,11 @@ graph():
{"a": torch.zeros(5), "b": torch.ones(5)},
torch.ones(4),
)
with self.assertRaisesRegex(RuntimeError, "to be equal to 6, but got 5"):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: ys[0].size()[0] == x.size()[0]"),
):
# expected 6, but got 5
ep_ns.module()(*bad_runtime_inp1)
bad_runtime_inp2 = (
@ -1501,9 +1512,10 @@ graph():
torch.ones(6),
)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[3].shape[0] to be equal to 4, but got 6"),
AssertionError,
escape("Guard failed: c.size()[0] == 4"),
):
# expected 4, but got 6
ep_ns.module()(*bad_runtime_inp2)
good_runtime_inp = (
@ -1651,6 +1663,8 @@ class GraphModule(torch.nn.Module):
x: "f32[3, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
@ -1736,6 +1750,8 @@ class GraphModule(torch.nn.Module):
x: "f32[3, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
@ -3104,16 +3120,22 @@ def forward(self, causal_mask, fill_value):
ep = export(Foo(), inputs, dynamic_shapes=shapes)
ep.module()(torch.randn(8, 5), torch.randn(8, 5))
with self.assertRaisesRegex(
RuntimeError, "Expected input at .* to be >= 4, but got 3"
AssertionError,
escape("Guard failed: x.size()[0] >= 4"),
):
# expected >= 4, but got 3
ep.module()(torch.randn(3, 5), torch.randn(3, 5))
with self.assertRaisesRegex(
RuntimeError, "Expected input at .* to be <= 16, but got 17"
AssertionError,
escape("Guard failed: x.size()[0] <= 16"),
):
# expected <= 16, but got 17
ep.module()(torch.randn(17, 5), torch.randn(17, 5))
with self.assertRaisesRegex(
RuntimeError, "Expected input at .* to be <= 32, but got 33"
AssertionError,
escape("Guard failed: x.size()[1] <= 32"),
):
# expected <= 32, but got 33
ep.module()(torch.randn(9, 33), torch.randn(9, 33))
def test_dim_hint_range_violations(self):
@ -3368,11 +3390,12 @@ def forward(self, p_linear_weight, p_linear_bias, x):
actual_torch_fns = []
for mod in gm.modules():
for node in mod.graph.nodes:
if node.name in {"sin", "cos"}:
torch_fn = node.meta.get("torch_fn")
print(torch_fn)
actual_torch_fns.append(torch_fn)
if hasattr(mod, "graph"):
for node in mod.graph.nodes:
if node.name in {"sin", "cos"}:
torch_fn = node.meta.get("torch_fn")
print(torch_fn)
actual_torch_fns.append(torch_fn)
exp_torch_fns = [
("cos_1", "method_descriptor.cos"),
("sin_1", "method_descriptor.sin"),
@ -3545,9 +3568,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
dynamic_shapes=({0: dimx}, {0: dimy}),
)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be equal to 5, but got 6",
AssertionError,
escape("Guard failed: x.size()[0] == -1 + y.size()[0]"),
):
# expected 5, but got 6
ep.module()(torch.randn(4), torch.randn(6))
self.assertEqual(ep.module()(torch.randn(4), torch.randn(5)).size()[0], 4)
@ -3606,13 +3630,16 @@ def forward(self, p_linear_weight, p_linear_bias, x):
dynamic_shapes=({0: dimz}, {0: dimy}),
)
with self.assertRaisesRegex(
RuntimeError, "Expected input.*shape.*to be <= 7, but got 8"
AssertionError,
escape("Guard failed: z.size()[0] <= 7"),
):
# expected <= 7, but got 8
ep.module()(torch.randn(8), torch.randn(15))
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be equal to 9, but got 8",
AssertionError,
escape("Guard failed: -1 + 2 * z.size()[0] == y.size()[0]"),
):
# expected 9, but got 8
ep.module()(torch.randn(5), torch.randn(8))
self.assertEqual(ep.module()(torch.randn(5), torch.randn(9)).size()[0], 4)
@ -3648,17 +3675,18 @@ def forward(self, p_linear_weight, p_linear_bias, x):
dynamic_shapes=({0: dimw},),
)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*= 9 to be "
"of the form 2\\*s92, where s92 is an integer",
AssertionError,
escape("Guard failed: w.size()[0] % 2 == 0"),
):
# expected 2*..., got 9
ep.module()(torch.randn(9))
self.assertEqual(ep.module()(torch.randn(8)).size()[0], 4)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be <= 12, but got 14",
AssertionError,
escape("Guard failed: w.size()[0] <= 12"),
):
# expected <= 12, but got 14
ep.module()(torch.randn(14))
def test_derived_dim_repeat_derived(self):
@ -3696,9 +3724,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be equal to 8, but got 5",
AssertionError,
escape("Guard failed: z.size()[0] >= 6"),
):
# expected 8, but got 5
ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))
self.assertEqual(
@ -3731,9 +3760,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be equal to 6, but got 5",
AssertionError,
escape("Guard failed: x2.size()[0] == x.size()[0]"),
):
# expected 6, but got 5
ep.module()(
torch.randn(6),
torch.randn(7),
@ -3759,9 +3789,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be equal to 6, but got 5",
AssertionError,
escape("Guard failed: x2.size()[0] == x.size()[0]"),
):
# expected 6, but got 5
ep.module()(
torch.randn(6),
torch.randn(7),
@ -4197,9 +4228,10 @@ def forward(self, x):
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be equal to 8, but got 5",
AssertionError,
escape("Guard failed: z.size()[0] >= 6"),
):
# expected 8, but got 5
ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))
self.assertEqual(
@ -4234,6 +4266,7 @@ def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
linear_weight = self.linear.weight
linear_bias = self.linear.bias
_guards_fn = self._guards_fn(x); _guards_fn = None
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
return pytree.tree_unflatten((linear,), self._out_spec)""",
)
@ -4274,6 +4307,7 @@ def forward(self, b_buffer, x):
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
buffer = self.buffer
_guards_fn = self._guards_fn(x); _guards_fn = None
add_ = torch.ops.aten.add_.Tensor(x, 5); x = None
add__1 = torch.ops.aten.add_.Tensor(buffer, 5); buffer = None
add = torch.ops.aten.add.Tensor(add_, add__1); add_ = add__1 = None
@ -4595,9 +4629,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimy}, {0: dimz}),
)
with self.assertRaisesRegex(
RuntimeError,
"Expected input.*shape.*to be equal to 7, but got 5",
AssertionError,
escape("Guard failed: y1.size()[0] == y.size()[0]"),
):
# expected 7, but got 5
ep.module()(
torch.randn(6),
torch.randn(7),
@ -4649,8 +4684,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
ep = export(foo, inputs, dynamic_shapes=dynamic_shapes)
self.assertEqual(foo(*inputs), ep.module()(*inputs))
for wrong_inputs in wrong_shape_inputs:
with self.assertRaises(RuntimeError):
ep.module()(*wrong_inputs)
with self.assertRaisesRegex(AssertionError, "Guard failed"):
with self.assertRaises(RuntimeError):
ep.module()(*wrong_inputs)
# check range_constraints - static dims shouldn't be present
ep = export(foo, inputs, dynamic_shapes=((dx, None), (dy, 4), (dz, 3)))
@ -4686,8 +4722,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
ep.module()(torch.randn(1, 2))
ep.module()(torch.randn(2, 2))
with self.assertRaisesRegex(
RuntimeError, "Expected input at .* to be <= 2, but got 3"
AssertionError,
escape("Guard failed: x.size()[0] <= 2"),
):
# expected <= 2, but got 3
ep.module()(torch.randn(3, 2))
vr = list(ep.range_constraints.values())[0]
self.assertEqual(vr.lower, 1)
@ -4704,7 +4742,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
(torch.randn(2, 2), torch.randn(3, 2)),
dynamic_shapes=({0: dx, 1: None}, {0: dx + 1, 1: None}),
)
ep.module()(torch.randn(1, 2), torch.randn(2, 2))
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: -1 + y.size()[0] != 1"),
):
# TODO: this should not error?
ep.module()(torch.randn(1, 2), torch.randn(2, 2))
range_lower_bounds = sorted(vr.lower for vr in ep.range_constraints.values())
range_upper_bounds = sorted(vr.upper for vr in ep.range_constraints.values())
self.assertEqual(range_lower_bounds, [1, 2])
@ -4894,7 +4937,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
self.assertEqual(got_shapes, expected_shapes)
def expect_error(bad_args, run_time_msg, compile_time_msg):
with self.assertRaisesRegex(RuntimeError, run_time_msg):
with self.assertRaisesRegex(AssertionError, run_time_msg):
ep.module()(*bad_args)
additional_inputs = torch.export.AdditionalInputs()
@ -4906,21 +4949,27 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
expect_error(
# 4->2, 4->2, 3->3
bad_args=(torch.randn(2), [torch.randn(2)], {"k": torch.randn(3)}),
run_time_msg="Expected input.*to be >= 3, but got 2",
run_time_msg=escape(
"Guard failed: x.size()[0] >= 3"
), # expected >= 3, but got 2
compile_time_msg="Expected input.*to be >= 3, but got 2",
)
expect_error(
# 4->6, 4->7, 3->3
bad_args=(torch.randn(6), [torch.randn(7)], {"k": torch.randn(3)}),
run_time_msg="Expected input.*to be equal to 6, but got 7",
run_time_msg=escape(
"Guard failed: y[0].size()[0] == x.size()[0]"
), # expected 6, but got 7
compile_time_msg="Expected input.*to be equal to 6, but got 7",
)
expect_error(
# 4->5, 4->5, 3->4
bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}),
run_time_msg="Expected input.*to be equal to 3, but got 4",
run_time_msg=escape(
"Guard failed: z['k'].size()[0] == 3"
), # expected 3, but got 4
compile_time_msg=r"You marked.*but your code specialized it to be a constant.*If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO",
)
@ -5569,7 +5618,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
self.assertTrue(torch.allclose(ep.module()(x, y), model(x, y)))
x2 = torch.arange(4).reshape((2, 2))
y2 = torch.arange(9).reshape((3, 3))
self.assertTrue(torch.allclose(ep.module()(x2, y2), model(x2, y2)))
with self.assertRaisesRegex(
AssertionError,
(
escape("Guard failed: max(x.size()[1], y.size()[1]) == x.size()[1]")
if is_retracebility_test(self._testMethodName)
else escape(
"Guard failed: max(1, x.size()[1], y.size()[1]) == x.size()[1]"
)
),
):
# TODO: this should not error?
self.assertTrue(torch.allclose(ep.module()(x2, y2), model(x2, y2)))
def test_export_max_nonstrict(self):
class FooMax(torch.nn.Module):
@ -5712,9 +5772,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
em = torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
x = torch.randn(3, 5)
with self.assertRaisesRegex(
RuntimeError,
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer",
AssertionError,
escape("Guard failed: 3 * x.size()[1] % 2 == 0"),
):
# expected 2*..., but got 5
em.module()(x)
def test_dont_duck_size_for_auto_dynamic(self):
@ -7814,6 +7875,7 @@ def forward(self, x):
bn_running_mean = self.bn.running_mean
bn_running_var = self.bn.running_var
bn_num_batches_tracked = self.bn.num_batches_tracked; bn_num_batches_tracked = None
_guards_fn = self._guards_fn(x); _guards_fn = None
conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
return pytree.tree_unflatten((batch_norm,), self._out_spec)""",
@ -7833,6 +7895,7 @@ def forward(self, x):
bn_running_mean = self.bn.running_mean
bn_running_var = self.bn.running_var
bn_num_batches_tracked = self.bn.num_batches_tracked
_guards_fn = self._guards_fn(x); _guards_fn = None
conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
add_ = torch.ops.aten.add_.Tensor(bn_num_batches_tracked, 1); bn_num_batches_tracked = add_ = None
batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
@ -8426,18 +8489,20 @@ def forward(self, x):
)
)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[1] to be equal to 5, but got 6"),
AssertionError,
escape("Guard failed: y == 5"),
):
# expected 5, but got 6
_ = exported.module()(torch.ones(8, 5), 6)
exported = torch.export.export(
foo, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes
)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[1] to be equal to 5.0, but got 6.0"),
AssertionError,
escape("Guard failed: y == 5.0"),
):
# expected 5.0, but got 6.0
_ = exported.module()(torch.ones(7, 5), 6.0)
def test_runtime_assert_for_prm_str(self):
@ -8449,8 +8514,10 @@ def forward(self, x):
inps = (torch.randn(4, 4), torch.randn(4), "trunc")
exported = export(foo, inps)
with self.assertRaisesRegex(
RuntimeError, "to be equal to trunc, but got floor"
AssertionError,
escape("Guard failed: mode == 'trunc'"),
):
# expected 'trunc', but got 'floor'
_ = exported.module()(torch.randn(4, 4), torch.randn(4), "floor")
self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps)))
@ -8577,9 +8644,12 @@ def forward(self, x):
dim0_x = torch.export.Dim("dim0_x")
exported = torch.export.export(Foo(), (inp,), dynamic_shapes=({0: dim0_x},))
reexported = torch.export.export(exported.module(), (inp,))
with self.assertRaisesRegex(
RuntimeError, "shape\[0\] to be equal to 5, but got 7"
AssertionError,
escape("Guard failed: x.size()[0] == 5"),
):
# expected 5, but got 7
reexported.module()(torch.ones(7, 5))
reexported = torch.export.export(
@ -8597,9 +8667,10 @@ def forward(self, x):
Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}}
)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be >= 3, but got 2"),
AssertionError,
escape("Guard failed: x.size()[0] >= 3"),
):
# expected >= 3, but got 2
torch.export.export(exported_v2.module(), (torch.randn(2, 2),))
def test_export_cond_symbool_pred(self):
@ -9394,8 +9465,10 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
dynamic_shapes=(None, None),
)
with self.assertRaisesRegex(
RuntimeError, "shape\[0\] to be equal to 4, but got 7"
AssertionError,
escape("Guard failed: b.size()[0] == 4"),
):
# expected 4, but got 7
ep_v2.module()(*test_inp)
def test_constant_output(self):
@ -9475,7 +9548,11 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
ep = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: a[1].size()[0] >= 3"),
):
# expected >= 3, but got 2
ep.module()(*test_inp)
def test_nested_module(self):
@ -9677,13 +9754,17 @@ graph():
).module()
with self.assertRaisesRegex(
RuntimeError, escape("Expected input at *args[0].shape[0]")
AssertionError,
escape("Guard failed: x.size()[0] >= 3"),
):
# expected >= 3, got 2
gm(torch.randn(2, 2))
with self.assertRaisesRegex(
RuntimeError, escape("Expected input at *args[0].shape[0]")
AssertionError,
escape("Guard failed: x.size()[0] >= 3"),
):
# expected >= 3, got 2
export(gm, (torch.randn(2, 2),))
ep = export(
@ -11502,7 +11583,11 @@ graph():
ep = export(M(), (4, 5))
self.assertEqual(ep.module()(4, 5), 20)
with self.assertRaisesRegex(RuntimeError, r"to be equal to 4, but got 3"):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: x == 4"),
):
# expected 4, but got 3
self.assertEqual(ep.module()(3, 6), 18)
ep = export(M(), (4, 5), dynamic_shapes={"x": Dim.DYNAMIC, "y": Dim.AUTO})
@ -11515,7 +11600,11 @@ graph():
ep = export(M(), (5, 5), dynamic_shapes={"x": None, "y": Dim.AUTO})
self.assertEqual(ep.module()(5, 6), 30)
with self.assertRaisesRegex(RuntimeError, r"to be equal to 5, but got 3"):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: x == 5"),
):
# expected 5, but got 3
self.assertEqual(ep.module()(3, 5), 18)
class M(torch.nn.Module):
@ -11531,7 +11620,6 @@ graph():
self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp)))
@testing.expectedFailureCppRuntime
@testing.expectedFailureRetraceabilityNonStrict # no runtime asserts added for assert x == 3
def test_symint_input_specialization(self):
class M(torch.nn.Module):
def forward(self, x, y):
@ -11556,7 +11644,11 @@ graph():
inp,
dynamic_shapes=(Dim.AUTO, None),
)
with self.assertRaisesRegex(RuntimeError, "to be equal to 3, but got 4"):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: x == 3"),
):
# expected 3, but got 4
ep.module()(4, torch.randn(4, 4))
@testing.expectedFailureCppRuntime
@ -11573,9 +11665,17 @@ graph():
)
ep.module()(4, torch.randn(4, 4))
with self.assertRaisesRegex(RuntimeError, "to be <= 10, but got 16"):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: x <= 10"),
):
# expected <= 10, but got 16
ep.module()(16, torch.randn(4, 4))
with self.assertRaisesRegex(RuntimeError, "to be >= 3, but got 2"):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: x >= 3"),
):
# expected >= 3, but got 2
ep.module()(2, torch.randn(4, 4))
# While tracing the range was found to be a subset of the original range
@ -11593,12 +11693,8 @@ graph():
)
constraints = list(ep.range_constraints.values())
constraint = constraints[0]
# retracebility does not remember the range asserts in the forward
lower, upper = (
(3, 10) if is_retracebility_test(self._testMethodName) else (4, 5)
)
self.assertEqual(constraint.lower, lower)
self.assertEqual(constraint.upper, upper)
self.assertEqual(constraint.lower, 4)
self.assertEqual(constraint.upper, 5)
# While tracing the range was found to be bigger than the original range
class M(torch.nn.Module):
@ -12260,6 +12356,7 @@ def forward(self, c_submod_params, x):
[
fqn
for fqn, _ in unflattened.named_modules(remove_duplicate=False)
if fqn != "_guards_fn"
]
),
expected_fqns,
@ -13648,9 +13745,12 @@ def forward(self, x, y):
else:
# no runtime assert in exported module but it fails anyway with wrong inputs
with self.assertRaisesRegex(
RuntimeError,
r"The size of tensor a \(40\) must match the size of tensor b \(20\) at non-singleton dimension 0",
AssertionError,
escape(
"Guard failed: x.size()[1] * x.size()[0] == y.size()[0] * y.size()[1]"
),
):
# expected 40, but got 20
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30))
# case 3: 3d reshape (previously failing with different issue)
@ -15168,9 +15268,12 @@ def forward(self, x):
self.assertEqual(num_asserts, 0)
# but it fails anyway with wrong inputs
with self.assertRaisesRegex(
RuntimeError,
r"shape '\[3, -1\]' is invalid for input of size 8",
AssertionError,
escape(
"Guard failed: x.size()[1] * x.size()[0] % (-1 + x.size()[0]) == 0"
),
):
# expected 3*..., but got 8
ep.module()(torch.randn(4, 2))
@testing.expectedFailureSerDer # T195866111
@ -15928,9 +16031,10 @@ def forward(self, q, k, v):
self.assertEqual(res[1], 5)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[1] to be equal to 5, but got 20"),
AssertionError,
escape("Guard failed: y == 5"),
):
# expected 5, but got 20
res = ep.module()(torch.tensor(4), 20)
class F(torch.nn.Module):

View File

@ -411,9 +411,10 @@ class TestPasses(TestCase):
)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
AssertionError,
escape("Guard failed: x.size()[1] <= 6"),
):
# expected <= 6, but got 7
ep.module()(torch.zeros(2, 7, 3))
self.assertEqual(
@ -442,15 +443,17 @@ class TestPasses(TestCase):
)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
AssertionError,
escape("Guard failed: x.size()[1] <= 6"),
):
# expected <= 6, but got 7
ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"),
AssertionError,
escape("Guard failed: y.size()[0] >= 3"),
):
# expected >= 3, but got 2
ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
def test_runtime_assert_some_dims_not_specified(self) -> None:
@ -475,16 +478,18 @@ class TestPasses(TestCase):
)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
AssertionError,
escape("Guard failed: x.size()[1] <= 6"),
):
# expected <= 6, but got 7
ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
AssertionError,
escape("Guard failed: y.size()[0] == 5"),
):
# expected 5, but got 2
ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
# Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
@ -509,14 +514,19 @@ class TestPasses(TestCase):
M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}, strict=True
)
with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
with self.assertRaisesRegex(
AssertionError,
escape("Guard failed: x.size()[1] == 2"),
):
# expected 2, but got 7
ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
# y is specialized to 5
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
AssertionError,
escape("Guard failed: y.size()[0] == 5"),
):
# expected 5, but got 2
ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
# Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
@ -803,6 +813,7 @@ def forward(self, token, obj_attr, x):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@ -822,6 +833,7 @@ def forward(self, x):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@ -841,6 +853,7 @@ def forward(self, x):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@ -860,6 +873,7 @@ def forward(self, x):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_5 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
@ -880,6 +894,7 @@ def forward(self, x):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add)
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@ -905,6 +920,7 @@ def forward(self, x):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_5 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
@ -940,6 +956,7 @@ def forward(self, x):
"""\
def forward(self, x1, x2):
x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec)
submod_0 = self.submod_0(x1, x2); submod_0 = None
submod_1 = self.submod_1(x1, x2); x1 = x2 = None
getitem = submod_1[0]
getitem_1 = submod_1[1]; submod_1 = None
@ -995,6 +1012,7 @@ def forward(self, sin, cos):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
submod_3 = self.submod_3
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_3, add); submod_3 = add = None
@ -1033,6 +1051,7 @@ def forward(self, cos):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_3 = self.submod_1
add_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_3, add); submod_3 = add = None
@ -1065,6 +1084,7 @@ def forward(self, add):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
@ -1115,6 +1135,7 @@ def forward(self, add_1):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
@ -1172,6 +1193,7 @@ def forward(self, add_1, add_2):
"""\
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
@ -1213,6 +1235,7 @@ def forward(self, add_1):
)
after_inline_str = new_gm.print_readable(print_output=False)
self.assertEqual(before_str, after_inline_str)
new_gm._guards_fn = gm._guards_fn
self.assertEqual(gm(*args), new_gm(*args))
def test_remove_auto_functionalized_pass(self) -> None:

View File

@ -2027,6 +2027,7 @@ def forward(self, obj_attr, x):
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
add = torch.ops.aten.add.Tensor(x, takes_foo); x = takes_foo = None
return pytree.tree_unflatten((add,), self._out_spec)""",

View File

@ -185,6 +185,7 @@ class TestExportTorchbind(TestCase):
def forward(self, x, n):
x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x, n); n = _guards_fn = None
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
@ -232,6 +233,7 @@ def forward(self, token, obj_attr, x, n):
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
@ -266,6 +268,7 @@ def forward(self, token, obj_attr, x):
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",
@ -300,6 +303,7 @@ def forward(self, token, obj_attr, x):
"""\
def forward(self, x, cc):
x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
_guards_fn = self._guards_fn(x, cc); _guards_fn = None
call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
@ -362,6 +366,7 @@ def forward(self, token, x, cc):
"""\
def forward(self, x, cc):
x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
_guards_fn = self._guards_fn(x, cc); _guards_fn = None
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",
@ -457,6 +462,7 @@ def forward(self, token, x, cc):
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
@ -499,6 +505,7 @@ def forward(self, token, obj_attr, x):
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
getitem_2 = takes_foo_list_return_default[0]
getitem_3 = takes_foo_list_return_default[1]
@ -551,6 +558,7 @@ def forward(self, token, obj_attr, x):
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
attr = self.attr
_guards_fn = self._guards_fn(x); _guards_fn = None
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
getitem_1 = takes_foo_tuple_return_default[0]
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
@ -1065,6 +1073,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
"""\
def forward(self, tq, x):
tq, x, = fx_pytree.tree_flatten_spec(([tq, x], {}), self._in_spec)
_guards_fn = self._guards_fn(tq, x); _guards_fn = None
queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x); x = queue_push_default = None
return pytree.tree_unflatten((tq,), self._out_spec)""",
)

View File

@ -359,9 +359,10 @@ class TestUnflatten(TestCase):
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
with self.assertRaisesRegex(
RuntimeError,
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
AssertionError,
escape("Guard failed: x.size()[0] == 2"),
):
# expected 2, but got 6
export_module.module()(torch.randn(6, 6))
unflattened = unflatten(export_module)

View File

@ -7805,6 +7805,8 @@ class GraphModule(torch.nn.Module):
x: "f32[s77, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
@ -7953,6 +7955,8 @@ class GraphModule(torch.nn.Module):
t: "f32[2, 3]";
t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec)
_guards_fn = self._guards_fn(t); _guards_fn = None
sum_1: "f32[]" = torch.ops.aten.sum.default(t)
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(sum_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None
@ -8101,6 +8105,8 @@ class GraphModule(torch.nn.Module):
x: "f32[s77, 3]";
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
_guards_fn = self._guards_fn(x); _guards_fn = None
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
@ -8520,6 +8526,8 @@ class GraphModule(torch.nn.Module):
a: "b8[]"; b1: "i64[1]"; b2: "i64[1]"; c: "f32[10]";
a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
_guards_fn = self._guards_fn(a, b1, b2, c); _guards_fn = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, (c, b1, b2)); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
@ -8602,6 +8610,8 @@ class GraphModule(torch.nn.Module):
x: "f32[s68, 3]"; y: "f32[s17]"; z: "f32[s68, 3]";
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
_guards_fn = self._guards_fn(x, y, z); _guards_fn = None
sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None
sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0)

View File

@ -249,7 +249,7 @@ with torch.no_grad():
assert res is not None
ep_file_path = res.get_exported_program_path()
assert ep_file_path is not None
gm = export_load(ep_file_path).module()
gm = export_load(ep_file_path).module(check_guards=False)
self.assertExpectedInline(
str(gm.code).strip(),
"""\

View File

@ -63,7 +63,7 @@ class MinifierUtilsTests(TestCase):
)
model = M()
gm = torch.export.export(model, inputs, strict=False).module()
gm = torch.export.export(model, inputs, strict=False).module(check_guards=False)
# TODO: make NNModuleToString.convert() generate string for nested submodules.
model_string = get_module_string(gm)

View File

@ -162,7 +162,7 @@ def save_graph_repro_ep(
assert args is not None
exported_program = torch.export.export(gm, args, strict=strict)
elif gm is None:
gm = exported_program.module()
gm = exported_program.module(check_guards=False)
# save a graph preview using gm
module_string = get_module_string(gm) # type: ignore[arg-type]
@ -302,7 +302,7 @@ def repro_common(
options: Any, exported_program: ExportedProgram
) -> tuple[torch.fx.GraphModule, Any, Any]:
torch._inductor.config.generate_intermediate_hooks = True
mod = exported_program.module()
mod = exported_program.module(check_guards=False)
args, kwargs = exported_program.example_inputs
return mod, args, kwargs # type: ignore[return-value]
@ -368,7 +368,7 @@ def export_for_aoti_minifier(
try:
ep = torch.export.export(gm, tuple_inputs, strict=strict)
gm = ep.module()
gm = ep.module(check_guards=False)
return gm
except Exception as e:
if skip_export_error:

View File

@ -1,5 +1,5 @@
// @generated by update_schema.py
// checksum<<e623701883a0cff67761e994ac9b3d5e44d3f27102c9420932a1275b5b0ad41d>>
// checksum<<a1c01cb72b55ca996960afa7e3b5ab6714405b036d8a3c33a81084a84e58bbab>>
namespace py3 torch._export
namespace cpp2 torch._export.schema
@ -342,6 +342,7 @@ struct ExportedProgram {
60: SchemaVersion schema_version;
70: list<string> verifiers;
80: string torch_version;
90: list<string> guards_code;
}
struct PayloadMeta {

View File

@ -9,7 +9,7 @@ from torch._export.serde.union import _Union, _union_dataclass
# NOTE: Please update this value if any modifications are made to the schema
SCHEMA_VERSION = (8, 13)
SCHEMA_VERSION = (8, 14)
TREESPEC_VERSION = 1
@ -449,6 +449,7 @@ class ExportedProgram:
schema_version: Annotated[SchemaVersion, 60]
verifiers: Annotated[list[str], 70] = field(default_factory=list)
torch_version: Annotated[str, 80] = "<=2.4"
guards_code: Annotated[list[str], 90] = field(default_factory=list)
#########################################################################

View File

@ -1,5 +1,5 @@
# @generated by update_schema.py
# checksum<<60fd32a0a8ae87c628c02d23641902e9339b813d4f553cdb39a2f9533b33060f>>
# checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>>
AOTInductorModelPickleData:
kind: struct
fields:
@ -140,6 +140,9 @@ ExportedProgram:
torch_version:
type: str
default: <=2.4
guards_code:
type: List[str]
default: '[]'
ExternKernelNode:
kind: struct
fields:
@ -548,5 +551,5 @@ UserOutputSpec:
type: Argument
SCHEMA_VERSION:
- 8
- 13
- 14
TREESPEC_VERSION: 1

View File

@ -1794,6 +1794,7 @@ class ExportedProgramSerializer(metaclass=Final):
),
verifiers=[v.dialect for v in exported_program.verifiers],
torch_version=torch.__version__,
guards_code=exported_program._guards_code,
)
# Test canonical form is well defined.
@ -3038,6 +3039,7 @@ class ExportedProgramDeserializer(metaclass=Final):
constants=res.constants,
verifiers=[load_verifier(v) for v in exported_program.verifiers],
)
result._guards_code = exported_program.guards_code
log.debug("\n[deserialize]: %s", result)
return result
@ -3502,6 +3504,7 @@ def canonicalize(
range_constraints = dict(
sorted(ep.range_constraints.items(), key=operator.itemgetter(0))
)
guards_code = sorted(ep.guards_code)
module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn)
signature = ep.graph_module.signature
graph = ep.graph_module.graph
@ -3698,6 +3701,7 @@ def canonicalize(
schema_version=ep.schema_version,
verifiers=ep.verifiers,
torch_version=ep.torch_version,
guards_code=guards_code,
)

View File

@ -419,7 +419,7 @@ def _check_symint(
# this means we deferred a guard from export analysis to runtime, let this pass
# we'll add a runtime assert checking equality to this replacement expression
pass
elif arg != symint:
elif arg != int(symint):
path = get_keystr(keypath)
if i is not None:
path += f".shape[{i}]"

View File

@ -292,6 +292,15 @@ def aot_compile(
"""
from .compile_fx import _aoti_flatten_inputs, compile_fx_aot
if hasattr(gm, "_guards_fn"):
# Do not compile the guards function, since it may contain checks
# that are not currently supported by AOTI. In particular, non-Tensor
# arguments are converted to None and will fail specialization checks.
node = next(iter(gm.graph.find_nodes(op="call_module", target="_guards_fn")))
gm.graph.erase_node(node)
delattr(gm, "_guards_fn")
gm.recompile()
flat_example_inputs, options = _aoti_flatten_inputs(
gm, args, kwargs, options=options
)

View File

@ -1250,7 +1250,7 @@ def aot_inductor_minifier_wrapper(
use_minifier = config.aot_inductor.dump_aoti_minifier
gm = exported_program.module()
gm = exported_program.module(check_guards=False)
assert isinstance(gm, torch.fx.GraphModule)
args, kwargs = exported_program.example_inputs
@ -1279,7 +1279,7 @@ def aot_inductor_minifier_wrapper(
tuple_inputs = tuple(flat_example_inputs)
flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False)
func(
flattened_ep.module(),
flattened_ep.module(check_guards=False),
tuple_inputs,
inductor_configs=config_copy,
package_path=package_path,

View File

@ -361,7 +361,7 @@ def _get_aten_graph_module_for_pattern(
example_inputs,
kwargs,
strict=True,
).module()
).module(check_guards=False)
aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
aten_pattern.recompile() # type: ignore[operator]

View File

@ -1,5 +1,5 @@
// @generated by update_schema.py
// checksum<<60fd32a0a8ae87c628c02d23641902e9339b813d4f553cdb39a2f9533b33060f>>
// checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>>
// clang-format off
#pragma once
@ -3110,6 +3110,7 @@ class ExportedProgram {
SchemaVersion schema_version;
std::vector<std::string> verifiers = {};
std::string torch_version = "<=2.4";
std::vector<std::string> guards_code = {};
public:
@ -3161,6 +3162,14 @@ class ExportedProgram {
torch_version = std::move(def);
}
const std::vector<std::string>& get_guards_code() const {
return guards_code;
}
void set_guards_code(std::vector<std::string> def) {
guards_code = std::move(def);
}
friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t);
friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t);
};
@ -3406,6 +3415,7 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nloh
nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version;
nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers;
nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version;
nlohmann_json_j["guards_code"] = nlohmann_json_t.guards_code;
}
inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) {
@ -3416,6 +3426,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nl
nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version);
nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers);
nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version);
nlohmann_json_t.guards_code = nlohmann_json_j.value("guards_code", nlohmann_json_default_obj.guards_code);
}
inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNode& nlohmann_json_t) {

View File

@ -681,7 +681,7 @@ class Pipe(torch.nn.Module):
``output_loss_value_spec={'loss': True, 'model_out': False}``
"""
traced = exported_program.module()
traced = exported_program.module(check_guards=False)
if split_policy is not None:
logger.info("Auto-splitting model")

View File

@ -540,6 +540,7 @@ def draft_export(
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
preserve_module_call_signature: tuple[str, ...] = (),
strict: bool = False,
prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
"""
A version of torch.export.export which is designed to consistently produce
@ -555,6 +556,7 @@ def draft_export(
dynamic_shapes=dynamic_shapes,
preserve_module_call_signature=preserve_module_call_signature,
strict=strict,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
)

View File

@ -371,6 +371,7 @@ def draft_export(
preserve_module_call_signature: tuple[str, ...] = (),
strict: bool = False,
pre_dispatch: bool = True,
prefer_deferred_runtime_asserts_over_guards: bool = False,
) -> ExportedProgram:
start_time = time.time()
kwargs = kwargs or {}
@ -396,6 +397,7 @@ def draft_export(
strict=strict,
pre_dispatch=pre_dispatch,
preserve_module_call_signature=preserve_module_call_signature,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
)
except Exception as exc:
if (
@ -420,6 +422,7 @@ def draft_export(
strict=strict,
pre_dispatch=pre_dispatch,
preserve_module_call_signature=preserve_module_call_signature,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
)
else:
log_draft_export_usage(

View File

@ -163,7 +163,7 @@ def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
"""
for node in gm.graph.nodes:
if node.op == "call_module":
if node.op == "call_module" and node.target != "_guards_fn":
_try_remove_connecting_pytrees(node)
gm.graph.eliminate_dead_code()

View File

@ -1,10 +1,14 @@
# mypy: allow-untyped-defs
import copy
import inspect
import math
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Optional
import sympy
import torch
import torch.utils._pytree as pytree
from torch._export.non_strict_utils import (
@ -12,11 +16,16 @@ from torch._export.non_strict_utils import (
_exit_enable_graph_inputs_of_type_nn_module,
_get_graph_inputs_of_type_nn_module,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
_convert_range_to_int,
)
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.traceback import NodeSource, NodeSourceAction
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import ValueRanges
from ._remove_effect_tokens_pass import _remove_effect_tokens
from ._tree_utils import reorder_kwargs
@ -73,20 +82,107 @@ def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
return flat_args_with_path
def _convert_guards_code_to_fn(
guards_code: list[str],
paths_of_placeholders: list[pytree.KeyPath],
):
"""
Generates Python code given guards code and paths of placeholders.
We assume that, based on source information,
- the tracer generates the guards code
- the input spec generates the paths of placeholders.
Example:
Suppose we are given the guards code "L['z']['k'].size()[1] == 3"
and we are given that ['z']['k'] is the path of placeholder #2.
Then we will generate:
```
torch._assert(
args[2].size()[1] == 3,
"Guard failed: z['k'].size()[1] == 3",
)
```
FAQ: Why do we generate code based on (flattened) args instead of
the original (unflattened) inputs? Because the latter would require
inserting an additional pytree.unflatten call in our graph.
FAQ: Why do we not emit RuntimeError on guard failure as we used to?
Because it is inconvenient :/, get used to AssertionError instead.
"""
import ast
from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP
actual_guards_code = []
shadow_guards_code = []
for c in guards_code:
a, s = c, c
for idx, path in enumerate(paths_of_placeholders):
# e.g., replace L['z']['k'] with args[2] for Python code (actual)
a = a.replace("L" + pytree.keystr(path), f"args[{idx}]")
# e.g., replace L['z']['k'] with z['k'] for error message (shadow)
s = s.replace(
"L" + pytree.keystr(path),
path[0].key + pytree.keystr(path[1:]), # type: ignore[attr-defined]
)
actual_guards_code.append(a)
shadow_guards_code.append(s.replace("\n", ""))
# generate function code as str
code_str = "\ndef _(*args):\n"
for actual, shadow in zip(actual_guards_code, shadow_guards_code):
# printing guards code may potentially introduce redundant parens;
# we can normalize them out for readability by parsing/unparsing
# NOTE: this is not necessary for correctness, just deemed desirable
_shadow = ast.unparse(ast.parse(shadow, mode="eval"))
# actual code and shadow error message
code_str += f' torch._assert({actual}, "Guard failed: {_shadow}")\n'
code_str += " return\n"
# populate namespace with sympy globals, materialize function (named `_`)
namespace = {**SYMPY_INTERP}
exec(code_str, namespace)
# create and return a module whose forward is the materialized function
# NOTE: we want Dynamo to trace through this module, to repopulate guards:
# otherwise we would lose them when retracing
# NOTE: calling this module is a side effect (it has no users), so it must
# be marked impure to avoid being cleaned up by DCE
guards_fn = GuardsFn()
guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"]) # type: ignore[call-overload, method-assign]
guards_fn._is_impure = True # type: ignore[assignment]
return guards_fn
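
A hypothetical usage sketch of the helper above (the dict-shaped input and the `<= 3` guard are assumptions for illustration):

```python
import torch
import torch.utils._pytree as pytree

# one flattened input, reachable at path ['x']
paths = [path for path, _ in pytree.tree_leaves_with_path({"x": torch.ones(3)})]
guards_fn = _convert_guards_code_to_fn(["L['x'].size()[0] <= 3"], paths)

guards_fn(torch.ones(3))  # passes silently
guards_fn(torch.ones(4))  # AssertionError: Guard failed: x.size()[0] <= 3
```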
@torch._dynamo.disable
def _check_input_constraints_pre_hook(self, args, kwargs):
if not self.validate_inputs:
return
def _check_input_constraints_for_module(self, args, kwargs):
flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)
_check_input_constraints_for_graph(
[node for node in self.graph.nodes if node.op == "placeholder"],
self.graph.find_nodes(op="placeholder"),
flat_args_with_path,
self.range_constraints,
)
def _check_input_constraints_pre_hook(self, args, kwargs):
# preserve current behavior for clients that do not want any validation
if not self.validate_inputs:
return
# when a guards function exists, assume that the graph calls it!
# so we do not need to check input constraints...but we still want
# to check that the inputs match, otherwise we'd get obscure pytree errors
if hasattr(self, "_guards_fn"):
_check_inputs_match(args, kwargs, self._in_spec)
return
# NOTE: this call is Dynamo disabled, as it used to be
_check_input_constraints_for_module(self, args, kwargs)
def _unlift_inputs_as_getattr(
gm: torch.fx.GraphModule,
lifted_inputs: Sequence[Optional[str]],
@ -419,10 +515,149 @@ def _create_stateful_graph_module(
return stateful_gm
def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.GraphModule:
def _get_input_paths(example_inputs, signature):
"""
Generate paths of placeholders, needed for generating the guards function.
NOTE: Here we make use of the example inputs used for export as well as
the signature of the unlifted graph module (not preserved by export).
"""
args, kwargs = example_inputs
ctx = signature.bind(*args, **kwargs).arguments
flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx)
return [path for path, _ in flat_example_inputs_with_paths]
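
For illustration, a sketch with a hypothetical `forward(x, z)` signature (not part of this diff):

```python
import inspect

import torch
import torch.utils._pytree as pytree

def forward(x, z):
    pass

example_inputs = ((torch.ones(2),), {"z": {"k": torch.ones(2, 3)}})
paths = _get_input_paths(example_inputs, inspect.signature(forward))
# prefixing a path with "L" yields the source strings used in guards code
print(["L" + pytree.keystr(p) for p in paths])  # ["L['x']", "L['z']['k']"]
```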
def _get_input_guards_for_graph(
placeholders: list[torch.fx.Node],
range_constraints: dict[sympy.Symbol, ValueRanges],
paths_for_placeholders: list[pytree.KeyPath],
):
"""
Guards generated by the tracer include conditions observed in code,
but do not include some additional checks we typically do in export.
For example, when dynamic shapes get specialized, are specified to be
within a range, or are specified to be in some equational relation,
corresponding input validation is done within a pre_hook, specifically,
`_check_input_constraints_for_graph`.
Here we generate guards corresponding to the checks that happen in
`_check_input_constraints_for_graph`, and add them to the guards already
generated by the tracer. In the future, it may be worthwhile to separate
them so that we can allow clients to turn off one but not the other.
(Looking at you, AOTI.)
NOTE: We should eventually reconcile this logic with `build_guards` that
is used by AOT Precompile.
"""
deferred_expressions = []
new_guards_code = []
sources: dict[sympy.Expr, str] = {}
def handle_symint(expr, src):
if len(expr.free_symbols) == 1:
# complex equations (e.g., involving derived dims) need to be
# handled later, since we may not have enough information
# while we are passing through the placeholders in order
deferred_expressions.append((src, expr))
if expr in sources:
# expressions that appear in multiple sources should force
# inputs corresponding to those sources to be equal
# e.g., x.shape[0] == y.shape[1]
orig_src = sources[expr]
new_guards_code.append(f"{src} == {orig_src}")
else:
sources[expr] = src
# process value ranges as elsewhere in export
min_val, max_val = _convert_range_to_int(range_constraints[expr])
if min_val > 2:
new_guards_code.append(f"{src} >= {min_val}")
if max_val < math.inf:
new_guards_code.append(f"{src} <= {max_val}")
for placeholder, path in zip(placeholders, paths_for_placeholders):
src = "L" + pytree.keystr(path)
meta = placeholder.meta["val"]
# specializations
if isinstance(meta, int):
new_guards_code.append(f"{src} == {meta}")
if isinstance(meta, float):
if meta == math.inf:
new_guards_code.append(f"{src} == math.inf")
elif meta == -math.inf:
new_guards_code.append(f"{src} == -math.inf")
else:
new_guards_code.append(f"{src} == {meta}")
elif isinstance(meta, str):
new_guards_code.append(f"{src} == '{meta}'")
# range constraints and equalities
elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints:
handle_symint(meta.node.expr, src)
elif isinstance(meta, torch.Tensor):
for i, dim in enumerate(meta.shape):
src = "L" + pytree.keystr(path) + f".size()[{i}]"
if isinstance(dim, int):
# specializations
new_guards_code.append(f"{src} == {dim}")
elif (
isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints
):
# range constraints and equalities
handle_symint(dim.node.expr, src)
unification_map: dict[sympy.Symbol, sympy.Expr] = {}
py_printer = torch.utils._sympy.printers.PythonPrinter()
# process complex equations (e.g., involving derived dims)
for src, expr in deferred_expressions:
# we know this is the only symbol in expr (see check above)
symbol = next(iter(expr.free_symbols))
if symbol in sources:
# if s0 is already known to be directly sourced from inputs,
# e.g., z.shape[2], we do not need to do anything further
# (assume we have already processed constraints on s0 above)
continue
# otherwise s0 has some "hidden" source like 'dim'
# example: src = y.shape[1], expr = s0 + 1
if symbol in unification_map:
# suppose that we already know that s0 = x.shape[0] * 2
# so we can emit the guard: x.shape[0] * 2 + 1 = y.shape[1]
substitution = expr.subs(unification_map)
new_guards_code.append(
py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src)))
)
else:
# we do not yet know what s0 is, but given s0 + 1 = y.shape[1],
# we can solve for s0...now knowing that s0 = y.shape[1] - 1
solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol)
if solution is not None:
definition = solution[1]
unification_map[symbol] = definition
return new_guards_code
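
As a sketch of the output, assuming a single placeholder `x` with meta shape `(s0, 3)` and `range_constraints = {s0: ValueRanges(4, 16)}`, the returned guards code would be roughly:

```python
[
    "L['x'].size()[0] >= 4",   # lower bound (emitted because min_val > 2)
    "L['x'].size()[0] <= 16",  # finite upper bound
    "L['x'].size()[1] == 3",   # static dim specialization
]
```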
def _unlift_exported_program_lifted_states(
ep: ExportedProgram, check_guards=True
) -> torch.fx.GraphModule:
# force check_guards=False for executorch because
# its pass infra has too many calls to .module()
# but does not like call_module nodes in the graph
# TODO: update executorch to check_guards=False
frame = inspect.currentframe()
while frame is not None:
if "executorch" in frame.f_code.co_filename:
check_guards = False
break
frame = frame.f_back
# TODO T206340015
if ep.verifiers[0].dialect != "TRAINING":
ep = _remove_effect_tokens(ep)
new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
_register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
forward_arg_names = (
@ -489,4 +724,37 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.Grap
)
unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
unlift_gm.meta.update(ep.graph_module.meta)
# create a _guards_fn submodule and insert a call to it after placeholders
graph = unlift_gm.graph
placeholders = graph.find_nodes(op="placeholder")
if check_guards and placeholders and ep.example_inputs:
input_paths = _get_input_paths(
ep.example_inputs,
inspect.signature(unlift_gm.forward),
)
guards_code = _get_input_guards_for_graph(
placeholders, ep.range_constraints, input_paths
)
guards_code.extend(ep._guards_code)
unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths)
root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack(
graph
)
with graph.inserting_after(placeholders[-1]):
node = graph.call_module("_guards_fn", tuple(placeholders))
node.meta["nn_module_stack"] = root_nn_module_stack
unlift_gm.recompile()
return unlift_gm
class GuardsFn(torch.nn.Module):
"""
Module class for guard functions.
"""
def forward(self, *args):
pass

View File

@ -887,7 +887,7 @@ class AdditionalInputs:
epm = ep.module()
for args, kwargs in self._examples:
torch.export._unlift._check_input_constraints_pre_hook(
torch.export._unlift._check_input_constraints_for_module(
epm, args, kwargs or {}
)

View File

@ -1047,6 +1047,8 @@ class ExportedProgram:
_verifiers: list[type[Verifier]]
"""List of verifier classes used to validate the exported program."""
_guards_code: list[str]
def __init__(
self,
root: Union[torch.nn.Module, dict[str, Any]],
@ -1084,6 +1086,8 @@ class ExportedProgram:
# Validate should be always the last step of the constructor.
self.validate()
self._guards_code = _convert_guards_to_code(_get_shape_env(self._graph_module))
@property
@compatibility(is_backward_compatible=False)
def graph_module(self):
@ -1379,13 +1383,20 @@ class ExportedProgram:
)
return string
def module(self) -> torch.fx.GraphModule:
def module(self, check_guards=True) -> torch.fx.GraphModule:
"""
Returns a self contained GraphModule with all the parameters/buffers inlined.
- When `check_guards=True` (default), a `_guards_fn` submodule is generated
and a call to it is inserted right after the placeholders in the graph.
This submodule checks guards on inputs.
- When `check_guards=False`, a subset of these checks is performed by a
forward pre-hook on the graph module, and no `_guards_fn` submodule is generated.
"""
from ._unlift import _unlift_exported_program_lifted_states
module = _unlift_exported_program_lifted_states(self)
module = _unlift_exported_program_lifted_states(self, check_guards=check_guards)
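
A minimal usage sketch (hypothetical module and shapes; the error text mirrors the tests above):

```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

ep = export(M(), (torch.randn(4, 4),), dynamic_shapes={"x": {0: Dim("d", min=3)}})

gm = ep.module()                    # graph calls self._guards_fn(x)
gm(torch.randn(2, 4))               # AssertionError: Guard failed: x.size()[0] >= 3

gm = ep.module(check_guards=False)  # no _guards_fn; pre-hook checks only
```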
def _train(self, mode: bool = True):
raise NotImplementedError("Calling train() is not supported yet.")
@ -1677,3 +1688,25 @@ def _create_graph_module_for_export(root, graph):
gm._graph = graph
return gm
def _convert_guards_to_code(shape_env):
if shape_env is None:
return []
local_vars = {
var
for var, sources in shape_env.var_to_sources.items()
if all(
not isinstance(source, torch._dynamo.source.ConstantSource)
for source in sources
)
}
py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter(
shape_env.var_to_sources, lambda s: s.name(), shape_env.var_to_sources
)
return [
py_printer.doprint(guard.expr)
for guard in shape_env.guards
if guard.expr.free_symbols.issubset(local_vars)
]

View File

@ -95,7 +95,7 @@ class SubgraphMatcher:
)
for node in pattern.nodes:
if node.op != "output":
if node.op != "output" and not node.is_impure():
assert len(node.users) > 0, (
"SubgraphMatcher cannot be initialized with an pattern with dead code"
)

View File

@ -3394,7 +3394,7 @@ def _generate_qdq_quantized_model(
maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
with maybe_no_grad:
export_model = export_for_training(mod, inputs, strict=True).module()
export_model = export_for_training(mod, inputs, strict=True).module(check_guards=False)
quantizer = (
quantizer
if quantizer