Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Replace export_for_training with export (#162396)
Summary: Replace export_for_training with export.

Test Plan: CI

Rollback Plan:

Differential Revision: D81935792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162396
Approved by: https://github.com/angelayi, https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: fc1b09a52a
Commit: de05dbc39c
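The change is a mechanical rename at every call site: torch.export.export now produces the training-IR ExportedProgram that export_for_training used to, so the arguments stay the same. Below is a minimal sketch of the migration; the module and inputs are illustrative and not taken from this PR.

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x).relu()

m = M()
example_inputs = (torch.randn(2, 4),)

# Before this PR the tests called:
#   ep = torch.export.export_for_training(m, example_inputs, strict=True)
# After, the same call goes through torch.export.export:
ep = torch.export.export(m, example_inputs, strict=True)
gm = ep.module()  # runnable GraphModule, as used throughout the updated tests
out = gm(*example_inputs)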
@@ -183,7 +183,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
)
torch.utils._pytree.register_constant(DeviceMesh)

-ep = torch.export.export_for_training(
+ep = torch.export.export(
Foo(), (torch.randn(4, 4, dtype=torch.float64),), strict=False
)
self.assertExpectedInline(
@@ -9,7 +9,7 @@ from torch._export.db.examples import (
filter_examples_by_support_level,
get_rewrite_cases,
)
-from torch.export import export_for_training
+from torch.export import export
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
@@ -35,7 +35,7 @@ class ExampleTests(TestCase):
kwargs_export = case.example_kwargs
args_model = copy.deepcopy(args_export)
kwargs_model = copy.deepcopy(kwargs_export)
-exported_program = export_for_training(
+exported_program = export(
model,
args_export,
kwargs_export,
@@ -68,7 +68,7 @@ class ExampleTests(TestCase):
with self.assertRaises(
(torchdynamo.exc.Unsupported, AssertionError, RuntimeError)
):
-export_for_training(
+export(
model,
case.example_args,
case.example_kwargs,
@@ -94,7 +94,7 @@ class ExampleTests(TestCase):
self, name: str, rewrite_case: ExportCase
) -> None:
# pyre-ignore
-export_for_training(
+export(
rewrite_case.model,
rewrite_case.example_args,
rewrite_case.example_kwargs,
@@ -9,7 +9,7 @@ import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._functorch.aot_autograd import aot_export_module
-from torch.export import export, export_for_training
+from torch.export import export
from torch.export.experimental import _export_forward_backward, _sticky_export
from torch.export.graph_signature import OutputKind
from torch.testing import FileCheck
@@ -32,7 +32,7 @@ class TestExperiment(TestCase):
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
-ep = torch.export.export_for_training(m, example_inputs, strict=True)
+ep = torch.export.export(m, example_inputs, strict=True)
joint_ep = _export_forward_backward(ep)
self.assertExpectedInline(
str(joint_ep.graph_module.code).strip(),
@@ -141,7 +141,7 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
-ep = torch.export.export_for_training(
+ep = torch.export.export(
m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}}, strict=True
)
_export_forward_backward(ep)
@@ -177,7 +177,7 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
labels = torch.ones(4, dtype=torch.int64)
inputs = (x, labels)

-ep = export_for_training(net, inputs, strict=True)
+ep = export(net, inputs, strict=True)
ep = _export_forward_backward(ep)

def test_joint_loss_index(self):
@@ -197,7 +197,7 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):

inputs = (torch.randn(4, 4),)
for i in [0, 1]:
-ep = export_for_training(Foo(i), inputs, strict=True)
+ep = export(Foo(i), inputs, strict=True)
ep_joint = _export_forward_backward(ep, joint_loss_index=i)
for j, spec in enumerate(ep_joint.graph_signature.output_specs):
if i == j:
@@ -42,13 +42,7 @@ from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.while_loop import while_loop
from torch._inductor.compile_fx import split_const_gm
from torch._subclasses import FakeTensorMode
-from torch.export import (
-default_decompositions,
-Dim,
-export,
-export_for_training,
-unflatten,
-)
+from torch.export import default_decompositions, Dim, export, unflatten
from torch.export._trace import (
_export,
_export_to_torch_ir,
@@ -1058,7 +1052,7 @@ graph():
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(exported_program.module()(*args), m(*args))

-gm: torch.fx.GraphModule = torch.export.export_for_training(
+gm: torch.fx.GraphModule = torch.export.export(
m, args=example_args, dynamic_shapes=dynamic_shapes
).module()

@@ -2456,7 +2450,7 @@ def forward(self, x, y):
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)

-ep_training = torch.export.export_for_training(m, (ref_x,))
+ep_training = torch.export.export(m, (ref_x,))
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@@ -2519,7 +2513,7 @@ graph():
ref_x = torch.randn(2, 2)
ref_out = m(ref_x)

-ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
+ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@@ -2651,7 +2645,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
-ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
+ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertTrue(torch.allclose(ep_training.module()(ref_x), ref_out))
self.assertExpectedInline(
str(ep_training.graph).strip(),
@@ -2706,7 +2700,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
-ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
+ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@@ -2746,7 +2740,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
-ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
+ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@@ -2784,7 +2778,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
-ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
+ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@@ -2823,7 +2817,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
-ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
+ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@@ -2863,7 +2857,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
-ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
+ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@@ -3983,7 +3977,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
x_linear = self.linear(x_conv)
return x_linear.cos() + y_conv_1d.sum()

-ep = torch.export.export_for_training(
+ep = torch.export.export(
Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
)

@@ -4251,9 +4245,7 @@ def forward(self, x):
return self.linear(x)

eager_model = Foo()
-ep_for_training = torch.export.export_for_training(
-eager_model, (torch.ones(2, 2),)
-)
+ep_for_training = torch.export.export(eager_model, (torch.ones(2, 2),))
self.assertExpectedInline(
str(ep_for_training.graph_module.code).strip(),
"""\
@@ -4291,7 +4283,7 @@ def forward(self, x):

eager_model_for_export = Foo()
eager_model_for_testing = Foo()
-ep_for_training = torch.export.export_for_training(
+ep_for_training = torch.export.export(
eager_model_for_export, (torch.ones(4, 4),)
)
self.assertExpectedInline(
@@ -4337,7 +4329,7 @@ def forward(self, x):
eager_model_for_export_training = Foo()
eager_model_for_export_inference = Foo()
eager_model_for_testing = Foo()
-ep_for_training = torch.export.export_for_training(
+ep_for_training = torch.export.export(
eager_model_for_export_training,
(torch.ones(4, 4),),
dynamic_shapes=({0: Dim("x")},),
@@ -4391,7 +4383,7 @@ def forward(self, x):
return x + y + self.buffer.sum()

eager_model = Foo()
-ep_for_training = torch.export.export_for_training(
+ep_for_training = torch.export.export(
eager_model,
([torch.ones(4, 4), torch.ones(4, 4)],),
)
@@ -4597,7 +4589,7 @@ def forward(self, x):
return self.linear(x) + self.buffer.sum()

eager_model = Foo()
-ep_for_training = torch.export.export_for_training(
+ep_for_training = torch.export.export(
eager_model,
(torch.ones(2, 2),),
)
@@ -7530,7 +7522,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):

inp = torch.randn(4, 4)

-ep = export_for_training(
+ep = torch.export.export(
Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
)
unflat = unflatten(ep).bar
@@ -7836,7 +7828,7 @@ graph():

decomp_table = {**default_decompositions(), **decomposition_table}

-ep = export_for_training(M(), (torch.randn(2, 2),)).run_decompositions(
+ep = torch.export.export(M(), (torch.randn(2, 2),)).run_decompositions(
decomp_table
)

@@ -7865,7 +7857,7 @@ def forward(self, c_lifted_tensor_0, x):
mod.eval()
inp = torch.randn(1, 1, 3, 3)

-gm = torch.export.export_for_training(mod, (inp,)).module()
+gm = torch.export.export(mod, (inp,)).module()
self.assertExpectedInline(
str(gm.code).strip(),
"""\
@@ -7885,7 +7877,7 @@ def forward(self, x):
)

mod.train()
-gm_train = torch.export.export_for_training(mod, (inp,)).module()
+gm_train = torch.export.export(mod, (inp,)).module()
self.assertExpectedInline(
str(gm_train.code).strip(),
"""\
@@ -8450,7 +8442,7 @@ def forward(self, x):
ref_x = torch.randn(2, 2)
ref_out = f(ref_x, mod)

-ep = torch.export.export_for_training(f, (torch.randn(2, 2), mod), strict=False)
+ep = torch.export.export(f, (torch.randn(2, 2), mod), strict=False)
self.assertEqual(ref_out, ep.module()(ref_x, mod))

def test_unbacked_noncontig_lin(self):
@@ -9645,7 +9637,7 @@ graph():
return m(x) * x

inps = (torch.randn(3, 3),)
-ep = export_for_training(M2(), inps).run_decompositions({})
+ep = torch.export.export(M2(), inps).run_decompositions({})
self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))

self.assertEqual(len(ep.state_dict), 0)
@@ -9682,7 +9674,7 @@ graph():

inps = (torch.randn(3, 3),)
# Strict export segfaults (Issue #128109)
-ep = export_for_training(M2(), inps, strict=False).run_decompositions({})
+ep = torch.export.export(M2(), inps, strict=False).run_decompositions({})
self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))

self.assertEqual(len(ep.state_dict), 0)
@@ -12013,7 +12005,7 @@ graph():

if is_training_ir_test(self._testMethodName):
test(
-torch.export.export_for_training(
+torch.export.export(
M(),
inp,
strict=not is_non_strict_test(self._testMethodName),
@@ -12134,7 +12126,7 @@ graph():
test(export(M(), inp))

strict = not is_non_strict_test(self._testMethodName)
-ept = torch.export.export_for_training(
+ept = torch.export.export(
M(),
inp,
strict=strict,
@@ -12209,7 +12201,7 @@ graph():

x = torch.zeros((4, 4, 10))

-ep_training = torch.export.export_for_training(model, (x,), strict=False)
+ep_training = torch.export.export(model, (x,), strict=False)
state_dict_before = ep_training.state_dict

ep = export(model, (x,), strict=False).run_decompositions()
@@ -12253,7 +12245,7 @@ def forward(self, c_params, x):

x = torch.zeros((4, 4, 10))

-ep_training = torch.export.export_for_training(model, (x,), strict=False)
+ep_training = torch.export.export(model, (x,), strict=False)
state_dict_before = ep_training.state_dict

ep = export(model, (x,), strict=False).run_decompositions()
@@ -12772,7 +12764,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):

model = Model()
with torch.no_grad():
-exported_program = torch.export.export_for_training(
+exported_program = torch.export.export(
model,
(torch.tensor(10), torch.tensor(12)),
{},
@@ -12868,7 +12860,7 @@ def forward(self, x, b_t, y):
# no grad
model = Model()
with torch.no_grad():
-ep_nograd = torch.export.export_for_training(
+ep_nograd = torch.export.export(
model,
(torch.tensor(10), torch.tensor(12)),
{},
@@ -12888,7 +12880,7 @@ def forward(self, x, b_t, y):

# enable grad
model = Model()
-ep_grad = torch.export.export_for_training(
+ep_grad = torch.export.export(
model,
(torch.tensor(10), torch.tensor(12)),
{},
@@ -13011,7 +13003,7 @@ def forward(self, x, b_t, y):
"torch.ops.higher_order.wrap_with_set_grad_enabled",
ep.graph_module.code,
)
-gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
+gm = torch.export.export(model, (torch.randn(4, 4),)).module()
self.assertIn(
"set_grad_enabled",
gm.code,
@@ -13040,7 +13032,7 @@ def forward(self, x, b_t, y):
)
# _export_for_traininig is using pre_dispatch=False
# Therefore the autocast calls are not replaced with a hop.
-gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
+gm = torch.export.export(model, (torch.randn(4, 4),)).module()
self.assertIn(
"autocast",
gm.code,
@@ -13287,7 +13279,7 @@ def forward(self, x, b_t, y):

inps = (torch.ones(5),)

-ep = export_for_training(M(), inps).run_decompositions({})
+ep = torch.export.export(M(), inps).run_decompositions({})
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
@@ -13608,7 +13600,7 @@ def forward(self, x):
return y + y_sum + unbacked_shape.sum()

inps = (torch.tensor(4), torch.randn(5, 5))
-ep_pre = torch.export.export_for_training(Foo(), inps, strict=False)
+ep_pre = torch.export.export(Foo(), inps, strict=False)
self.assertExpectedInline(
str(ep_pre.graph_module.submod_1.code).strip(),
"""\
@@ -14298,7 +14290,7 @@ graph():
return val.b.a

mod = Foo()
-ep = export_for_training(mod, (torch.randn(4, 4),), strict=False)
+ep = torch.export.export(mod, (torch.randn(4, 4),), strict=False)
self.assertExpectedInline(
str(ep.graph).strip(),
"""\
@@ -15311,7 +15303,7 @@ def forward(self, x):
x = torch.randn(2, 4)
y = torch.ones(4)

-ep_for_training = torch.export.export_for_training(M(), (x, y), strict=strict)
+ep_for_training = torch.export.export(M(), (x, y), strict=strict)
self.assertExpectedInline(
normalize_gm(
ep_for_training.graph_module.print_readable(print_output=False)
@@ -15,14 +15,14 @@ test_classes = {}

def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
if "strict" in kwargs:
-ep = torch.export.export_for_training(*args, **kwargs)
+ep = torch.export.export(*args, **kwargs)
else:
-ep = torch.export.export_for_training(*args, **kwargs, strict=True)
+ep = torch.export.export(*args, **kwargs, strict=True)
return ep.run_decompositions({})


def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
-ep = torch.export.export_for_training(*args, **kwargs)
+ep = torch.export.export(*args, **kwargs)

return ep.run_decompositions({})
@@ -45,7 +45,7 @@ from torch._export.serde.serialize import (
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
-from torch.export import Dim, export_for_training, load, save, unflatten
+from torch.export import Dim, export, load, save, unflatten
from torch.export.pt2_archive.constants import ARCHIVE_VERSION_PATH
from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges
from torch.testing._internal.common_utils import (
@@ -115,7 +115,7 @@ class TestSerialize(TestCase):
return torch.ops.aten.add.Tensor._schema

inp = (torch.ones(10),)
-ep = export_for_training(TestModule(), inp, strict=True)
+ep = export(TestModule(), inp, strict=True)

# Register the custom op handler.
foo_custom_op = FooExtensionOp()
@@ -180,9 +180,7 @@ class TestSerialize(TestCase):

model = MyModule().eval()
random_inputs = (torch.rand([2, 3]), torch.rand([2, 3]))
-exp_program = export_for_training(
-model, random_inputs, {"use_p": True}, strict=True
-)
+exp_program = export(model, random_inputs, {"use_p": True}, strict=True)

output_buffer = io.BytesIO()
# Tests that example inputs are preserved when saving and loading module.
@@ -201,7 +199,7 @@ class TestSerialize(TestCase):
def forward(self, x):
return x.sin()

-exp_program = export_for_training(M(), (torch.randn(4, 4),), strict=True)
+exp_program = export(M(), (torch.randn(4, 4),), strict=True)

output_buffer = io.BytesIO()
# Tests that example forward arg names are preserved when saving and loading module.
@@ -241,7 +239,7 @@ def forward(self, x):
inp = (torch.ones(10),)
# Module will only be able to roundtrip if metadata
# can be correctly parsed.
-ep = export_for_training(MyModule(), inp, strict=True)
+ep = export(MyModule(), inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer)
loaded_ep = load(buffer)
@@ -282,7 +280,7 @@ def forward(self, x):
return h + out_c

inp = (torch.ones(10),)
-ep = export_for_training(Foo(), inp, strict=True)
+ep = export(Foo(), inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer)
loaded_ep = load(buffer)
@@ -324,7 +322,7 @@ def forward(self, x):

# Check that module can be roundtripped, thereby confirming proper deserialization.
inp = (torch.ones(10),)
-ep = export_for_training(MyModule(), inp, strict=True)
+ep = export(MyModule(), inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer)
loaded_ep = load(buffer)
@@ -347,7 +345,7 @@ def forward(self, x):
eps=1e-5,
)

-exported_module = export_for_training(
+exported_module = export(
MyModule(),
(
torch.ones([512, 512], requires_grad=True),
@@ -391,7 +389,7 @@ def forward(self, x):
"b": {1: dim1_bc},
"c": {0: dim0_ac, 1: dim1_bc},
}
-exported_module = export_for_training(
+exported_module = export(
DynamicShapeSimpleModel(),
inputs,
dynamic_shapes=dynamic_shapes,
@@ -455,7 +453,7 @@ def forward(self, x):
"b": {1: dim1_bc},
"c": {0: dim0_ac, 1: dim1_bc},
}
-exported_module = export_for_training(
+exported_module = export(
DynamicShapeSimpleModel(),
inputs,
dynamic_shapes=dynamic_shapes,
@@ -485,9 +483,7 @@ def forward(self, x):
return torch.split(x, 2)

input = torch.arange(10.0).reshape(5, 2)
-exported_module = export_for_training(
-MyModule(), (input,), strict=True
-).run_decompositions()
+exported_module = export(MyModule(), (input,), strict=True).run_decompositions()

serialized = ExportedProgramSerializer().serialize(exported_module)
node = serialized.exported_program.graph_module.graph.nodes[-1]
@@ -550,7 +546,7 @@ def forward(self, x):
def forward(self, x):
return torch.ops.aten.var_mean.correction(x, [1])[0]

-exported_module = export_for_training(
+exported_module = export(
MyModule(), (torch.ones([512, 512], requires_grad=True),), strict=True
).run_decompositions()

@@ -571,7 +567,7 @@ def forward(self, x):
def forward(self, x):
return x + x

-ep = export_for_training(
+ep = export(
M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},), strict=True
)

@@ -720,7 +716,7 @@ def forward(self, x):
f = Foo()

x, _ = torch.sort(torch.randn(3, 4))
-exported_module = export_for_training(f, (x,), strict=True).run_decompositions()
+exported_module = export(f, (x,), strict=True).run_decompositions()
serialized = ExportedProgramSerializer().serialize(exported_module)

node = serialized.exported_program.graph_module.graph.nodes[-1]
@@ -738,9 +734,7 @@ def forward(self, x):
b = x + y
return b + a

-ep = export_for_training(
-Module(), (torch.randn(3, 2), torch.randn(3, 2)), strict=True
-)
+ep = export(Module(), (torch.randn(3, 2), torch.randn(3, 2)), strict=True)
s = ExportedProgramSerializer().serialize(ep)
c = canonicalize(s.exported_program)
g = c.graph_module.graph
@@ -754,7 +748,7 @@ def forward(self, x):
def forward(self, x):
return torch.ops.aten.sum.dim_IntList(x, [])

-ep = torch.export.export_for_training(M(), (torch.randn(3, 2),), strict=True)
+ep = torch.export.export(M(), (torch.randn(3, 2),), strict=True)
serialized = ExportedProgramSerializer().serialize(ep)
for node in serialized.exported_program.graph_module.graph.nodes:
if "aten.sum.dim_IntList" in node.target:
@@ -1024,7 +1018,7 @@ class TestDeserialize(TestCase):

def _check_graph(pre_dispatch):
if pre_dispatch:
-ep = torch.export.export_for_training(
+ep = torch.export.export(
fn,
_deepcopy_inputs(inputs),
{},
@@ -1574,7 +1568,7 @@ class TestDeserialize(TestCase):
a = a * 2
return a, b

-ep = torch.export.export_for_training(M(), (torch.ones(3),), strict=True)
+ep = torch.export.export(M(), (torch.ones(3),), strict=True)

# insert another getitem node
for node in ep.graph.nodes:
@@ -1720,7 +1714,7 @@ def forward(self, x):
def forward(self):
return self.p * self.p

-ep = torch.export.export_for_training(M(), (), strict=True)
+ep = torch.export.export(M(), (), strict=True)
ep._example_inputs = None
roundtrip_ep = deserialize(serialize(ep))
self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))
@@ -1762,7 +1756,7 @@ class TestSchemaVersioning(TestCase):
return x + x

f = Module()
-ep = export_for_training(f, (torch.randn(1, 3),), strict=True)
+ep = export(f, (torch.randn(1, 3),), strict=True)

serialized_program = ExportedProgramSerializer().serialize(ep)
serialized_program.exported_program.schema_version.major = -1
@@ -1798,7 +1792,7 @@ class TestSaveLoad(TestCase):
y = self.linear(y)
return y

-ep = export_for_training(Module(), inp, strict=True)
+ep = export(Module(), inp, strict=True)

buffer = io.BytesIO()
save(ep, buffer)
@@ -1816,7 +1810,7 @@ class TestSaveLoad(TestCase):
f = Foo()

inp = (torch.randn(2, 2),)
-ep = export_for_training(f, inp, strict=True)
+ep = export(f, inp, strict=True)

with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
save(ep, f.name)
@@ -1833,7 +1827,7 @@ class TestSaveLoad(TestCase):
f = Foo()

inp = (torch.tensor([6]), torch.tensor([7]))
-ep = export_for_training(f, inp, strict=True)
+ep = export(f, inp, strict=True)

with TemporaryFileName(suffix=".pt2") as fname:
path = Path(fname)
@@ -1851,7 +1845,7 @@ class TestSaveLoad(TestCase):

f = Foo()

-ep = export_for_training(f, inp, strict=True)
+ep = export(f, inp, strict=True)

buffer = io.BytesIO()
save(ep, buffer, extra_files={"extra.txt": "moo"})
@@ -1872,7 +1866,7 @@ class TestSaveLoad(TestCase):

f = Foo()

-ep = export_for_training(f, (torch.randn(1, 3),), strict=True)
+ep = export(f, (torch.randn(1, 3),), strict=True)

with self.assertRaisesRegex(
ValueError, r"Saved archive version -1 does not match our current"
@@ -1908,7 +1902,7 @@ class TestSaveLoad(TestCase):
list_tensor = [torch.tensor(3), torch.tensor(4)]
return x + self.a + list_tensor[0] + list_tensor[1]

-ep = export_for_training(Foo(), (torch.tensor(1),), strict=True)
+ep = export(Foo(), (torch.tensor(1),), strict=True)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
@@ -1934,7 +1928,7 @@ class TestSerializeCustomClass(TestCase):
f = Foo()

inputs = (torch.zeros(4, 4),)
-ep = export_for_training(f, inputs, strict=True)
+ep = export(f, inputs, strict=True)

# Replace one of the values with an instance of our custom class
for node in ep.graph.nodes:
@@ -1988,7 +1982,7 @@ class TestSerializeCustomClass(TestCase):

inputs = (torch.zeros(2, 3),)
with enable_torchbind_tracing():
-ep = export_for_training(f, inputs, strict=False)
+ep = export(f, inputs, strict=False)

serialized_vals = serialize(ep)
ep = deserialize(serialized_vals)
@@ -2008,7 +2002,7 @@ class TestSerializeCustomClass(TestCase):

inputs = (torch.zeros(2, 3),)
with enable_torchbind_tracing():
-ep = export_for_training(f, inputs, strict=False)
+ep = export(f, inputs, strict=False)

serialized_vals = serialize(ep)
ep = deserialize(serialized_vals)
@@ -2043,7 +2037,7 @@ def forward(self, x):
f = Foo()

inputs = (torch.zeros(4, 4),)
-ep = export_for_training(f, inputs, strict=True)
+ep = export(f, inputs, strict=True)

new_gm = copy.deepcopy(ep.graph_module)
new_gm.meta["custom"] = {}
@@ -2078,7 +2072,7 @@ def forward(self, x):
f = Foo()

inputs = (torch.ones(2, 2),)
-ep = export_for_training(f, inputs, strict=True)
+ep = export(f, inputs, strict=True)

new_gm = copy.deepcopy(ep.graph_module)
new_gm.meta["custom"] = {}
@@ -2114,7 +2108,7 @@ def forward(self, x):
f = Foo()

inputs = (torch.zeros(4, 4),)
-ep = export_for_training(f, inputs, strict=True)
+ep = export(f, inputs, strict=True)

new_gm = copy.deepcopy(ep.graph_module)
new_gm.meta["custom"] = {}
@@ -138,7 +138,7 @@ class TestExportTorchbind(TestCase):
def export_wrapper(f, args, kwargs, strict, pre_dispatch):
with enable_torchbind_tracing():
if pre_dispatch:
-exported_program = torch.export.export_for_training(
+exported_program = torch.export.export(
f, args, kwargs, strict=strict
).run_decompositions({})
else:
@@ -755,7 +755,7 @@ def forward(self, arg0_1, arg1_1):
b = torch.randn(2, 2)
tq.push(a)
tq.push(b)
-ep = torch.export.export_for_training(
+ep = torch.export.export(
mod, (tq, torch.randn(2, 2)), strict=False
).run_decompositions({})
self.assertExpectedInline(
@@ -809,9 +809,9 @@ def forward(self, L_safe_obj_ : torch.ScriptObject):
)

with enable_torchbind_tracing():
-ep = torch.export.export_for_training(
-mod, (safe_obj,), strict=False
-).run_decompositions({})
+ep = torch.export.export(mod, (safe_obj,), strict=False).run_decompositions(
+{}
+)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
@@ -1407,9 +1407,9 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
x = torch.randn(3, 1)
eager_out = mod(test_obj, x)
compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(test_obj, x)
-ep = torch.export.export_for_training(
-mod, (test_obj, x), strict=False
-).run_decompositions({})
+ep = torch.export.export(mod, (test_obj, x), strict=False).run_decompositions(
+{}
+)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
@@ -7,14 +7,14 @@ except ImportError:
import test_unflatten # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library

-from torch.export import export_for_training
+from torch.export import export


test_classes = {}


def mocked_training_ir_export(*args, **kwargs):
-return export_for_training(*args, **kwargs, strict=True)
+return export(*args, **kwargs, strict=True)


def make_dynamic_cls(cls):
@@ -6,7 +6,7 @@ from functorch.experimental import control_flow
from torch import Tensor
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export.verifier import SpecViolationError, Verifier
-from torch.export import export_for_training
+from torch.export import export
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase

@@ -20,7 +20,7 @@ class TestVerifier(TestCase):

f = Foo()

-ep = export_for_training(f, (torch.randn(100), torch.randn(100)), strict=True)
+ep = export(f, (torch.randn(100), torch.randn(100)), strict=True)

verifier = Verifier()
verifier.check(ep)
@@ -47,7 +47,7 @@ class TestVerifier(TestCase):

f = Foo()

-ep = export_for_training(
+ep = export(
f, (torch.randn(100), torch.randn(100)), strict=True
).run_decompositions({})
for node in ep.graph.nodes:
@@ -72,7 +72,7 @@ class TestVerifier(TestCase):

f = Foo()

-ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True)
+ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True)

verifier = Verifier()
verifier.check(ep)
@@ -91,7 +91,7 @@ class TestVerifier(TestCase):

f = Foo()

-ep = export_for_training(
+ep = export(
f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True
).run_decompositions({})
for node in ep.graph_module.true_graph_0.graph.nodes:
@@ -111,7 +111,7 @@ class TestVerifier(TestCase):
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)

-ep = export_for_training(M(), (torch.randn(10, 10),), strict=True)
+ep = export(M(), (torch.randn(10, 10),), strict=True)
ep.validate()

def test_ep_verifier_invalid_param(self) -> None:
@@ -125,7 +125,7 @@ class TestVerifier(TestCase):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + self.a

-ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True)
+ep = export(M(), (torch.randn(100), torch.randn(100)), strict=True)

# Parameter doesn't exist in the state dict
ep.graph_signature.input_specs[0] = InputSpec(
@@ -150,7 +150,7 @@ class TestVerifier(TestCase):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + self.a

-ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True)
+ep = export(M(), (torch.randn(100), torch.randn(100)), strict=True)

# Buffer doesn't exist in the state dict
ep.graph_signature.input_specs[0] = InputSpec(
@@ -182,9 +182,7 @@ class TestVerifier(TestCase):
self.my_buffer2.add_(1.0)
return output

-ep = export_for_training(
-M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True
-)
+ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True)
ep.validate()

def test_ep_verifier_invalid_output(self) -> None:
@@ -207,9 +205,7 @@ class TestVerifier(TestCase):
self.my_buffer2.add_(1.0)
return output

-ep = export_for_training(
-M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True
-)
+ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True)

output_node = list(ep.graph.nodes)[-1]
output_node.args = (
@@ -6,7 +6,7 @@ from typing import Callable

import torch
import torch.nn.functional as F
-from torch.export import export_for_training
+from torch.export import export
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx

@@ -172,7 +172,7 @@ class TestMatcher(JitTestCase):
torch.randn(1, 3, 3, 3) * 10,
torch.randn(3, 3, 3, 3),
)
-pattern_gm = export_for_training(
+pattern_gm = export(
WrapperModule(pattern), example_inputs, strict=True
).module()
before_split_res = pattern_gm(*example_inputs)
@@ -203,11 +203,11 @@ class TestMatcher(JitTestCase):
torch.randn(1, 3, 3, 3) * 10,
torch.randn(3, 3, 3, 3),
)
-pattern_gm = export_for_training(
+pattern_gm = export(
WrapperModule(pattern), example_inputs, strict=True
).module()
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
-target_gm = export_for_training(
+target_gm = export(
WrapperModule(target_graph), example_inputs, strict=True
).module()
internal_matches = matcher.match(target_gm.graph)
@@ -248,11 +248,9 @@ class TestMatcher(JitTestCase):
return linear, {"linear": linear, "x": x}

example_inputs = (torch.randn(3, 5),)
-pattern_gm = export_for_training(
-Pattern(), example_inputs, strict=True
-).module()
+pattern_gm = export(Pattern(), example_inputs, strict=True).module()
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
-target_gm = export_for_training(M(), example_inputs, strict=True).module()
+target_gm = export(M(), example_inputs, strict=True).module()
internal_matches = matcher.match(target_gm.graph)
for internal_match in internal_matches:
name_node_map = internal_match.name_node_map
@@ -34,7 +34,7 @@ from torch._library import capture_triton
from torch._utils_internal import full_aoti_runtime_assert
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
-from torch.export import Dim, export, export_for_training
+from torch.export import Dim, export
from torch.export.pt2_archive._package import load_pt2
from torch.testing import FileCheck
from torch.testing._internal import common_utils
@@ -2525,9 +2525,7 @@ class AOTInductorTestsTemplate:
config.patch({"freezing": True, "aot_inductor.force_mmap_weights": True}),
torch.no_grad(),
):
-exported_model = export_for_training(
-model, example_inputs, strict=True
-).module()
+exported_model = export(model, example_inputs, strict=True).module()
quantizer = X86InductorQuantizer()
quantizer.set_global(
xiq.get_default_x86_inductor_quantization_config(reduce_range=True)
@@ -24,7 +24,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
OP_TO_ANNOTATOR,
QuantizationConfig,
)
-from torch.export import export_for_training
+from torch.export import export
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly

@@ -101,7 +101,7 @@ class TestDuplicateDQPass(QuantizationTestCase):

# program capture
m = copy.deepcopy(m_eager)
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
@@ -102,7 +102,7 @@ class TestMetaDataPorting(QuantizationTestCase):

# program capture
m = copy.deepcopy(m_eager)
-m = torch.export.export_for_training(m, example_inputs, strict=True).module()
+m = torch.export.export(m, example_inputs, strict=True).module()

m = prepare_pt2e(m, quantizer)
# Calibrate
@@ -19,7 +19,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
-from torch.export import export_for_training
+from torch.export import export
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import (
IS_WINDOWS,
@@ -86,7 +86,7 @@ class TestNumericDebugger(TestCase):
def test_simple(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
self._assert_each_node_has_debug_handle(ep)
debug_handle_map = self._extract_debug_handles(ep)
@@ -96,7 +96,7 @@ class TestNumericDebugger(TestCase):
def test_control_flow(self):
m = TestHelperModules.ControlFlow()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)

self._assert_each_node_has_debug_handle(ep)
@@ -107,7 +107,7 @@ class TestNumericDebugger(TestCase):
def test_quantize_pt2e_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

@@ -167,14 +167,14 @@ class TestNumericDebugger(TestCase):
def test_re_export_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()

self._assert_each_node_has_debug_handle(ep)
debug_handle_map_ref = self._extract_debug_handles(ep)

-ep_reexport = export_for_training(m, example_inputs, strict=True)
+ep_reexport = export(m, example_inputs, strict=True)

self._assert_each_node_has_debug_handle(ep_reexport)
debug_handle_map = self._extract_debug_handles(ep_reexport)
@@ -184,7 +184,7 @@ class TestNumericDebugger(TestCase):
def test_run_decompositions_same_handle_id(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)

self._assert_each_node_has_debug_handle(ep)
@@ -209,7 +209,7 @@ class TestNumericDebugger(TestCase):

for m in test_models:
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)

self._assert_each_node_has_debug_handle(ep)
@@ -232,7 +232,7 @@ class TestNumericDebugger(TestCase):
def test_prepare_for_propagation_comparison(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_logger = prepare_for_propagation_comparison(m)
@@ -249,7 +249,7 @@ class TestNumericDebugger(TestCase):
def test_extract_results_from_loggers(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_ref_logger = prepare_for_propagation_comparison(m)
@@ -274,7 +274,7 @@ class TestNumericDebugger(TestCase):
def test_extract_results_from_loggers_list_output(self):
m = TestHelperModules.Conv2dWithSplit()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_ref_logger = prepare_for_propagation_comparison(m)
@@ -304,7 +304,7 @@ class TestNumericDebugger(TestCase):
def test_added_node_gets_unique_id(self) -> None:
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
-ep = export_for_training(m, example_inputs, strict=True)
+ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
ref_handles = self._extract_debug_handles(ep)
ref_counter = Counter(ref_handles.values())
@@ -39,7 +39,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
OP_TO_ANNOTATOR,
QuantizationConfig,
)
-from torch.export import export_for_training
+from torch.export import export
from torch.fx import Node
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@@ -767,7 +767,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))

# program capture
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, BackendAQuantizer())
# make sure the two observers for input are shared
conv_output_obs = []
@@ -827,7 +827,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
)

# program capture
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
# make sure the two input observers and output are shared
@@ -1146,7 +1146,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
)

# program capture
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
quantizer = BackendAQuantizer()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
@@ -1296,7 +1296,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

m = M().eval()
example_inputs = torch.randn(1, 2, 3, 3)
-m = export_for_training(m, (example_inputs,), strict=True).module()
+m = export(m, (example_inputs,), strict=True).module()
with self.assertRaises(Exception):
m = prepare_pt2e(m, BackendAQuantizer())

@@ -1419,7 +1419,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
weight_meta = None
for n in m.graph.nodes:
if (
@@ -1506,7 +1506,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().eval()
quantizer = TestQuantizer()
example_inputs = (torch.randn(1, 2, 3, 3),)
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
node_occurrence = {
@@ -1557,7 +1557,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
torch.randn(1, 2, 3, 3),
torch.randn(1, 2, 3, 3),
)
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
node_occurrence = {
@@ -1812,7 +1812,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

example_inputs = (torch.randn(1),)
m = M().train()
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
if inplace:
target = torch.ops.aten.dropout_.default
else:
@@ -1877,7 +1877,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()

# Assert that batch norm op exists and is in train mode
bn_node = self._get_node(m, bn_train_op)
@@ -1908,7 +1908,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m.train()

# After export: this is not OK
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
@@ -1949,7 +1949,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()

def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
targets = [n.target for n in m.graph.nodes]
@@ -2015,7 +2015,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
torch.ao.quantization.allow_exported_model_train_eval(m)

# Mock m.recompile() to count how many times it's been called
@@ -2047,7 +2047,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def test_model_is_exported(self):
m = TestHelperModules.ConvWithBNRelu(relu=True)
example_inputs = (torch.rand(3, 3, 5, 5),)
-exported_gm = export_for_training(m, example_inputs, strict=True).module()
+exported_gm = export(m, example_inputs, strict=True).module()
fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
self.assertTrue(
torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
@@ -2065,9 +2065,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
)
-m.conv_bn_relu = export_for_training(
-m.conv_bn_relu, example_inputs, strict=True
-).module()
+m.conv_bn_relu = export(m.conv_bn_relu, example_inputs, strict=True).module()
m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
m(*example_inputs)
m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)
@@ -2075,7 +2073,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer = XNNPACKQuantizer().set_module_type(
torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
)
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)

@@ -2247,7 +2245,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

def dynamic_quantize_pt2e(model, example_inputs):
torch._dynamo.reset()
-model = export_for_training(model, example_inputs, strict=True).module()
+model = export(model, example_inputs, strict=True).module()
# Per channel quantization for weight
# Dynamic quantization for activation
# Please read a detail: https://fburl.com/code/30zds51q
@@ -2462,7 +2460,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):

example_inputs = (torch.randn(1, 3, 5, 5),)
m = M()
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(),
)
@@ -2544,7 +2542,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
edge_or_node_to_obs_or_fq[x] = new_observer

example_inputs = (torch.rand(1, 32, 16, 16),)
-gm = export_for_training(Model().eval(), example_inputs, strict=True).module()
+gm = export(Model().eval(), example_inputs, strict=True).module()
gm = prepare_pt2e(gm, BackendAQuantizer())
gm = convert_pt2e(gm)
for n in gm.graph.nodes:
@@ -2571,9 +2569,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
"ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
)

-m.conv_bn_relu = export_for_training(
-m.conv_bn_relu, example_inputs, strict=True
-).module()
+m.conv_bn_relu = export(m.conv_bn_relu, example_inputs, strict=True).module()
for node in m.conv_bn_relu.graph.nodes:
if node.op not in ["placeholder", "output", "get_attr"]:
check_nn_module(node)
@@ -34,7 +34,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
-from torch.export import export_for_training
+from torch.export import export
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@@ -140,9 +140,7 @@ class PT2EQATTestCase(QuantizationTestCase):
is_per_channel=is_per_channel, is_qat=True
)
)
-model_pt2e = export_for_training(
-model_pt2e, example_inputs, strict=True
-).module()
+model_pt2e = export(model_pt2e, example_inputs, strict=True).module()
model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
torch.manual_seed(MANUAL_SEED)
after_prepare_result_pt2e = model_pt2e(*example_inputs)
@@ -229,7 +227,7 @@ class PT2EQATTestCase(QuantizationTestCase):
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel, is_qat=True)
)
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)

@@ -618,7 +616,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
m = M(self.conv_class, self.bn_class, backbone)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
@@ -676,7 +674,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_conv_bn_bias_derived_qspec(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnDerivedBiasQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@@ -723,7 +721,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_per_channel_weight_custom_dtype(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnInt32WeightQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@@ -777,7 +775,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_conv_bn_per_channel_weight_bias(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
-m = export_for_training(m, example_inputs, strict=True).module()
+m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@@ -834,7 +832,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
it into conv in `convert_pt2e` even in train mode.
"""
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
-m = export_for_training(m, self.example_inputs, strict=True).module()
+m = export(m, self.example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
@@ -850,7 +848,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
Test that batch norm stat tracking (which results in an add_ tensor) is removed when folding batch norm.
"""
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
-m = export_for_training(m, self.example_inputs, strict=True).module()
+m = export(m, self.example_inputs, strict=True).module()

def _has_add_(graph):
for node in graph.nodes:
@@ -1115,9 +1113,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
in_channels = child.linear1.weight.size(1)

example_input = (torch.rand((1, in_channels)),)
-traced_child = export_for_training(
-child, example_input, strict=True
-).module()
+traced_child = export(child, example_input, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=True, is_qat=True
@@ -1148,7 +1144,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
self._convert_qat_linears(model)
model(*example_inputs)

-model_pt2e = export_for_training(model, example_inputs, strict=True).module()
+model_pt2e = export(model, example_inputs, strict=True).module()

quantizer = XNNPACKQuantizer()
quantizer.set_module_type(torch.nn.Linear, None)
@ -10,7 +10,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
@ -34,7 +34,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
) -> torch.nn.Module:
# resetting dynamo cache
torch._dynamo.reset()
model = export_for_training(model, example_inputs, strict=True).module()
model = export(model, example_inputs, strict=True).module()
model_copy = copy.deepcopy(model)

model = prepare_pt2e(model, quantizer)
@ -17,7 +17,7 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
QUANT_ANNOTATION_KEY,
X86InductorQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
@ -668,7 +668,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):

# program capture
m = copy.deepcopy(m_eager)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()

# QAT Model failed to deepcopy
export_model = m if is_qat else copy.deepcopy(m)
@ -2344,7 +2344,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
# Use a linear count instead of names because the names might change, but
# the order should be the same.
@ -29,7 +29,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
PT2EQuantizationTestCase,
@ -362,7 +362,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
# Use a linear count instead of names because the names might change, but
# the order should be the same.
@ -498,7 +498,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)

# program capture
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()

m = prepare_pt2e(m, quantizer)
m(*example_inputs)
@ -763,9 +763,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
model_fx = _convert_to_reference_decomposed_fx(model_fx)

with torchdynamo.config.patch(allow_rnn=True):
model_graph = export_for_training(
model_graph, example_inputs, strict=True
).module()
model_graph = export(model_graph, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
@ -825,9 +823,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
model_fx = _convert_to_reference_decomposed_fx(model_fx)

with torchdynamo.config.patch(allow_rnn=True):
model_graph = export_for_training(
model_graph, example_inputs, strict=True
).module()
model_graph = export(model_graph, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
@ -1035,7 +1031,7 @@ class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
m = torchvision.models.resnet18().eval()
m_copy = copy.deepcopy(m)
# program capture
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()

quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
@ -27,9 +27,7 @@ class TestQuantizePT2EModels(TestCase):
m = m.eval()
input_shape = (1, 3, 224, 224)
example_inputs = (torch.randn(input_shape),)
m = torch.export.export_for_training(
m, copy.deepcopy(example_inputs), strict=True
).module()
m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
m(*example_inputs)
m = export.export(m, copy.deepcopy(example_inputs))
ops = _get_ops_list(m.graph_module)
@ -50,7 +50,7 @@ def lower_pt2e_quantized_to_x86(
m.recompile()

lowered_model = (
torch.export.export_for_training(model, example_inputs, strict=True)
torch.export.export(model, example_inputs, strict=True)
.run_decompositions(_post_autograd_decomp_table())
.module()
)
@ -356,7 +356,7 @@ def _get_aten_graph_module_for_pattern(
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
)

aten_pattern = torch.export.export_for_training(
aten_pattern = torch.export.export(
pattern, # type: ignore[arg-type]
example_inputs,
kwargs,
@ -1002,9 +1002,7 @@ class Pipe(torch.nn.Module):
) -> ExportedProgram:
logger.info("Tracing model ...")
try:
ep = torch.export.export_for_training(
mod, example_args, example_kwargs, strict=True
)
ep = torch.export.export(mod, example_args, example_kwargs, strict=True)
except Exception as e:
raise RuntimeError(
"It seems that we cannot capture your model as a full graph. "
@ -58,7 +58,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
)

from torch.export import export_for_training
from torch.export import export
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase
@ -1513,7 +1513,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
{0: torch.export.Dim("dim")} if i == 0 else None
for i in range(len(example_inputs))
)
m = export_for_training(
m = export(
m,
example_inputs,
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
@ -1554,7 +1554,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
m_fx = _convert_to_reference_decomposed_fx(
m_fx, backend_config=backend_config
)
m_fx = export_for_training(
m_fx = export(
m_fx,
example_inputs,
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
@ -1578,7 +1578,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
# resetting dynamo cache
torch._dynamo.reset()

m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
if is_qat:
m = prepare_qat_pt2e(m, quantizer)
else:
@ -3183,12 +3183,15 @@ class TestHelperModules:
x = self.adaptive_avg_pool2d(x)
return x


class ConvWithBNRelu(torch.nn.Module):
def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
super().__init__()
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
bns = {
1: torch.nn.BatchNorm1d,
2: torch.nn.BatchNorm2d,
3: torch.nn.BatchNorm3d,
}
self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)

if bn:
@ -3394,7 +3397,7 @@ def _generate_qdq_quantized_model(

maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
with maybe_no_grad:
export_model = export_for_training(mod, inputs, strict=True).module(check_guards=False)
export_model = export(mod, inputs, strict=True).module(check_guards=False)
quantizer = (
quantizer
if quantizer
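For context, every call site touched above follows the same PT2E quantization flow; only the capture entry point changes from export_for_training to torch.export.export. A minimal sketch of that flow in plain Python, where the toy module M and the per-channel config choice are illustrative and not taken from the patch:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.linear(x)


example_inputs = (torch.randn(2, 2),)
m = M().eval()

# Program capture: torch.export.export now plays the role export_for_training did.
m = torch.export.export(m, example_inputs, strict=True).module()

# Annotation, calibration, and conversion are unchanged.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
m = prepare_pt2e(m, quantizer)
m(*example_inputs)  # calibration pass with example inputs
m = convert_pt2e(m)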