[Codemod][AddExplicitStrictExportArg] caffe2/test (#143688)

Reviewed By: avikchaudhuri

Differential Revision: D67530154

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143688
Approved by: https://github.com/tugsbayasgalan
Author: Yanan Cao (PyTorch)
Date: 2024-12-27 07:58:44 +00:00
Committed by: PyTorch MergeBot
Parent: 969415885d
Commit: ba5cacbc17
29 changed files with 149 additions and 119 deletions
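Note: the rewrite is mechanical. Every torch.export.export (and bare export) call that does not already pass strict gains an explicit strict=True, matching the default at the time of this change, so each call site keeps strict tracing behavior even if the default later changes. A minimal before/after sketch of the transformation (the Add module and inputs below are illustrative, not taken from the diff):

    import torch

    class Add(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    example_inputs = (torch.randn(2, 3), torch.randn(2, 3))

    # Before the codemod: strict mode is implied by the default.
    # ep = torch.export.export(Add(), example_inputs)

    # After the codemod: strict mode is spelled out at the call site.
    ep = torch.export.export(Add(), example_inputs, strict=True)
    assert isinstance(ep, torch.export.ExportedProgram)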

View File

@@ -78,6 +78,7 @@ Result:
             example_case.example_args,
             example_case.example_kwargs,
             dynamic_shapes=example_case.dynamic_shapes,
+            strict=True,
         )
         graph_output = str(exported_program)
         graph_output = re.sub(r" # File(.|\n)*?\n", "", graph_output)

View File

@@ -73,8 +73,7 @@ class TensorParallelTest(DTensorTestBase):
         with torch.no_grad():
             res = model(*inputs)
         exported_program = torch.export.export(
-            model,
-            inputs,
+            model, inputs, strict=True
         ).run_decompositions()
         tp_exported_program = tensor_parallel_transformation(
             exported_program,
@@ -111,8 +110,7 @@ class TensorParallelTest(DTensorTestBase):
         with torch.inference_mode():
             res = model(*inputs)
         exported_program = torch.export.export(
-            model,
-            inputs,
+            model, inputs, strict=True
         ).run_decompositions()
         tp_exported_program = tensor_parallel_transformation(
             exported_program,
@@ -147,8 +145,7 @@ class TensorParallelTest(DTensorTestBase):
         with torch.inference_mode():
             res = model(*inputs)
         exported_program = torch.export.export(
-            model,
-            inputs,
+            model, inputs, strict=True
         ).run_decompositions()
         tp_exported_program = tensor_parallel_transformation(
             exported_program,

View File

@@ -3,6 +3,7 @@
 PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
 with test_export_persist_assert)
 """
 
+import copy
 import functools
 import inspect
@@ -2479,7 +2480,7 @@ def forward(self, x):
         random_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),)
         dim_x = torch.export.Dim("dim_x", min=1, max=32)
         exp_program = torch.export.export(
-            model, random_inputs, dynamic_shapes={"x": {0: dim_x}}
+            model, random_inputs, dynamic_shapes={"x": {0: dim_x}}, strict=True
         )
         output_buffer = io.BytesIO()
         # Tests if we can restore saved nn.Parameters when we load them again
@@ -2509,7 +2510,9 @@ def forward(self, x):
         batchsize = torch.export.Dim("dim0", min=3, max=1024)
         dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}
 
-        torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec)
+        torch.export.export(
+            model, (a, b), dynamic_shapes=dynamic_shape_spec, strict=True
+        )
 
     def test_export_fast_binary_broadcast_check_unbacked(self):
         class MyModel(torch.nn.Module):
@@ -2522,7 +2525,7 @@ def forward(self, x):
         model = MyModel().eval().cuda()
         numel = torch.tensor(10)
         scalar = torch.randn(1)
-        torch.export.export(model, (numel, scalar))
+        torch.export.export(model, (numel, scalar), strict=True)
 
     def test_export_meta(self):
         class MyModule(torch.nn.Module):
@@ -2563,7 +2566,7 @@ def forward(self, x):
             "by dim0 = 2\\*dim1(.*\n)*.*"
             "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
         ):
-            torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes)
+            torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes, strict=True)
 
         class Bar(torch.nn.Module):
             def forward(self, x):
@@ -2581,7 +2584,7 @@ def forward(self, x):
             torch._dynamo.exc.UserError,
             "Not all values.*valid.*inferred to be a constant",
         ):
-            torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes)
+            torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes, strict=True)
 
         class Qux(torch.nn.Module):
             def forward(self, x):
@@ -2599,7 +2602,7 @@ def forward(self, x):
             torch._dynamo.exc.UserError,
             "Not all values.*satisfy the generated guard",
         ):
-            torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes)
+            torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes, strict=True)
 
     def test_untracked_inputs_in_constraints(self):
         from copy import copy
@@ -2617,7 +2620,9 @@ def forward(self, x):
 
         dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
         example_inputs = (copy(x), y)
-        ep = torch.export.export(foo, example_inputs, dynamic_shapes=dynamic_shapes)
+        ep = torch.export.export(
+            foo, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
+        )
         ep.module()(torch.randn(3), y)  # no specialization error
 
     def test_export_raise_guard_full_constraint(self):
@@ -2734,6 +2739,7 @@ def forward(self, x):
             foo,
             (a, {"k": b}),
             dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}},
+            strict=True,
         )
 
     def test_enforce_equalities(self):
@@ -2752,16 +2758,10 @@ def forward(self, x):
             torch._dynamo.exc.UserError,
             ".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3",
         ):
-            torch.export.export(
-                bar,
-                (x, y),
-                dynamic_shapes=dynamic_shapes,
-            )
+            torch.export.export(bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True)
         y = torch.randn(10, 3, 3)
         ebar = torch.export.export(
-            bar,
-            (x, y),
-            dynamic_shapes=dynamic_shapes,
+            bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True
         )
         self.assertEqual(
             [
@@ -2923,15 +2923,15 @@ def forward(self, x):
             torch._dynamo.exc.UserError,
             r"Constraints violated \(dim0\)",
         ):
-            torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes)
+            torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes, strict=True)
 
-        torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes)
+        torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes, strict=True)
 
         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
             r"Constraints violated \(dim0\)",
         ):
-            torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes)
+            torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes, strict=True)
 
     def test_list_contains(self):
         def func(x):

View File

@@ -24,7 +24,7 @@ class FxPassesPreGradTests(torch._dynamo.test_case.TestCase):
         sample_input = torch.randn(4, 4)
         m = TestModule()
         m(sample_input)
-        exported_program = torch.export.export(m, (sample_input,))
+        exported_program = torch.export.export(m, (sample_input,), strict=True)
         gm = exported_program.graph_module
 
         pass_execution_and_save(fx_pass, gm, sample_input, "Apply testing pass")

View File

@@ -72,7 +72,7 @@ class SourceTests(torch._dynamo.test_case.TestCase):
             lambda x, _: CausalLMOutputWithPast(),
         )
 
-        torch.export.export(Model(), ())
+        torch.export.export(Model(), (), strict=True)
 
 
 if __name__ == "__main__":

View File

@@ -342,7 +342,7 @@ class TestDraftExport(TestCase):
         inputs = (torch.randn(3, 3),)
         with self.assertRaises(RuntimeError):
             with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
-                export(mod, inputs)
+                export(mod, inputs, strict=True)
 
         ep, report = draft_export(mod, inputs)
         for ep_out, eager_out in zip(ep.module()(*inputs), mod(*inputs)):
@@ -384,7 +384,7 @@ class TestDraftExport(TestCase):
             "Real tensor propagation found an aliasing mismatch",
         ):
             with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
-                export(mod, inputs)
+                export(mod, inputs, strict=True)
 
         ep, report = draft_export(mod, inputs)
         for ep_out, eager_out in zip(

View File

@@ -83,7 +83,7 @@ class TestHOP(TestCase):
         input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
         args = (*input, *inp.args)
         kwargs = inp.kwargs
-        ep = export(model, args, kwargs)
+        ep = export(model, args, kwargs, strict=True)
         self._compare(model, ep, args, kwargs)
 
     @ops(hop_tests, allowed_dtypes=(torch.float,))

View File

@@ -24,7 +24,7 @@ class TestPassInfra(TestCase):
         class NullPass(_ExportPassBaseDeprecatedDoNotUse):
             pass
 
-        ep = export(f, (torch.ones(3, 2),))
+        ep = export(f, (torch.ones(3, 2),), strict=True)
         old_nodes = ep.graph.nodes
 
         ep = ep._transform_do_not_use(NullPass())
@@ -66,7 +66,7 @@ class TestPassInfra(TestCase):
         x = torch.tensor([2])
         y = torch.tensor([5])
         mod = M()
-        _ = export(mod, (torch.tensor(True), x, y))._transform_do_not_use(
+        _ = export(mod, (torch.tensor(True), x, y), strict=True)._transform_do_not_use(
             _ExportPassBaseDeprecatedDoNotUse()
         )
 
@@ -98,7 +98,7 @@ class TestPassInfra(TestCase):
 
         inps = (torch.rand(1), torch.rand(1))
         m = CustomModule()
-        ep_before = export(m, inps)
+        ep_before = export(m, inps, strict=True)
 
         # No op transformation that doesn't perform any meaningful changes to node
         ep_after = ep_before._transform_do_not_use(_ExportPassBaseDeprecatedDoNotUse())
@@ -131,7 +131,9 @@ class TestPassInfra(TestCase):
         input_tensor1 = torch.tensor(5.0)
         input_tensor2 = torch.tensor(6.0)
-        ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
+        ep_before = torch.export.export(
+            my_module, (input_tensor1, input_tensor2), strict=True
+        )
 
         from torch.fx.passes.infra.pass_base import PassResult
 
         def modify_input_output_pass(gm):
@@ -169,7 +171,7 @@ class TestPassInfra(TestCase):
 
         my_module = CustomModule()
         inputs = (torch.tensor(6.0), torch.tensor(7.0))
-        ep_before = export(my_module, inputs)
+        ep_before = export(my_module, inputs, strict=True)
 
         def replace_pass(gm):
             for node in gm.graph.nodes:

View File

@@ -404,7 +404,9 @@ class TestPasses(TestCase):
 
         x = torch.zeros(2, 2, 3)
         dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
-        ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})
+        ep = torch.export.export(
+            M(), (x,), dynamic_shapes={"x": {1: dim1_x}}, strict=True
+        )
 
         with self.assertRaisesRegex(
             RuntimeError,
@@ -431,7 +433,10 @@ class TestPasses(TestCase):
 
         dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y", min=3)
         ep = torch.export.export(
-            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
+            M(),
+            (x, y),
+            dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}},
+            strict=True,
         )
 
         with self.assertRaisesRegex(
@@ -461,7 +466,10 @@ class TestPasses(TestCase):
 
         dim0_x = torch.export.Dim("dim0_x", min=3)
         ep = torch.export.export(
-            M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
+            M(),
+            (x, y),
+            dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None},
+            strict=True,
         )
 
         with self.assertRaisesRegex(
@@ -496,7 +504,7 @@ class TestPasses(TestCase):
 
         dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
         ep = torch.export.export(
-            M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}
+            M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}, strict=True
         )
 
         with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
@@ -526,7 +534,7 @@ class TestPasses(TestCase):
 
         x = torch.zeros(4, 2, 3)
 
-        ep = export(M(), (x,))
+        ep = export(M(), (x,), strict=True)
         self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)
 
         ep = ep._transform_do_not_use(ReplaceViewOpsWithViewCopyOpsPass())
@@ -542,7 +550,7 @@ class TestPasses(TestCase):
 
         x = torch.zeros(4, 2, 3)
         foo = Module()
-        ep = export(foo, (x,))._transform_do_not_use(
+        ep = export(foo, (x,), strict=True)._transform_do_not_use(
             ReplaceViewOpsWithViewCopyOpsPass()
         )
         # After this pass, there shouldn't be any view nodes in the graph
@@ -684,7 +692,7 @@ def forward(self, token, obj_attr, x):
 
         x = torch.tensor([2])
         mod = M()
-        ep = export(mod, (x,))
+        ep = export(mod, (x,), strict=True)
 
         with self.assertRaisesRegex(
             RuntimeError, r"Runtime assertion failed for expression u[\d+] \<\= 5"
@@ -709,7 +717,9 @@ def forward(self, token, obj_attr, x):
 
         mod = M()
         dim0_x = torch.export.Dim("dim0_x")
-        ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: dim0_x}})
+        ep = torch.export.export(
+            mod, (x,), dynamic_shapes={"x": {0: dim0_x}}, strict=True
+        )
 
         num_assert = count_call_function(
             ep.graph, torch.ops.aten._assert_scalar.default
@@ -762,7 +772,7 @@ def forward(self, token, obj_attr, x):
         x = torch.tensor([2])
         y = torch.tensor([5])
         mod = M()
-        ep = export(mod, (torch.tensor(True), x, y))
+        ep = export(mod, (torch.tensor(True), x, y), strict=True)
 
         with self.assertRaisesRegex(
             RuntimeError, "is outside of inline constraint \\[2, 5\\]."
@@ -779,7 +789,7 @@ def forward(self, token, obj_attr, x):
         func = Module()
 
         x = torch.randn(1, dtype=torch.float32)
-        ep = torch.export.export(func, args=(x,))
+        ep = torch.export.export(func, args=(x,), strict=True)
         _ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)
 
     def test_predispatch_set_grad(self):
@@ -1231,7 +1241,7 @@ def forward(self, add_1):
 
         mod = M()
         x = torch.randn([3, 3])
-        ep = export(mod, (x,))
+        ep = export(mod, (x,), strict=True)
         inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
         nodes = inplace_ep.graph.nodes
         for node in nodes:
@@ -1274,7 +1284,7 @@ def forward(self, add_1):
 
         mod = M()
         x = torch.randn([3, 3])
-        ep = export(mod, (x,)).run_decompositions({})
+        ep = export(mod, (x,), strict=True).run_decompositions({})
         inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
         graph_text = str(inplace_ep.graph)
         self.assertExpectedInline(
@@ -1304,7 +1314,7 @@ default](args = (%x, %b_state), kwargs = {})
         # move the exported program from cpu to cuda:0
         mod = Model()
         example_inputs = (torch.rand(1, 10, 4),)
-        ep = export(mod, example_inputs)
+        ep = export(mod, example_inputs, strict=True)
         location = torch.device("cuda:0")
         ep = move_to_device_pass(ep, location=location)
         gm = ep.module()

View File

@@ -141,7 +141,7 @@ class TestSparseProp(TestCase):
             index_dtype=itype,
         ):
             # Build the traced graph.
-            prog = torch.export.export(net, (sparse_input,))
+            prog = torch.export.export(net, (sparse_input,), strict=True)
             # Test arg/output.
             for i, node in enumerate(prog.graph.nodes):
                 meta = node.meta.get("val", None)
@@ -163,7 +163,7 @@ class TestSparseProp(TestCase):
         ):
             result = net(sparse_input)
             # Build the traced graph.
-            prog = torch.export.export(net, (sparse_input,))
+            prog = torch.export.export(net, (sparse_input,), strict=True)
             # Test arg/sum/output.
             for i, node in enumerate(prog.graph.nodes):
                 meta = node.meta.get("val", None)
@@ -187,7 +187,7 @@ class TestSparseProp(TestCase):
         ):
             result = net(sparse_input)
             # Build the traced graph.
-            prog = torch.export.export(net, (sparse_input,))
+            prog = torch.export.export(net, (sparse_input,), strict=True)
             # Test arg/neg/abs/mul/relu/output.
             for i, node in enumerate(prog.graph.nodes):
                 meta = node.meta.get("val", None)
@@ -209,7 +209,7 @@ class TestSparseProp(TestCase):
         ):
             result = net(sparse_input)
             # Build the traced graph.
-            prog = torch.export.export(net, (sparse_input,))
+            prog = torch.export.export(net, (sparse_input,), strict=True)
             # Test arg/todense/output.
             for i, node in enumerate(prog.graph.nodes):
                 meta = node.meta.get("val", None)
@@ -235,7 +235,7 @@ class TestSparseProp(TestCase):
         S = A.to_sparse_csr()
         result = net(S, Y)
         # Build the traced graph.
-        prog = torch.export.export(net, (S, Y))
+        prog = torch.export.export(net, (S, Y), strict=True)
         # Test args/add/output.
         for i, node in enumerate(prog.graph.nodes):
             meta = node.meta.get("val", None)
@@ -253,7 +253,7 @@ class TestSparseProp(TestCase):
         x = [torch.randn(3, 3) for _ in range(3)]
         result = net(x)
         # Build the traced graph.
-        prog = torch.export.export(net, args=(x,))
+        prog = torch.export.export(net, args=(x,), strict=True)
         # Test args/to_sparse/output.
         for i, node in enumerate(prog.graph.nodes):
             meta = node.meta.get("val", None)
@@ -269,7 +269,7 @@ class TestSparseProp(TestCase):
         x = [torch.randn(3, 3) for _ in range(3)]
         result = net(x)
         # Build the traced graph.
-        prog = torch.export.export(net, args=(x,))
+        prog = torch.export.export(net, args=(x,), strict=True)
         # Test args/to_sparse/output.
         for i, node in enumerate(prog.graph.nodes):
             meta = node.meta.get("val", None)

View File

@@ -436,7 +436,7 @@ def forward(self, x, y):
             def forward(self, a, b):
                 return (CustomOutput(a * a, b * b), CustomOutput(a * b.T, a + b.T))
 
-        ep = export(Foo(), (torch.randn(2, 3), torch.randn(3, 2)))
+        ep = export(Foo(), (torch.randn(2, 3), torch.randn(3, 2)), strict=True)
         swapped = _swap_modules(ep, {})
         inp = (torch.randn(2, 3), torch.randn(3, 2))
         res1 = torch.fx.Interpreter(swapped).run(*inp)

View File

@@ -94,7 +94,7 @@ class TestUnflatten(TestCase):
             return x
 
         orig_eager = MyModule()
-        export_module = export(orig_eager, (torch.rand(2, 3),), {})
+        export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
         unflattened = unflatten(export_module)
 
         inputs = (torch.rand(2, 3),)
@@ -134,7 +134,7 @@ class TestUnflatten(TestCase):
                 return x * self.rootparam
 
         eager_module = MyModule()
-        export_module = export(eager_module, (torch.rand(2, 3),), {})
+        export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
         unflattened_module = unflatten(export_module)
 
         # Buffer should look the same before and after one run
@@ -170,7 +170,7 @@ class TestUnflatten(TestCase):
             return x
 
         eager_module = MyModule()
-        export_module = export(eager_module, (torch.rand(2, 3),), {})
+        export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
         unflattened_module = unflatten(export_module)
 
         inputs = (torch.rand(2, 3),)
@@ -193,7 +193,7 @@ class TestUnflatten(TestCase):
 
         eager_module = Shared()
         inps = (torch.rand(10),)
-        export_module = export(eager_module, inps, {})
+        export_module = export(eager_module, inps, {}, strict=True)
         unflattened_module = unflatten(export_module)
         self.compare_outputs(eager_module, unflattened_module, inps)
         self.assertTrue(hasattr(unflattened_module, "sub_net"))
@@ -297,7 +297,7 @@ class TestUnflatten(TestCase):
                     x = x + self.param_dict[f"key_{i}"]
                 return x
 
-        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
+        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
         unflattened = unflatten(export_module)
 
         self.compare_outputs(
@@ -348,7 +348,7 @@ class TestUnflatten(TestCase):
                     a = a + self.param_dict[f"key_{i}"].sum()
                 return a
 
-        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
+        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
         with self.assertRaisesRegex(
             RuntimeError,
             escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
@@ -404,7 +404,9 @@ class TestUnflatten(TestCase):
             return x
 
         orig_eager = MyModule()
-        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
+        export_module = torch.export.export(
+            orig_eager, (torch.rand(2, 3),), {}, strict=True
+        )
         unflattened = unflatten(export_module)
 
         # in-place compilation should work. Pass fullgraph to ensure no graph breaks.
@@ -431,7 +433,7 @@ class TestUnflatten(TestCase):
 
         orig_eager = MyModule()
         inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
-        export_module = export(orig_eager, inputs, {})
+        export_module = export(orig_eager, inputs, {}, strict=True)
         unflattened = unflatten(export_module)
 
         torch.fx.symbolic_trace(
@@ -463,7 +465,9 @@ class TestUnflatten(TestCase):
                 return x + self.submod.subsubmod(x)
 
         orig_eager = MyModule()
-        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
+        export_module = torch.export.export(
+            orig_eager, (torch.rand(2, 3),), {}, strict=True
+        )
         unflattened = unflatten(export_module)
 
         inputs = (torch.rand(2, 3),)
@@ -499,7 +503,7 @@ class TestUnflatten(TestCase):
 
         inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
         mod = Foo()
-        ep_strict = torch.export.export(mod, inp)  # noqa: F841
+        ep_strict = torch.export.export(mod, inp, strict=True)  # noqa: F841
         ep_non_strict = torch.export.export(mod, inp, strict=False)
 
         gm_unflat_non_strict = unflatten(ep_non_strict)
@@ -523,7 +527,9 @@ class TestUnflatten(TestCase):
                 return x + sum(self.submod(x))
 
         orig_eager = MyModule()
-        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
+        export_module = torch.export.export(
+            orig_eager, (torch.rand(2, 3),), {}, strict=True
+        )
         unflattened = unflatten(export_module)
 
         inputs = (torch.rand(2, 3),)
@@ -551,7 +557,7 @@ class TestUnflatten(TestCase):
 
         mod = M()
         inputs = (torch.randn(3, 3, device="meta"),)
-        ep = export(mod, inputs)
+        ep = export(mod, inputs, strict=True)
         unflattened = unflatten(ep)
         self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
         self.assertTrue(unflattened.p.requires_grad is False)
@@ -567,7 +573,7 @@ class TestUnflatten(TestCase):
                 return x.transpose(0, 1)
 
         x = torch.randn(32, 3, 64, 64)
-        exported_program = export(TransposeModule(), args=(x,))
+        exported_program = export(TransposeModule(), args=(x,), strict=True)
         unflattened_module = unflatten(exported_program)
 
         # Check the inputs of the created call_module node are in order
@@ -599,7 +605,7 @@ class TestUnflatten(TestCase):
             def forward(self, x):
                 return x + self.submod(x)
 
-        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
+        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
         unflattened = unflatten(export_module)
 
         self.compare_outputs(
@@ -751,7 +757,7 @@ class TestUnflatten(TestCase):
 
         mod = Module()
 
-        ep = torch.export.export(mod, (torch.randn(3, 4),))
+        ep = torch.export.export(mod, (torch.randn(3, 4),), strict=True)
         unflattened = torch.export.unflatten(ep)
 
         fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
@@ -808,7 +814,7 @@ class TestUnflatten(TestCase):
 
         m = Foo()
         inps = (torch.randn(4, 4),)
-        ep = export(m, inps)
+        ep = export(m, inps, strict=True)
         unep = unflatten(ep)
 
         self.assertTrue(id(unep.m.bias) == id(unep.bias))
@@ -827,7 +833,7 @@ class TestUnflatten(TestCase):
 
         m = Foo()
         inps = (torch.randn(4, 4),)
-        ep = export(m, inps)
+        ep = export(m, inps, strict=True)
         unep = unflatten(ep)
 
         self.assertTrue(torch.allclose(unep(*inps), m(*inps)))
@@ -849,7 +855,7 @@ class TestUnflatten(TestCase):
         mod = M()
         x = torch.randn(4, 8)
 
-        ep = export(mod, (x,))
+        ep = export(mod, (x,), strict=True)
         unflattened = unflatten(ep)
 
         torch.testing.assert_close(unflattened(x), mod(x))
@@ -942,7 +948,7 @@ class TestUnflatten(TestCase):
             return x
 
         orig_eager = MyModule()
-        export_module = export(orig_eager, (torch.rand(2, 3),), {})
+        export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
         with _disable_interpreter():
             unflattened = unflatten(export_module)
 

View File

@@ -6606,7 +6606,7 @@ class TestHopSchema(TestCase):
             x,
         )
         model = M()
-        torch.export.export(model, args)
+        torch.export.export(model, args, strict=True)
         graph_str = self._check_export(model, args, None)
         self.assertExpectedInline(
             graph_str,

View File

@@ -193,7 +193,9 @@ class TestSplitOutputType(TestCase):
         relu
         """
         tag_node = defaultdict(list)
-        gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
+        gm: torch.fx.GraphModule = torch.export.export(
+            module, (inputs,), strict=True
+        ).module()
         # Add tag to all nodes and build dictionary record tag to call_module nodes
         for node in gm.graph.nodes:
             if "conv" in node.name:

View File

@@ -92,10 +92,7 @@ class TestFXNodeSource(TestCase):
 
         model = Model()
         example_inputs = (torch.randn(8, 10),)
-        ep = torch.export.export(
-            model,
-            example_inputs,
-        )
+        ep = torch.export.export(model, example_inputs, strict=True)
         gm = ep.module()
         provenance = get_graph_provenance_json(gm.graph)
         provenance = json.loads(provenance)

View File

@@ -221,7 +221,7 @@ class TestSourceMatcher(JitTestCase):
                 torch._check(b + 1 < y.size(0))
                 return y[: b + 1]
 
-        ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)))
+        ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)), strict=True)
         fake_inputs = [
             node.meta["val"] for node in ep.graph.nodes if node.op == "placeholder"
         ]

View File

@@ -1885,7 +1885,7 @@ class AOTInductorTestsTemplate:
         example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
 
         # Export on CPU
-        exported_program = export(Model(), example_inputs)
+        exported_program = export(Model(), example_inputs, strict=True)
 
         # Compile exported model on GPU
         gm = exported_program.graph_module.to(self.device)

View File

@@ -156,10 +156,7 @@ class TestAOTInductorPackage(TestCase):
 
         torch.manual_seed(0)
         with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
-            ep = torch.export.export(
-                model,
-                example_inputs,
-            )
+            ep = torch.export.export(model, example_inputs, strict=True)
             with fresh_inductor_cache():
                 # cubin files are removed when exiting this context
                 package_path = torch._inductor.aoti_compile_and_package(
@@ -304,7 +301,7 @@ class TestAOTInductorPackage(TestCase):
             torch.randn(3, 4, device=self.device),
         )
         ep1 = torch.export.export(
-            Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
+            Model1(), example_inputs1, dynamic_shapes=dynamic_shapes, strict=True
         )
         aoti_files1 = torch._inductor.aot_compile(
             ep1.module(), example_inputs1, options=options
@@ -321,7 +318,7 @@ class TestAOTInductorPackage(TestCase):
                 return x * t
 
         example_inputs2 = (torch.randn(5, 5, device=self.device),)
-        ep2 = torch.export.export(Model2(self.device), example_inputs2)
+        ep2 = torch.export.export(Model2(self.device), example_inputs2, strict=True)
         aoti_files2 = torch._inductor.aot_compile(
             ep2.module(), example_inputs2, options=options
         )
@@ -360,7 +357,7 @@ class TestAOTInductorPackage(TestCase):
         )
         self.check_model(Model1(), example_inputs1)
         ep1 = torch.export.export(
-            Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
+            Model1(), example_inputs1, dynamic_shapes=dynamic_shapes, strict=True
        )
         aoti_files1 = torch._inductor.aot_compile(
             ep1.module(), example_inputs1, options=options
@@ -372,7 +369,7 @@ class TestAOTInductorPackage(TestCase):
             torch.randn(3, 4, device=device),
         )
         ep2 = torch.export.export(
-            Model1(), example_inputs2, dynamic_shapes=dynamic_shapes
+            Model1(), example_inputs2, dynamic_shapes=dynamic_shapes, strict=True
         )
         aoti_files2 = torch._inductor.aot_compile(
             ep2.module(), example_inputs2, options=options
@@ -404,7 +401,7 @@ class TestAOTInductorPackage(TestCase):
             torch.randn(2, 4, device=self.device),
             torch.randn(3, 4, device=self.device),
         )
-        ep = torch.export.export(Model(), example_inputs)
+        ep = torch.export.export(Model(), example_inputs, strict=True)
         aoti_files = torch._inductor.aot_compile(
             ep.module(),
             example_inputs,
@@ -433,12 +430,10 @@ class TestAOTInductorPackage(TestCase):
             torch.randn(2, 4, device=self.device),
             torch.randn(3, 4, device=self.device),
         )
-        ep = torch.export.export(Model(), example_inputs)
+        ep = torch.export.export(Model(), example_inputs, strict=True)
         buffer = io.BytesIO()
-        buffer = torch._inductor.aoti_compile_and_package(
-            ep, package_path=buffer
-        )  # type: ignore[arg-type]
+        buffer = torch._inductor.aoti_compile_and_package(ep, package_path=buffer)  # type: ignore[arg-type]
 
         for _ in range(2):
             loaded = load_package(buffer)
             self.assertTrue(

View File

@@ -149,6 +149,7 @@ class TestExportAPIDynamo(common_utils.TestCase):
                     2: torch.export.Dim("customb_dim_2"),
                 },
             },
+            strict=True,
         )
 
         self.assert_export(exported_program)

View File

@@ -273,7 +273,9 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
             == pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
         ):
             with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
-                ref_model = torch.export.export(ref_model, args=ref_input_args)
+                ref_model = torch.export.export(
+                    ref_model, args=ref_input_args, strict=True
+                )
             if (
                 self.dynamic_shapes
             ):  # TODO: Support dynamic shapes for torch.export.ExportedProgram

View File

@@ -138,7 +138,7 @@ class TestModularizePass(common_utils.TestCase):
 
         if is_exported_program:
             model = torch.export.export(
-                TestModule(), args=(torch.randn(3), torch.randn(3))
+                TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
             )
         else:
             model = TestModule()
@@ -185,7 +185,7 @@ class TestModularizePass(common_utils.TestCase):
 
         if is_exported_program:
             model = torch.export.export(
-                TestModule(), args=(torch.randn(3), torch.randn(3))
+                TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
             )
         else:
             model = TestModule()
@@ -239,7 +239,7 @@ class TestModularizePass(common_utils.TestCase):
 
         if is_exported_program:
             model = torch.export.export(
-                TestModule(), args=(torch.randn(3), torch.randn(3))
+                TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
             )
         else:
             model = TestModule()

View File

@@ -194,7 +194,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
 
         x = torch.randn(2, 3)
         with torch.no_grad():
-            exported_program = torch.export.export(Model(), args=(x,))
+            exported_program = torch.export.export(Model(), args=(x,), strict=True)
             _ = torch.onnx.dynamo_export(
                 exported_program,
                 x,

View File

@@ -57,7 +57,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         x = torch.randn(2, 3, 4, dtype=torch.float)
         dim0 = torch.export.Dim("dim0")
         exported_program = torch.export.export(
-            Model(), (x,), dynamic_shapes={"x": {0: dim0}}
+            Model(), (x,), dynamic_shapes={"x": {0: dim0}}, strict=True
         )
 
         onnx_program = torch.onnx.dynamo_export(exported_program, x)
@@ -75,7 +75,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
                 return x + 1.0
 
         x = torch.randn(1, 1, 2, dtype=torch.float)
-        exported_program = torch.export.export(Model(), args=(x,))
+        exported_program = torch.export.export(Model(), args=(x,), strict=True)
         onnx_program = torch.onnx.dynamo_export(exported_program, x)
 
         with tempfile.NamedTemporaryFile(suffix=".pte") as f:
@@ -101,7 +101,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
         dynamic_shapes = {"x": {0: dim0_x}, "y": None}
         # specialized input y to 5 during tracing
         exported_program = torch.export.export(
-            f, (tensor_input, 5), dynamic_shapes=dynamic_shapes
+            f, (tensor_input, 5), dynamic_shapes=dynamic_shapes, strict=True
         )
 
         onnx_program = torch.onnx.dynamo_export(exported_program, tensor_input, 5)
@@ -134,13 +134,16 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
                 return bar.sum() + self.buf.sum()
 
         tensor_input = torch.ones(5, 5)
-        exported_program = torch.export.export(Foo(), (tensor_input,))
+        exported_program = torch.export.export(Foo(), (tensor_input,), strict=True)
 
         dim0_x = torch.export.Dim("dim0_x")
         # NOTE: If input is ExportedProgram, we need to specify dynamic_shapes
         # as a tuple.
         reexported_program = torch.export.export(
-            exported_program.module(), (tensor_input,), dynamic_shapes=({0: dim0_x},)
+            exported_program.module(),
+            (tensor_input,),
+            dynamic_shapes=({0: dim0_x},),
+            strict=True,
         )
         reexported_onnx_program = torch.onnx.dynamo_export(
             reexported_program, tensor_input
@@ -162,7 +165,10 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
 
         dim = torch.export.Dim("dim")
         exported_program = torch.export.export(
-            foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, {0: dim})
+            foo,
+            (torch.randn(4, 4), torch.randn(4, 4)),
+            dynamic_shapes=(None, {0: dim}),
+            strict=True,
         )
         onnx_program = torch.onnx.dynamo_export(
             exported_program, torch.randn(4, 4), torch.randn(4, 4)
@@ -192,6 +198,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
             # We are specifying dynamism on the first kwarg even though user passed in
             # different order
             dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None),
+            strict=True,
         )
         onnx_program = torch.onnx.dynamo_export(
             exported_program,
@@ -237,6 +244,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
                 input_b,
             ),
             dynamic_shapes=({0: dim}, {0: dim}),
+            strict=True,
         )
 
         onnx_program = torch.onnx.dynamo_export(exported_program, input_x, input_b)
@@ -276,7 +284,9 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
                 return None
 
         dynamic_shapes = torch_pytree.tree_map(dynamify_inp, inp)
-        exported_program = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
+        exported_program = torch.export.export(
+            foo, inp, dynamic_shapes=dynamic_shapes, strict=True
+        )
         onnx_program = torch.onnx.dynamo_export(exported_program, inp_a, inp_b)
 
         # NOTE: Careful with the input format. The input format should be
@@ -302,6 +312,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
             (),
             {"x": torch.randn(3, 3), "y": torch.randn(3, 3)},
             dynamic_shapes=dynamic_shapes,
+            strict=True,
         )
         onnx_program = torch.onnx.dynamo_export(
             exported_program, x=torch.randn(3, 3), y=torch.randn(3, 3)

View File

@@ -61,7 +61,7 @@ class TestMetaDataPorting(QuantizationTestCase):
     def _test_quant_tag_preservation_through_decomp(
         self, model, example_inputs, from_node_to_tags
     ):
-        ep = torch.export.export(model, example_inputs)
+        ep = torch.export.export(model, example_inputs, strict=True)
         found_tags = True
         not_found_nodes = ""
         for from_node, tag in from_node_to_tags.items():

View File

@@ -133,7 +133,7 @@ class TestNumericDebugger(TestCase):
     def test_copy_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = torch.export.export(m, example_inputs)
+        ep = torch.export.export(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
 
         self._assert_each_node_has_debug_handle(ep)
@@ -148,7 +148,7 @@ class TestNumericDebugger(TestCase):
     def test_deepcopy_preserve_handle(self):
         m = TestHelperModules.Conv2dThenConv1d()
         example_inputs = m.example_inputs()
-        ep = torch.export.export(m, example_inputs)
+        ep = torch.export.export(m, example_inputs, strict=True)
         generate_numeric_debug_handle(ep)
 
         debug_handle_map_ref = self._extract_debug_handles(ep)

View File

@@ -1460,7 +1460,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
 
         with TemporaryFileName() as fname:
             # serialization
-            quantized_ep = torch.export.export(m, example_inputs)
+            quantized_ep = torch.export.export(m, example_inputs, strict=True)
             torch.export.save(quantized_ep, fname)
             # deserialization
             loaded_ep = torch.export.load(fname)

View File

@@ -1130,7 +1130,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
         model_pt2e = convert_pt2e(model_pt2e)
         quant_result_pt2e = model_pt2e(*example_inputs)  # noqa: F841
 
-        exported_model = torch.export.export(model_pt2e, example_inputs)
+        exported_model = torch.export.export(model_pt2e, example_inputs, strict=True)
 
         node_occurrence = {
             # conv2d: 1 for act, 1 for weight, 1 for output

View File

@@ -912,7 +912,9 @@ class FakeTensorTest(TestCase):
                 return input + np.random.randn(*input.shape)
 
         with FakeTensorMode():
-            ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
+            ep = torch.export.export(
+                MyNumpyModel(), args=(torch.randn(1000),), strict=True
+            )
         self.assertTrue(isinstance(ep, torch.export.ExportedProgram))
 
     def test_unsqueeze_copy(self):

View File

@@ -61,11 +61,10 @@ class TestOutDtypeOp(TestCase):
         weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
         m = M(weight)
         x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
-        ep = torch.export.export(
-            m,
-            (x,),
-        )
-        FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(ep.graph_module.code)
+        ep = torch.export.export(m, (x,), strict=True)
+        FileCheck().check("torch.ops.higher_order.out_dtype").check(
+            "aten.mm.default"
+        ).run(ep.graph_module.code)
         self.assertTrue(torch.allclose(m(x), ep.module()(x)))
         for node in ep.graph.nodes:
             if node.op == "call_function" and node.target is out_dtype:
@@ -128,7 +127,12 @@ class TestOutDtypeOp(TestCase):
 
         with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
             _ = torch.export.export(
-                M(), (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
+                M(),
+                (
+                    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+                    torch.randint(-128, 127, (5, 5), dtype=torch.int8),
+                ),
+                strict=True,
             )
 
     def test_out_dtype_non_op_overload(self):
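
For readers unfamiliar with the flag: with strict=True, torch.export traces the module with TorchDynamo, which validates that the captured graph soundly represents the original Python code and raises on constructs it cannot prove safe; with strict=False, tracing runs through the normal Python interpreter with fewer soundness checks. A rough sketch of the two modes (the module here is illustrative, not from the diff):

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x.cos() + x.sin()

    args = (torch.randn(4),)

    # Strict export: TorchDynamo traces the module and errors on
    # anything it cannot soundly capture.
    ep_strict = torch.export.export(M(), args, strict=True)

    # Non-strict export: tracing runs through the Python interpreter,
    # trading some soundness guarantees for flexibility.
    ep_non_strict = torch.export.export(M(), args, strict=False)

    print(ep_strict.graph_module.code)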