Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
shape guards (#161178)
Summary: This PR introduces shape guards to export. Previously, only value ranges, equalities, and specializations were tracked for symbolic expressions, and a forward hook checked them at runtime. Now export generates a function that checks the shape guards and calls it in the exported program.

Test Plan: updated several tests

Differential Revision: D80713603

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161178
Approved by: https://github.com/tugsbayasgalan
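To make the behavior change concrete, here is a minimal sketch (not part of the diff) of how the new guards surface to users, assuming the semantics shown in the updated tests below; the module name and bounds are illustrative:

```python
# Hedged sketch: shape guards are checked by a generated _guards_fn submodule
# and now fail with AssertionError("Guard failed: ...") instead of RuntimeError.
import torch
from torch.export import Dim, export

class M(torch.nn.Module):  # hypothetical example module
    def forward(self, x):
        return x + 1

ep = export(M(), (torch.ones(3, 3),), dynamic_shapes={"x": {0: Dim("d", max=16)}})
gm = ep.module()           # forward begins with a call to self._guards_fn(x)
gm(torch.ones(8, 3))       # ok: 8 <= 16
try:
    gm(torch.ones(17, 3))  # violates the upper bound recorded as a guard
except AssertionError as e:
    print(e)               # e.g. "Guard failed: x.size()[0] <= 16"
```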
Committed by: PyTorch MergeBot
Parent: 2c538c9acf
Commit: 711c8c821e
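The diff also threads a check_guards flag through ExportedProgram.module() at internal call sites (the AOTI minifier and repro tooling, pipeline splitting, and quantization pattern export), so tooling that retraces or stringifies the graph can skip the generated guards submodule. A hedged sketch of the two call forms, with a hypothetical module:

```python
# Sketch only; assumes the check_guards keyword introduced by this PR.
import torch

class N(torch.nn.Module):  # hypothetical module
    def forward(self, x):
        return x.sin()

ep = torch.export.export(N(), (torch.randn(4, 4),))
gm_checked = ep.module()                  # includes the _guards_fn call in forward
gm_plain = ep.module(check_guards=False)  # guards submodule omitted, as used below
```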
@ -1,5 +1,6 @@
# Owner(s): ["oncall: export"]
import copy
import re
import tempfile
import unittest
@ -407,7 +408,12 @@ class TestDraftExport(TestCase):
|
||||
|
||||
inp = (torch.ones(3, 3),)
|
||||
|
||||
ep = draft_export(M(), inp, dynamic_shapes={"a": {0: Dim("a0")}})
|
||||
ep = draft_export(
|
||||
M(),
|
||||
inp,
|
||||
dynamic_shapes={"a": {0: Dim("a0")}},
|
||||
prefer_deferred_runtime_asserts_over_guards=True,
|
||||
)
|
||||
report = ep._report
|
||||
|
||||
self.assertEqual(len(report.failures), 1)
|
||||
@ -417,7 +423,11 @@ class TestDraftExport(TestCase):
|
||||
self.assertEqual(ep.module()(*inp), M()(*inp))
|
||||
|
||||
inp = (torch.randn(4, 3),)
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
re.escape("Guard failed: a.size()[0] <= 3"),
|
||||
):
|
||||
# expected <= 3, but got 4
|
||||
ep.module()(*inp)
|
||||
|
||||
def test_side_effect1(self):
|
||||
|
@ -319,6 +319,7 @@ def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
linear_weight = self.linear.weight
|
||||
linear_bias = self.linear.bias
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
|
||||
return pytree.tree_unflatten((linear,), self._out_spec)""",
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from re import escape
|
||||
from typing import Dict, List, Union
|
||||
@ -514,7 +514,13 @@ class TestExport(TestCase):
|
||||
# )
|
||||
|
||||
def _check_dynamic_shapes_specs_and_shapes(
|
||||
self, model, inputs, specs, passing_shapes, failing_shapes, test_serdes=False
|
||||
self,
|
||||
model,
|
||||
inputs,
|
||||
specs,
|
||||
passing_shapes,
|
||||
failing_shapes,
|
||||
test_serdes=False,
|
||||
):
|
||||
from torch._export.serde.dynamic_shapes import (
|
||||
_dump_dynamic_shapes,
|
||||
@ -556,7 +562,7 @@ class TestExport(TestCase):
|
||||
ep.module()(*test_inputs)
|
||||
for shapes in failing_shapes:
|
||||
test_inputs = _construct_inputs(shapes)
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaisesRegex(AssertionError, "Guard failed"):
|
||||
ep.module()(*test_inputs)
|
||||
|
||||
def test_basic(self):
|
||||
@ -635,7 +641,7 @@ class TestExport(TestCase):
|
||||
from torch.fx.traceback import NodeSourceAction
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op in ("placeholder", "output"):
|
||||
if node.op in ("placeholder", "output", "call_module"):
|
||||
continue
|
||||
if "weight" in node.name or "bias" in node.name:
|
||||
self.assertTrue(
|
||||
@ -664,7 +670,7 @@ class TestExport(TestCase):
|
||||
graph_id = id(ep2.graph)
|
||||
|
||||
for node in gm2.graph.nodes:
|
||||
if node.op in ("placeholder", "output"):
|
||||
if node.op in ("placeholder", "output", "call_module"):
|
||||
continue
|
||||
|
||||
if "weight" in node.name or "bias" in node.name:
|
||||
@ -927,7 +933,8 @@ graph():
|
||||
"""\
|
||||
graph():
|
||||
%lifted_tensor_0 : [num_users=1] = get_attr[target=lifted_tensor_0]
|
||||
%x : [num_users=1] = placeholder[target=x]
|
||||
%x : [num_users=2] = placeholder[target=x]
|
||||
%_guards_fn : [num_users=0] = call_module[target=_guards_fn](args = (%x,), kwargs = {})
|
||||
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %lifted_tensor_0), kwargs = {})
|
||||
return (add,)""",
|
||||
)
|
||||
@ -1491,7 +1498,11 @@ graph():
|
||||
{"a": torch.zeros(5), "b": torch.ones(5)},
|
||||
torch.ones(4),
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "to be equal to 6, but got 5"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: ys[0].size()[0] == x.size()[0]"),
|
||||
):
|
||||
# expected 6, but got 5
|
||||
ep_ns.module()(*bad_runtime_inp1)
|
||||
|
||||
bad_runtime_inp2 = (
|
||||
@ -1501,9 +1512,10 @@ graph():
|
||||
torch.ones(6),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[3].shape[0] to be equal to 4, but got 6"),
|
||||
AssertionError,
|
||||
escape("Guard failed: c.size()[0] == 4"),
|
||||
):
|
||||
# expected 4, but got 6
|
||||
ep_ns.module()(*bad_runtime_inp2)
|
||||
|
||||
good_runtime_inp = (
|
||||
@ -1651,6 +1663,8 @@ class GraphModule(torch.nn.Module):
|
||||
x: "f32[3, 3]";
|
||||
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
|
||||
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
|
||||
|
||||
@ -1736,6 +1750,8 @@ class GraphModule(torch.nn.Module):
|
||||
x: "f32[3, 3]";
|
||||
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(x)
|
||||
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 3); sum_1 = None
|
||||
|
||||
@ -3104,16 +3120,22 @@ def forward(self, causal_mask, fill_value):
|
||||
ep = export(Foo(), inputs, dynamic_shapes=shapes)
|
||||
ep.module()(torch.randn(8, 5), torch.randn(8, 5))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Expected input at .* to be >= 4, but got 3"
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] >= 4"),
|
||||
):
|
||||
# expected >= 4, but got 3
|
||||
ep.module()(torch.randn(3, 5), torch.randn(3, 5))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Expected input at .* to be <= 16, but got 17"
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] <= 16"),
|
||||
):
|
||||
# expected <= 16, but got 17
|
||||
ep.module()(torch.randn(17, 5), torch.randn(17, 5))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Expected input at .* to be <= 32, but got 33"
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[1] <= 32"),
|
||||
):
|
||||
# expected <= 32, but got 33
|
||||
ep.module()(torch.randn(9, 33), torch.randn(9, 33))
|
||||
|
||||
def test_dim_hint_range_violations(self):
|
||||
@ -3368,11 +3390,12 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
|
||||
actual_torch_fns = []
|
||||
for mod in gm.modules():
|
||||
for node in mod.graph.nodes:
|
||||
if node.name in {"sin", "cos"}:
|
||||
torch_fn = node.meta.get("torch_fn")
|
||||
print(torch_fn)
|
||||
actual_torch_fns.append(torch_fn)
|
||||
if hasattr(mod, "graph"):
|
||||
for node in mod.graph.nodes:
|
||||
if node.name in {"sin", "cos"}:
|
||||
torch_fn = node.meta.get("torch_fn")
|
||||
print(torch_fn)
|
||||
actual_torch_fns.append(torch_fn)
|
||||
exp_torch_fns = [
|
||||
("cos_1", "method_descriptor.cos"),
|
||||
("sin_1", "method_descriptor.sin"),
|
||||
@ -3545,9 +3568,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
dynamic_shapes=({0: dimx}, {0: dimy}),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be equal to 5, but got 6",
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] == -1 + y.size()[0]"),
|
||||
):
|
||||
# expected 5, but got 6
|
||||
ep.module()(torch.randn(4), torch.randn(6))
|
||||
|
||||
self.assertEqual(ep.module()(torch.randn(4), torch.randn(5)).size()[0], 4)
|
||||
@ -3606,13 +3630,16 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
dynamic_shapes=({0: dimz}, {0: dimy}),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Expected input.*shape.*to be <= 7, but got 8"
|
||||
AssertionError,
|
||||
escape("Guard failed: z.size()[0] <= 7"),
|
||||
):
|
||||
# expected <= 7, but got 8
|
||||
ep.module()(torch.randn(8), torch.randn(15))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be equal to 9, but got 8",
|
||||
AssertionError,
|
||||
escape("Guard failed: -1 + 2 * z.size()[0] == y.size()[0]"),
|
||||
):
|
||||
# expected 9, but got 8
|
||||
ep.module()(torch.randn(5), torch.randn(8))
|
||||
|
||||
self.assertEqual(ep.module()(torch.randn(5), torch.randn(9)).size()[0], 4)
|
||||
@ -3648,17 +3675,18 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
dynamic_shapes=({0: dimw},),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*= 9 to be "
|
||||
"of the form 2\\*s92, where s92 is an integer",
|
||||
AssertionError,
|
||||
escape("Guard failed: w.size()[0] % 2 == 0"),
|
||||
):
|
||||
# expected 2*..., got 9
|
||||
ep.module()(torch.randn(9))
|
||||
|
||||
self.assertEqual(ep.module()(torch.randn(8)).size()[0], 4)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be <= 12, but got 14",
|
||||
AssertionError,
|
||||
escape("Guard failed: w.size()[0] <= 12"),
|
||||
):
|
||||
# expected <= 12, but got 14
|
||||
ep.module()(torch.randn(14))
|
||||
|
||||
def test_derived_dim_repeat_derived(self):
|
||||
@ -3696,9 +3724,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be equal to 8, but got 5",
|
||||
AssertionError,
|
||||
escape("Guard failed: z.size()[0] >= 6"),
|
||||
):
|
||||
# expected 8, but got 5
|
||||
ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))
|
||||
|
||||
self.assertEqual(
|
||||
@ -3731,9 +3760,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be equal to 6, but got 5",
|
||||
AssertionError,
|
||||
escape("Guard failed: x2.size()[0] == x.size()[0]"),
|
||||
):
|
||||
# expected 6, but got 5
|
||||
ep.module()(
|
||||
torch.randn(6),
|
||||
torch.randn(7),
|
||||
@ -3759,9 +3789,10 @@ def forward(self, p_linear_weight, p_linear_bias, x):
|
||||
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}, {0: dimx1}, {0: dimx2}),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be equal to 6, but got 5",
|
||||
AssertionError,
|
||||
escape("Guard failed: x2.size()[0] == x.size()[0]"),
|
||||
):
|
||||
# expected 6, but got 5
|
||||
ep.module()(
|
||||
torch.randn(6),
|
||||
torch.randn(7),
|
||||
@ -4197,9 +4228,10 @@ def forward(self, x):
|
||||
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimz}),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be equal to 8, but got 5",
|
||||
AssertionError,
|
||||
escape("Guard failed: z.size()[0] >= 6"),
|
||||
):
|
||||
# expected 8, but got 5
|
||||
ep.module()(torch.randn(6), torch.randn(7), torch.randn(5))
|
||||
|
||||
self.assertEqual(
|
||||
@ -4234,6 +4266,7 @@ def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
linear_weight = self.linear.weight
|
||||
linear_bias = self.linear.bias
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
|
||||
return pytree.tree_unflatten((linear,), self._out_spec)""",
|
||||
)
|
||||
@ -4274,6 +4307,7 @@ def forward(self, b_buffer, x):
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
buffer = self.buffer
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add_ = torch.ops.aten.add_.Tensor(x, 5); x = None
|
||||
add__1 = torch.ops.aten.add_.Tensor(buffer, 5); buffer = None
|
||||
add = torch.ops.aten.add.Tensor(add_, add__1); add_ = add__1 = None
|
||||
@ -4595,9 +4629,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
dynamic_shapes=({0: dimx}, {0: dimy}, {0: dimy}, {0: dimz}),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected input.*shape.*to be equal to 7, but got 5",
|
||||
AssertionError,
|
||||
escape("Guard failed: y1.size()[0] == y.size()[0]"),
|
||||
):
|
||||
# expected 7, but got 5
|
||||
ep.module()(
|
||||
torch.randn(6),
|
||||
torch.randn(7),
|
||||
@ -4649,8 +4684,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
ep = export(foo, inputs, dynamic_shapes=dynamic_shapes)
|
||||
self.assertEqual(foo(*inputs), ep.module()(*inputs))
|
||||
for wrong_inputs in wrong_shape_inputs:
|
||||
with self.assertRaises(RuntimeError):
|
||||
ep.module()(*wrong_inputs)
|
||||
with self.assertRaisesRegex(AssertionError, "Guard failed"):
|
||||
with self.assertRaises(RuntimeError):
|
||||
ep.module()(*wrong_inputs)
|
||||
|
||||
# check range_constraints - static dims shouldn't be present
|
||||
ep = export(foo, inputs, dynamic_shapes=((dx, None), (dy, 4), (dz, 3)))
|
||||
@ -4686,8 +4722,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
ep.module()(torch.randn(1, 2))
|
||||
ep.module()(torch.randn(2, 2))
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Expected input at .* to be <= 2, but got 3"
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] <= 2"),
|
||||
):
|
||||
# expected <= 2, but got 3
|
||||
ep.module()(torch.randn(3, 2))
|
||||
vr = list(ep.range_constraints.values())[0]
|
||||
self.assertEqual(vr.lower, 1)
|
||||
@ -4704,7 +4742,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
(torch.randn(2, 2), torch.randn(3, 2)),
|
||||
dynamic_shapes=({0: dx, 1: None}, {0: dx + 1, 1: None}),
|
||||
)
|
||||
ep.module()(torch.randn(1, 2), torch.randn(2, 2))
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: -1 + y.size()[0] != 1"),
|
||||
):
|
||||
# TODO: this should not error?
|
||||
ep.module()(torch.randn(1, 2), torch.randn(2, 2))
|
||||
range_lower_bounds = sorted(vr.lower for vr in ep.range_constraints.values())
|
||||
range_upper_bounds = sorted(vr.upper for vr in ep.range_constraints.values())
|
||||
self.assertEqual(range_lower_bounds, [1, 2])
|
||||
@ -4894,7 +4937,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
self.assertEqual(got_shapes, expected_shapes)
|
||||
|
||||
def expect_error(bad_args, run_time_msg, compile_time_msg):
|
||||
with self.assertRaisesRegex(RuntimeError, run_time_msg):
|
||||
with self.assertRaisesRegex(AssertionError, run_time_msg):
|
||||
ep.module()(*bad_args)
|
||||
|
||||
additional_inputs = torch.export.AdditionalInputs()
|
||||
@ -4906,21 +4949,27 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
expect_error(
|
||||
# 4->2, 4->2, 3->3
|
||||
bad_args=(torch.randn(2), [torch.randn(2)], {"k": torch.randn(3)}),
|
||||
run_time_msg="Expected input.*to be >= 3, but got 2",
|
||||
run_time_msg=escape(
|
||||
"Guard failed: x.size()[0] >= 3"
|
||||
), # expected >= 3, but got 2
|
||||
compile_time_msg="Expected input.*to be >= 3, but got 2",
|
||||
)
|
||||
|
||||
expect_error(
|
||||
# 4->6, 4->7, 3->3
|
||||
bad_args=(torch.randn(6), [torch.randn(7)], {"k": torch.randn(3)}),
|
||||
run_time_msg="Expected input.*to be equal to 6, but got 7",
|
||||
run_time_msg=escape(
|
||||
"Guard failed: y[0].size()[0] == x.size()[0]"
|
||||
), # expected 6, but got 7
|
||||
compile_time_msg="Expected input.*to be equal to 6, but got 7",
|
||||
)
|
||||
|
||||
expect_error(
|
||||
# 4->5, 4->5, 3->4
|
||||
bad_args=(torch.randn(5), [torch.randn(5)], {"k": torch.randn(4)}),
|
||||
run_time_msg="Expected input.*to be equal to 3, but got 4",
|
||||
run_time_msg=escape(
|
||||
"Guard failed: z['k'].size()[0] == 3"
|
||||
), # expected 3, but got 4
|
||||
compile_time_msg=r"You marked.*but your code specialized it to be a constant.*If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO",
|
||||
)
|
||||
|
||||
@ -5569,7 +5618,18 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
self.assertTrue(torch.allclose(ep.module()(x, y), model(x, y)))
|
||||
x2 = torch.arange(4).reshape((2, 2))
|
||||
y2 = torch.arange(9).reshape((3, 3))
|
||||
self.assertTrue(torch.allclose(ep.module()(x2, y2), model(x2, y2)))
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
(
|
||||
escape("Guard failed: max(x.size()[1], y.size()[1]) == x.size()[1]")
|
||||
if is_retracebility_test(self._testMethodName)
|
||||
else escape(
|
||||
"Guard failed: max(1, x.size()[1], y.size()[1]) == x.size()[1]"
|
||||
)
|
||||
),
|
||||
):
|
||||
# TODO: this should not error?
|
||||
self.assertTrue(torch.allclose(ep.module()(x2, y2), model(x2, y2)))
|
||||
|
||||
def test_export_max_nonstrict(self):
|
||||
class FooMax(torch.nn.Module):
|
||||
@ -5712,9 +5772,10 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||
em = torch.export.export(m, (a,), dynamic_shapes=dynamic_shapes)
|
||||
x = torch.randn(3, 5)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Expected.*shape\\[1\\] = 5 to be of the form 2\\*s33, where s33 is an integer",
|
||||
AssertionError,
|
||||
escape("Guard failed: 3 * x.size()[1] % 2 == 0"),
|
||||
):
|
||||
# expected 2*..., but got 5
|
||||
em.module()(x)
|
||||
|
||||
def test_dont_duck_size_for_auto_dynamic(self):
|
||||
@ -7814,6 +7875,7 @@ def forward(self, x):
|
||||
bn_running_mean = self.bn.running_mean
|
||||
bn_running_var = self.bn.running_var
|
||||
bn_num_batches_tracked = self.bn.num_batches_tracked; bn_num_batches_tracked = None
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
|
||||
batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, False, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
|
||||
return pytree.tree_unflatten((batch_norm,), self._out_spec)""",
|
||||
@ -7833,6 +7895,7 @@ def forward(self, x):
|
||||
bn_running_mean = self.bn.running_mean
|
||||
bn_running_var = self.bn.running_var
|
||||
bn_num_batches_tracked = self.bn.num_batches_tracked
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
conv2d = torch.ops.aten.conv2d.default(x, conv_weight, conv_bias); x = conv_weight = conv_bias = None
|
||||
add_ = torch.ops.aten.add_.Tensor(bn_num_batches_tracked, 1); bn_num_batches_tracked = add_ = None
|
||||
batch_norm = torch.ops.aten.batch_norm.default(conv2d, bn_weight, bn_bias, bn_running_mean, bn_running_var, True, 0.1, 1e-05, True); conv2d = bn_weight = bn_bias = bn_running_mean = bn_running_var = None
|
||||
@ -8426,18 +8489,20 @@ def forward(self, x):
|
||||
)
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[1] to be equal to 5, but got 6"),
|
||||
AssertionError,
|
||||
escape("Guard failed: y == 5"),
|
||||
):
|
||||
# expected 5, but got 6
|
||||
_ = exported.module()(torch.ones(8, 5), 6)
|
||||
|
||||
exported = torch.export.export(
|
||||
foo, (tensor_inp, 5.0), dynamic_shapes=dynamic_shapes
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[1] to be equal to 5.0, but got 6.0"),
|
||||
AssertionError,
|
||||
escape("Guard failed: y == 5.0"),
|
||||
):
|
||||
# expected 5.0, but got 6.0
|
||||
_ = exported.module()(torch.ones(7, 5), 6.0)
|
||||
|
||||
def test_runtime_assert_for_prm_str(self):
|
||||
@ -8449,8 +8514,10 @@ def forward(self, x):
|
||||
inps = (torch.randn(4, 4), torch.randn(4), "trunc")
|
||||
exported = export(foo, inps)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "to be equal to trunc, but got floor"
|
||||
AssertionError,
|
||||
escape("Guard failed: mode == 'trunc'"),
|
||||
):
|
||||
# expected 'trunc', but got 'floor'
|
||||
_ = exported.module()(torch.randn(4, 4), torch.randn(4), "floor")
|
||||
self.assertTrue(torch.allclose(exported.module()(*inps), foo(*inps)))
|
||||
|
||||
@ -8577,9 +8644,12 @@ def forward(self, x):
|
||||
dim0_x = torch.export.Dim("dim0_x")
|
||||
exported = torch.export.export(Foo(), (inp,), dynamic_shapes=({0: dim0_x},))
|
||||
reexported = torch.export.export(exported.module(), (inp,))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "shape\[0\] to be equal to 5, but got 7"
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] == 5"),
|
||||
):
|
||||
# expected 5, but got 7
|
||||
reexported.module()(torch.ones(7, 5))
|
||||
|
||||
reexported = torch.export.export(
|
||||
@ -8597,9 +8667,10 @@ def forward(self, x):
|
||||
Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x_v2}}
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[0].shape[0] to be >= 3, but got 2"),
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] >= 3"),
|
||||
):
|
||||
# expected >= 3, but got 2
|
||||
torch.export.export(exported_v2.module(), (torch.randn(2, 2),))
|
||||
|
||||
def test_export_cond_symbool_pred(self):
|
||||
@ -9394,8 +9465,10 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
|
||||
dynamic_shapes=(None, None),
|
||||
)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "shape\[0\] to be equal to 4, but got 7"
|
||||
AssertionError,
|
||||
escape("Guard failed: b.size()[0] == 4"),
|
||||
):
|
||||
# expected 4, but got 7
|
||||
ep_v2.module()(*test_inp)
|
||||
|
||||
def test_constant_output(self):
|
||||
@ -9475,7 +9548,11 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
|
||||
ep = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
|
||||
|
||||
test_inp = ((torch.randn(4, 4), torch.randn(2, 4)), torch.randn(4, 4))
|
||||
with self.assertRaisesRegex(RuntimeError, "shape\[0\] to be >= 3, but got 2"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: a[1].size()[0] >= 3"),
|
||||
):
|
||||
# expected >= 3, but got 2
|
||||
ep.module()(*test_inp)
|
||||
|
||||
def test_nested_module(self):
|
||||
@ -9677,13 +9754,17 @@ graph():
|
||||
).module()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, escape("Expected input at *args[0].shape[0]")
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] >= 3"),
|
||||
):
|
||||
# expected >= 3, got 2
|
||||
gm(torch.randn(2, 2))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, escape("Expected input at *args[0].shape[0]")
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] >= 3"),
|
||||
):
|
||||
# expected >= 3, got 2
|
||||
export(gm, (torch.randn(2, 2),))
|
||||
|
||||
ep = export(
|
||||
@ -11502,7 +11583,11 @@ graph():
|
||||
|
||||
ep = export(M(), (4, 5))
|
||||
self.assertEqual(ep.module()(4, 5), 20)
|
||||
with self.assertRaisesRegex(RuntimeError, r"to be equal to 4, but got 3"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: x == 4"),
|
||||
):
|
||||
# expected 4, but got 3
|
||||
self.assertEqual(ep.module()(3, 6), 18)
|
||||
|
||||
ep = export(M(), (4, 5), dynamic_shapes={"x": Dim.DYNAMIC, "y": Dim.AUTO})
|
||||
@ -11515,7 +11600,11 @@ graph():
|
||||
|
||||
ep = export(M(), (5, 5), dynamic_shapes={"x": None, "y": Dim.AUTO})
|
||||
self.assertEqual(ep.module()(5, 6), 30)
|
||||
with self.assertRaisesRegex(RuntimeError, r"to be equal to 5, but got 3"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: x == 5"),
|
||||
):
|
||||
# expected 5, but got 3
|
||||
self.assertEqual(ep.module()(3, 5), 18)
|
||||
|
||||
class M(torch.nn.Module):
|
||||
@ -11531,7 +11620,6 @@ graph():
|
||||
self.assertTrue(torch.allclose(ep.module()(*inp), M()(*inp)))
|
||||
|
||||
@testing.expectedFailureCppRuntime
|
||||
@testing.expectedFailureRetraceabilityNonStrict # no runtime asserts added for assert x == 3
|
||||
def test_symint_input_specialization(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
@ -11556,7 +11644,11 @@ graph():
|
||||
inp,
|
||||
dynamic_shapes=(Dim.AUTO, None),
|
||||
)
|
||||
with self.assertRaisesRegex(RuntimeError, "to be equal to 3, but got 4"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: x == 3"),
|
||||
):
|
||||
# expected 3, but got 4
|
||||
ep.module()(4, torch.randn(4, 4))
|
||||
|
||||
@testing.expectedFailureCppRuntime
|
||||
@ -11573,9 +11665,17 @@ graph():
|
||||
)
|
||||
|
||||
ep.module()(4, torch.randn(4, 4))
|
||||
with self.assertRaisesRegex(RuntimeError, "to be <= 10, but got 16"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: x <= 10"),
|
||||
):
|
||||
# expected <= 10, but got 16
|
||||
ep.module()(16, torch.randn(4, 4))
|
||||
with self.assertRaisesRegex(RuntimeError, "to be >= 3, but got 2"):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: x >= 3"),
|
||||
):
|
||||
# expected >= 3, but got 2
|
||||
ep.module()(2, torch.randn(4, 4))
|
||||
|
||||
# While tracing the range was found to be a subset of the original range
|
||||
@ -11593,12 +11693,8 @@ graph():
|
||||
)
|
||||
constraints = list(ep.range_constraints.values())
|
||||
constraint = constraints[0]
|
||||
# retracebility does not remember the range asserts in the forward
|
||||
lower, upper = (
|
||||
(3, 10) if is_retracebility_test(self._testMethodName) else (4, 5)
|
||||
)
|
||||
self.assertEqual(constraint.lower, lower)
|
||||
self.assertEqual(constraint.upper, upper)
|
||||
self.assertEqual(constraint.lower, 4)
|
||||
self.assertEqual(constraint.upper, 5)
|
||||
|
||||
# While tracing the range was found to be bigger than the original range
|
||||
class M(torch.nn.Module):
|
||||
@ -12260,6 +12356,7 @@ def forward(self, c_submod_params, x):
|
||||
[
|
||||
fqn
|
||||
for fqn, _ in unflattened.named_modules(remove_duplicate=False)
|
||||
if fqn != "_guards_fn"
|
||||
]
|
||||
),
|
||||
expected_fqns,
|
||||
@ -13648,9 +13745,12 @@ def forward(self, x, y):
|
||||
else:
|
||||
# no runtime assert in exported module but it fails anyway with wrong inputs
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"The size of tensor a \(40\) must match the size of tensor b \(20\) at non-singleton dimension 0",
|
||||
AssertionError,
|
||||
escape(
|
||||
"Guard failed: x.size()[1] * x.size()[0] == y.size()[0] * y.size()[1]"
|
||||
),
|
||||
):
|
||||
# expected 40, but got 20
|
||||
ep.module()(torch.randn(5, 8), torch.randn(4, 5), torch.randn(30))
|
||||
|
||||
# case 3: 3d reshape (previously failing with different issue)
|
||||
@ -15168,9 +15268,12 @@ def forward(self, x):
|
||||
self.assertEqual(num_asserts, 0)
|
||||
# but it fails anyway with wrong inputs
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"shape '\[3, -1\]' is invalid for input of size 8",
|
||||
AssertionError,
|
||||
escape(
|
||||
"Guard failed: x.size()[1] * x.size()[0] % (-1 + x.size()[0]) == 0"
|
||||
),
|
||||
):
|
||||
# expected 3*..., but got 8
|
||||
ep.module()(torch.randn(4, 2))
|
||||
|
||||
@testing.expectedFailureSerDer # T195866111
|
||||
@ -15928,9 +16031,10 @@ def forward(self, q, k, v):
|
||||
self.assertEqual(res[1], 5)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[1] to be equal to 5, but got 20"),
|
||||
AssertionError,
|
||||
escape("Guard failed: y == 5"),
|
||||
):
|
||||
# expected 5, but got 20
|
||||
res = ep.module()(torch.tensor(4), 20)
|
||||
|
||||
class F(torch.nn.Module):
|
||||
|
@ -411,9 +411,10 @@ class TestPasses(TestCase):
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[1] <= 6"),
|
||||
):
|
||||
# expected <= 6, but got 7
|
||||
ep.module()(torch.zeros(2, 7, 3))
|
||||
|
||||
self.assertEqual(
|
||||
@ -442,15 +443,17 @@ class TestPasses(TestCase):
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[1] <= 6"),
|
||||
):
|
||||
# expected <= 6, but got 7
|
||||
ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[1].shape[0] to be >= 3, but got 2"),
|
||||
AssertionError,
|
||||
escape("Guard failed: y.size()[0] >= 3"),
|
||||
):
|
||||
# expected >= 3, but got 2
|
||||
ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
|
||||
|
||||
def test_runtime_assert_some_dims_not_specified(self) -> None:
|
||||
@ -475,16 +478,18 @@ class TestPasses(TestCase):
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[0].shape[1] to be <= 6, but got 7"),
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[1] <= 6"),
|
||||
):
|
||||
# expected <= 6, but got 7
|
||||
ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
|
||||
|
||||
# y is specialized to 5
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
|
||||
AssertionError,
|
||||
escape("Guard failed: y.size()[0] == 5"),
|
||||
):
|
||||
# expected 5, but got 2
|
||||
ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
|
||||
|
||||
# Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
|
||||
@ -509,14 +514,19 @@ class TestPasses(TestCase):
|
||||
M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}, strict=True
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
|
||||
with self.assertRaisesRegex(
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[1] == 2"),
|
||||
):
|
||||
# expected 2, but got 7
|
||||
ep.module()(torch.zeros(4, 7, 3), torch.ones(5, 5, 5))
|
||||
|
||||
# y is specialized to 5
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[1].shape[0] to be equal to 5, but got 2"),
|
||||
AssertionError,
|
||||
escape("Guard failed: y.size()[0] == 5"),
|
||||
):
|
||||
# expected 5, but got 2
|
||||
ep.module()(torch.zeros(4, 2, 3), torch.ones(2, 5, 5))
|
||||
|
||||
# Since we didn't insert the constraint for x[1] >= 2, it should work for case where x[1] == 1
|
||||
@ -803,6 +813,7 @@ def forward(self, token, obj_attr, x):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
sin = torch.ops.aten.sin.default(add); add = None
|
||||
sum_1 = torch.ops.aten.sum.default(sin); sin = None
|
||||
@ -822,6 +833,7 @@ def forward(self, x):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
sin = torch.ops.aten.sin.default(add); add = None
|
||||
sum_1 = torch.ops.aten.sum.default(sin); sin = None
|
||||
@ -841,6 +853,7 @@ def forward(self, x):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
sin = torch.ops.aten.sin.default(add); add = None
|
||||
sum_1 = torch.ops.aten.sum.default(sin); sin = None
|
||||
@ -860,6 +873,7 @@ def forward(self, x):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
submod_5 = self.submod_1
|
||||
sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
|
||||
@ -880,6 +894,7 @@ def forward(self, x):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
sin = torch.ops.aten.sin.default(add)
|
||||
sum_1 = torch.ops.aten.sum.default(sin); sin = None
|
||||
@ -905,6 +920,7 @@ def forward(self, x):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
submod_5 = self.submod_1
|
||||
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
|
||||
@ -940,6 +956,7 @@ def forward(self, x):
|
||||
"""\
|
||||
def forward(self, x1, x2):
|
||||
x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec)
|
||||
submod_0 = self.submod_0(x1, x2); submod_0 = None
|
||||
submod_1 = self.submod_1(x1, x2); x1 = x2 = None
|
||||
getitem = submod_1[0]
|
||||
getitem_1 = submod_1[1]; submod_1 = None
|
||||
@ -995,6 +1012,7 @@ def forward(self, sin, cos):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
submod_3 = self.submod_3
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
sin = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_3, add); submod_3 = add = None
|
||||
@ -1033,6 +1051,7 @@ def forward(self, cos):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
submod_3 = self.submod_1
|
||||
add_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_3, add); submod_3 = add = None
|
||||
@ -1065,6 +1084,7 @@ def forward(self, add):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
submod_4 = self.submod_1
|
||||
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
|
||||
@ -1115,6 +1135,7 @@ def forward(self, add_1):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
submod_4 = self.submod_1
|
||||
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
|
||||
@ -1172,6 +1193,7 @@ def forward(self, add_1, add_2):
|
||||
"""\
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
add = torch.ops.aten.add.Tensor(x, 1); x = None
|
||||
submod_4 = self.submod_1
|
||||
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
|
||||
@ -1213,6 +1235,7 @@ def forward(self, add_1):
|
||||
)
|
||||
after_inline_str = new_gm.print_readable(print_output=False)
|
||||
self.assertEqual(before_str, after_inline_str)
|
||||
new_gm._guards_fn = gm._guards_fn
|
||||
self.assertEqual(gm(*args), new_gm(*args))
|
||||
|
||||
def test_remove_auto_functionalized_pass(self) -> None:
|
||||
|
@ -2027,6 +2027,7 @@ def forward(self, obj_attr, x):
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
attr = self.attr
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
|
||||
add = torch.ops.aten.add.Tensor(x, takes_foo); x = takes_foo = None
|
||||
return pytree.tree_unflatten((add,), self._out_spec)""",
|
||||
|
@ -185,6 +185,7 @@ class TestExportTorchbind(TestCase):
|
||||
def forward(self, x, n):
|
||||
x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec)
|
||||
attr = self.attr
|
||||
_guards_fn = self._guards_fn(x, n); n = _guards_fn = None
|
||||
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
|
||||
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
||||
return pytree.tree_unflatten((add,), self._out_spec)""",
|
||||
@ -232,6 +233,7 @@ def forward(self, token, obj_attr, x, n):
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
attr = self.attr
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
|
||||
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
||||
return pytree.tree_unflatten((add,), self._out_spec)""",
|
||||
@ -266,6 +268,7 @@ def forward(self, token, obj_attr, x):
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
attr = self.attr
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
|
||||
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
||||
return pytree.tree_unflatten((add,), self._out_spec)""",
|
||||
@ -300,6 +303,7 @@ def forward(self, token, obj_attr, x):
|
||||
"""\
|
||||
def forward(self, x, cc):
|
||||
x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x, cc); _guards_fn = None
|
||||
call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
|
||||
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
|
||||
return pytree.tree_unflatten((add,), self._out_spec)""",
|
||||
@ -362,6 +366,7 @@ def forward(self, token, x, cc):
|
||||
"""\
|
||||
def forward(self, x, cc):
|
||||
x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x, cc); _guards_fn = None
|
||||
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None
|
||||
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
||||
return pytree.tree_unflatten((add,), self._out_spec)""",
|
||||
@ -457,6 +462,7 @@ def forward(self, token, x, cc):
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
attr = self.attr
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
|
||||
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
|
||||
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
|
||||
@ -499,6 +505,7 @@ def forward(self, token, obj_attr, x):
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
attr = self.attr
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
|
||||
getitem_2 = takes_foo_list_return_default[0]
|
||||
getitem_3 = takes_foo_list_return_default[1]
|
||||
@ -551,6 +558,7 @@ def forward(self, token, obj_attr, x):
|
||||
def forward(self, x):
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
attr = self.attr
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
|
||||
getitem_1 = takes_foo_tuple_return_default[0]
|
||||
getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None
|
||||
@ -1065,6 +1073,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
|
||||
"""\
|
||||
def forward(self, tq, x):
|
||||
tq, x, = fx_pytree.tree_flatten_spec(([tq, x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(tq, x); _guards_fn = None
|
||||
queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x); x = queue_push_default = None
|
||||
return pytree.tree_unflatten((tq,), self._out_spec)""",
|
||||
)
|
||||
|
@ -359,9 +359,10 @@ class TestUnflatten(TestCase):
|
||||
|
||||
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
|
||||
AssertionError,
|
||||
escape("Guard failed: x.size()[0] == 2"),
|
||||
):
|
||||
# expected 2, but got 6
|
||||
export_module.module()(torch.randn(6, 6))
|
||||
|
||||
unflattened = unflatten(export_module)
|
||||
|
@ -7805,6 +7805,8 @@ class GraphModule(torch.nn.Module):
|
||||
x: "f32[s77, 3]";
|
||||
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
|
||||
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
|
||||
while_loop_cond_graph_0 = self.while_loop_cond_graph_0
|
||||
@ -7953,6 +7955,8 @@ class GraphModule(torch.nn.Module):
|
||||
t: "f32[2, 3]";
|
||||
|
||||
t, = fx_pytree.tree_flatten_spec(([t], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(t); _guards_fn = None
|
||||
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(t)
|
||||
_assert_tensor_metadata_default = torch.ops.aten._assert_tensor_metadata.default(sum_1, dtype = torch.float32, device = device(type='cpu'), layout = torch.strided); _assert_tensor_metadata_default = None
|
||||
to: "i64[]" = torch.ops.aten.to.dtype(sum_1, torch.int64); sum_1 = None
|
||||
@ -8101,6 +8105,8 @@ class GraphModule(torch.nn.Module):
|
||||
x: "f32[s77, 3]";
|
||||
|
||||
x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x); _guards_fn = None
|
||||
|
||||
sym_size_int_1: "Sym(s77)" = torch.ops.aten.sym_size.int(x, 0)
|
||||
|
||||
sin: "f32[s77, 3]" = torch.ops.aten.sin.default(x); x = None
|
||||
@ -8520,6 +8526,8 @@ class GraphModule(torch.nn.Module):
|
||||
a: "b8[]"; b1: "i64[1]"; b2: "i64[1]"; c: "f32[10]";
|
||||
|
||||
a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(a, b1, b2, c); _guards_fn = None
|
||||
|
||||
true_graph_0 = self.true_graph_0
|
||||
false_graph_0 = self.false_graph_0
|
||||
cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, (c, b1, b2)); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None
|
||||
@ -8602,6 +8610,8 @@ class GraphModule(torch.nn.Module):
|
||||
x: "f32[s68, 3]"; y: "f32[s17]"; z: "f32[s68, 3]";
|
||||
|
||||
x, y, z, = fx_pytree.tree_flatten_spec(([x, y, z], {}), self._in_spec)
|
||||
_guards_fn = self._guards_fn(x, y, z); _guards_fn = None
|
||||
|
||||
sym_size_int_4: "Sym(s17)" = torch.ops.aten.sym_size.int(y, 0); y = None
|
||||
sym_size_int_5: "Sym(s68)" = torch.ops.aten.sym_size.int(z, 0)
|
||||
|
||||
|
@ -249,7 +249,7 @@ with torch.no_grad():
|
||||
assert res is not None
|
||||
ep_file_path = res.get_exported_program_path()
|
||||
assert ep_file_path is not None
|
||||
gm = export_load(ep_file_path).module()
|
||||
gm = export_load(ep_file_path).module(check_guards=False)
|
||||
self.assertExpectedInline(
|
||||
str(gm.code).strip(),
|
||||
"""\
|
||||
|
@ -63,7 +63,7 @@ class MinifierUtilsTests(TestCase):
|
||||
)
|
||||
|
||||
model = M()
|
||||
gm = torch.export.export(model, inputs, strict=False).module()
|
||||
gm = torch.export.export(model, inputs, strict=False).module(check_guards=False)
|
||||
|
||||
# TODO: make NNModuleToString.convert() generate string for nested submodules.
|
||||
model_string = get_module_string(gm)
|
||||
|
@ -162,7 +162,7 @@ def save_graph_repro_ep(
|
||||
assert args is not None
|
||||
exported_program = torch.export.export(gm, args, strict=strict)
|
||||
elif gm is None:
|
||||
gm = exported_program.module()
|
||||
gm = exported_program.module(check_guards=False)
|
||||
|
||||
# save a graph preview using gm
|
||||
module_string = get_module_string(gm) # type: ignore[arg-type]
|
||||
@ -302,7 +302,7 @@ def repro_common(
|
||||
options: Any, exported_program: ExportedProgram
|
||||
) -> tuple[torch.fx.GraphModule, Any, Any]:
|
||||
torch._inductor.config.generate_intermediate_hooks = True
|
||||
mod = exported_program.module()
|
||||
mod = exported_program.module(check_guards=False)
|
||||
args, kwargs = exported_program.example_inputs
|
||||
return mod, args, kwargs # type: ignore[return-value]
|
||||
|
||||
@ -368,7 +368,7 @@ def export_for_aoti_minifier(
|
||||
|
||||
try:
|
||||
ep = torch.export.export(gm, tuple_inputs, strict=strict)
|
||||
gm = ep.module()
|
||||
gm = ep.module(check_guards=False)
|
||||
return gm
|
||||
except Exception as e:
|
||||
if skip_export_error:
|
||||
|
@ -1,5 +1,5 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<e623701883a0cff67761e994ac9b3d5e44d3f27102c9420932a1275b5b0ad41d>>
|
||||
// checksum<<a1c01cb72b55ca996960afa7e3b5ab6714405b036d8a3c33a81084a84e58bbab>>
|
||||
|
||||
namespace py3 torch._export
|
||||
namespace cpp2 torch._export.schema
|
||||
@ -342,6 +342,7 @@ struct ExportedProgram {
|
||||
60: SchemaVersion schema_version;
|
||||
70: list<string> verifiers;
|
||||
80: string torch_version;
|
||||
90: list<string> guards_code;
|
||||
}
|
||||
|
||||
struct PayloadMeta {
|
||||
|
@ -9,7 +9,7 @@ from torch._export.serde.union import _Union, _union_dataclass
|
||||
|
||||
|
||||
# NOTE: Please update this value if any modifications are made to the schema
|
||||
SCHEMA_VERSION = (8, 13)
|
||||
SCHEMA_VERSION = (8, 14)
|
||||
TREESPEC_VERSION = 1
|
||||
|
||||
|
||||
@ -449,6 +449,7 @@ class ExportedProgram:
|
||||
schema_version: Annotated[SchemaVersion, 60]
|
||||
verifiers: Annotated[list[str], 70] = field(default_factory=list)
|
||||
torch_version: Annotated[str, 80] = "<=2.4"
|
||||
guards_code: Annotated[list[str], 90] = field(default_factory=list)
|
||||
|
||||
|
||||
#########################################################################
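The serialization schema additions above (and the thrift, YAML, and C++ counterparts elsewhere in this diff) store the guards as plain strings. A hypothetical fragment of a serialized program showing what the new field might hold; the guard strings reuse the "L[...]" source syntax documented later in _convert_guards_code_to_fn:

```python
# Illustrative only: field names follow the schema in this diff, values are made up.
serialized_ep = {
    "schema_version": {"major": 8, "minor": 14},
    "verifiers": [],
    "torch_version": "2.9",  # illustrative
    "guards_code": [
        "L['x'].size()[0] >= 3",
        "L['x'].size()[0] <= 16",
    ],
    # ... graph_module, range_constraints, etc.
}
```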
|
||||
|
@ -1,5 +1,5 @@
|
||||
# @generated by update_schema.py
|
||||
# checksum<<60fd32a0a8ae87c628c02d23641902e9339b813d4f553cdb39a2f9533b33060f>>
|
||||
# checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>>
|
||||
AOTInductorModelPickleData:
|
||||
kind: struct
|
||||
fields:
|
||||
@ -140,6 +140,9 @@ ExportedProgram:
|
||||
torch_version:
|
||||
type: str
|
||||
default: <=2.4
|
||||
guards_code:
|
||||
type: List[str]
|
||||
default: '[]'
|
||||
ExternKernelNode:
|
||||
kind: struct
|
||||
fields:
|
||||
@ -548,5 +551,5 @@ UserOutputSpec:
|
||||
type: Argument
|
||||
SCHEMA_VERSION:
|
||||
- 8
|
||||
- 13
|
||||
- 14
|
||||
TREESPEC_VERSION: 1
|
||||
|
@ -1794,6 +1794,7 @@ class ExportedProgramSerializer(metaclass=Final):
|
||||
),
|
||||
verifiers=[v.dialect for v in exported_program.verifiers],
|
||||
torch_version=torch.__version__,
|
||||
guards_code=exported_program._guards_code,
|
||||
)
|
||||
|
||||
# Test canonical form is well defined.
|
||||
@ -3038,6 +3039,7 @@ class ExportedProgramDeserializer(metaclass=Final):
|
||||
constants=res.constants,
|
||||
verifiers=[load_verifier(v) for v in exported_program.verifiers],
|
||||
)
|
||||
result._guards_code = exported_program.guards_code
|
||||
log.debug("\n[deserialize]: %s", result)
|
||||
return result
|
||||
|
||||
@ -3502,6 +3504,7 @@ def canonicalize(
|
||||
range_constraints = dict(
|
||||
sorted(ep.range_constraints.items(), key=operator.itemgetter(0))
|
||||
)
|
||||
guards_code = sorted(ep.guards_code)
|
||||
module_call_graph = sorted(ep.graph_module.module_call_graph, key=lambda x: x.fqn)
|
||||
signature = ep.graph_module.signature
|
||||
graph = ep.graph_module.graph
|
||||
@ -3698,6 +3701,7 @@ def canonicalize(
|
||||
schema_version=ep.schema_version,
|
||||
verifiers=ep.verifiers,
|
||||
torch_version=ep.torch_version,
|
||||
guards_code=guards_code,
|
||||
)
|
||||
|
||||
|
||||
|
@ -419,7 +419,7 @@ def _check_symint(
|
||||
# this means we deferred a guard from export analysis to runtime, let this pass
|
||||
# we'll add a runtime assert checking equality to this replacement expression
|
||||
pass
|
||||
elif arg != symint:
|
||||
elif arg != int(symint):
|
||||
path = get_keystr(keypath)
|
||||
if i is not None:
|
||||
path += f".shape[{i}]"
|
||||
|
@ -292,6 +292,15 @@ def aot_compile(
|
||||
"""
|
||||
from .compile_fx import _aoti_flatten_inputs, compile_fx_aot
|
||||
|
||||
if hasattr(gm, "_guards_fn"):
|
||||
# Do not compile the guards function, since it may contain checks
|
||||
# that are not currently supported by AOTI. In particular, non-Tensor
|
||||
# arguments are converted to None and will fail specialization checks.
|
||||
node = next(iter(gm.graph.find_nodes(op="call_module", target="_guards_fn")))
|
||||
gm.graph.erase_node(node)
|
||||
delattr(gm, "_guards_fn")
|
||||
gm.recompile()
|
||||
|
||||
flat_example_inputs, options = _aoti_flatten_inputs(
|
||||
gm, args, kwargs, options=options
|
||||
)
|
||||
|
@ -1250,7 +1250,7 @@ def aot_inductor_minifier_wrapper(
|
||||
|
||||
use_minifier = config.aot_inductor.dump_aoti_minifier
|
||||
|
||||
gm = exported_program.module()
|
||||
gm = exported_program.module(check_guards=False)
|
||||
assert isinstance(gm, torch.fx.GraphModule)
|
||||
|
||||
args, kwargs = exported_program.example_inputs
|
||||
@ -1279,7 +1279,7 @@ def aot_inductor_minifier_wrapper(
|
||||
tuple_inputs = tuple(flat_example_inputs)
|
||||
flattened_ep = torch.export.export(gm_copy, tuple_inputs, strict=False)
|
||||
func(
|
||||
flattened_ep.module(),
|
||||
flattened_ep.module(check_guards=False),
|
||||
tuple_inputs,
|
||||
inductor_configs=config_copy,
|
||||
package_path=package_path,
|
||||
|
@ -361,7 +361,7 @@ def _get_aten_graph_module_for_pattern(
|
||||
example_inputs,
|
||||
kwargs,
|
||||
strict=True,
|
||||
).module()
|
||||
).module(check_guards=False)
|
||||
|
||||
aten_pattern.graph.eliminate_dead_code() # type: ignore[operator, union-attr]
|
||||
aten_pattern.recompile() # type: ignore[operator]
|
||||
|
13
torch/csrc/utils/generated_serialization_types.h
generated
@ -1,5 +1,5 @@
|
||||
// @generated by update_schema.py
|
||||
// checksum<<60fd32a0a8ae87c628c02d23641902e9339b813d4f553cdb39a2f9533b33060f>>
|
||||
// checksum<<74d07b92c36d5854263145c231553dcda15215f0460e7ace43554248c05378ec>>
|
||||
// clang-format off
|
||||
|
||||
#pragma once
|
||||
@ -3110,6 +3110,7 @@ class ExportedProgram {
|
||||
SchemaVersion schema_version;
|
||||
std::vector<std::string> verifiers = {};
|
||||
std::string torch_version = "<=2.4";
|
||||
std::vector<std::string> guards_code = {};
|
||||
|
||||
public:
|
||||
|
||||
@ -3161,6 +3162,14 @@ class ExportedProgram {
|
||||
torch_version = std::move(def);
|
||||
}
|
||||
|
||||
const std::vector<std::string>& get_guards_code() const {
|
||||
return guards_code;
|
||||
}
|
||||
|
||||
void set_guards_code(std::vector<std::string> def) {
|
||||
guards_code = std::move(def);
|
||||
}
|
||||
|
||||
friend void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nlohmann_json_t);
|
||||
friend void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t);
|
||||
};
|
||||
@ -3406,6 +3415,7 @@ inline void to_json(nlohmann::json& nlohmann_json_j, const ExportedProgram& nloh
|
||||
nlohmann_json_j["schema_version"] = nlohmann_json_t.schema_version;
|
||||
nlohmann_json_j["verifiers"] = nlohmann_json_t.verifiers;
|
||||
nlohmann_json_j["torch_version"] = nlohmann_json_t.torch_version;
|
||||
nlohmann_json_j["guards_code"] = nlohmann_json_t.guards_code;
|
||||
}
|
||||
|
||||
inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nlohmann_json_t) {
|
||||
@ -3416,6 +3426,7 @@ inline void from_json(const nlohmann::json& nlohmann_json_j, ExportedProgram& nl
|
||||
nlohmann_json_t.schema_version = nlohmann_json_j.value("schema_version", nlohmann_json_default_obj.schema_version);
|
||||
nlohmann_json_t.verifiers = nlohmann_json_j.value("verifiers", nlohmann_json_default_obj.verifiers);
|
||||
nlohmann_json_t.torch_version = nlohmann_json_j.value("torch_version", nlohmann_json_default_obj.torch_version);
|
||||
nlohmann_json_t.guards_code = nlohmann_json_j.value("guards_code", nlohmann_json_default_obj.guards_code);
|
||||
}
|
||||
|
||||
inline void to_json(nlohmann::json& nlohmann_json_j, const ExternKernelNode& nlohmann_json_t) {
|
||||
|
@ -681,7 +681,7 @@ class Pipe(torch.nn.Module):
|
||||
``output_loss_value_spec={'loss': True, 'model_out': False}``
|
||||
"""
|
||||
|
||||
traced = exported_program.module()
|
||||
traced = exported_program.module(check_guards=False)
|
||||
|
||||
if split_policy is not None:
|
||||
logger.info("Auto-splitting model")
|
||||
|
@ -540,6 +540,7 @@ def draft_export(
|
||||
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any, ...], list[Any]]] = None,
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
strict: bool = False,
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
) -> ExportedProgram:
|
||||
"""
|
||||
A version of torch.export.export which is designed to consistently produce
|
||||
@ -555,6 +556,7 @@ def draft_export(
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
strict=strict,
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
)
|
||||
|
||||
|
||||
|
@ -371,6 +371,7 @@ def draft_export(
|
||||
preserve_module_call_signature: tuple[str, ...] = (),
|
||||
strict: bool = False,
|
||||
pre_dispatch: bool = True,
|
||||
prefer_deferred_runtime_asserts_over_guards: bool = False,
|
||||
) -> ExportedProgram:
|
||||
start_time = time.time()
|
||||
kwargs = kwargs or {}
|
||||
@ -396,6 +397,7 @@ def draft_export(
|
||||
strict=strict,
|
||||
pre_dispatch=pre_dispatch,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
)
|
||||
except Exception as exc:
|
||||
if (
|
||||
@ -420,6 +422,7 @@ def draft_export(
|
||||
strict=strict,
|
||||
pre_dispatch=pre_dispatch,
|
||||
preserve_module_call_signature=preserve_module_call_signature,
|
||||
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
|
||||
)
|
||||
else:
|
||||
log_draft_export_usage(
|
||||
|
@ -163,7 +163,7 @@ def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None:
|
||||
"""
|
||||
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_module":
|
||||
if node.op == "call_module" and node.target != "_guards_fn":
|
||||
_try_remove_connecting_pytrees(node)
|
||||
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
@@ -1,10 +1,14 @@
# mypy: allow-untyped-defs
import copy
import inspect
import math
import warnings
from collections.abc import Sequence
from itertools import chain
from typing import Any, Optional

import sympy

import torch
import torch.utils._pytree as pytree
from torch._export.non_strict_utils import (
@@ -12,11 +16,16 @@ from torch._export.non_strict_utils import (
    _exit_enable_graph_inputs_of_type_nn_module,
    _get_graph_inputs_of_type_nn_module,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    _convert_range_to_int,
)
from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.traceback import NodeSource, NodeSourceAction
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import ValueRanges

from ._remove_effect_tokens_pass import _remove_effect_tokens
from ._tree_utils import reorder_kwargs
@@ -73,20 +82,107 @@ def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list:
    return flat_args_with_path


def _convert_guards_code_to_fn(
    guards_code: list[str],
    paths_of_placeholders: list[pytree.KeyPath],
):
    """
    Generates Python code given guards code and paths of placeholders.
    We assume that, based on source information,
    - the tracer generates the guards code
    - the input spec generates the paths of placeholders.

    Example:

    Suppose we are given the guards code "L['z']['k'].size()[0] == 3"
    and we are given that ['z']['k'] is the path of placeholder #2.
    Then we will generate:
    ```
    torch._assert(
        args[2].size()[0] == 3,
        "Guard failed: z['k'].size()[0] == 3",
    )
    ```

    FAQ: Why do we generate code based on (flattened) args instead of
    the original (unflattened) inputs? Because this would require
    inserting an additional pytree.unflatten call in our graph.

    FAQ: Why do we not emit RuntimeError on guard failure, as we used to?
    Because it is inconvenient to raise one from the generated asserts;
    expect an AssertionError instead.
    """

    import ast

    from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP

    actual_guards_code = []
    shadow_guards_code = []
    for c in guards_code:
        a, s = c, c
        for idx, path in enumerate(paths_of_placeholders):
            # e.g., replace L['z']['k'] with args[2] for Python code (actual)
            a = a.replace("L" + pytree.keystr(path), f"args[{idx}]")
            # e.g., replace L['z']['k'] with z['k'] for error message (shadow)
            s = s.replace(
                "L" + pytree.keystr(path),
                path[0].key + pytree.keystr(path[1:]),  # type: ignore[attr-defined]
            )
        actual_guards_code.append(a)
        shadow_guards_code.append(s.replace("\n", ""))

    # generate function code as str
    code_str = "\ndef _(*args):\n"
    for actual, shadow in zip(actual_guards_code, shadow_guards_code):
        # printing guards code may potentially introduce redundant parens;
        # we can normalize them out for readability by parsing/unparsing
        # NOTE: this is not necessary for correctness, just deemed desirable
        _shadow = ast.unparse(ast.parse(shadow, mode="eval"))
        # actual code and shadow error message
        code_str += f'    torch._assert({actual}, "Guard failed: {_shadow}")\n'
    code_str += "    return\n"

    # populate namespace with sympy globals, materialize function (named `_`)
    namespace = {**SYMPY_INTERP}
    exec(code_str, namespace)

    # create and return a module whose forward is the materialized function
    # NOTE: we want Dynamo to trace through this module, to repopulate guards:
    # otherwise we would lose them when retracing
    # NOTE: calling this module is a side effect (it has no users), so it must
    # be marked impure to avoid being cleaned up by DCE
    guards_fn = GuardsFn()
    guards_fn.forward = torch._dynamo.dont_skip_tracing(namespace["_"])  # type: ignore[call-overload, method-assign]
    guards_fn._is_impure = True  # type: ignore[assignment]
    return guards_fn
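
To make the string-rewriting step concrete, here is a minimal, self-contained sketch of the same idea. The guard string, placeholder source, and tensor shapes below are invented for illustration; the real helper additionally normalizes parentheses with `ast` and wraps the materialized function in a `GuardsFn` module:

```
import torch

guards_code = ["L['z']['k'].size()[0] == 3"]   # as produced by the tracer (assumed)
placeholder_sources = ["L['z']['k']"]          # source of flat placeholder #0 (assumed)

code = "def _(*args):\n"
for guard in guards_code:
    check = guard
    for idx, src in enumerate(placeholder_sources):
        check = check.replace(src, f"args[{idx}]")    # rewrite against flat args
    msg = guard.replace("L['z']['k']", "z['k']")      # readable source for the error
    code += f'    torch._assert({check}, "Guard failed: {msg}")\n'
code += "    return\n"

namespace = {"torch": torch}
exec(code, namespace)
guards_fn = namespace["_"]

guards_fn(torch.ones(3, 4))   # passes: size(0) == 3
try:
    guards_fn(torch.ones(4, 4))
except AssertionError as e:
    print(e)                  # Guard failed: z['k'].size()[0] == 3
```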

@torch._dynamo.disable
def _check_input_constraints_pre_hook(self, args, kwargs):
    if not self.validate_inputs:
        return

def _check_input_constraints_for_module(self, args, kwargs):
    flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec)

    _check_input_constraints_for_graph(
        [node for node in self.graph.nodes if node.op == "placeholder"],
        self.graph.find_nodes(op="placeholder"),
        flat_args_with_path,
        self.range_constraints,
    )


def _check_input_constraints_pre_hook(self, args, kwargs):
    # preserve current behavior for clients that do not want any validation
    if not self.validate_inputs:
        return

    # when a guards function exists, assume that the graph calls it,
    # so we do not need to check input constraints...but we still want
    # to check inputs match, otherwise we'd get obscure pytree errors
    if hasattr(self, "_guards_fn"):
        _check_inputs_match(args, kwargs, self._in_spec)
        return

    # NOTE: this call is Dynamo disabled, as it used to be
    _check_input_constraints_for_module(self, args, kwargs)


def _unlift_inputs_as_getattr(
    gm: torch.fx.GraphModule,
    lifted_inputs: Sequence[Optional[str]],
@@ -419,10 +515,149 @@ def _create_stateful_graph_module(
    return stateful_gm


def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.GraphModule:
def _get_input_paths(example_inputs, signature):
    """
    Generate paths of placeholders, needed for generating the guards function.

    NOTE: Here we make use of the example inputs used for export as well as
    the signature of the unlifted graph module (not preserved by export).
    """

    args, kwargs = example_inputs
    ctx = signature.bind(*args, **kwargs).arguments
    flat_example_inputs_with_paths = pytree.tree_leaves_with_path(ctx)
    return [path for path, _ in flat_example_inputs_with_paths]
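
A small illustration of what `_get_input_paths` computes, with a hypothetical signature and example inputs: binding the example inputs to the module's signature and flattening the bound arguments with key paths yields the `L[...]`-style sources that the guard strings refer to.

```
import inspect
import torch
import torch.utils._pytree as pytree

def forward(x, z):  # stand-in for the unlifted module's forward (assumed)
    return x

example_args = (torch.ones(2),)
example_kwargs = {"z": {"k": torch.ones(3, 4)}}

ctx = inspect.signature(forward).bind(*example_args, **example_kwargs).arguments
paths = [path for path, _ in pytree.tree_leaves_with_path(ctx)]
print(["L" + pytree.keystr(p) for p in paths])
# expected (ordering follows the signature): ["L['x']", "L['z']['k']"]
```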


def _get_input_guards_for_graph(
    placeholders: list[torch.fx.Node],
    range_constraints: dict[sympy.Symbol, ValueRanges],
    paths_for_placeholders: list[pytree.KeyPath],
):
    """
    Guards generated by the tracer include conditions observed in code, but
    do not include some additional checks we typically do in export.
    For example, when dynamic shapes get specialized, are specified to be
    within a range, or are specified to be in some equational relation,
    the corresponding input validation is done within a pre_hook, specifically,
    `_check_input_constraints_for_graph`.

    Here we generate guards corresponding to the checks that happen in
    `_check_input_constraints_for_graph`, and add them to the guards already
    generated by the tracer. In the future, it may be worthwhile to separate
    them so that we can allow clients to turn off one but not the other.
    (Looking at you, AOTI.)

    NOTE: We should eventually reconcile this logic with `build_guards` that
    is used by AOT Precompile.
    """

    deferred_expressions = []
    new_guards_code = []
    sources: dict[sympy.Expr, str] = {}

    def handle_symint(expr, src):
        if len(expr.free_symbols) == 1:
            # complex equations (e.g., involving derived dims) need to be
            # handled later, since we may not have enough information
            # just as we are passing through the placeholders in order
            deferred_expressions.append((src, expr))
        if expr in sources:
            # expressions that appear in multiple sources should force
            # inputs corresponding to those sources to be equal
            # e.g., x.shape[0] == y.shape[1]
            orig_src = sources[expr]
            new_guards_code.append(f"{src} == {orig_src}")
        else:
            sources[expr] = src
            # process value ranges as elsewhere in export
            min_val, max_val = _convert_range_to_int(range_constraints[expr])
            if min_val > 2:
                new_guards_code.append(f"{src} >= {min_val}")
            if max_val < math.inf:
                new_guards_code.append(f"{src} <= {max_val}")

    for placeholder, path in zip(placeholders, paths_for_placeholders):
        src = "L" + pytree.keystr(path)
        meta = placeholder.meta["val"]
        # specializations
        if isinstance(meta, int):
            new_guards_code.append(f"{src} == {meta}")
        if isinstance(meta, float):
            if meta == math.inf:
                new_guards_code.append(f"{src} == math.inf")
            elif meta == -math.inf:
                new_guards_code.append(f"{src} == -math.inf")
            else:
                new_guards_code.append(f"{src} == {meta}")
        elif isinstance(meta, str):
            new_guards_code.append(f"{src} == '{meta}'")
        # range constraints and equalities
        elif isinstance(meta, torch.SymInt) and meta.node.expr in range_constraints:
            handle_symint(meta.node.expr, src)
        elif isinstance(meta, torch.Tensor):
            for i, dim in enumerate(meta.shape):
                src = "L" + pytree.keystr(path) + f".size()[{i}]"
                if isinstance(dim, int):
                    # specializations
                    new_guards_code.append(f"{src} == {dim}")
                elif (
                    isinstance(dim, torch.SymInt) and dim.node.expr in range_constraints
                ):
                    # range constraints and equalities
                    handle_symint(dim.node.expr, src)

    unification_map: dict[sympy.Symbol, sympy.Expr] = {}
    py_printer = torch.utils._sympy.printers.PythonPrinter()

    # process complex equations (e.g., involving derived dims)
    for src, expr in deferred_expressions:
        # we know this is the only symbol in expr (see check above)
        symbol = next(iter(expr.free_symbols))
        if symbol in sources:
            # if s0 is already known to be directly sourced from inputs,
            # e.g., z.shape[2], we do not need to do anything further
            # (assume we have already processed constraints on s0 above)
            continue

        # otherwise s0 has some "hidden" source like 'dim'
        # example: src = y.shape[1], expr = s0 + 1
        if symbol in unification_map:
            # suppose that we already know that s0 = x.shape[0] * 2
            # so we can emit the guard: x.shape[0] * 2 + 1 = y.shape[1]
            substitution = expr.subs(unification_map)
            new_guards_code.append(
                py_printer.doprint(sympy.Eq(substitution, sympy.Symbol(src)))
            )
        else:
            # we do not yet know what s0 is, but given s0 + 1 = y.shape[1],
            # we can solve for s0...now knowing that s0 = y.shape[1] - 1
            solution = try_solve(sympy.Eq(expr, sympy.Symbol(src)), symbol)
            if solution is not None:
                definition = solution[1]
                unification_map[symbol] = definition

    return new_guards_code
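
The deferred-expression handling hinges on `try_solve`: when a symbol only has a "hidden" source such as a `Dim`, its defining equation is solved so that later guards can be phrased purely in terms of input sources. A toy walk-through, with assumed names:

```
import sympy
from torch.utils._sympy.solve import try_solve

s0 = sympy.Symbol("s0", integer=True, positive=True)
src = sympy.Symbol("L['y'].size()[1]")        # a source string treated as an opaque symbol

# Suppose y.shape[1] was traced as s0 + 1; recover s0 in terms of the input source.
solution = try_solve(sympy.Eq(s0 + 1, src), s0)
if solution is not None:
    print(solution[1])                        # expected: L['y'].size()[1] - 1
```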

def _unlift_exported_program_lifted_states(
    ep: ExportedProgram, check_guards=True
) -> torch.fx.GraphModule:
    # force check_guards=False for executorch because its pass infra makes
    # too many calls to .module() and does not like call_module nodes
    # in the graph
    # TODO: update executorch to pass check_guards=False
    frame = inspect.currentframe()
    while frame is not None:
        if "executorch" in frame.f_code.co_filename:
            check_guards = False
            break
        frame = frame.f_back

    # TODO T206340015
    if ep.verifiers[0].dialect != "TRAINING":
        ep = _remove_effect_tokens(ep)

    new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph))
    _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants)
    forward_arg_names = (
@@ -489,4 +724,37 @@ def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.fx.Grap
    )
    unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep)
    unlift_gm.meta.update(ep.graph_module.meta)

    # create a _guards_fn submodule and insert a call to it after placeholders
    graph = unlift_gm.graph
    placeholders = graph.find_nodes(op="placeholder")
    if check_guards and placeholders and ep.example_inputs:
        input_paths = _get_input_paths(
            ep.example_inputs,
            inspect.signature(unlift_gm.forward),
        )
        guards_code = _get_input_guards_for_graph(
            placeholders, ep.range_constraints, input_paths
        )
        guards_code.extend(ep._guards_code)
        unlift_gm._guards_fn = _convert_guards_code_to_fn(guards_code, input_paths)

        root_nn_module_stack = torch.fx._utils.first_call_function_nn_module_stack(
            graph
        )
        with graph.inserting_after(placeholders[-1]):
            node = graph.call_module("_guards_fn", tuple(placeholders))
            node.meta["nn_module_stack"] = root_nn_module_stack

    unlift_gm.recompile()

    return unlift_gm


class GuardsFn(torch.nn.Module):
    """
    Module class for guard functions.
    """

    def forward(self, *args):
        pass
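
Putting the pieces together, a hedged end-to-end sketch of what the `_guards_fn` wiring buys you (the model and shapes are invented; exact guard messages may differ):

```
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

ep = export(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: Dim("b")}})
m = ep.module()                 # installs a _guards_fn submodule and a call to it

m(torch.randn(7, 3))            # ok: dim 0 is dynamic
try:
    m(torch.randn(7, 5))        # dim 1 was specialized to 3 at export time
except AssertionError as e:
    print(e)                    # e.g. "Guard failed: x.size()[1] == 3"
```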

@@ -887,7 +887,7 @@ class AdditionalInputs:

        epm = ep.module()
        for args, kwargs in self._examples:
            torch.export._unlift._check_input_constraints_pre_hook(
            torch.export._unlift._check_input_constraints_for_module(
                epm, args, kwargs or {}
            )

@@ -1047,6 +1047,8 @@ class ExportedProgram:
    _verifiers: list[type[Verifier]]
    """List of verifier classes used to validate the exported program."""

    _guards_code: list[str]

    def __init__(
        self,
        root: Union[torch.nn.Module, dict[str, Any]],
@@ -1084,6 +1086,8 @@ class ExportedProgram:
        # Validate should always be the last step of the constructor.
        self.validate()

        self._guards_code = _convert_guards_to_code(_get_shape_env(self._graph_module))

    @property
    @compatibility(is_backward_compatible=False)
    def graph_module(self):
@@ -1379,13 +1383,20 @@ class ExportedProgram:
        )
        return string

    def module(self) -> torch.fx.GraphModule:
    def module(self, check_guards=True) -> torch.fx.GraphModule:
        """
        Returns a self-contained GraphModule with all the parameters/buffers inlined.

        - When `check_guards=True` (default), a `_guards_fn` submodule is generated
          and a call to it is inserted right after the placeholders in the graph.
          This submodule checks guards on inputs.
        - When `check_guards=False`, a subset of these checks is performed by a
          forward pre-hook on the graph module. No `_guards_fn` submodule is generated.
        """
        from ._unlift import _unlift_exported_program_lifted_states

        module = _unlift_exported_program_lifted_states(self)
        module = _unlift_exported_program_lifted_states(self, check_guards=check_guards)

        def _train(self, mode: bool = True):
            raise NotImplementedError("Calling train() is not supported yet.")
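
As a usage note for the new flag (a sketch; whether `_guards_fn` is attached also depends on the program having placeholders and recorded example inputs):

```
import torch
from torch.export import export

class N(torch.nn.Module):
    def forward(self, x):
        return x * 2

ep = export(N(), (torch.randn(2, 2),))
m_checked = ep.module()                       # graph calls self._guards_fn(*placeholders)
m_unchecked = ep.module(check_guards=False)   # input checks stay in the forward pre-hook
print(hasattr(m_checked, "_guards_fn"), hasattr(m_unchecked, "_guards_fn"))
# typically: True False
```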

@@ -1677,3 +1688,25 @@ def _create_graph_module_for_export(root, graph):
    gm._graph = graph

    return gm


def _convert_guards_to_code(shape_env):
    if shape_env is None:
        return []

    local_vars = {
        var
        for var, sources in shape_env.var_to_sources.items()
        if all(
            not isinstance(source, torch._dynamo.source.ConstantSource)
            for source in sources
        )
    }
    py_printer = torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter(
        shape_env.var_to_sources, lambda s: s.name(), shape_env.var_to_sources
    )
    return [
        py_printer.doprint(guard.expr)
        for guard in shape_env.guards
        if guard.expr.free_symbols.issubset(local_vars)
    ]
@@ -95,7 +95,7 @@ class SubgraphMatcher:
        )

        for node in pattern.nodes:
            if node.op != "output":
            if node.op != "output" and not node.is_impure():
                assert len(node.users) > 0, (
                    "SubgraphMatcher cannot be initialized with an pattern with dead code"
                )
@@ -3394,7 +3394,7 @@ def _generate_qdq_quantized_model(

    maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
    with maybe_no_grad:
        export_model = export_for_training(mod, inputs, strict=True).module()
        export_model = export_for_training(mod, inputs, strict=True).module(check_guards=False)
        quantizer = (
            quantizer
            if quantizer