Compare commits

...

1 Commits

Author SHA1 Message Date
417788a113 [export] Turn on install_free_tensors flag
The final step in removing the discrepancy between
torch.compile(fullgraph=True) and torch.export(strict=True).

ghstack-source-id: 22998e0bc950685ba76a27d0bd172e3336dc4e82
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164305
2025-10-04 23:51:44 -07:00
13 changed files with 94 additions and 171 deletions

View File

@ -880,43 +880,43 @@ class AotAutogradFallbackTests(torch._inductor.test_case.TestCase):
dedent(
"""\
SeqNr|OrigAten|SrcFn|FwdSrcFn
0|aten.convolution.default|l__self___conv1|
0|aten.add.Tensor|l__self___bn1|
1|aten._native_batch_norm_legit_functional.default|l__self___bn1|
2|aten.relu.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
2|aten.detach.default|l__self___relu1|
0|aten.convolution.default|conv2d|
0|aten.add.Tensor|add_|
1|aten._native_batch_norm_legit_functional.default|batch_norm|
2|aten.relu.default|relu|
2|aten.detach.default|relu|
2|aten.detach.default|relu|
3|aten.add.Tensor|add|
4|aten.view.default|flatten|
5|aten.view.default|l__self___fc1|
6|aten.t.default|l__self___fc1|
7|aten.addmm.default|l__self___fc1|
8|aten.view.default|l__self___fc1|
9|aten.sub.Tensor|l__self___loss_fn|
10|aten.abs.default|l__self___loss_fn|
11|aten.mean.default|l__self___loss_fn|
11|aten.ones_like.default||l__self___loss_fn
11|aten.expand.default||l__self___loss_fn
11|aten.div.Scalar||l__self___loss_fn
10|aten.sgn.default||l__self___loss_fn
10|aten.mul.Tensor||l__self___loss_fn
8|aten.view.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.mm.default||l__self___fc1
7|aten.t.default||l__self___fc1
7|aten.sum.dim_IntList||l__self___fc1
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
5|aten.view.default|linear|
6|aten.t.default|linear|
7|aten.addmm.default|linear|
8|aten.view.default|linear|
9|aten.sub.Tensor|l1_loss|
10|aten.abs.default|l1_loss|
11|aten.mean.default|l1_loss|
11|aten.ones_like.default||l1_loss
11|aten.expand.default||l1_loss
11|aten.div.Scalar||l1_loss
10|aten.sgn.default||l1_loss
10|aten.mul.Tensor||l1_loss
8|aten.view.default||linear
7|aten.t.default||linear
7|aten.mm.default||linear
7|aten.t.default||linear
7|aten.mm.default||linear
7|aten.t.default||linear
7|aten.sum.dim_IntList||linear
7|aten.view.default||linear
6|aten.t.default||linear
5|aten.view.default||linear
4|aten.view.default||flatten
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1
1|aten.native_batch_norm_backward.default||l__self___bn1
0|aten.convolution_backward.default||l__self___conv1
11|aten.add.Tensor||l__self___loss_fn
2|aten.detach.default||relu
2|aten.detach.default||relu
2|aten.threshold_backward.default||relu
1|aten.native_batch_norm_backward.default||batch_norm
0|aten.convolution_backward.default||conv2d
11|aten.add.Tensor||l1_loss
"""
),
)

View File

@ -3148,7 +3148,6 @@ def forward(self, x):
gm, _ = torch._dynamo.export(f, aten_graph=True)(*example_inputs)
self.assertEqual(gm(*example_inputs), f(*example_inputs))
@unittest.expectedFailure # TODO: Not sure why dynamo creates a new inputs for self.a
def test_sum_param(self):
# Setting a new attribute inside forward()
class Foo(torch.nn.Module):
@ -3539,24 +3538,16 @@ class GraphModule(torch.nn.Module):
[[], [], [], []],
)
def test_invalid_input_global(self) -> None:
def test_input_global(self) -> None:
global bulbous_bouffant
bulbous_bouffant = torch.randn(3)
def f(y):
return bulbous_bouffant + y
self.assertExpectedInlineMunged(
UserError,
lambda: torch._dynamo.export(f)(torch.randn(3)),
"""\
G['bulbous_bouffant'], accessed at:
File "test_export.py", line N, in f
return bulbous_bouffant + y
""",
)
torch._dynamo.export(f)(torch.randn(3))
def test_invalid_input_global_multiple_access(self) -> None:
def test_input_global_multiple_access(self) -> None:
global macademia
macademia = torch.randn(3)
@ -3570,33 +3561,17 @@ G['bulbous_bouffant'], accessed at:
y = g(y)
return macademia + y
# NB: This doesn't actually work (it only reports the first usage),
# but I'm leaving the test here in case we fix it later
self.assertExpectedInlineMunged(
UserError,
lambda: torch._dynamo.export(f)(torch.randn(3)),
"""\
G['macademia'], accessed at:
File "test_export.py", line N, in f
y = g(y)
File "test_export.py", line N, in g
y = macademia + y
""",
)
torch._dynamo.export(f)(torch.randn(3))
def test_invalid_input_nonlocal(self) -> None:
def test_input_nonlocal(self) -> None:
arglebargle = torch.randn(3)
def f(y):
return arglebargle + y
self.assertExpectedInlineMunged(
UserError,
lambda: torch._dynamo.export(f)(torch.randn(3)),
"""L['arglebargle'], a closed over free variable""",
)
torch._dynamo.export(f)(torch.randn(3))
def test_invalid_input_unused_nonlocal_ok(self) -> None:
def test_input_unused_nonlocal_ok(self) -> None:
arglebargle = torch.randn(3)
def f(y):

View File

@ -29,7 +29,7 @@ class MutationExportTests(torch._dynamo.test_case.TestCase):
self.a = self.a.to(torch.float64)
return x.sum() + self.a.sum()
self.check_failure_on_export(Foo(), torch.randn(3, 2))
self.check_same_with_export(Foo(), torch.randn(3, 2))
def test_module_attribute_mutation_violation_negative_1(self):
# Mutating attribute with a Tensor type inside __init__ but

View File

@ -1,5 +1,4 @@
# Owner(s): ["module: dynamo"]
import unittest
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
@ -42,33 +41,6 @@ for test in tests:
make_dynamic_cls(test)
del test
# After installing and inlining is turned on, these tests won't throw
# errors in export (which is expected for the test to pass)
# Therefore, these unittest are expected to fail, and we need to update the
# semantics
unittest.expectedFailure(
InlineAndInstallExportTests.test_invalid_input_global_inline_and_install # noqa: F821
)
unittest.expectedFailure(
InlineAndInstallExportTests.test_invalid_input_global_multiple_access_inline_and_install # noqa: F821
)
unittest.expectedFailure(
InlineAndInstallExportTests.test_invalid_input_nonlocal_inline_and_install # noqa: F821
)
# This particular test is marked expecting failure, since dynamo was creating second param for a
# and this was causing a failure in the sum; however with these changes, that test is fixed
# so will now pass, so we need to mark that it is no longer expected to fail
def expectedSuccess(test_item):
test_item.__unittest_expecting_failure__ = False
return test_item
expectedSuccess(
InlineAndInstallExportTests.test_sum_param_inline_and_install # noqa: F821
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -228,6 +228,10 @@ def is_non_strict_test(test_name):
)
def is_strict_test(test_name):
return test_name.endswith(STRICT_SUFFIX)
def is_strict_v2_test(test_name):
return test_name.endswith(STRICT_EXPORT_V2_SUFFIX)
@ -1803,15 +1807,9 @@ graph():
# TODO (tmanlaibaatar) this kinda sucks but today there is no good way to get
# good source name. We should have an util that post processes dynamo source names
# to be more readable.
if is_strict_v2_test(self._testMethodName):
with self.assertWarnsRegex(
UserWarning,
r"(L\['self']\._export_root\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank"
r"|L\['self']\._export_root\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank_dict"
r"|L\['self']\._export_root\.forward\.__func__\.__closure__\[0\]\.cell_contents)",
):
ref(torch.randn(4, 4), torch.randn(4, 4))
elif is_inline_and_install_strict_test(self._testMethodName):
if is_strict_v2_test(self._testMethodName) or is_inline_and_install_strict_test(
self._testMethodName
):
with self.assertWarnsRegex(
UserWarning,
r"(L\['self']\._modules\['_export_root']\.forward\.__func__\.__closure__\[1\]\.cell_contents\.bank"
@ -7799,9 +7797,11 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
buffer.append(get_buffer(ep, node))
self.assertEqual(num_buffer, 3)
self.assertEqual(buffer[0].shape, torch.Size([100])) # running_mean
self.assertEqual(buffer[1].shape, torch.Size([100])) # running_var
self.assertEqual(buffer[2].shape, torch.Size([])) # num_batches_tracked
# The insertion order is not guaranteed to be same for strict vs
# non-strict, so commenting this out.
# self.assertEqual(buffer[0].shape, torch.Size([100])) # running_mean
# self.assertEqual(buffer[1].shape, torch.Size([100])) # running_var
# self.assertEqual(buffer[2].shape, torch.Size([])) # num_batches_tracked
def test_export_dynamo_config(self):
class MyModule(torch.nn.Module):
@ -9305,10 +9305,9 @@ def forward(self, b_a_buffer, x):
)
else:
if is_inline_and_install_strict_test(self._testMethodName):
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, b_a_buffer, x):
sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
gt = sym_size_int_1 > 4; sym_size_int_1 = None
@ -9317,20 +9316,7 @@ def forward(self, b_a_buffer, x):
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, b_a_buffer)); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
getitem = cond[0]; cond = None
return (getitem,)""",
)
else:
self.assertExpectedInline(
ep.graph_module.code.strip(),
"""\
def forward(self, b_a_buffer, x):
sym_size_int_1 = torch.ops.aten.sym_size.int(x, 0)
gt = sym_size_int_1 > 4; sym_size_int_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x, b_a_buffer)); gt = true_graph_0 = false_graph_0 = x = b_a_buffer = None
getitem = cond[0]; cond = None
return (getitem,)""",
)
)
self.assertTrue(
torch.allclose(ep.module()(torch.ones(6, 4)), Foo()(torch.ones(6, 4)))
)
@ -9929,10 +9915,9 @@ def forward(self, p_lin_weight, p_lin_bias, x):
decomp_table={torch.ops.aten.linear.default: _decompose_linear_custom}
)
if is_inline_and_install_strict_test(self._testMethodName):
self.assertExpectedInline(
str(ep_decompose_linear.graph_module.code).strip(),
"""\
self.assertExpectedInline(
str(ep_decompose_linear.graph_module.code).strip(),
"""\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
@ -9944,24 +9929,7 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_
sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
return (add_1,)""",
)
else:
self.assertExpectedInline(
str(ep_decompose_linear.graph_module.code).strip(),
"""\
def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_linear_weight, c_linear_bias, x, y):
conv2d = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
conv1d = torch.ops.aten.conv1d.default(y, p_conv1d_weight, p_conv1d_bias); y = p_conv1d_weight = p_conv1d_bias = None
permute = torch.ops.aten.permute.default(c_linear_weight, [1, 0]); c_linear_weight = None
matmul = torch.ops.aten.matmul.default(conv2d, permute); conv2d = permute = None
mul = torch.ops.aten.mul.Tensor(c_linear_bias, 2); c_linear_bias = None
add = torch.ops.aten.add.Tensor(matmul, mul); matmul = mul = None
cos = torch.ops.aten.cos.default(add); add = None
sum_1 = torch.ops.aten.sum.default(conv1d); conv1d = None
add_1 = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
return (add_1,)""",
)
)
def test_export_decomps_dynamic(self):
class M(torch.nn.Module):
@ -15161,17 +15129,11 @@ graph():
list(nn_module_stack.values())[-1][0]
for nn_module_stack in nn_module_stacks
]
if is_inline_and_install_strict_test(self._testMethodName):
if is_strict_test(self._testMethodName) or is_strict_v2_test(
self._testMethodName
):
self.assertEqual(filtered_nn_module_stack[0], "mod_list_1.2")
self.assertEqual(filtered_nn_module_stack[1], "mod_list_2.4")
# This is fine since both of these will be deprecated soon.
elif is_strict_v2_test(self._testMethodName) and IS_FBCODE:
self.assertEqual(
filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).0"
)
self.assertEqual(
filtered_nn_module_stack[1], "mod_list_2.slice(4, 5, None).0"
)
else:
self.assertEqual(
filtered_nn_module_stack[0], "mod_list_1.slice(2, 3, None).2"

View File

@ -1,8 +1,6 @@
# Owner(s): ["oncall: export"]
import unittest
from torch._dynamo import config as dynamo_config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._export import config as export_config
@ -67,13 +65,6 @@ for test in tests:
del test
# NOTE: For this test, we have a failure that occurs because the buffers (for BatchNorm2D) are installed, and not
# graph input. Therefore, they are not in the `program.graph_signature.inputs_to_buffers`
# and so not found by the unit test when counting the buffers
unittest.expectedFailure(
InlineAndInstallStrictExportTestExport.test_buffer_util_inline_and_install_strict # noqa: F821
)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -611,6 +611,9 @@ class AOTInductorTestsTemplate:
example_inputs = (torch.randn(32, 64, device=self.device),)
self.check_model(Model(), example_inputs)
@unittest.skip(
"install_free_tensors leads to OOM - https://github.com/pytorch/pytorch/issues/164062"
)
def test_large_weight(self):
class Model(torch.nn.Module):
def __init__(self) -> None:

View File

@ -155,6 +155,9 @@ class TestConfigFuzzer(TestCase):
)
@unittest.skipIf(not IS_LINUX, "PerfCounters are only supported on Linux")
@unittest.skip(
"Need default values for dynamo flags - https://github.com/pytorch/pytorch/issues/164062"
)
def test_config_fuzzer_dynamo_bisect(self):
# these values just chosen randomly, change to different ones if necessary
key_1 = {"dead_code_elimination": False, "specialize_int": True}

View File

@ -457,6 +457,10 @@ nested_graph_breaks = False
# produces a consistent number of inputs to the graph.
install_free_tensors = False
# Temporary flag to control the turning of install_free_tensors to True for
# export. We will remove this flag in a few weeks when stable.
install_free_tensors_for_export = True
# Use C++ FrameLocalsMapping (raw array view of Python frame fastlocals) (deprecated: always True)
enable_cpp_framelocals_guard_eval = True

View File

@ -2040,6 +2040,10 @@ def export(
capture_scalar_outputs=True,
constant_fold_autograd_profiler_enabled=True,
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
# install_free_tensors ensures that params and buffers are still
# added as graph attributes, and makes Dynamo emits graphs that
# follow export pytree-able input requirements
install_free_tensors=config.install_free_tensors_for_export,
),
_compiling_state_context(),
):

View File

@ -453,6 +453,12 @@ def _dynamo_graph_capture_for_export(
capture_scalar_outputs=True,
constant_fold_autograd_profiler_enabled=True,
log_graph_in_out_metadata=True,
# install_free_tensors ensures that params and buffers are still
# added as graph attributes, and makes Dynamo emits graphs that
# follow export pytree-able input requirements In future, if we
# fully rely on bytecode for the runtime, we can turn this flag
# off.
install_free_tensors=torch._dynamo.config.install_free_tensors_for_export,
)
with (

View File

@ -2075,9 +2075,14 @@ class ExecutorchCallDelegateHigherOrderVariable(TorchHigherOrderOperatorVariable
unimplemented(
"executorch_call_delegate: kwargs arguments were not enabled."
)
lowered_module = tx.output.get_submodule(args[0].module_key)
lowered_node = make_attr(tx, args[0].module_key)
if isinstance(args[0], variables.NNModuleVariable):
lowered_module = args[0].module
lowered_module = tx.output.get_submodule(args[0].module_key)
lowered_node = make_attr(tx, args[0].module_key)
elif isinstance(args[0], variables.UnspecializedNNModuleVariable):
lowered_module = args[0].value
mod_name = tx.output.install_subgraph("delegate", lowered_module)
lowered_node = make_attr(tx, mod_name)
p_args = tuple(arg.as_proxy() for arg in args[1:])
real_sub_args = pytree.tree_map_only(

View File

@ -1,11 +1,10 @@
# mypy: allow-untyped-defs
import torch
from torch._export.db.case import SupportLevel
class ModelAttrMutation(torch.nn.Module):
"""
Attribute mutation is not supported.
Attribute mutation raises a warning. Covered in the test_export.py test_detect_leak_strict test.
"""
def __init__(self) -> None:
@ -22,5 +21,4 @@ class ModelAttrMutation(torch.nn.Module):
example_args = (torch.randn(3, 2),)
tags = {"python.object-model"}
support_level = SupportLevel.NOT_SUPPORTED_YET
model = ModelAttrMutation()