mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Codemod][AddExplicitStrictExportArg] caffe2/test (#143688)
Reviewed By: avikchaudhuri Differential Revision: D67530154 Pull Request resolved: https://github.com/pytorch/pytorch/pull/143688 Approved by: https://github.com/tugsbayasgalan
This commit is contained in:
committed by
PyTorch MergeBot
parent
969415885d
commit
ba5cacbc17
@ -78,6 +78,7 @@ Result:
|
||||
example_case.example_args,
|
||||
example_case.example_kwargs,
|
||||
dynamic_shapes=example_case.dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
graph_output = str(exported_program)
|
||||
graph_output = re.sub(r" # File(.|\n)*?\n", "", graph_output)
|
||||
|
@ -73,8 +73,7 @@ class TensorParallelTest(DTensorTestBase):
|
||||
with torch.no_grad():
|
||||
res = model(*inputs)
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
inputs,
|
||||
model, inputs, strict=True
|
||||
).run_decompositions()
|
||||
tp_exported_program = tensor_parallel_transformation(
|
||||
exported_program,
|
||||
@ -111,8 +110,7 @@ class TensorParallelTest(DTensorTestBase):
|
||||
with torch.inference_mode():
|
||||
res = model(*inputs)
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
inputs,
|
||||
model, inputs, strict=True
|
||||
).run_decompositions()
|
||||
tp_exported_program = tensor_parallel_transformation(
|
||||
exported_program,
|
||||
@ -147,8 +145,7 @@ class TensorParallelTest(DTensorTestBase):
|
||||
with torch.inference_mode():
|
||||
res = model(*inputs)
|
||||
exported_program = torch.export.export(
|
||||
model,
|
||||
inputs,
|
||||
model, inputs, strict=True
|
||||
).run_decompositions()
|
||||
tp_exported_program = tensor_parallel_transformation(
|
||||
exported_program,
|
||||
|
@ -3,6 +3,7 @@
|
||||
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
|
||||
with test_export_persist_assert)
|
||||
"""
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
@ -2479,7 +2480,7 @@ def forward(self, x):
|
||||
random_inputs = (torch.rand([32, 3, 32, 32]).to("cuda"),)
|
||||
dim_x = torch.export.Dim("dim_x", min=1, max=32)
|
||||
exp_program = torch.export.export(
|
||||
model, random_inputs, dynamic_shapes={"x": {0: dim_x}}
|
||||
model, random_inputs, dynamic_shapes={"x": {0: dim_x}}, strict=True
|
||||
)
|
||||
output_buffer = io.BytesIO()
|
||||
# Tests if we can restore saved nn.Parameters when we load them again
|
||||
@ -2509,7 +2510,9 @@ def forward(self, x):
|
||||
batchsize = torch.export.Dim("dim0", min=3, max=1024)
|
||||
dynamic_shape_spec = {"a": [batchsize, None, None], "b": [None, None]}
|
||||
|
||||
torch.export.export(model, (a, b), dynamic_shapes=dynamic_shape_spec)
|
||||
torch.export.export(
|
||||
model, (a, b), dynamic_shapes=dynamic_shape_spec, strict=True
|
||||
)
|
||||
|
||||
def test_export_fast_binary_broadcast_check_unbacked(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
@ -2522,7 +2525,7 @@ def forward(self, x):
|
||||
model = MyModel().eval().cuda()
|
||||
numel = torch.tensor(10)
|
||||
scalar = torch.randn(1)
|
||||
torch.export.export(model, (numel, scalar))
|
||||
torch.export.export(model, (numel, scalar), strict=True)
|
||||
|
||||
def test_export_meta(self):
|
||||
class MyModule(torch.nn.Module):
|
||||
@ -2563,7 +2566,7 @@ def forward(self, x):
|
||||
"by dim0 = 2\\*dim1(.*\n)*.*"
|
||||
"Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*",
|
||||
):
|
||||
torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes)
|
||||
torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
|
||||
class Bar(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -2581,7 +2584,7 @@ def forward(self, x):
|
||||
torch._dynamo.exc.UserError,
|
||||
"Not all values.*valid.*inferred to be a constant",
|
||||
):
|
||||
torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes)
|
||||
torch.export.export(bar, (t,), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
|
||||
class Qux(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
@ -2599,7 +2602,7 @@ def forward(self, x):
|
||||
torch._dynamo.exc.UserError,
|
||||
"Not all values.*satisfy the generated guard",
|
||||
):
|
||||
torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes)
|
||||
torch.export.export(qux, (t,), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
|
||||
def test_untracked_inputs_in_constraints(self):
|
||||
from copy import copy
|
||||
@ -2617,7 +2620,9 @@ def forward(self, x):
|
||||
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
|
||||
|
||||
example_inputs = (copy(x), y)
|
||||
ep = torch.export.export(foo, example_inputs, dynamic_shapes=dynamic_shapes)
|
||||
ep = torch.export.export(
|
||||
foo, example_inputs, dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
ep.module()(torch.randn(3), y) # no specialization error
|
||||
|
||||
def test_export_raise_guard_full_constraint(self):
|
||||
@ -2734,6 +2739,7 @@ def forward(self, x):
|
||||
foo,
|
||||
(a, {"k": b}),
|
||||
dynamic_shapes={"x": {0: dim0_a}, "y": {"k": {0: dim0_b}}},
|
||||
strict=True,
|
||||
)
|
||||
|
||||
def test_enforce_equalities(self):
|
||||
@ -2752,16 +2758,10 @@ def forward(self, x):
|
||||
torch._dynamo.exc.UserError,
|
||||
".*y.*size.*2.* = 4 is not equal to .*x.*size.*1.* = 3",
|
||||
):
|
||||
torch.export.export(
|
||||
bar,
|
||||
(x, y),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
)
|
||||
torch.export.export(bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
y = torch.randn(10, 3, 3)
|
||||
ebar = torch.export.export(
|
||||
bar,
|
||||
(x, y),
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
bar, (x, y), dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
self.assertEqual(
|
||||
[
|
||||
@ -2923,15 +2923,15 @@ def forward(self, x):
|
||||
torch._dynamo.exc.UserError,
|
||||
r"Constraints violated \(dim0\)",
|
||||
):
|
||||
torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes)
|
||||
torch.export.export(foo, (x,), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
|
||||
torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes)
|
||||
torch.export.export(bar, (x,), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
torch._dynamo.exc.UserError,
|
||||
r"Constraints violated \(dim0\)",
|
||||
):
|
||||
torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes)
|
||||
torch.export.export(qux, (x,), dynamic_shapes=dynamic_shapes, strict=True)
|
||||
|
||||
def test_list_contains(self):
|
||||
def func(x):
|
||||
|
@ -24,7 +24,7 @@ class FxPassesPreGradTests(torch._dynamo.test_case.TestCase):
|
||||
sample_input = torch.randn(4, 4)
|
||||
m = TestModule()
|
||||
m(sample_input)
|
||||
exported_program = torch.export.export(m, (sample_input,))
|
||||
exported_program = torch.export.export(m, (sample_input,), strict=True)
|
||||
gm = exported_program.graph_module
|
||||
|
||||
pass_execution_and_save(fx_pass, gm, sample_input, "Apply testing pass")
|
||||
|
@ -72,7 +72,7 @@ class SourceTests(torch._dynamo.test_case.TestCase):
|
||||
lambda x, _: CausalLMOutputWithPast(),
|
||||
)
|
||||
|
||||
torch.export.export(Model(), ())
|
||||
torch.export.export(Model(), (), strict=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -342,7 +342,7 @@ class TestDraftExport(TestCase):
|
||||
inputs = (torch.randn(3, 3),)
|
||||
with self.assertRaises(RuntimeError):
|
||||
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
|
||||
export(mod, inputs)
|
||||
export(mod, inputs, strict=True)
|
||||
|
||||
ep, report = draft_export(mod, inputs)
|
||||
for ep_out, eager_out in zip(ep.module()(*inputs), mod(*inputs)):
|
||||
@ -384,7 +384,7 @@ class TestDraftExport(TestCase):
|
||||
"Real tensor propagation found an aliasing mismatch",
|
||||
):
|
||||
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
|
||||
export(mod, inputs)
|
||||
export(mod, inputs, strict=True)
|
||||
|
||||
ep, report = draft_export(mod, inputs)
|
||||
for ep_out, eager_out in zip(
|
||||
|
@ -83,7 +83,7 @@ class TestHOP(TestCase):
|
||||
input = inp.input if isinstance(inp.input, tuple) else (inp.input,)
|
||||
args = (*input, *inp.args)
|
||||
kwargs = inp.kwargs
|
||||
ep = export(model, args, kwargs)
|
||||
ep = export(model, args, kwargs, strict=True)
|
||||
self._compare(model, ep, args, kwargs)
|
||||
|
||||
@ops(hop_tests, allowed_dtypes=(torch.float,))
|
||||
|
@ -24,7 +24,7 @@ class TestPassInfra(TestCase):
|
||||
class NullPass(_ExportPassBaseDeprecatedDoNotUse):
|
||||
pass
|
||||
|
||||
ep = export(f, (torch.ones(3, 2),))
|
||||
ep = export(f, (torch.ones(3, 2),), strict=True)
|
||||
old_nodes = ep.graph.nodes
|
||||
|
||||
ep = ep._transform_do_not_use(NullPass())
|
||||
@ -66,7 +66,7 @@ class TestPassInfra(TestCase):
|
||||
x = torch.tensor([2])
|
||||
y = torch.tensor([5])
|
||||
mod = M()
|
||||
_ = export(mod, (torch.tensor(True), x, y))._transform_do_not_use(
|
||||
_ = export(mod, (torch.tensor(True), x, y), strict=True)._transform_do_not_use(
|
||||
_ExportPassBaseDeprecatedDoNotUse()
|
||||
)
|
||||
|
||||
@ -98,7 +98,7 @@ class TestPassInfra(TestCase):
|
||||
inps = (torch.rand(1), torch.rand(1))
|
||||
m = CustomModule()
|
||||
|
||||
ep_before = export(m, inps)
|
||||
ep_before = export(m, inps, strict=True)
|
||||
|
||||
# No op transformation that doesn't perform any meaningful changes to node
|
||||
ep_after = ep_before._transform_do_not_use(_ExportPassBaseDeprecatedDoNotUse())
|
||||
@ -131,7 +131,9 @@ class TestPassInfra(TestCase):
|
||||
input_tensor1 = torch.tensor(5.0)
|
||||
input_tensor2 = torch.tensor(6.0)
|
||||
|
||||
ep_before = torch.export.export(my_module, (input_tensor1, input_tensor2))
|
||||
ep_before = torch.export.export(
|
||||
my_module, (input_tensor1, input_tensor2), strict=True
|
||||
)
|
||||
from torch.fx.passes.infra.pass_base import PassResult
|
||||
|
||||
def modify_input_output_pass(gm):
|
||||
@ -169,7 +171,7 @@ class TestPassInfra(TestCase):
|
||||
|
||||
my_module = CustomModule()
|
||||
inputs = (torch.tensor(6.0), torch.tensor(7.0))
|
||||
ep_before = export(my_module, inputs)
|
||||
ep_before = export(my_module, inputs, strict=True)
|
||||
|
||||
def replace_pass(gm):
|
||||
for node in gm.graph.nodes:
|
||||
|
@ -404,7 +404,9 @@ class TestPasses(TestCase):
|
||||
x = torch.zeros(2, 2, 3)
|
||||
|
||||
dim1_x = torch.export.Dim("dim1_x", min=2, max=6)
|
||||
ep = torch.export.export(M(), (x,), dynamic_shapes={"x": {1: dim1_x}})
|
||||
ep = torch.export.export(
|
||||
M(), (x,), dynamic_shapes={"x": {1: dim1_x}}, strict=True
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
@ -431,7 +433,10 @@ class TestPasses(TestCase):
|
||||
dim0_x, dim0_y = torch.export.dims("dim0_x", "dim0_y", min=3)
|
||||
|
||||
ep = torch.export.export(
|
||||
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}}
|
||||
M(),
|
||||
(x, y),
|
||||
dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": {0: dim0_y}},
|
||||
strict=True,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
@ -461,7 +466,10 @@ class TestPasses(TestCase):
|
||||
dim0_x = torch.export.Dim("dim0_x", min=3)
|
||||
|
||||
ep = torch.export.export(
|
||||
M(), (x, y), dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None}
|
||||
M(),
|
||||
(x, y),
|
||||
dynamic_shapes={"x": {0: dim0_x, 1: dim1_x}, "y": None},
|
||||
strict=True,
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
@ -496,7 +504,7 @@ class TestPasses(TestCase):
|
||||
|
||||
dim1_y = torch.export.Dim("dim1_y", min=3, max=6)
|
||||
ep = torch.export.export(
|
||||
M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}
|
||||
M(), (x, y), dynamic_shapes={"x": None, "y": {1: dim1_y}}, strict=True
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, escape("shape[1] to be equal to 2")):
|
||||
@ -526,7 +534,7 @@ class TestPasses(TestCase):
|
||||
|
||||
x = torch.zeros(4, 2, 3)
|
||||
|
||||
ep = export(M(), (x,))
|
||||
ep = export(M(), (x,), strict=True)
|
||||
self.assertEqual(count_call_function(ep.graph, torch.ops.aten.view.default), 1)
|
||||
|
||||
ep = ep._transform_do_not_use(ReplaceViewOpsWithViewCopyOpsPass())
|
||||
@ -542,7 +550,7 @@ class TestPasses(TestCase):
|
||||
|
||||
x = torch.zeros(4, 2, 3)
|
||||
foo = Module()
|
||||
ep = export(foo, (x,))._transform_do_not_use(
|
||||
ep = export(foo, (x,), strict=True)._transform_do_not_use(
|
||||
ReplaceViewOpsWithViewCopyOpsPass()
|
||||
)
|
||||
# After this pass, there shouldn't be any view nodes in the graph
|
||||
@ -684,7 +692,7 @@ def forward(self, token, obj_attr, x):
|
||||
|
||||
x = torch.tensor([2])
|
||||
mod = M()
|
||||
ep = export(mod, (x,))
|
||||
ep = export(mod, (x,), strict=True)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, r"Runtime assertion failed for expression u[\d+] \<\= 5"
|
||||
@ -709,7 +717,9 @@ def forward(self, token, obj_attr, x):
|
||||
|
||||
mod = M()
|
||||
dim0_x = torch.export.Dim("dim0_x")
|
||||
ep = torch.export.export(mod, (x,), dynamic_shapes={"x": {0: dim0_x}})
|
||||
ep = torch.export.export(
|
||||
mod, (x,), dynamic_shapes={"x": {0: dim0_x}}, strict=True
|
||||
)
|
||||
|
||||
num_assert = count_call_function(
|
||||
ep.graph, torch.ops.aten._assert_scalar.default
|
||||
@ -762,7 +772,7 @@ def forward(self, token, obj_attr, x):
|
||||
x = torch.tensor([2])
|
||||
y = torch.tensor([5])
|
||||
mod = M()
|
||||
ep = export(mod, (torch.tensor(True), x, y))
|
||||
ep = export(mod, (torch.tensor(True), x, y), strict=True)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "is outside of inline constraint \\[2, 5\\]."
|
||||
@ -779,7 +789,7 @@ def forward(self, token, obj_attr, x):
|
||||
|
||||
func = Module()
|
||||
x = torch.randn(1, dtype=torch.float32)
|
||||
ep = torch.export.export(func, args=(x,))
|
||||
ep = torch.export.export(func, args=(x,), strict=True)
|
||||
_ExportPassBaseDeprecatedDoNotUse()(ep.graph_module)
|
||||
|
||||
def test_predispatch_set_grad(self):
|
||||
@ -1231,7 +1241,7 @@ def forward(self, add_1):
|
||||
|
||||
mod = M()
|
||||
x = torch.randn([3, 3])
|
||||
ep = export(mod, (x,))
|
||||
ep = export(mod, (x,), strict=True)
|
||||
inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
|
||||
nodes = inplace_ep.graph.nodes
|
||||
for node in nodes:
|
||||
@ -1274,7 +1284,7 @@ def forward(self, add_1):
|
||||
|
||||
mod = M()
|
||||
x = torch.randn([3, 3])
|
||||
ep = export(mod, (x,)).run_decompositions({})
|
||||
ep = export(mod, (x,), strict=True).run_decompositions({})
|
||||
inplace_ep = unsafe_remove_auto_functionalized_pass(ep)
|
||||
graph_text = str(inplace_ep.graph)
|
||||
self.assertExpectedInline(
|
||||
@ -1304,7 +1314,7 @@ default](args = (%x, %b_state), kwargs = {})
|
||||
# move the exported program from cpu to cuda:0
|
||||
mod = Model()
|
||||
example_inputs = (torch.rand(1, 10, 4),)
|
||||
ep = export(mod, example_inputs)
|
||||
ep = export(mod, example_inputs, strict=True)
|
||||
location = torch.device("cuda:0")
|
||||
ep = move_to_device_pass(ep, location=location)
|
||||
gm = ep.module()
|
||||
|
@ -141,7 +141,7 @@ class TestSparseProp(TestCase):
|
||||
index_dtype=itype,
|
||||
):
|
||||
# Build the traced graph.
|
||||
prog = torch.export.export(net, (sparse_input,))
|
||||
prog = torch.export.export(net, (sparse_input,), strict=True)
|
||||
# Test arg/output.
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
meta = node.meta.get("val", None)
|
||||
@ -163,7 +163,7 @@ class TestSparseProp(TestCase):
|
||||
):
|
||||
result = net(sparse_input)
|
||||
# Build the traced graph.
|
||||
prog = torch.export.export(net, (sparse_input,))
|
||||
prog = torch.export.export(net, (sparse_input,), strict=True)
|
||||
# Test arg/sum/output.
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
meta = node.meta.get("val", None)
|
||||
@ -187,7 +187,7 @@ class TestSparseProp(TestCase):
|
||||
):
|
||||
result = net(sparse_input)
|
||||
# Build the traced graph.
|
||||
prog = torch.export.export(net, (sparse_input,))
|
||||
prog = torch.export.export(net, (sparse_input,), strict=True)
|
||||
# Test arg/neg/abs/mul/relu/output.
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
meta = node.meta.get("val", None)
|
||||
@ -209,7 +209,7 @@ class TestSparseProp(TestCase):
|
||||
):
|
||||
result = net(sparse_input)
|
||||
# Build the traced graph.
|
||||
prog = torch.export.export(net, (sparse_input,))
|
||||
prog = torch.export.export(net, (sparse_input,), strict=True)
|
||||
# Test arg/todense/output.
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
meta = node.meta.get("val", None)
|
||||
@ -235,7 +235,7 @@ class TestSparseProp(TestCase):
|
||||
S = A.to_sparse_csr()
|
||||
result = net(S, Y)
|
||||
# Build the traced graph.
|
||||
prog = torch.export.export(net, (S, Y))
|
||||
prog = torch.export.export(net, (S, Y), strict=True)
|
||||
# Test args/add/output.
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
meta = node.meta.get("val", None)
|
||||
@ -253,7 +253,7 @@ class TestSparseProp(TestCase):
|
||||
x = [torch.randn(3, 3) for _ in range(3)]
|
||||
result = net(x)
|
||||
# Build the traced graph.
|
||||
prog = torch.export.export(net, args=(x,))
|
||||
prog = torch.export.export(net, args=(x,), strict=True)
|
||||
# Test args/to_sparse/output.
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
meta = node.meta.get("val", None)
|
||||
@ -269,7 +269,7 @@ class TestSparseProp(TestCase):
|
||||
x = [torch.randn(3, 3) for _ in range(3)]
|
||||
result = net(x)
|
||||
# Build the traced graph.
|
||||
prog = torch.export.export(net, args=(x,))
|
||||
prog = torch.export.export(net, args=(x,), strict=True)
|
||||
# Test args/to_sparse/output.
|
||||
for i, node in enumerate(prog.graph.nodes):
|
||||
meta = node.meta.get("val", None)
|
||||
|
@ -436,7 +436,7 @@ def forward(self, x, y):
|
||||
def forward(self, a, b):
|
||||
return (CustomOutput(a * a, b * b), CustomOutput(a * b.T, a + b.T))
|
||||
|
||||
ep = export(Foo(), (torch.randn(2, 3), torch.randn(3, 2)))
|
||||
ep = export(Foo(), (torch.randn(2, 3), torch.randn(3, 2)), strict=True)
|
||||
swapped = _swap_modules(ep, {})
|
||||
inp = (torch.randn(2, 3), torch.randn(3, 2))
|
||||
res1 = torch.fx.Interpreter(swapped).run(*inp)
|
||||
|
@ -94,7 +94,7 @@ class TestUnflatten(TestCase):
|
||||
return x
|
||||
|
||||
orig_eager = MyModule()
|
||||
export_module = export(orig_eager, (torch.rand(2, 3),), {})
|
||||
export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
inputs = (torch.rand(2, 3),)
|
||||
@ -134,7 +134,7 @@ class TestUnflatten(TestCase):
|
||||
return x * self.rootparam
|
||||
|
||||
eager_module = MyModule()
|
||||
export_module = export(eager_module, (torch.rand(2, 3),), {})
|
||||
export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
|
||||
unflattened_module = unflatten(export_module)
|
||||
|
||||
# Buffer should look the same before and after one run
|
||||
@ -170,7 +170,7 @@ class TestUnflatten(TestCase):
|
||||
return x
|
||||
|
||||
eager_module = MyModule()
|
||||
export_module = export(eager_module, (torch.rand(2, 3),), {})
|
||||
export_module = export(eager_module, (torch.rand(2, 3),), {}, strict=True)
|
||||
unflattened_module = unflatten(export_module)
|
||||
|
||||
inputs = (torch.rand(2, 3),)
|
||||
@ -193,7 +193,7 @@ class TestUnflatten(TestCase):
|
||||
|
||||
eager_module = Shared()
|
||||
inps = (torch.rand(10),)
|
||||
export_module = export(eager_module, inps, {})
|
||||
export_module = export(eager_module, inps, {}, strict=True)
|
||||
unflattened_module = unflatten(export_module)
|
||||
self.compare_outputs(eager_module, unflattened_module, inps)
|
||||
self.assertTrue(hasattr(unflattened_module, "sub_net"))
|
||||
@ -297,7 +297,7 @@ class TestUnflatten(TestCase):
|
||||
x = x + self.param_dict[f"key_{i}"]
|
||||
return x
|
||||
|
||||
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
|
||||
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
self.compare_outputs(
|
||||
@ -348,7 +348,7 @@ class TestUnflatten(TestCase):
|
||||
a = a + self.param_dict[f"key_{i}"].sum()
|
||||
return a
|
||||
|
||||
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
|
||||
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
|
||||
@ -404,7 +404,9 @@ class TestUnflatten(TestCase):
|
||||
return x
|
||||
|
||||
orig_eager = MyModule()
|
||||
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
|
||||
export_module = torch.export.export(
|
||||
orig_eager, (torch.rand(2, 3),), {}, strict=True
|
||||
)
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
# in-place compilation should work. Pass fullgraph to ensure no graph breaks.
|
||||
@ -431,7 +433,7 @@ class TestUnflatten(TestCase):
|
||||
|
||||
orig_eager = MyModule()
|
||||
inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
|
||||
export_module = export(orig_eager, inputs, {})
|
||||
export_module = export(orig_eager, inputs, {}, strict=True)
|
||||
|
||||
unflattened = unflatten(export_module)
|
||||
torch.fx.symbolic_trace(
|
||||
@ -463,7 +465,9 @@ class TestUnflatten(TestCase):
|
||||
return x + self.submod.subsubmod(x)
|
||||
|
||||
orig_eager = MyModule()
|
||||
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
|
||||
export_module = torch.export.export(
|
||||
orig_eager, (torch.rand(2, 3),), {}, strict=True
|
||||
)
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
inputs = (torch.rand(2, 3),)
|
||||
@ -499,7 +503,7 @@ class TestUnflatten(TestCase):
|
||||
inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
|
||||
mod = Foo()
|
||||
|
||||
ep_strict = torch.export.export(mod, inp) # noqa: F841
|
||||
ep_strict = torch.export.export(mod, inp, strict=True) # noqa: F841
|
||||
ep_non_strict = torch.export.export(mod, inp, strict=False)
|
||||
|
||||
gm_unflat_non_strict = unflatten(ep_non_strict)
|
||||
@ -523,7 +527,9 @@ class TestUnflatten(TestCase):
|
||||
return x + sum(self.submod(x))
|
||||
|
||||
orig_eager = MyModule()
|
||||
export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
|
||||
export_module = torch.export.export(
|
||||
orig_eager, (torch.rand(2, 3),), {}, strict=True
|
||||
)
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
inputs = (torch.rand(2, 3),)
|
||||
@ -551,7 +557,7 @@ class TestUnflatten(TestCase):
|
||||
mod = M()
|
||||
|
||||
inputs = (torch.randn(3, 3, device="meta"),)
|
||||
ep = export(mod, inputs)
|
||||
ep = export(mod, inputs, strict=True)
|
||||
unflattened = unflatten(ep)
|
||||
self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
|
||||
self.assertTrue(unflattened.p.requires_grad is False)
|
||||
@ -567,7 +573,7 @@ class TestUnflatten(TestCase):
|
||||
return x.transpose(0, 1)
|
||||
|
||||
x = torch.randn(32, 3, 64, 64)
|
||||
exported_program = export(TransposeModule(), args=(x,))
|
||||
exported_program = export(TransposeModule(), args=(x,), strict=True)
|
||||
unflattened_module = unflatten(exported_program)
|
||||
|
||||
# Check the inputs of the created call_module node are in order
|
||||
@ -599,7 +605,7 @@ class TestUnflatten(TestCase):
|
||||
def forward(self, x):
|
||||
return x + self.submod(x)
|
||||
|
||||
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
|
||||
export_module = torch.export.export(Mod(), (torch.randn((2, 3)),), strict=True)
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
self.compare_outputs(
|
||||
@ -751,7 +757,7 @@ class TestUnflatten(TestCase):
|
||||
|
||||
mod = Module()
|
||||
|
||||
ep = torch.export.export(mod, (torch.randn(3, 4),))
|
||||
ep = torch.export.export(mod, (torch.randn(3, 4),), strict=True)
|
||||
|
||||
unflattened = torch.export.unflatten(ep)
|
||||
fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
|
||||
@ -808,7 +814,7 @@ class TestUnflatten(TestCase):
|
||||
|
||||
m = Foo()
|
||||
inps = (torch.randn(4, 4),)
|
||||
ep = export(m, inps)
|
||||
ep = export(m, inps, strict=True)
|
||||
unep = unflatten(ep)
|
||||
self.assertTrue(id(unep.m.bias) == id(unep.bias))
|
||||
|
||||
@ -827,7 +833,7 @@ class TestUnflatten(TestCase):
|
||||
|
||||
m = Foo()
|
||||
inps = (torch.randn(4, 4),)
|
||||
ep = export(m, inps)
|
||||
ep = export(m, inps, strict=True)
|
||||
unep = unflatten(ep)
|
||||
self.assertTrue(torch.allclose(unep(*inps), m(*inps)))
|
||||
|
||||
@ -849,7 +855,7 @@ class TestUnflatten(TestCase):
|
||||
|
||||
mod = M()
|
||||
x = torch.randn(4, 8)
|
||||
ep = export(mod, (x,))
|
||||
ep = export(mod, (x,), strict=True)
|
||||
unflattened = unflatten(ep)
|
||||
torch.testing.assert_close(unflattened(x), mod(x))
|
||||
|
||||
@ -942,7 +948,7 @@ class TestUnflatten(TestCase):
|
||||
return x
|
||||
|
||||
orig_eager = MyModule()
|
||||
export_module = export(orig_eager, (torch.rand(2, 3),), {})
|
||||
export_module = export(orig_eager, (torch.rand(2, 3),), {}, strict=True)
|
||||
with _disable_interpreter():
|
||||
unflattened = unflatten(export_module)
|
||||
|
||||
|
@ -6606,7 +6606,7 @@ class TestHopSchema(TestCase):
|
||||
x,
|
||||
)
|
||||
model = M()
|
||||
torch.export.export(model, args)
|
||||
torch.export.export(model, args, strict=True)
|
||||
graph_str = self._check_export(model, args, None)
|
||||
self.assertExpectedInline(
|
||||
graph_str,
|
||||
|
@ -193,7 +193,9 @@ class TestSplitOutputType(TestCase):
|
||||
relu
|
||||
"""
|
||||
tag_node = defaultdict(list)
|
||||
gm: torch.fx.GraphModule = torch.export.export(module, (inputs,)).module()
|
||||
gm: torch.fx.GraphModule = torch.export.export(
|
||||
module, (inputs,), strict=True
|
||||
).module()
|
||||
# Add tag to all nodes and build dictionary record tag to call_module nodes
|
||||
for node in gm.graph.nodes:
|
||||
if "conv" in node.name:
|
||||
|
@ -92,10 +92,7 @@ class TestFXNodeSource(TestCase):
|
||||
|
||||
model = Model()
|
||||
example_inputs = (torch.randn(8, 10),)
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
example_inputs,
|
||||
)
|
||||
ep = torch.export.export(model, example_inputs, strict=True)
|
||||
gm = ep.module()
|
||||
provenance = get_graph_provenance_json(gm.graph)
|
||||
provenance = json.loads(provenance)
|
||||
|
@ -221,7 +221,7 @@ class TestSourceMatcher(JitTestCase):
|
||||
torch._check(b + 1 < y.size(0))
|
||||
return y[: b + 1]
|
||||
|
||||
ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)))
|
||||
ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)), strict=True)
|
||||
fake_inputs = [
|
||||
node.meta["val"] for node in ep.graph.nodes if node.op == "placeholder"
|
||||
]
|
||||
|
@ -1885,7 +1885,7 @@ class AOTInductorTestsTemplate:
|
||||
example_inputs = (torch.randn(10, 10), torch.randn(10, 10))
|
||||
|
||||
# Export on CPU
|
||||
exported_program = export(Model(), example_inputs)
|
||||
exported_program = export(Model(), example_inputs, strict=True)
|
||||
|
||||
# Compile exported model on GPU
|
||||
gm = exported_program.graph_module.to(self.device)
|
||||
|
@ -156,10 +156,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
|
||||
torch.manual_seed(0)
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
||||
ep = torch.export.export(
|
||||
model,
|
||||
example_inputs,
|
||||
)
|
||||
ep = torch.export.export(model, example_inputs, strict=True)
|
||||
with fresh_inductor_cache():
|
||||
# cubin files are removed when exiting this context
|
||||
package_path = torch._inductor.aoti_compile_and_package(
|
||||
@ -304,7 +301,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
torch.randn(3, 4, device=self.device),
|
||||
)
|
||||
ep1 = torch.export.export(
|
||||
Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
|
||||
Model1(), example_inputs1, dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
aoti_files1 = torch._inductor.aot_compile(
|
||||
ep1.module(), example_inputs1, options=options
|
||||
@ -321,7 +318,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
return x * t
|
||||
|
||||
example_inputs2 = (torch.randn(5, 5, device=self.device),)
|
||||
ep2 = torch.export.export(Model2(self.device), example_inputs2)
|
||||
ep2 = torch.export.export(Model2(self.device), example_inputs2, strict=True)
|
||||
aoti_files2 = torch._inductor.aot_compile(
|
||||
ep2.module(), example_inputs2, options=options
|
||||
)
|
||||
@ -360,7 +357,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
)
|
||||
self.check_model(Model1(), example_inputs1)
|
||||
ep1 = torch.export.export(
|
||||
Model1(), example_inputs1, dynamic_shapes=dynamic_shapes
|
||||
Model1(), example_inputs1, dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
aoti_files1 = torch._inductor.aot_compile(
|
||||
ep1.module(), example_inputs1, options=options
|
||||
@ -372,7 +369,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
torch.randn(3, 4, device=device),
|
||||
)
|
||||
ep2 = torch.export.export(
|
||||
Model1(), example_inputs2, dynamic_shapes=dynamic_shapes
|
||||
Model1(), example_inputs2, dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
aoti_files2 = torch._inductor.aot_compile(
|
||||
ep2.module(), example_inputs2, options=options
|
||||
@ -404,7 +401,7 @@ class TestAOTInductorPackage(TestCase):
|
||||
torch.randn(2, 4, device=self.device),
|
||||
torch.randn(3, 4, device=self.device),
|
||||
)
|
||||
ep = torch.export.export(Model(), example_inputs)
|
||||
ep = torch.export.export(Model(), example_inputs, strict=True)
|
||||
aoti_files = torch._inductor.aot_compile(
|
||||
ep.module(),
|
||||
example_inputs,
|
||||
@ -433,12 +430,10 @@ class TestAOTInductorPackage(TestCase):
|
||||
torch.randn(2, 4, device=self.device),
|
||||
torch.randn(3, 4, device=self.device),
|
||||
)
|
||||
ep = torch.export.export(Model(), example_inputs)
|
||||
ep = torch.export.export(Model(), example_inputs, strict=True)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
buffer = torch._inductor.aoti_compile_and_package(
|
||||
ep, package_path=buffer
|
||||
) # type: ignore[arg-type]
|
||||
buffer = torch._inductor.aoti_compile_and_package(ep, package_path=buffer) # type: ignore[arg-type]
|
||||
for _ in range(2):
|
||||
loaded = load_package(buffer)
|
||||
self.assertTrue(
|
||||
|
@ -149,6 +149,7 @@ class TestExportAPIDynamo(common_utils.TestCase):
|
||||
2: torch.export.Dim("customb_dim_2"),
|
||||
},
|
||||
},
|
||||
strict=True,
|
||||
)
|
||||
|
||||
self.assert_export(exported_program)
|
||||
|
@ -273,7 +273,9 @@ class _TestONNXRuntime(pytorch_test_common.ExportTestCase):
|
||||
== pytorch_test_common.TorchModelType.TORCH_EXPORT_EXPORTEDPROGRAM
|
||||
):
|
||||
with _dynamo_config.patch(do_not_emit_runtime_asserts=True):
|
||||
ref_model = torch.export.export(ref_model, args=ref_input_args)
|
||||
ref_model = torch.export.export(
|
||||
ref_model, args=ref_input_args, strict=True
|
||||
)
|
||||
if (
|
||||
self.dynamic_shapes
|
||||
): # TODO: Support dynamic shapes for torch.export.ExportedProgram
|
||||
|
@ -138,7 +138,7 @@ class TestModularizePass(common_utils.TestCase):
|
||||
|
||||
if is_exported_program:
|
||||
model = torch.export.export(
|
||||
TestModule(), args=(torch.randn(3), torch.randn(3))
|
||||
TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
|
||||
)
|
||||
else:
|
||||
model = TestModule()
|
||||
@ -185,7 +185,7 @@ class TestModularizePass(common_utils.TestCase):
|
||||
|
||||
if is_exported_program:
|
||||
model = torch.export.export(
|
||||
TestModule(), args=(torch.randn(3), torch.randn(3))
|
||||
TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
|
||||
)
|
||||
else:
|
||||
model = TestModule()
|
||||
@ -239,7 +239,7 @@ class TestModularizePass(common_utils.TestCase):
|
||||
|
||||
if is_exported_program:
|
||||
model = torch.export.export(
|
||||
TestModule(), args=(torch.randn(3), torch.randn(3))
|
||||
TestModule(), args=(torch.randn(3), torch.randn(3)), strict=True
|
||||
)
|
||||
else:
|
||||
model = TestModule()
|
||||
|
@ -194,7 +194,7 @@ class TestFxToOnnx(pytorch_test_common.ExportTestCase):
|
||||
|
||||
x = torch.randn(2, 3)
|
||||
with torch.no_grad():
|
||||
exported_program = torch.export.export(Model(), args=(x,))
|
||||
exported_program = torch.export.export(Model(), args=(x,), strict=True)
|
||||
_ = torch.onnx.dynamo_export(
|
||||
exported_program,
|
||||
x,
|
||||
|
@ -57,7 +57,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
x = torch.randn(2, 3, 4, dtype=torch.float)
|
||||
dim0 = torch.export.Dim("dim0")
|
||||
exported_program = torch.export.export(
|
||||
Model(), (x,), dynamic_shapes={"x": {0: dim0}}
|
||||
Model(), (x,), dynamic_shapes={"x": {0: dim0}}, strict=True
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(exported_program, x)
|
||||
|
||||
@ -75,7 +75,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
return x + 1.0
|
||||
|
||||
x = torch.randn(1, 1, 2, dtype=torch.float)
|
||||
exported_program = torch.export.export(Model(), args=(x,))
|
||||
exported_program = torch.export.export(Model(), args=(x,), strict=True)
|
||||
onnx_program = torch.onnx.dynamo_export(exported_program, x)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".pte") as f:
|
||||
@ -101,7 +101,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
dynamic_shapes = {"x": {0: dim0_x}, "y": None}
|
||||
# specialized input y to 5 during tracing
|
||||
exported_program = torch.export.export(
|
||||
f, (tensor_input, 5), dynamic_shapes=dynamic_shapes
|
||||
f, (tensor_input, 5), dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(exported_program, tensor_input, 5)
|
||||
|
||||
@ -134,13 +134,16 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
return bar.sum() + self.buf.sum()
|
||||
|
||||
tensor_input = torch.ones(5, 5)
|
||||
exported_program = torch.export.export(Foo(), (tensor_input,))
|
||||
exported_program = torch.export.export(Foo(), (tensor_input,), strict=True)
|
||||
|
||||
dim0_x = torch.export.Dim("dim0_x")
|
||||
# NOTE: If input is ExportedProgram, we need to specify dynamic_shapes
|
||||
# as a tuple.
|
||||
reexported_program = torch.export.export(
|
||||
exported_program.module(), (tensor_input,), dynamic_shapes=({0: dim0_x},)
|
||||
exported_program.module(),
|
||||
(tensor_input,),
|
||||
dynamic_shapes=({0: dim0_x},),
|
||||
strict=True,
|
||||
)
|
||||
reexported_onnx_program = torch.onnx.dynamo_export(
|
||||
reexported_program, tensor_input
|
||||
@ -162,7 +165,10 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
|
||||
dim = torch.export.Dim("dim")
|
||||
exported_program = torch.export.export(
|
||||
foo, (torch.randn(4, 4), torch.randn(4, 4)), dynamic_shapes=(None, {0: dim})
|
||||
foo,
|
||||
(torch.randn(4, 4), torch.randn(4, 4)),
|
||||
dynamic_shapes=(None, {0: dim}),
|
||||
strict=True,
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
exported_program, torch.randn(4, 4), torch.randn(4, 4)
|
||||
@ -192,6 +198,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
# We are specifying dynamism on the first kwarg even though user passed in
|
||||
# different order
|
||||
dynamic_shapes=(None, {0: dim}, {0: dim_for_kw1}, None),
|
||||
strict=True,
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
exported_program,
|
||||
@ -237,6 +244,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
input_b,
|
||||
),
|
||||
dynamic_shapes=({0: dim}, {0: dim}),
|
||||
strict=True,
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(exported_program, input_x, input_b)
|
||||
|
||||
@ -276,7 +284,9 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
return None
|
||||
|
||||
dynamic_shapes = torch_pytree.tree_map(dynamify_inp, inp)
|
||||
exported_program = torch.export.export(foo, inp, dynamic_shapes=dynamic_shapes)
|
||||
exported_program = torch.export.export(
|
||||
foo, inp, dynamic_shapes=dynamic_shapes, strict=True
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(exported_program, inp_a, inp_b)
|
||||
|
||||
# NOTE: Careful with the input format. The input format should be
|
||||
@ -302,6 +312,7 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime):
|
||||
(),
|
||||
{"x": torch.randn(3, 3), "y": torch.randn(3, 3)},
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
strict=True,
|
||||
)
|
||||
onnx_program = torch.onnx.dynamo_export(
|
||||
exported_program, x=torch.randn(3, 3), y=torch.randn(3, 3)
|
||||
|
@ -61,7 +61,7 @@ class TestMetaDataPorting(QuantizationTestCase):
|
||||
def _test_quant_tag_preservation_through_decomp(
|
||||
self, model, example_inputs, from_node_to_tags
|
||||
):
|
||||
ep = torch.export.export(model, example_inputs)
|
||||
ep = torch.export.export(model, example_inputs, strict=True)
|
||||
found_tags = True
|
||||
not_found_nodes = ""
|
||||
for from_node, tag in from_node_to_tags.items():
|
||||
|
@ -133,7 +133,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_copy_preserve_handle(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = torch.export.export(m, example_inputs)
|
||||
ep = torch.export.export(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
|
||||
self._assert_each_node_has_debug_handle(ep)
|
||||
@ -148,7 +148,7 @@ class TestNumericDebugger(TestCase):
|
||||
def test_deepcopy_preserve_handle(self):
|
||||
m = TestHelperModules.Conv2dThenConv1d()
|
||||
example_inputs = m.example_inputs()
|
||||
ep = torch.export.export(m, example_inputs)
|
||||
ep = torch.export.export(m, example_inputs, strict=True)
|
||||
generate_numeric_debug_handle(ep)
|
||||
|
||||
debug_handle_map_ref = self._extract_debug_handles(ep)
|
||||
|
@ -1460,7 +1460,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
||||
|
||||
with TemporaryFileName() as fname:
|
||||
# serialization
|
||||
quantized_ep = torch.export.export(m, example_inputs)
|
||||
quantized_ep = torch.export.export(m, example_inputs, strict=True)
|
||||
torch.export.save(quantized_ep, fname)
|
||||
# deserialization
|
||||
loaded_ep = torch.export.load(fname)
|
||||
|
@ -1130,7 +1130,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
|
||||
model_pt2e = convert_pt2e(model_pt2e)
|
||||
quant_result_pt2e = model_pt2e(*example_inputs) # noqa: F841
|
||||
|
||||
exported_model = torch.export.export(model_pt2e, example_inputs)
|
||||
exported_model = torch.export.export(model_pt2e, example_inputs, strict=True)
|
||||
|
||||
node_occurrence = {
|
||||
# conv2d: 1 for act, 1 for weight, 1 for output
|
||||
|
@ -912,7 +912,9 @@ class FakeTensorTest(TestCase):
|
||||
return input + np.random.randn(*input.shape)
|
||||
|
||||
with FakeTensorMode():
|
||||
ep = torch.export.export(MyNumpyModel(), args=(torch.randn(1000),))
|
||||
ep = torch.export.export(
|
||||
MyNumpyModel(), args=(torch.randn(1000),), strict=True
|
||||
)
|
||||
self.assertTrue(isinstance(ep, torch.export.ExportedProgram))
|
||||
|
||||
def test_unsqueeze_copy(self):
|
||||
|
@ -61,11 +61,10 @@ class TestOutDtypeOp(TestCase):
|
||||
weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
||||
m = M(weight)
|
||||
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
|
||||
ep = torch.export.export(
|
||||
m,
|
||||
(x,),
|
||||
)
|
||||
FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(ep.graph_module.code)
|
||||
ep = torch.export.export(m, (x,), strict=True)
|
||||
FileCheck().check("torch.ops.higher_order.out_dtype").check(
|
||||
"aten.mm.default"
|
||||
).run(ep.graph_module.code)
|
||||
self.assertTrue(torch.allclose(m(x), ep.module()(x)))
|
||||
for node in ep.graph.nodes:
|
||||
if node.op == "call_function" and node.target is out_dtype:
|
||||
@ -128,7 +127,12 @@ class TestOutDtypeOp(TestCase):
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
|
||||
_ = torch.export.export(
|
||||
M(), (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
|
||||
M(),
|
||||
(
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
torch.randint(-128, 127, (5, 5), dtype=torch.int8),
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
|
||||
def test_out_dtype_non_op_overload(self):
|
||||
|
Reference in New Issue
Block a user