Compare commits

...

1 Commits

Author SHA1 Message Date
d8659552e2 fx annotate 2025-10-14 22:57:15 -07:00
2 changed files with 22 additions and 4 deletions

View File

@ -720,6 +720,20 @@ class TestExport(TestCase):
)
self.assertEqual(node.meta["from_node"][-1].graph_id, graph_id)
def test_fx_annotate(self):
class Foo(torch.nn.Module):
def forward(self, x):
x += 1
with torch.fx.traceback.annotate({"a": "b"}):
x += 1
x += 1
return x
ep = export(Foo(), (torch.randn(2),))
add_1 = list(ep.graph.nodes)[2]
self.assertTrue("custom" in add_1.meta and add_1.meta["custom"].get("a") == "b")
@requires_gpu
def test_flex_attention_export(self):
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

View File

@ -809,7 +809,10 @@ def _export_to_torch_ir(
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards,
)
with torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)):
with (
torch._dynamo.config.patch(dataclasses.asdict(dynamo_cfg)),
torch.fx.traceback.preserve_node_meta(),
):
try:
module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = (
_ExportModuleSpecTrackerDict()
@ -899,6 +902,7 @@ def _export_to_aten_ir(
_ignore_backend_decomps(),
_compiling_state_context(),
custom_triton_ops_decomposition_ctx(),
torch.fx.traceback.preserve_node_meta(),
):
gm, graph_signature = transform(aot_export_module)(
mod,
@ -1927,9 +1931,8 @@ def _non_strict_export(
in mod._forward_pre_hooks.values()
):
_check_input_constraints_pre_hook(mod, args, kwargs)
with torch.fx.traceback.preserve_node_meta():
args = (*args, *kwargs.values())
tree_out = torch.fx.Interpreter(mod).run(*args)
args = (*args, *kwargs.values())
tree_out = torch.fx.Interpreter(mod).run(*args)
else:
tree_out = mod(*args, **kwargs)
flat_outs, out_spec = pytree.tree_flatten(tree_out)
@ -2026,6 +2029,7 @@ def _non_strict_export(
),
_fakify_module_inputs(fake_args, fake_kwargs, fake_mode),
_override_builtin_ops(),
torch.fx.traceback.preserve_node_meta(),
):
aten_export_artifact = _to_aten_func( # type: ignore[operator]
patched_mod,