diff --git a/test/dynamo/test_fx_annotate.py b/test/dynamo/test_fx_annotate.py
new file mode 100644
index 000000000000..d62465ac57d8
--- /dev/null
+++ b/test/dynamo/test_fx_annotate.py
@@ -0,0 +1,270 @@
+# Owner(s): ["module: dynamo"]
+
+import torch
+import torch._dynamo.test_case
+import torch.fx.traceback as fx_traceback
+import torch.utils.checkpoint
+from torch._dynamo.test_case import run_tests
+from torch._dynamo.testing import AotEagerAndRecordGraphs
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.testing._internal.triton_utils import requires_cuda_and_triton
+
+
+def checkpoint_wrapper(fn):
+    def inner(*args):
+        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
+
+    return inner
+
+
+class AnnotateTests(torch._dynamo.test_case.TestCase):
+    # TODO - should not need this because we should turn this on in Dynamo, but
+    # for some reason, tests fail.
+    def setUp(self):
+        super().setUp()
+        self.cm = torch.fx.traceback.preserve_node_meta()
+        self.cm.__enter__()
+
+    def tearDown(self):
+        super().tearDown()
+        self.cm.__exit__(None, None, None)
+
+    def get_custom_metadata(self, gm):
+        def helper(gm):
+            custom_metadata = []
+            for node in gm.graph.nodes:
+                if hasattr(node, "meta") and node.meta.get("custom", None):
+                    custom_metadata.append((node.op, node.name, node.meta["custom"]))
+                if node.op == "get_attr" and isinstance(
+                    getattr(gm, node.target), torch.fx.GraphModule
+                ):
+                    custom_metadata.append(helper(getattr(gm, node.target)))
+            return custom_metadata
+
+        return "\n".join(str(x) for x in helper(gm))
+
+    def test_annotations(self):
+        class Mod(torch.nn.Module):
+            def forward(self, x):
+                with fx_traceback.annotate({"pp_stage": 0}):
+                    with fx_traceback.annotate({"fdsp_bucket": 0}):
+                        sin = torch.sin(x)
+                    sub = sin - 2
+                    with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
+                        mul = sub * 2
+                div = mul / 3
+                return div
+
+        m = Mod()
+        backend = AotEagerAndRecordGraphs()
+        opt_m = torch.compile(m, backend=backend, fullgraph=True)
+        x = torch.randn(10, requires_grad=True)
+        opt_m(x).sum().backward()
+
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+
+        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
+        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        self.assertExpectedInline(
+            str(dynamo_metadata),
+            """\
+('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sub', {'pp_stage': 0})
+('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(fw_metadata),
+            """\
+('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'sub', {'pp_stage': 0})
+('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(bw_metadata),
+            """\
+('call_function', 'mul_1', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})
+('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
+('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""",  # noqa: B950
+        )
+
+    def test_activation_checkpointing(self):
+        @checkpoint_wrapper
+        def gn(x):
+            return torch.sin(x)
+
+        def fn(x):
+            with fx_traceback.annotate({"ac_sin": 0}):
+                ac = gn(x)
+            return torch.sigmoid(ac)
+
+        backend = AotEagerAndRecordGraphs()
+        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
+        x = torch.randn(10, requires_grad=True)
+        opt_fn(x).sum().backward()
+
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+
+        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
+        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        self.assertExpectedInline(
+            str(dynamo_metadata),
+            """\
+('placeholder', 'l_x_', {'ac_sin': 0})
+('get_attr', 'wrap_body_0', {'ac_sin': 0})
+[('placeholder', 'l_x_', {'ac_sin': 0}), ('call_function', 'sin', {'ac_sin': 0}), ('output', 'output', {'ac_sin': 0})]
+('call_function', 'tag_activation_checkpoint', {'ac_sin': 0})
+('call_function', 'ac', {'ac_sin': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(fw_metadata),
+            """('call_function', 'sin', {'ac_sin': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(bw_metadata),
+            """\
+('call_function', 'cos', {'ac_sin': 0})
+('call_function', 'mul', {'ac_sin': 0})""",  # noqa: B950
+        )
+
+    def test_activation_checkpointing_annotation_inside(self):
+        @checkpoint_wrapper
+        def gn(x):
+            x = x + 1
+            with fx_traceback.annotate({"stage": 0}):
+                p = torch.sin(x)
+            return p + 1
+
+        def fn(x):
+            ac = gn(x)
+            return torch.sigmoid(ac)
+
+        backend = AotEagerAndRecordGraphs()
+        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
+        x = torch.randn(10, requires_grad=True)
+        opt_fn(x).sum().backward()
+
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+
+        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
+        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        self.assertExpectedInline(
+            str(dynamo_metadata),
+            """[('call_function', 'p', {'stage': 0})]""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(fw_metadata),
+            """('call_function', 'sin', {'stage': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(bw_metadata),
+            """\
+('call_function', 'cos', {'stage': 0})
+('call_function', 'mul', {'stage': 0})""",  # noqa: B950
+        )
+
+    @requires_cuda_and_triton
+    def test_ac_flex_attention(self):
+        def _squared(score, b, h, m, n):
+            return score * score
+
+        def mask_mod(b, h, q, k):
+            return q >= 0
+
+        a = 12
+        b = 64
+        block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)
+
+        def gn(x: torch.Tensor):
+            with fx_traceback.annotate({"compile_inductor": 0}):
+                return flex_attention(
+                    x, x, x, block_mask=block_mask, score_mod=_squared
+                )
+
+        def fn(x):
+            x = torch.sin(x)
+            x = gn(x)
+            return torch.cos(x)
+
+        x = torch.randn(
+            1,
+            1,
+            a * b,
+            b,
+            dtype=torch.bfloat16,
+            device="cuda",
+            requires_grad=True,
+        )
+
+        backend = AotEagerAndRecordGraphs()
+        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
+        opt_fn(x).sum().backward()
+
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+
+        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
+        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        self.assertExpectedInline(
+            str(dynamo_metadata),
+            """\
+('placeholder', 'l_gn_closure_1_cell_contents_kv_indices', {'compile_inductor': 0})
+('placeholder', 'l_gn_closure_1_cell_contents_kv_num_blocks', {'compile_inductor': 0})
+('placeholder', 'l_gn_closure_1_cell_contents_full_kv_num_blocks', {'compile_inductor': 0})
+('placeholder', 'l_gn_closure_1_cell_contents_full_kv_indices', {'compile_inductor': 0})
+('placeholder', 'l_gn_closure_1_cell_contents_q_num_blocks', {'compile_inductor': 0})
+('placeholder', 'l_gn_closure_1_cell_contents_q_indices', {'compile_inductor': 0})
+('placeholder', 'l_gn_closure_1_cell_contents_full_q_num_blocks', {'compile_inductor': 0})
+('placeholder', 'l_gn_closure_1_cell_contents_full_q_indices', {'compile_inductor': 0})
+('get_attr', 'score_mod_0', {'compile_inductor': 0})
+[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('placeholder', 'child_4', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
+('get_attr', 'mask_fn_0', {'compile_inductor': 0})
+[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
+('call_function', 'flex_attention', {'compile_inductor': 0})
+('call_function', 'out', {'compile_inductor': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(fw_metadata),
+            """\
+('get_attr', 'sdpa_score0', {'compile_inductor': 0})
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
+('get_attr', 'sdpa_mask0', {'compile_inductor': 0})
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
+('call_function', 'flex_attention', {'compile_inductor': 0})
+('call_function', 'getitem', {'compile_inductor': 0})
+('call_function', 'getitem_1', {'compile_inductor': 0})
+('call_function', 'detach_1', {'compile_inductor': 0})
+('call_function', 'detach_4', {'compile_inductor': 0})
+('call_function', 'detach_5', {'compile_inductor': 0})""",  # noqa: B950
+        )
+        self.assertExpectedInline(
+            str(bw_metadata),
+            """\
+('placeholder', 'getitem', {'compile_inductor': 0})
+('placeholder', 'detach_5', {'compile_inductor': 0})
+('call_function', 'zeros', {'compile_inductor': 0})
+('call_function', 'detach', {'compile_inductor': 0})
+('call_function', 'detach_2', {'compile_inductor': 0})
+('call_function', 'detach_3', {'compile_inductor': 0})
+('get_attr', 'fw_graph0', {'compile_inductor': 0})
+[]
+('get_attr', 'joint_graph0', {'compile_inductor': 0})
+[]
+('get_attr', 'mask_graph0', {'compile_inductor': 0})
+[('call_function', 'ge', {'compile_inductor': 0})]
+('call_function', 'flex_attention_backward', {'compile_inductor': 0})
+('call_function', 'getitem_3', {'compile_inductor': 0})
+('call_function', 'getitem_4', {'compile_inductor': 0})
+('call_function', 'getitem_5', {'compile_inductor': 0})""",  # noqa: B950
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/export/test_export.py b/test/export/test_export.py
index 29b4922be1f4..aea2d6c00e10 100755
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -15660,11 +15660,6 @@ def forward(self, x):
             test_serdes=True,
         )
 
-    @testing.expectedFailureTrainingIRToRunDecomp
-    @testing.expectedFailureRetraceability
-    @testing.expectedFailureStrictV2
-    @testing.expectedFailureStrict  # annotation needs to be handled in dynamo
-    @testing.expectedFailureSerDer
     def test_preserve_annotation(self):
         class M(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index f4850f7e5e9b..3b79dc26dc78 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -750,6 +750,9 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
     return handle
 
 
+# TODO - We want to run the preserve_node_meta context manager here, but CI
+# fails (it's unclear if the failures were flaky).
+# @torch.fx.traceback.preserve_node_meta()
 @preserve_global_state
 def trace_frame(
     code: types.CodeType,
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index 45adf58993e9..fc0f7238a5c5 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -51,7 +51,6 @@ from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
 from .utils import (
     getfile,
     hashable,
-    is_annotate_wrapped_function,
     is_lru_cache_wrapped_function,
     NP_SUPPORTED_MODULES,
     unwrap_if_wrapper,
@@ -155,7 +154,6 @@ manual_torch_name_rule_map: dict[
         type[UserFunctionVariable],
     ],
 ] = {
-    "torch.fx.traceback.annotate": UserFunctionVariable,
     "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
     "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
     "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -2996,9 +2994,6 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
                 continue
             obj = torch_dir + k[len("torch/") :]
         if obj is not None:
-            if is_annotate_wrapped_function(obj):
-                # pyrefly: ignore  # missing-attribute
-                obj = obj.__wrapped__
             if is_lru_cache_wrapped_function(obj):
                 obj = obj.__wrapped__
             if obj in d and d[obj] != v:
@@ -3430,6 +3425,7 @@ MOD_INLINELIST = [
     "torch.fx._symbolic_trace",
     "torch.fx.experimental.proxy_tensor",
     "torch.fx.passes.shape_prop",
+    "torch.fx.traceback",
     "torch.nn",
     "torch.overrides",
     "torch.random",
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index b6d51f70a6e4..8da851d66b98 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -1111,14 +1111,6 @@ def is_lru_cache_wrapped_function(
     )
 
 
-def is_annotate_wrapped_function(
-    value: Any,
-) -> bool:
-    return value == torch.fx.traceback.annotate and is_function(
-        inspect.getattr_static(value, "__wrapped__")
-    )
-
-
 _FuncTypes: TypeAlias = Union[
     types.FunctionType,
     types.BuiltinFunctionType,
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index a4eabc01a1de..24de4476a62e 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -29,6 +29,7 @@ from .ctx_manager import (
     DynamoConfigPatchVariable,
     ErrorOnGraphBreakVariable,
     FSDPParamGroupUseTrainingStateVariable,
+    FxTracebackAnnotateVariable,
     GradIncrementNestingCtxManagerVariable,
     GradInplaceRequiresGradCtxManagerVariable,
     GradModeVariable,
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index 20a4a9b389b3..cbd798511422 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -1262,6 +1262,34 @@ class SDPAKernelVariable(ContextWrappingVariable):
         return "_sdpa_kernel_variadic"
 
 
+class FxTracebackAnnotateVariable(ContextWrappingVariable):
+    """
+    fx.traceback.annotate is a context manager that allows users to annotate the
+    fx graph nodes with custom metadata. In the context of Dynamo, we don't have
+    to trace the body of the context manager. Instead, we want to directly run
+    the body of the context manager, so that the Dynamo-created Fx graphs have
+    the right custom metadata. This variable tracker just runs the __enter__ and
+    __exit__ methods (instead of tracing them).
+    """
+
+    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
+        super().__init__(
+            target_values=target_values, initial_values=initial_values, **kwargs
+        )
+
+    def enter(self, tx, *args):
+        cm = torch.fx.traceback.annotate(self.target_values)
+        cm.__enter__()
+        self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
+        return variables.ConstantVariable.create(None)
+
+    def module_name(self):
+        return "torch.fx.traceback"
+
+    def fn_name(self):
+        return "annotate"
+
+
 class StreamVariable(VariableTracker):
     def __init__(self, proxy, value, device, **kwargs) -> None:
         if proxy is not None and "example_value" in proxy.node.meta:
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index d8800f0fa74f..646f4ab0a186 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -125,6 +125,7 @@ supported_ctx_manager_classes = dict.fromkeys(
         torch.autograd.graph.disable_saved_tensors_hooks,
         torch.cpu.amp.autocast_mode.autocast,
         torch.cuda.amp.autocast_mode.autocast,
+        torch.fx.traceback.annotate,
         # We'll let Dynamo inline into the contextlib part of these context
         # manager instances, all the way till it invokes the wrapped function
         # itself (at which point we wrap it back to special context manager
@@ -325,6 +326,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             DisabledSavedTensorsHooksVariable,
             DualLevelContextManager,
             FSDPParamGroupUseTrainingStateVariable,
+            FxTracebackAnnotateVariable,
             GradIncrementNestingCtxManagerVariable,
             GradInplaceRequiresGradCtxManagerVariable,
             GradModeVariable,
@@ -359,6 +361,11 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
             assert len(args) <= 1 and len(kwargs) == 0
             inf_mode = args[0].as_python_constant() if len(args) == 1 else True
             return InferenceModeVariable.create(tx, inf_mode)
+        elif self.value is torch.fx.traceback.annotate:
+            assert len(args) <= 1 and len(kwargs) == 0
+            return FxTracebackAnnotateVariable(
+                args[0].as_python_constant(), source=self.source
+            )
         elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
             from torch._dynamo.variables.builder import wrap_fx_proxy_cls
 
diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py
index 1e615925e1e6..d40f1b353a5c 100644
--- a/torch/fx/traceback.py
+++ b/torch/fx/traceback.py
@@ -273,11 +273,7 @@ def annotate(annotation_dict: dict):
     global current_meta
 
     has_custom = "custom" in current_meta
-    old_custom = {}
-    # cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
-    # as dynamo doesn't support copy.copy()
-    for k, v in current_meta.get("custom", {}).items():
-        old_custom[k] = v  # noqa: PERF403
+    old_custom = copy.copy(current_meta.get("custom", {}))
 
     try:
         if not has_custom:
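Usage sketch (illustrative, not part of the patch): the flow exercised by the new tests is to hold `torch.fx.traceback.preserve_node_meta()` open, annotate a region of a module's `forward` with `fx_traceback.annotate(...)`, compile with a graph-recording backend, and then read the annotations back from `node.meta["custom"]`. `AotEagerAndRecordGraphs` below is the internal testing backend the new test uses to capture the Dynamo graph; any backend that exposes the captured graphs would work the same way.

import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.testing import AotEagerAndRecordGraphs  # internal test backend


class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            y = torch.sin(x)
        return y + 1


# Mirrors AnnotateTests.setUp: node meta must be preserved while tracing.
with fx_traceback.preserve_node_meta():
    backend = AotEagerAndRecordGraphs()
    opt = torch.compile(M(), backend=backend, fullgraph=True)
    opt(torch.randn(4))
    # Inspect the custom metadata Dynamo attached to the captured graph nodes.
    for node in backend.graphs[0].graph.nodes:
        if node.meta.get("custom"):
            print(node.name, node.meta["custom"])  # e.g. sin {'pp_stage': 0}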