Replace export_for_training with export (#162396)

Summary: replace export_for_training with epxort

Test Plan:
CI

Rollback Plan:

Differential Revision: D81935792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162396
Approved by: https://github.com/angelayi, https://github.com/jerryzh168
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar
2025-09-10 14:19:34 +00:00
committed by PyTorch MergeBot
parent fc1b09a52a
commit de05dbc39c
24 changed files with 180 additions and 215 deletions

View File

@ -183,7 +183,7 @@ class TestDTensorCompile(torch._dynamo.test_case.TestCase):
)
torch.utils._pytree.register_constant(DeviceMesh)
ep = torch.export.export_for_training(
ep = torch.export.export(
Foo(), (torch.randn(4, 4, dtype=torch.float64),), strict=False
)
self.assertExpectedInline(

View File

@ -9,7 +9,7 @@ from torch._export.db.examples import (
filter_examples_by_support_level,
get_rewrite_cases,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
@ -35,7 +35,7 @@ class ExampleTests(TestCase):
kwargs_export = case.example_kwargs
args_model = copy.deepcopy(args_export)
kwargs_model = copy.deepcopy(kwargs_export)
exported_program = export_for_training(
exported_program = export(
model,
args_export,
kwargs_export,
@ -68,7 +68,7 @@ class ExampleTests(TestCase):
with self.assertRaises(
(torchdynamo.exc.Unsupported, AssertionError, RuntimeError)
):
export_for_training(
export(
model,
case.example_args,
case.example_kwargs,
@ -94,7 +94,7 @@ class ExampleTests(TestCase):
self, name: str, rewrite_case: ExportCase
) -> None:
# pyre-ignore
export_for_training(
export(
rewrite_case.model,
rewrite_case.example_args,
rewrite_case.example_kwargs,

View File

@ -9,7 +9,7 @@ import torch
import torch._dynamo
from torch._dynamo.test_case import run_tests, TestCase
from torch._functorch.aot_autograd import aot_export_module
from torch.export import export, export_for_training
from torch.export import export
from torch.export.experimental import _export_forward_backward, _sticky_export
from torch.export.graph_signature import OutputKind
from torch.testing import FileCheck
@ -32,7 +32,7 @@ class TestExperiment(TestCase):
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
ep = torch.export.export_for_training(m, example_inputs, strict=True)
ep = torch.export.export(m, example_inputs, strict=True)
joint_ep = _export_forward_backward(ep)
self.assertExpectedInline(
str(joint_ep.graph_module.code).strip(),
@ -141,7 +141,7 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
m = Module()
example_inputs = (torch.randn(3),)
m(*example_inputs)
ep = torch.export.export_for_training(
ep = torch.export.export(
m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}}, strict=True
)
_export_forward_backward(ep)
@ -177,7 +177,7 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
labels = torch.ones(4, dtype=torch.int64)
inputs = (x, labels)
ep = export_for_training(net, inputs, strict=True)
ep = export(net, inputs, strict=True)
ep = _export_forward_backward(ep)
def test_joint_loss_index(self):
@ -197,7 +197,7 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
inputs = (torch.randn(4, 4),)
for i in [0, 1]:
ep = export_for_training(Foo(i), inputs, strict=True)
ep = export(Foo(i), inputs, strict=True)
ep_joint = _export_forward_backward(ep, joint_loss_index=i)
for j, spec in enumerate(ep_joint.graph_signature.output_specs):
if i == j:

View File

@ -42,13 +42,7 @@ from torch._higher_order_ops.scan import scan
from torch._higher_order_ops.while_loop import while_loop
from torch._inductor.compile_fx import split_const_gm
from torch._subclasses import FakeTensorMode
from torch.export import (
default_decompositions,
Dim,
export,
export_for_training,
unflatten,
)
from torch.export import default_decompositions, Dim, export, unflatten
from torch.export._trace import (
_export,
_export_to_torch_ir,
@ -1058,7 +1052,7 @@ graph():
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(exported_program.module()(*args), m(*args))
gm: torch.fx.GraphModule = torch.export.export_for_training(
gm: torch.fx.GraphModule = torch.export.export(
m, args=example_args, dynamic_shapes=dynamic_shapes
).module()
@ -2456,7 +2450,7 @@ def forward(self, x, y):
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,))
ep_training = torch.export.export(m, (ref_x,))
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -2519,7 +2513,7 @@ graph():
ref_x = torch.randn(2, 2)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -2651,7 +2645,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertTrue(torch.allclose(ep_training.module()(ref_x), ref_out))
self.assertExpectedInline(
str(ep_training.graph).strip(),
@ -2706,7 +2700,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -2746,7 +2740,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -2784,7 +2778,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -2823,7 +2817,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -2863,7 +2857,7 @@ graph():
m = Foo()
ref_x = torch.randn(3, 4)
ref_out = m(ref_x)
ep_training = torch.export.export_for_training(m, (ref_x,), strict=False)
ep_training = torch.export.export(m, (ref_x,), strict=False)
self.assertExpectedInline(
str(ep_training.graph).strip(),
"""\
@ -3983,7 +3977,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
x_linear = self.linear(x_conv)
return x_linear.cos() + y_conv_1d.sum()
ep = torch.export.export_for_training(
ep = torch.export.export(
Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50))
)
@ -4251,9 +4245,7 @@ def forward(self, x):
return self.linear(x)
eager_model = Foo()
ep_for_training = torch.export.export_for_training(
eager_model, (torch.ones(2, 2),)
)
ep_for_training = torch.export.export(eager_model, (torch.ones(2, 2),))
self.assertExpectedInline(
str(ep_for_training.graph_module.code).strip(),
"""\
@ -4291,7 +4283,7 @@ def forward(self, x):
eager_model_for_export = Foo()
eager_model_for_testing = Foo()
ep_for_training = torch.export.export_for_training(
ep_for_training = torch.export.export(
eager_model_for_export, (torch.ones(4, 4),)
)
self.assertExpectedInline(
@ -4337,7 +4329,7 @@ def forward(self, x):
eager_model_for_export_training = Foo()
eager_model_for_export_inference = Foo()
eager_model_for_testing = Foo()
ep_for_training = torch.export.export_for_training(
ep_for_training = torch.export.export(
eager_model_for_export_training,
(torch.ones(4, 4),),
dynamic_shapes=({0: Dim("x")},),
@ -4391,7 +4383,7 @@ def forward(self, x):
return x + y + self.buffer.sum()
eager_model = Foo()
ep_for_training = torch.export.export_for_training(
ep_for_training = torch.export.export(
eager_model,
([torch.ones(4, 4), torch.ones(4, 4)],),
)
@ -4597,7 +4589,7 @@ def forward(self, x):
return self.linear(x) + self.buffer.sum()
eager_model = Foo()
ep_for_training = torch.export.export_for_training(
ep_for_training = torch.export.export(
eager_model,
(torch.ones(2, 2),),
)
@ -7530,7 +7522,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
inp = torch.randn(4, 4)
ep = export_for_training(
ep = torch.export.export(
Foo(), (inp,), strict=False, preserve_module_call_signature=("bar",)
)
unflat = unflatten(ep).bar
@ -7836,7 +7828,7 @@ graph():
decomp_table = {**default_decompositions(), **decomposition_table}
ep = export_for_training(M(), (torch.randn(2, 2),)).run_decompositions(
ep = torch.export.export(M(), (torch.randn(2, 2),)).run_decompositions(
decomp_table
)
@ -7865,7 +7857,7 @@ def forward(self, c_lifted_tensor_0, x):
mod.eval()
inp = torch.randn(1, 1, 3, 3)
gm = torch.export.export_for_training(mod, (inp,)).module()
gm = torch.export.export(mod, (inp,)).module()
self.assertExpectedInline(
str(gm.code).strip(),
"""\
@ -7885,7 +7877,7 @@ def forward(self, x):
)
mod.train()
gm_train = torch.export.export_for_training(mod, (inp,)).module()
gm_train = torch.export.export(mod, (inp,)).module()
self.assertExpectedInline(
str(gm_train.code).strip(),
"""\
@ -8450,7 +8442,7 @@ def forward(self, x):
ref_x = torch.randn(2, 2)
ref_out = f(ref_x, mod)
ep = torch.export.export_for_training(f, (torch.randn(2, 2), mod), strict=False)
ep = torch.export.export(f, (torch.randn(2, 2), mod), strict=False)
self.assertEqual(ref_out, ep.module()(ref_x, mod))
def test_unbacked_noncontig_lin(self):
@ -9645,7 +9637,7 @@ graph():
return m(x) * x
inps = (torch.randn(3, 3),)
ep = export_for_training(M2(), inps).run_decompositions({})
ep = torch.export.export(M2(), inps).run_decompositions({})
self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
self.assertEqual(len(ep.state_dict), 0)
@ -9682,7 +9674,7 @@ graph():
inps = (torch.randn(3, 3),)
# Strict export segfaults (Issue #128109)
ep = export_for_training(M2(), inps, strict=False).run_decompositions({})
ep = torch.export.export(M2(), inps, strict=False).run_decompositions({})
self.assertTrue(torch.allclose(ep.module()(*inps), M2()(*inps)))
self.assertEqual(len(ep.state_dict), 0)
@ -12013,7 +12005,7 @@ graph():
if is_training_ir_test(self._testMethodName):
test(
torch.export.export_for_training(
torch.export.export(
M(),
inp,
strict=not is_non_strict_test(self._testMethodName),
@ -12134,7 +12126,7 @@ graph():
test(export(M(), inp))
strict = not is_non_strict_test(self._testMethodName)
ept = torch.export.export_for_training(
ept = torch.export.export(
M(),
inp,
strict=strict,
@ -12209,7 +12201,7 @@ graph():
x = torch.zeros((4, 4, 10))
ep_training = torch.export.export_for_training(model, (x,), strict=False)
ep_training = torch.export.export(model, (x,), strict=False)
state_dict_before = ep_training.state_dict
ep = export(model, (x,), strict=False).run_decompositions()
@ -12253,7 +12245,7 @@ def forward(self, c_params, x):
x = torch.zeros((4, 4, 10))
ep_training = torch.export.export_for_training(model, (x,), strict=False)
ep_training = torch.export.export(model, (x,), strict=False)
state_dict_before = ep_training.state_dict
ep = export(model, (x,), strict=False).run_decompositions()
@ -12772,7 +12764,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
model = Model()
with torch.no_grad():
exported_program = torch.export.export_for_training(
exported_program = torch.export.export(
model,
(torch.tensor(10), torch.tensor(12)),
{},
@ -12868,7 +12860,7 @@ def forward(self, x, b_t, y):
# no grad
model = Model()
with torch.no_grad():
ep_nograd = torch.export.export_for_training(
ep_nograd = torch.export.export(
model,
(torch.tensor(10), torch.tensor(12)),
{},
@ -12888,7 +12880,7 @@ def forward(self, x, b_t, y):
# enable grad
model = Model()
ep_grad = torch.export.export_for_training(
ep_grad = torch.export.export(
model,
(torch.tensor(10), torch.tensor(12)),
{},
@ -13011,7 +13003,7 @@ def forward(self, x, b_t, y):
"torch.ops.higher_order.wrap_with_set_grad_enabled",
ep.graph_module.code,
)
gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
gm = torch.export.export(model, (torch.randn(4, 4),)).module()
self.assertIn(
"set_grad_enabled",
gm.code,
@ -13040,7 +13032,7 @@ def forward(self, x, b_t, y):
)
# _export_for_traininig is using pre_dispatch=False
# Therefore the autocast calls are not replaced with a hop.
gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module()
gm = torch.export.export(model, (torch.randn(4, 4),)).module()
self.assertIn(
"autocast",
gm.code,
@ -13287,7 +13279,7 @@ def forward(self, x, b_t, y):
inps = (torch.ones(5),)
ep = export_for_training(M(), inps).run_decompositions({})
ep = torch.export.export(M(), inps).run_decompositions({})
self.assertExpectedInline(
str(ep.graph_module.code.strip()),
"""\
@ -13608,7 +13600,7 @@ def forward(self, x):
return y + y_sum + unbacked_shape.sum()
inps = (torch.tensor(4), torch.randn(5, 5))
ep_pre = torch.export.export_for_training(Foo(), inps, strict=False)
ep_pre = torch.export.export(Foo(), inps, strict=False)
self.assertExpectedInline(
str(ep_pre.graph_module.submod_1.code).strip(),
"""\
@ -14298,7 +14290,7 @@ graph():
return val.b.a
mod = Foo()
ep = export_for_training(mod, (torch.randn(4, 4),), strict=False)
ep = torch.export.export(mod, (torch.randn(4, 4),), strict=False)
self.assertExpectedInline(
str(ep.graph).strip(),
"""\
@ -15311,7 +15303,7 @@ def forward(self, x):
x = torch.randn(2, 4)
y = torch.ones(4)
ep_for_training = torch.export.export_for_training(M(), (x, y), strict=strict)
ep_for_training = torch.export.export(M(), (x, y), strict=strict)
self.assertExpectedInline(
normalize_gm(
ep_for_training.graph_module.print_readable(print_output=False)

View File

@ -15,14 +15,14 @@ test_classes = {}
def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs):
if "strict" in kwargs:
ep = torch.export.export_for_training(*args, **kwargs)
ep = torch.export.export(*args, **kwargs)
else:
ep = torch.export.export_for_training(*args, **kwargs, strict=True)
ep = torch.export.export(*args, **kwargs, strict=True)
return ep.run_decompositions({})
def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs):
ep = torch.export.export_for_training(*args, **kwargs)
ep = torch.export.export(*args, **kwargs)
return ep.run_decompositions({})

View File

@ -45,7 +45,7 @@ from torch._export.serde.serialize import (
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import Dim, export_for_training, load, save, unflatten
from torch.export import Dim, export, load, save, unflatten
from torch.export.pt2_archive.constants import ARCHIVE_VERSION_PATH
from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges
from torch.testing._internal.common_utils import (
@ -115,7 +115,7 @@ class TestSerialize(TestCase):
return torch.ops.aten.add.Tensor._schema
inp = (torch.ones(10),)
ep = export_for_training(TestModule(), inp, strict=True)
ep = export(TestModule(), inp, strict=True)
# Register the custom op handler.
foo_custom_op = FooExtensionOp()
@ -180,9 +180,7 @@ class TestSerialize(TestCase):
model = MyModule().eval()
random_inputs = (torch.rand([2, 3]), torch.rand([2, 3]))
exp_program = export_for_training(
model, random_inputs, {"use_p": True}, strict=True
)
exp_program = export(model, random_inputs, {"use_p": True}, strict=True)
output_buffer = io.BytesIO()
# Tests that example inputs are preserved when saving and loading module.
@ -201,7 +199,7 @@ class TestSerialize(TestCase):
def forward(self, x):
return x.sin()
exp_program = export_for_training(M(), (torch.randn(4, 4),), strict=True)
exp_program = export(M(), (torch.randn(4, 4),), strict=True)
output_buffer = io.BytesIO()
# Tests that example forward arg names are preserved when saving and loading module.
@ -241,7 +239,7 @@ def forward(self, x):
inp = (torch.ones(10),)
# Module will only be able to roundtrip if metadata
# can be correctly parsed.
ep = export_for_training(MyModule(), inp, strict=True)
ep = export(MyModule(), inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer)
loaded_ep = load(buffer)
@ -282,7 +280,7 @@ def forward(self, x):
return h + out_c
inp = (torch.ones(10),)
ep = export_for_training(Foo(), inp, strict=True)
ep = export(Foo(), inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer)
loaded_ep = load(buffer)
@ -324,7 +322,7 @@ def forward(self, x):
# Check that module can be roundtripped, thereby confirming proper deserialization.
inp = (torch.ones(10),)
ep = export_for_training(MyModule(), inp, strict=True)
ep = export(MyModule(), inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer)
loaded_ep = load(buffer)
@ -347,7 +345,7 @@ def forward(self, x):
eps=1e-5,
)
exported_module = export_for_training(
exported_module = export(
MyModule(),
(
torch.ones([512, 512], requires_grad=True),
@ -391,7 +389,7 @@ def forward(self, x):
"b": {1: dim1_bc},
"c": {0: dim0_ac, 1: dim1_bc},
}
exported_module = export_for_training(
exported_module = export(
DynamicShapeSimpleModel(),
inputs,
dynamic_shapes=dynamic_shapes,
@ -455,7 +453,7 @@ def forward(self, x):
"b": {1: dim1_bc},
"c": {0: dim0_ac, 1: dim1_bc},
}
exported_module = export_for_training(
exported_module = export(
DynamicShapeSimpleModel(),
inputs,
dynamic_shapes=dynamic_shapes,
@ -485,9 +483,7 @@ def forward(self, x):
return torch.split(x, 2)
input = torch.arange(10.0).reshape(5, 2)
exported_module = export_for_training(
MyModule(), (input,), strict=True
).run_decompositions()
exported_module = export(MyModule(), (input,), strict=True).run_decompositions()
serialized = ExportedProgramSerializer().serialize(exported_module)
node = serialized.exported_program.graph_module.graph.nodes[-1]
@ -550,7 +546,7 @@ def forward(self, x):
def forward(self, x):
return torch.ops.aten.var_mean.correction(x, [1])[0]
exported_module = export_for_training(
exported_module = export(
MyModule(), (torch.ones([512, 512], requires_grad=True),), strict=True
).run_decompositions()
@ -571,7 +567,7 @@ def forward(self, x):
def forward(self, x):
return x + x
ep = export_for_training(
ep = export(
M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},), strict=True
)
@ -720,7 +716,7 @@ def forward(self, x):
f = Foo()
x, _ = torch.sort(torch.randn(3, 4))
exported_module = export_for_training(f, (x,), strict=True).run_decompositions()
exported_module = export(f, (x,), strict=True).run_decompositions()
serialized = ExportedProgramSerializer().serialize(exported_module)
node = serialized.exported_program.graph_module.graph.nodes[-1]
@ -738,9 +734,7 @@ def forward(self, x):
b = x + y
return b + a
ep = export_for_training(
Module(), (torch.randn(3, 2), torch.randn(3, 2)), strict=True
)
ep = export(Module(), (torch.randn(3, 2), torch.randn(3, 2)), strict=True)
s = ExportedProgramSerializer().serialize(ep)
c = canonicalize(s.exported_program)
g = c.graph_module.graph
@ -754,7 +748,7 @@ def forward(self, x):
def forward(self, x):
return torch.ops.aten.sum.dim_IntList(x, [])
ep = torch.export.export_for_training(M(), (torch.randn(3, 2),), strict=True)
ep = torch.export.export(M(), (torch.randn(3, 2),), strict=True)
serialized = ExportedProgramSerializer().serialize(ep)
for node in serialized.exported_program.graph_module.graph.nodes:
if "aten.sum.dim_IntList" in node.target:
@ -1024,7 +1018,7 @@ class TestDeserialize(TestCase):
def _check_graph(pre_dispatch):
if pre_dispatch:
ep = torch.export.export_for_training(
ep = torch.export.export(
fn,
_deepcopy_inputs(inputs),
{},
@ -1574,7 +1568,7 @@ class TestDeserialize(TestCase):
a = a * 2
return a, b
ep = torch.export.export_for_training(M(), (torch.ones(3),), strict=True)
ep = torch.export.export(M(), (torch.ones(3),), strict=True)
# insert another getitem node
for node in ep.graph.nodes:
@ -1720,7 +1714,7 @@ def forward(self, x):
def forward(self):
return self.p * self.p
ep = torch.export.export_for_training(M(), (), strict=True)
ep = torch.export.export(M(), (), strict=True)
ep._example_inputs = None
roundtrip_ep = deserialize(serialize(ep))
self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))
@ -1762,7 +1756,7 @@ class TestSchemaVersioning(TestCase):
return x + x
f = Module()
ep = export_for_training(f, (torch.randn(1, 3),), strict=True)
ep = export(f, (torch.randn(1, 3),), strict=True)
serialized_program = ExportedProgramSerializer().serialize(ep)
serialized_program.exported_program.schema_version.major = -1
@ -1798,7 +1792,7 @@ class TestSaveLoad(TestCase):
y = self.linear(y)
return y
ep = export_for_training(Module(), inp, strict=True)
ep = export(Module(), inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer)
@ -1816,7 +1810,7 @@ class TestSaveLoad(TestCase):
f = Foo()
inp = (torch.randn(2, 2),)
ep = export_for_training(f, inp, strict=True)
ep = export(f, inp, strict=True)
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
save(ep, f.name)
@ -1833,7 +1827,7 @@ class TestSaveLoad(TestCase):
f = Foo()
inp = (torch.tensor([6]), torch.tensor([7]))
ep = export_for_training(f, inp, strict=True)
ep = export(f, inp, strict=True)
with TemporaryFileName(suffix=".pt2") as fname:
path = Path(fname)
@ -1851,7 +1845,7 @@ class TestSaveLoad(TestCase):
f = Foo()
ep = export_for_training(f, inp, strict=True)
ep = export(f, inp, strict=True)
buffer = io.BytesIO()
save(ep, buffer, extra_files={"extra.txt": "moo"})
@ -1872,7 +1866,7 @@ class TestSaveLoad(TestCase):
f = Foo()
ep = export_for_training(f, (torch.randn(1, 3),), strict=True)
ep = export(f, (torch.randn(1, 3),), strict=True)
with self.assertRaisesRegex(
ValueError, r"Saved archive version -1 does not match our current"
@ -1908,7 +1902,7 @@ class TestSaveLoad(TestCase):
list_tensor = [torch.tensor(3), torch.tensor(4)]
return x + self.a + list_tensor[0] + list_tensor[1]
ep = export_for_training(Foo(), (torch.tensor(1),), strict=True)
ep = export(Foo(), (torch.tensor(1),), strict=True)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
@ -1934,7 +1928,7 @@ class TestSerializeCustomClass(TestCase):
f = Foo()
inputs = (torch.zeros(4, 4),)
ep = export_for_training(f, inputs, strict=True)
ep = export(f, inputs, strict=True)
# Replace one of the values with an instance of our custom class
for node in ep.graph.nodes:
@ -1988,7 +1982,7 @@ class TestSerializeCustomClass(TestCase):
inputs = (torch.zeros(2, 3),)
with enable_torchbind_tracing():
ep = export_for_training(f, inputs, strict=False)
ep = export(f, inputs, strict=False)
serialized_vals = serialize(ep)
ep = deserialize(serialized_vals)
@ -2008,7 +2002,7 @@ class TestSerializeCustomClass(TestCase):
inputs = (torch.zeros(2, 3),)
with enable_torchbind_tracing():
ep = export_for_training(f, inputs, strict=False)
ep = export(f, inputs, strict=False)
serialized_vals = serialize(ep)
ep = deserialize(serialized_vals)
@ -2043,7 +2037,7 @@ def forward(self, x):
f = Foo()
inputs = (torch.zeros(4, 4),)
ep = export_for_training(f, inputs, strict=True)
ep = export(f, inputs, strict=True)
new_gm = copy.deepcopy(ep.graph_module)
new_gm.meta["custom"] = {}
@ -2078,7 +2072,7 @@ def forward(self, x):
f = Foo()
inputs = (torch.ones(2, 2),)
ep = export_for_training(f, inputs, strict=True)
ep = export(f, inputs, strict=True)
new_gm = copy.deepcopy(ep.graph_module)
new_gm.meta["custom"] = {}
@ -2114,7 +2108,7 @@ def forward(self, x):
f = Foo()
inputs = (torch.zeros(4, 4),)
ep = export_for_training(f, inputs, strict=True)
ep = export(f, inputs, strict=True)
new_gm = copy.deepcopy(ep.graph_module)
new_gm.meta["custom"] = {}

View File

@ -138,7 +138,7 @@ class TestExportTorchbind(TestCase):
def export_wrapper(f, args, kwargs, strict, pre_dispatch):
with enable_torchbind_tracing():
if pre_dispatch:
exported_program = torch.export.export_for_training(
exported_program = torch.export.export(
f, args, kwargs, strict=strict
).run_decompositions({})
else:
@ -755,7 +755,7 @@ def forward(self, arg0_1, arg1_1):
b = torch.randn(2, 2)
tq.push(a)
tq.push(b)
ep = torch.export.export_for_training(
ep = torch.export.export(
mod, (tq, torch.randn(2, 2)), strict=False
).run_decompositions({})
self.assertExpectedInline(
@ -809,9 +809,9 @@ def forward(self, L_safe_obj_ : torch.ScriptObject):
)
with enable_torchbind_tracing():
ep = torch.export.export_for_training(
mod, (safe_obj,), strict=False
).run_decompositions({})
ep = torch.export.export(mod, (safe_obj,), strict=False).run_decompositions(
{}
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
@ -1407,9 +1407,9 @@ def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
x = torch.randn(3, 1)
eager_out = mod(test_obj, x)
compiled_out = torch.compile(mod, backend=backend, fullgraph=True)(test_obj, x)
ep = torch.export.export_for_training(
mod, (test_obj, x), strict=False
).run_decompositions({})
ep = torch.export.export(mod, (test_obj, x), strict=False).run_decompositions(
{}
)
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\

View File

@ -7,14 +7,14 @@ except ImportError:
import test_unflatten # @manual=fbcode//caffe2/test:test_export-library
import testing # @manual=fbcode//caffe2/test:test_export-library
from torch.export import export_for_training
from torch.export import export
test_classes = {}
def mocked_training_ir_export(*args, **kwargs):
return export_for_training(*args, **kwargs, strict=True)
return export(*args, **kwargs, strict=True)
def make_dynamic_cls(cls):

View File

@ -6,7 +6,7 @@ from functorch.experimental import control_flow
from torch import Tensor
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export.verifier import SpecViolationError, Verifier
from torch.export import export_for_training
from torch.export import export
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
@ -20,7 +20,7 @@ class TestVerifier(TestCase):
f = Foo()
ep = export_for_training(f, (torch.randn(100), torch.randn(100)), strict=True)
ep = export(f, (torch.randn(100), torch.randn(100)), strict=True)
verifier = Verifier()
verifier.check(ep)
@ -47,7 +47,7 @@ class TestVerifier(TestCase):
f = Foo()
ep = export_for_training(
ep = export(
f, (torch.randn(100), torch.randn(100)), strict=True
).run_decompositions({})
for node in ep.graph.nodes:
@ -72,7 +72,7 @@ class TestVerifier(TestCase):
f = Foo()
ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True)
ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True)
verifier = Verifier()
verifier.check(ep)
@ -91,7 +91,7 @@ class TestVerifier(TestCase):
f = Foo()
ep = export_for_training(
ep = export(
f, (torch.randn(3, 3), torch.randn(3, 3)), strict=True
).run_decompositions({})
for node in ep.graph_module.true_graph_0.graph.nodes:
@ -111,7 +111,7 @@ class TestVerifier(TestCase):
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
ep = export_for_training(M(), (torch.randn(10, 10),), strict=True)
ep = export(M(), (torch.randn(10, 10),), strict=True)
ep.validate()
def test_ep_verifier_invalid_param(self) -> None:
@ -125,7 +125,7 @@ class TestVerifier(TestCase):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + self.a
ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True)
ep = export(M(), (torch.randn(100), torch.randn(100)), strict=True)
# Parameter doesn't exist in the state dict
ep.graph_signature.input_specs[0] = InputSpec(
@ -150,7 +150,7 @@ class TestVerifier(TestCase):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y + self.a
ep = export_for_training(M(), (torch.randn(100), torch.randn(100)), strict=True)
ep = export(M(), (torch.randn(100), torch.randn(100)), strict=True)
# Buffer doesn't exist in the state dict
ep.graph_signature.input_specs[0] = InputSpec(
@ -182,9 +182,7 @@ class TestVerifier(TestCase):
self.my_buffer2.add_(1.0)
return output
ep = export_for_training(
M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True
)
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True)
ep.validate()
def test_ep_verifier_invalid_output(self) -> None:
@ -207,9 +205,7 @@ class TestVerifier(TestCase):
self.my_buffer2.add_(1.0)
return output
ep = export_for_training(
M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True
)
ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)), strict=True)
output_node = list(ep.graph.nodes)[-1]
output_node.args = (

View File

@ -6,7 +6,7 @@ from typing import Callable
import torch
import torch.nn.functional as F
from torch.export import export_for_training
from torch.export import export
from torch.fx import symbolic_trace
from torch.fx.experimental.proxy_tensor import make_fx
@ -172,7 +172,7 @@ class TestMatcher(JitTestCase):
torch.randn(1, 3, 3, 3) * 10,
torch.randn(3, 3, 3, 3),
)
pattern_gm = export_for_training(
pattern_gm = export(
WrapperModule(pattern), example_inputs, strict=True
).module()
before_split_res = pattern_gm(*example_inputs)
@ -203,11 +203,11 @@ class TestMatcher(JitTestCase):
torch.randn(1, 3, 3, 3) * 10,
torch.randn(3, 3, 3, 3),
)
pattern_gm = export_for_training(
pattern_gm = export(
WrapperModule(pattern), example_inputs, strict=True
).module()
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
target_gm = export_for_training(
target_gm = export(
WrapperModule(target_graph), example_inputs, strict=True
).module()
internal_matches = matcher.match(target_gm.graph)
@ -248,11 +248,9 @@ class TestMatcher(JitTestCase):
return linear, {"linear": linear, "x": x}
example_inputs = (torch.randn(3, 5),)
pattern_gm = export_for_training(
Pattern(), example_inputs, strict=True
).module()
pattern_gm = export(Pattern(), example_inputs, strict=True).module()
matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
target_gm = export_for_training(M(), example_inputs, strict=True).module()
target_gm = export(M(), example_inputs, strict=True).module()
internal_matches = matcher.match(target_gm.graph)
for internal_match in internal_matches:
name_node_map = internal_match.name_node_map

View File

@ -34,7 +34,7 @@ from torch._library import capture_triton
from torch._utils_internal import full_aoti_runtime_assert
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
from torch.export import Dim, export, export_for_training
from torch.export import Dim, export
from torch.export.pt2_archive._package import load_pt2
from torch.testing import FileCheck
from torch.testing._internal import common_utils
@ -2525,9 +2525,7 @@ class AOTInductorTestsTemplate:
config.patch({"freezing": True, "aot_inductor.force_mmap_weights": True}),
torch.no_grad(),
):
exported_model = export_for_training(
model, example_inputs, strict=True
).module()
exported_model = export(model, example_inputs, strict=True).module()
quantizer = X86InductorQuantizer()
quantizer.set_global(
xiq.get_default_x86_inductor_quantization_config(reduce_range=True)

View File

@ -24,7 +24,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
OP_TO_ANNOTATOR,
QuantizationConfig,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, raise_on_run_directly
@ -101,7 +101,7 @@ class TestDuplicateDQPass(QuantizationTestCase):
# program capture
m = copy.deepcopy(m_eager)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
# Calibrate

View File

@ -102,7 +102,7 @@ class TestMetaDataPorting(QuantizationTestCase):
# program capture
m = copy.deepcopy(m_eager)
m = torch.export.export_for_training(m, example_inputs, strict=True).module()
m = torch.export.export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
# Calibrate

View File

@ -19,7 +19,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import TestHelperModules
from torch.testing._internal.common_utils import (
IS_WINDOWS,
@ -86,7 +86,7 @@ class TestNumericDebugger(TestCase):
def test_simple(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
self._assert_each_node_has_debug_handle(ep)
debug_handle_map = self._extract_debug_handles(ep)
@ -96,7 +96,7 @@ class TestNumericDebugger(TestCase):
def test_control_flow(self):
m = TestHelperModules.ControlFlow()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
self._assert_each_node_has_debug_handle(ep)
@ -107,7 +107,7 @@ class TestNumericDebugger(TestCase):
def test_quantize_pt2e_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
@ -167,14 +167,14 @@ class TestNumericDebugger(TestCase):
def test_re_export_preserve_handle(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
self._assert_each_node_has_debug_handle(ep)
debug_handle_map_ref = self._extract_debug_handles(ep)
ep_reexport = export_for_training(m, example_inputs, strict=True)
ep_reexport = export(m, example_inputs, strict=True)
self._assert_each_node_has_debug_handle(ep_reexport)
debug_handle_map = self._extract_debug_handles(ep_reexport)
@ -184,7 +184,7 @@ class TestNumericDebugger(TestCase):
def test_run_decompositions_same_handle_id(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
self._assert_each_node_has_debug_handle(ep)
@ -209,7 +209,7 @@ class TestNumericDebugger(TestCase):
for m in test_models:
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
self._assert_each_node_has_debug_handle(ep)
@ -232,7 +232,7 @@ class TestNumericDebugger(TestCase):
def test_prepare_for_propagation_comparison(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_logger = prepare_for_propagation_comparison(m)
@ -249,7 +249,7 @@ class TestNumericDebugger(TestCase):
def test_extract_results_from_loggers(self):
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_ref_logger = prepare_for_propagation_comparison(m)
@ -274,7 +274,7 @@ class TestNumericDebugger(TestCase):
def test_extract_results_from_loggers_list_output(self):
m = TestHelperModules.Conv2dWithSplit()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
m = ep.module()
m_ref_logger = prepare_for_propagation_comparison(m)
@ -304,7 +304,7 @@ class TestNumericDebugger(TestCase):
def test_added_node_gets_unique_id(self) -> None:
m = TestHelperModules.Conv2dThenConv1d()
example_inputs = m.example_inputs()
ep = export_for_training(m, example_inputs, strict=True)
ep = export(m, example_inputs, strict=True)
generate_numeric_debug_handle(ep)
ref_handles = self._extract_debug_handles(ep)
ref_counter = Counter(ref_handles.values())

View File

@ -39,7 +39,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
OP_TO_ANNOTATOR,
QuantizationConfig,
)
from torch.export import export_for_training
from torch.export import export
from torch.fx import Node
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@ -767,7 +767,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5))
# program capture
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, BackendAQuantizer())
# make sure the two observers for input are shared
conv_output_obs = []
@ -827,7 +827,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
)
# program capture
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
# make sure the two input observers and output are shared
@ -1146,7 +1146,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
)
# program capture
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
quantizer = BackendAQuantizer()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
@ -1296,7 +1296,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().eval()
example_inputs = torch.randn(1, 2, 3, 3)
m = export_for_training(m, (example_inputs,), strict=True).module()
m = export(m, (example_inputs,), strict=True).module()
with self.assertRaises(Exception):
m = prepare_pt2e(m, BackendAQuantizer())
@ -1419,7 +1419,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
weight_meta = None
for n in m.graph.nodes:
if (
@ -1506,7 +1506,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().eval()
quantizer = TestQuantizer()
example_inputs = (torch.randn(1, 2, 3, 3),)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
node_occurrence = {
@ -1557,7 +1557,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
torch.randn(1, 2, 3, 3),
torch.randn(1, 2, 3, 3),
)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
node_occurrence = {
@ -1812,7 +1812,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1),)
m = M().train()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
if inplace:
target = torch.ops.aten.dropout_.default
else:
@ -1877,7 +1877,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
# Assert that batch norm op exists and is in train mode
bn_node = self._get_node(m, bn_train_op)
@ -1908,7 +1908,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m.train()
# After export: this is not OK
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
with self.assertRaises(NotImplementedError):
m.eval()
with self.assertRaises(NotImplementedError):
@ -1949,7 +1949,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
bn_train_op, bn_eval_op = self._get_bn_train_eval_ops()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool):
targets = [n.target for n in m.graph.nodes]
@ -2015,7 +2015,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
m = M().train()
example_inputs = (torch.randn(1, 3, 3, 3),)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
torch.ao.quantization.allow_exported_model_train_eval(m)
# Mock m.recompile() to count how many times it's been called
@ -2047,7 +2047,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def test_model_is_exported(self):
m = TestHelperModules.ConvWithBNRelu(relu=True)
example_inputs = (torch.rand(3, 3, 5, 5),)
exported_gm = export_for_training(m, example_inputs, strict=True).module()
exported_gm = export(m, example_inputs, strict=True).module()
fx_traced_gm = torch.fx.symbolic_trace(m, example_inputs)
self.assertTrue(
torch.ao.quantization.pt2e.export_utils.model_is_exported(exported_gm)
@ -2065,9 +2065,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(is_per_channel=True, is_qat=True)
)
m.conv_bn_relu = export_for_training(
m.conv_bn_relu, example_inputs, strict=True
).module()
m.conv_bn_relu = export(m.conv_bn_relu, example_inputs, strict=True).module()
m.conv_bn_relu = prepare_qat_pt2e(m.conv_bn_relu, quantizer)
m(*example_inputs)
m.conv_bn_relu = convert_pt2e(m.conv_bn_relu)
@ -2075,7 +2073,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
quantizer = XNNPACKQuantizer().set_module_type(
torch.nn.Linear, get_symmetric_quantization_config(is_per_channel=False)
)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m = convert_pt2e(m)
@ -2247,7 +2245,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
def dynamic_quantize_pt2e(model, example_inputs):
torch._dynamo.reset()
model = export_for_training(model, example_inputs, strict=True).module()
model = export(model, example_inputs, strict=True).module()
# Per channel quantization for weight
# Dynamic quantization for activation
# Please read a detail: https://fburl.com/code/30zds51q
@ -2462,7 +2460,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)
m = M()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer().set_global(
get_symmetric_quantization_config(),
)
@ -2544,7 +2542,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
edge_or_node_to_obs_or_fq[x] = new_observer
example_inputs = (torch.rand(1, 32, 16, 16),)
gm = export_for_training(Model().eval(), example_inputs, strict=True).module()
gm = export(Model().eval(), example_inputs, strict=True).module()
gm = prepare_pt2e(gm, BackendAQuantizer())
gm = convert_pt2e(gm)
for n in gm.graph.nodes:
@ -2571,9 +2569,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
"ConvWithBNRelu" in node.meta["nn_module_stack"]["L__self__"][1]
)
m.conv_bn_relu = export_for_training(
m.conv_bn_relu, example_inputs, strict=True
).module()
m.conv_bn_relu = export(m.conv_bn_relu, example_inputs, strict=True).module()
for node in m.conv_bn_relu.graph.nodes:
if node.op not in ["placeholder", "output", "get_attr"]:
check_nn_module(node)

View File

@ -34,7 +34,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
@ -140,9 +140,7 @@ class PT2EQATTestCase(QuantizationTestCase):
is_per_channel=is_per_channel, is_qat=True
)
)
model_pt2e = export_for_training(
model_pt2e, example_inputs, strict=True
).module()
model_pt2e = export(model_pt2e, example_inputs, strict=True).module()
model_pt2e = prepare_qat_pt2e(model_pt2e, quantizer)
torch.manual_seed(MANUAL_SEED)
after_prepare_result_pt2e = model_pt2e(*example_inputs)
@ -229,7 +227,7 @@ class PT2EQATTestCase(QuantizationTestCase):
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel, is_qat=True)
)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@ -618,7 +616,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
m = M(self.conv_class, self.bn_class, backbone)
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_qat=True))
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m)
@ -676,7 +674,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_conv_bn_bias_derived_qspec(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnDerivedBiasQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@ -723,7 +721,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_per_channel_weight_custom_dtype(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnInt32WeightQuantizer()
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@ -777,7 +775,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
def test_qat_conv_bn_per_channel_weight_bias(self):
m = self._get_conv_bn_model()
example_inputs = self.example_inputs
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
quantizer = ConvBnDerivedBiasQuantizer(is_per_channel=True)
m = prepare_qat_pt2e(m, quantizer)
m(*example_inputs)
@ -834,7 +832,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
it into conv in `convert_pt2e` even in train mode.
"""
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
m = export_for_training(m, self.example_inputs, strict=True).module()
m = export(m, self.example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantizer.set_global(
get_symmetric_quantization_config(is_per_channel=False, is_qat=True),
@ -850,7 +848,7 @@ class TestQuantizePT2EQAT_ConvBn_Base(PT2EQATTestCase):
Test that batch norm stat tracking (which results in an add_ tensor) is removed when folding batch norm.
"""
m = self._get_conv_bn_model(has_conv_bias=False, has_bn=True, has_relu=False)
m = export_for_training(m, self.example_inputs, strict=True).module()
m = export(m, self.example_inputs, strict=True).module()
def _has_add_(graph):
for node in graph.nodes:
@ -1115,9 +1113,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
in_channels = child.linear1.weight.size(1)
example_input = (torch.rand((1, in_channels)),)
traced_child = export_for_training(
child, example_input, strict=True
).module()
traced_child = export(child, example_input, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=True, is_qat=True
@ -1148,7 +1144,7 @@ class TestQuantizeMixQATAndPTQ(QuantizationTestCase):
self._convert_qat_linears(model)
model(*example_inputs)
model_pt2e = export_for_training(model, example_inputs, strict=True).module()
model_pt2e = export(model, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantizer.set_module_type(torch.nn.Linear, None)

View File

@ -10,7 +10,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
@ -34,7 +34,7 @@ class TestPT2ERepresentation(QuantizationTestCase):
) -> torch.nn.Module:
# resetting dynamo cache
torch._dynamo.reset()
model = export_for_training(model, example_inputs, strict=True).module()
model = export(model, example_inputs, strict=True).module()
model_copy = copy.deepcopy(model)
model = prepare_pt2e(model, quantizer)

View File

@ -17,7 +17,7 @@ from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
QUANT_ANNOTATION_KEY,
X86InductorQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
@ -668,7 +668,7 @@ class X86InductorQuantTestCase(QuantizationTestCase):
# program capture
m = copy.deepcopy(m_eager)
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
# QAT Model failed to deepcopy
export_model = m if is_qat else copy.deepcopy(m)
@ -2344,7 +2344,7 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
# Use a linear count instead of names because the names might change, but
# the order should be the same.

View File

@ -29,7 +29,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
PT2EQuantizationTestCase,
@ -362,7 +362,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
)
example_inputs = (torch.randn(2, 2),)
m = M().eval()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
# Use a linear count instead of names because the names might change, but
# the order should be the same.
@ -498,7 +498,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
example_inputs = (torch.randn(1, 3, 5, 5),)
# program capture
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
@ -763,9 +763,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
model_fx = _convert_to_reference_decomposed_fx(model_fx)
with torchdynamo.config.patch(allow_rnn=True):
model_graph = export_for_training(
model_graph, example_inputs, strict=True
).module()
model_graph = export(model_graph, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
@ -825,9 +823,7 @@ class TestXNNPACKQuantizer(PT2EQuantizationTestCase):
model_fx = _convert_to_reference_decomposed_fx(model_fx)
with torchdynamo.config.patch(allow_rnn=True):
model_graph = export_for_training(
model_graph, example_inputs, strict=True
).module()
model_graph = export(model_graph, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=False
@ -1035,7 +1031,7 @@ class TestXNNPACKQuantizerModels(PT2EQuantizationTestCase):
m = torchvision.models.resnet18().eval()
m_copy = copy.deepcopy(m)
# program capture
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)

View File

@ -27,9 +27,7 @@ class TestQuantizePT2EModels(TestCase):
m = m.eval()
input_shape = (1, 3, 224, 224)
example_inputs = (torch.randn(input_shape),)
m = torch.export.export_for_training(
m, copy.deepcopy(example_inputs), strict=True
).module()
m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
m(*example_inputs)
m = export.export(m, copy.deepcopy(example_inputs))
ops = _get_ops_list(m.graph_module)

View File

@ -50,7 +50,7 @@ def lower_pt2e_quantized_to_x86(
m.recompile()
lowered_model = (
torch.export.export_for_training(model, example_inputs, strict=True)
torch.export.export(model, example_inputs, strict=True)
.run_decompositions(_post_autograd_decomp_table())
.module()
)

View File

@ -356,7 +356,7 @@ def _get_aten_graph_module_for_pattern(
[x.cuda() if isinstance(x, torch.Tensor) else x for x in example_inputs]
)
aten_pattern = torch.export.export_for_training(
aten_pattern = torch.export.export(
pattern, # type: ignore[arg-type]
example_inputs,
kwargs,

View File

@ -1002,9 +1002,7 @@ class Pipe(torch.nn.Module):
) -> ExportedProgram:
logger.info("Tracing model ...")
try:
ep = torch.export.export_for_training(
mod, example_args, example_kwargs, strict=True
)
ep = torch.export.export(mod, example_args, example_kwargs, strict=True)
except Exception as e:
raise RuntimeError(
"It seems that we cannot capture your model as a full graph. "

View File

@ -58,7 +58,7 @@ from torch.ao.quantization.quantizer.xnnpack_quantizer import (
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.export import export
from torch.jit.mobile import _load_for_lite_interpreter
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import TEST_WITH_ROCM, TestCase
@ -1513,7 +1513,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
{0: torch.export.Dim("dim")} if i == 0 else None
for i in range(len(example_inputs))
)
m = export_for_training(
m = export(
m,
example_inputs,
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
@ -1554,7 +1554,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
m_fx = _convert_to_reference_decomposed_fx(
m_fx, backend_config=backend_config
)
m_fx = export_for_training(
m_fx = export(
m_fx,
example_inputs,
dynamic_shapes=dynamic_shapes if export_with_dynamic_shape else None,
@ -1578,7 +1578,7 @@ class PT2EQuantizationTestCase(QuantizationTestCase):
# resetting dynamo cache
torch._dynamo.reset()
m = export_for_training(m, example_inputs, strict=True).module()
m = export(m, example_inputs, strict=True).module()
if is_qat:
m = prepare_qat_pt2e(m, quantizer)
else:
@ -3183,12 +3183,15 @@ class TestHelperModules:
x = self.adaptive_avg_pool2d(x)
return x
class ConvWithBNRelu(torch.nn.Module):
def __init__(self, relu, dim=2, bn=True, bias=True, padding=0):
super().__init__()
convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
bns = {1: torch.nn.BatchNorm1d, 2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
bns = {
1: torch.nn.BatchNorm1d,
2: torch.nn.BatchNorm2d,
3: torch.nn.BatchNorm3d,
}
self.conv = convs[dim](3, 3, 3, bias=bias, padding=padding)
if bn:
@ -3394,7 +3397,7 @@ def _generate_qdq_quantized_model(
maybe_no_grad = contextlib.nullcontext() if is_qat else torch.no_grad()
with maybe_no_grad:
export_model = export_for_training(mod, inputs, strict=True).module(check_guards=False)
export_model = export(mod, inputs, strict=True).module(check_guards=False)
quantizer = (
quantizer
if quantizer