From d9f94e0d7d96e52a636899a1b104cf610dd1a905 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Fri, 17 Oct 2025 16:38:12 -0700
Subject: [PATCH] [dynamo] Support fx.traceback.annotate as decorator (#165805)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165805
Approved by: https://github.com/Lucaskabela, https://github.com/SherlockNoMad, https://github.com/yushangdi
---
 test/dynamo/test_fx_annotate.py  | 50 ++++++++++++++++++++++++++++++++
 torch/_dynamo/variables/torch.py |  6 +++-
 2 files changed, 55 insertions(+), 1 deletion(-)

diff --git a/test/dynamo/test_fx_annotate.py b/test/dynamo/test_fx_annotate.py
index ede0b51ef123..337ce0f5764c 100644
--- a/test/dynamo/test_fx_annotate.py
+++ b/test/dynamo/test_fx_annotate.py
@@ -238,6 +238,56 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
 ('call_function', 'getitem_5', {'compile_inductor': 0})""",  # noqa: B950
         )
 
+    def test_as_decorator(self):
+        class Mod(torch.nn.Module):
+            @fx_traceback.annotate({"fdsp_bucket": 0})
+            def sin(self, x):
+                return torch.sin(x)
+
+            def forward(self, x):
+                with fx_traceback.annotate({"pp_stage": 0}):
+                    sin = self.sin(x)
+                    sub = sin - 2
+                    mul = sub * 2
+                div = mul / 3
+                return div
+
+        m = Mod()
+        backend = AotEagerAndRecordGraphs()
+        opt_m = torch.compile(m, backend=backend, fullgraph=True)
+        x = torch.randn(10, requires_grad=True)
+        m(x)
+        opt_m(x).sum().backward()
+
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
+        self.assertExpectedInline(
+            str(dynamo_metadata),
+            """\
+('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sub', {'pp_stage': 0})
+('call_function', 'mul', {'pp_stage': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(fw_metadata),
+            """\
+('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sub', {'pp_stage': 0})
+('call_function', 'mul', {'pp_stage': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(bw_metadata),
+            """\
+('call_function', 'mul_1', {'pp_stage': 0})
+('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""",  # noqa: B950
+        )
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index d659f3a24d86..1e39187274cc 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -126,6 +126,7 @@ supported_ctx_manager_classes = dict.fromkeys(
         torch.cpu.amp.autocast_mode.autocast,
         torch.cuda.amp.autocast_mode.autocast,
         torch.fx.traceback.annotate,
+        torch.fx.traceback.annotate.__wrapped__,  # type: ignore[attr-defined]
         # We'll let Dynamo inline into the contextlib part of these context
         # manager instances, all the way till it invokes the wrapped function
         # itself (at which point we wrap it back to special context manager
@@ -364,7 +365,10 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             assert len(args) <= 1 and len(kwargs) == 0
             inf_mode = args[0].as_python_constant() if len(args) == 1 else True
             return InferenceModeVariable.create(tx, inf_mode)
-        elif self.value is torch.fx.traceback.annotate:
+        elif self.value in (
+            torch.fx.traceback.annotate,
+            torch.fx.traceback.annotate.__wrapped__,  # type: ignore[attr-defined]
+        ):
             assert len(args) <= 1 and len(kwargs) == 0
             return FxTracebackAnnotateVariable(
                 args[0].as_python_constant(), source=self.source
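
Usage note (illustrative, not part of the patch): with this change, Dynamo accepts torch.fx.traceback.annotate both as a context manager and as a decorator, and the two forms stack their metadata on the traced nodes. Below is a minimal sketch mirroring the new test; the backend choice, tensor shape, and metadata keys are just the ones used there and are otherwise arbitrary.

    import torch
    import torch.fx.traceback as fx_traceback

    class Mod(torch.nn.Module):
        # Decorator form: nodes traced from sin() carry this metadata.
        @fx_traceback.annotate({"fdsp_bucket": 0})
        def sin(self, x):
            return torch.sin(x)

        def forward(self, x):
            # Context-manager form: nodes in this block carry the pp_stage tag;
            # the decorated call picks up both tags.
            with fx_traceback.annotate({"pp_stage": 0}):
                y = self.sin(x) - 2
            return y / 3

    # Compile with fullgraph so the annotations flow through Dynamo end to end.
    opt_m = torch.compile(Mod(), backend="aot_eager", fullgraph=True)
    out = opt_m(torch.randn(10, requires_grad=True))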