Compare commits


1 Commit

2de09a0d6f [export] preserve_node_meta by default (#165972)
Summary:

annotate() requires preserve_node_meta() to be active, so have export enter preserve_node_meta() by default instead of requiring callers to wrap annotate() themselves.

Test Plan: test_export

Reviewed By: angelayi, malaybag

Differential Revision: D85122326
2025-10-27 10:52:48 -07:00
3 changed files with 22 additions and 7 deletions
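For context, a minimal usage sketch of what this change enables, mirroring the new test added below (the module, metadata key, and inspection loop are illustrative only): metadata passed to torch.fx.traceback.annotate() now lands in node.meta["custom"] on the exported graph without the caller entering preserve_node_meta() explicitly.

import torch

class Foo(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        # Nodes traced inside this block should carry {"a": "b"} under meta["custom"].
        with torch.fx.traceback.annotate({"a": "b"}):
            x = x + 1
        return x + 1

ep = torch.export.export(Foo(), (torch.randn(2),))
for node in ep.graph.nodes:
    print(node.name, node.meta.get("custom"))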

View File

@@ -721,6 +721,20 @@ class TestExport(TestCase):
             )
             self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)
 
+    def test_fx_annotate(self):
+        class Foo(torch.nn.Module):
+            def forward(self, x):
+                x += 1
+                with torch.fx.traceback.annotate({"a": "b"}):
+                    x += 1
+                x += 1
+                return x
+
+        ep = export(Foo(), (torch.randn(2),))
+        add_1 = list(ep.graph.nodes)[2]
+        self.assertTrue("custom" in add_1.meta and add_1.meta["custom"].get("a") == "b")
+
     @requires_gpu
     def test_flex_attention_export(self):
         from torch.nn.attention.flex_attention import create_block_mask, flex_attention

View File

@@ -9091,14 +9091,12 @@ class GraphModule(torch.nn.Module):
     class true_graph_0(torch.nn.Module):
         def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
             item: "Sym(u0)" = torch.ops.aten.item.default(b1); b1 = None
             mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
             return (mul,)
 
     class false_graph_0(torch.nn.Module):
         def forward(self, c: "f32[10]", b1: "i64[1]", b2: "i64[1]"):
             item: "Sym(u1)" = torch.ops.aten.item.default(b2); b2 = None
             mul: "f32[10]" = torch.ops.aten.mul.Tensor(c, item); c = item = None
             return (mul,)
 """,  # noqa: B950
@@ -9183,7 +9181,6 @@ class GraphModule(torch.nn.Module):
     class false_graph_0(torch.nn.Module):
         def forward(self, x: "f32[s68, 3]", sym_size_int_4: "Sym(s17)", sym_size_int_5: "Sym(s68)", z: "f32[s68, 3]"):
             mul: "f32[s68, 3]" = torch.ops.aten.mul.Tensor(z, sym_size_int_5); z = sym_size_int_5 = None
             add: "f32[s68, 3]" = torch.ops.aten.add.Tensor(x, mul); x = mul = None
             return (add,)
 """,  # noqa: B950

View File

@@ -798,7 +798,10 @@ def _export_to_torch_ir(
         prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
     )
 
-    with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
+    with (
+        torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)),
+        torch.fx.traceback.preserve_node_meta(),
+    ):
         try:
             module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
                 _ExportModuleSpecTrackerDict()
@@ -888,6 +891,7 @@ def _export_to_aten_ir(
         _ignore_backend_decomps(),
         _compiling_state_context(),
         custom_triton_ops_decomposition_ctx(),
+        torch.fx.traceback.preserve_node_meta(),
     ):
         gm, graph_signature = transform(aot_export_module)(
             mod,
@@ -1916,9 +1920,8 @@ def _non_strict_export(
                 in mod._forward_pre_hooks.values()
             ):
                 _check_input_constraints_pre_hook(mod, args, kwargs)
-                with torch.fx.traceback.preserve_node_meta():
-                    args = (*args, *kwargs.values())
-                    tree_out = torch.fx.Interpreter(mod).run(*args)
+                args = (*args, *kwargs.values())
+                tree_out = torch.fx.Interpreter(mod).run(*args)
             else:
                 tree_out = mod(*args, **kwargs)
             flat_outs, out_spec = pytree.tree_flatten(tree_out)
@@ -2015,6 +2018,7 @@ def _non_strict_export(
         ),
         _fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
         _override_builtin_ops(),
+        torch.fx.traceback.preserve_node_meta(),
     ):
         aten_export_artifact = _to_aten_func(  # type: ignore[operator]
             patched_mod,
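
Design note: the non-strict path no longer wraps the fx.Interpreter run in its own preserve_node_meta() block (removed above); instead preserve_node_meta() is entered once alongside the other export context managers, so every export path records annotate() metadata. A minimal sketch of that grouped context-manager pattern, with a hypothetical helper and config dict standing in for the export internals:

import torch

def run_traced(fn, *args, dynamo_overrides=None):
    # Enter dynamo config patching and node-meta preservation together,
    # using the parenthesized with-group form seen in _trace.py above,
    # so both cover the whole traced region.
    with (
        torch._dynamo.config.patch(dynamo_overrides or {}),
        torch.fx.traceback.preserve_node_meta(),
    ):
        return fn(*args)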