[dynamo] Support torch.fx.traceback.annotate (#164678)

Builds on top of https://github.com/pytorch/pytorch/pull/163673 and https://github.com/pytorch/pytorch/pull/164174. This will be used in follow-up PRs to apply regional Inductor compilation.

The existing implementation lets Dynamo trace into `torch.fx.traceback.annotate`, but that's not what we want. We want Dynamo to essentially run the `torch.fx.traceback.annotate` context manager eagerly, so that every FX node created in the Dynamo FX graph carries the custom metadata.

This does not work with graph breaks yet; if needed, that can be addressed in a separate PR.
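
For illustration, a minimal sketch of the intended usage, closely mirroring the new test added in this PR (the function, annotation key, and shapes are illustrative; `AotEagerAndRecordGraphs` is the test-only backend used in the test):

```python
import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.testing import AotEagerAndRecordGraphs


def fn(x):
    # Nodes created while this context manager is active should carry the
    # annotation in node.meta["custom"].
    with fx_traceback.annotate({"pp_stage": 0}):
        y = torch.sin(x)
    return y + 1


backend = AotEagerAndRecordGraphs()
opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
opt_fn(torch.randn(10))

# The `sin` node (and the placeholder created inside the context) should show
# {'pp_stage': 0}; the `y + 1` node is outside the context and stays unannotated.
for node in backend.graphs[0].graph.nodes:
    print(node.op, node.name, node.meta.get("custom"))
```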

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164678
Approved by: https://github.com/SherlockNoMad, https://github.com/jansel, https://github.com/xmfan
Animesh Jain
2025-10-05 00:13:50 -07:00
committed by PyTorch MergeBot
parent 9fff8155c3
commit 2883b5ab77
9 changed files with 298 additions and 22 deletions


@@ -0,0 +1,259 @@
# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.test_case
import torch.fx.traceback as fx_traceback
import torch.utils.checkpoint
from torch._dynamo.test_case import run_tests
from torch._dynamo.testing import AotEagerAndRecordGraphs
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.triton_utils import requires_cuda_and_triton


def checkpoint_wrapper(fn):
    def inner(*args):
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

    return inner


class AnnotateTests(torch._dynamo.test_case.TestCase):
    def get_custom_metadata(self, gm):
        def helper(gm):
            custom_metadata = []
            for node in gm.graph.nodes:
                if hasattr(node, "meta") and node.meta.get("custom", None):
                    custom_metadata.append((node.op, node.name, node.meta["custom"]))
                if node.op == "get_attr" and isinstance(
                    getattr(gm, node.target), torch.fx.GraphModule
                ):
                    custom_metadata.append(helper(getattr(gm, node.target)))
            return custom_metadata

        return "\n".join(str(x) for x in helper(gm))

    def test_annotations(self):
        class Mod(torch.nn.Module):
            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    with fx_traceback.annotate({"fdsp_bucket": 0}):
                        sin = torch.sin(x)
                    sub = sin - 2
                    with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                        mul = sub * 2
                div = mul / 3
                return div

        m = Mod()
        backend = AotEagerAndRecordGraphs()
        opt_m = torch.compile(m, backend=backend, fullgraph=True)
        x = torch.randn(10, requires_grad=True)
        opt_m(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """\
('placeholder', 'l_x_', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sub', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """\
('call_function', 'sin', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'sub', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('call_function', 'mul_1', {'pp_stage': 0, 'cuda_stream': 2, 'fsdp_bucket': 1})
('call_function', 'cos', {'pp_stage': 0, 'fdsp_bucket': 0})
('call_function', 'mul_2', {'pp_stage': 0, 'fdsp_bucket': 0})""",  # noqa: B950
        )

    def test_activation_checkpointing(self):
        @checkpoint_wrapper
        def gn(x):
            return torch.sin(x)

        def fn(x):
            with fx_traceback.annotate({"ac_sin": 0}):
                ac = gn(x)
            return torch.sigmoid(ac)

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        x = torch.randn(10, requires_grad=True)
        opt_fn(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """\
('placeholder', 'l_x_', {'ac_sin': 0})
('get_attr', 'wrap_body_0', {'ac_sin': 0})
[('placeholder', 'l_x_', {'ac_sin': 0}), ('call_function', 'sin', {'ac_sin': 0}), ('output', 'output', {'ac_sin': 0})]
('call_function', 'tag_activation_checkpoint', {'ac_sin': 0})
('call_function', 'ac', {'ac_sin': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """('call_function', 'sin', {'ac_sin': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('call_function', 'cos', {'ac_sin': 0})
('call_function', 'mul', {'ac_sin': 0})""",  # noqa: B950
        )

    def test_activation_checkpointing_annotation_inside(self):
        @checkpoint_wrapper
        def gn(x):
            x = x + 1
            with fx_traceback.annotate({"stage": 0}):
                p = torch.sin(x)
            return p + 1

        def fn(x):
            ac = gn(x)
            return torch.sigmoid(ac)

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        x = torch.randn(10, requires_grad=True)
        opt_fn(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """[('call_function', 'p', {'stage': 0})]""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """('call_function', 'sin', {'stage': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('call_function', 'cos', {'stage': 0})
('call_function', 'mul', {'stage': 0})""",  # noqa: B950
        )

    @requires_cuda_and_triton
    def test_ac_flex_attention(self):
        def _squared(score, b, h, m, n):
            return score * score

        def mask_mod(b, h, q, k):
            return q >= 0

        a = 12
        b = 64
        block_mask = create_block_mask(mask_mod, None, None, a * b, a * b)

        def gn(x: torch.Tensor):
            with fx_traceback.annotate({"compile_inductor": 0}):
                return flex_attention(
                    x, x, x, block_mask=block_mask, score_mod=_squared
                )

        def fn(x):
            x = torch.sin(x)
            x = gn(x)
            return torch.cos(x)

        x = torch.randn(
            1,
            1,
            a * b,
            b,
            dtype=torch.bfloat16,
            device="cuda",
            requires_grad=True,
        )

        backend = AotEagerAndRecordGraphs()
        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        opt_fn(x).sum().backward()

        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)

        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
        self.assertExpectedInline(
            str(dynamo_metadata),
            """\
('placeholder', 'l_gn_closure_1_cell_contents_kv_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_kv_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_kv_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_q_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_q_indices', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_q_num_blocks', {'compile_inductor': 0})
('placeholder', 'l_gn_closure_1_cell_contents_full_q_indices', {'compile_inductor': 0})
('get_attr', 'score_mod_0', {'compile_inductor': 0})
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('placeholder', 'child_4', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'mask_fn_0', {'compile_inductor': 0})
[('placeholder', 'child', {'compile_inductor': 0}), ('placeholder', 'child_1', {'compile_inductor': 0}), ('placeholder', 'child_2', {'compile_inductor': 0}), ('placeholder', 'child_3', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('call_function', 'flex_attention', {'compile_inductor': 0})
('call_function', 'out', {'compile_inductor': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(fw_metadata),
            """\
('get_attr', 'sdpa_score0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'sdpa_mask0', {'compile_inductor': 0})
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('call_function', 'flex_attention', {'compile_inductor': 0})
('call_function', 'getitem', {'compile_inductor': 0})
('call_function', 'getitem_1', {'compile_inductor': 0})
('call_function', 'detach_1', {'compile_inductor': 0})
('call_function', 'detach_4', {'compile_inductor': 0})
('call_function', 'detach_5', {'compile_inductor': 0})""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(bw_metadata),
            """\
('placeholder', 'getitem', {'compile_inductor': 0})
('placeholder', 'detach_5', {'compile_inductor': 0})
('call_function', 'zeros', {'compile_inductor': 0})
('call_function', 'detach', {'compile_inductor': 0})
('call_function', 'detach_2', {'compile_inductor': 0})
('call_function', 'detach_3', {'compile_inductor': 0})
('get_attr', 'fw_graph0', {'compile_inductor': 0})
[]
('get_attr', 'joint_graph0', {'compile_inductor': 0})
[]
('get_attr', 'mask_graph0', {'compile_inductor': 0})
[('call_function', 'ge', {'compile_inductor': 0})]
('call_function', 'flex_attention_backward', {'compile_inductor': 0})
('call_function', 'getitem_3', {'compile_inductor': 0})
('call_function', 'getitem_4', {'compile_inductor': 0})
('call_function', 'getitem_5', {'compile_inductor': 0})""",  # noqa: B950
        )


if __name__ == "__main__":
    run_tests()


@@ -15573,11 +15573,6 @@ def forward(self, x):
            test_serdes=True,
        )

    @testing.expectedFailureTrainingIRToRunDecomp
    @testing.expectedFailureRetraceability
    @testing.expectedFailureStrictV2
    @testing.expectedFailureStrict  # annotation needs to be handled in dynamo
    @testing.expectedFailureSerDer
    def test_preserve_annotation(self):
        class M(torch.nn.Module):
            def forward(self, x):


@@ -750,6 +750,7 @@ def register_bytecode_hook(hook: BytecodeHook) -> RemovableHandle:
@preserve_global_state
@torch.fx.traceback.preserve_node_meta()
def trace_frame(
    code: types.CodeType,
    globals: dict[str, object],


@@ -51,7 +51,6 @@ from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import (
    getfile,
    hashable,
    is_annotate_wrapped_function,
    is_lru_cache_wrapped_function,
    NP_SUPPORTED_MODULES,
    unwrap_if_wrapper,
@@ -155,7 +154,6 @@ manual_torch_name_rule_map: dict[
        type[UserFunctionVariable],
    ],
] = {
    "torch.fx.traceback.annotate": UserFunctionVariable,
    "torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
    "torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
    "torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,
@@ -3004,8 +3002,6 @@ def get_torch_obj_rule_map() -> dict[Any, type["VariableTracker"]]:
                    continue
                obj = torch_dir + k[len("torch/") :]
            if obj is not None:
                if is_annotate_wrapped_function(obj):
                    obj = obj.__wrapped__
                if is_lru_cache_wrapped_function(obj):
                    obj = obj.__wrapped__
                if obj in d and d[obj] != v:
@@ -3437,6 +3433,7 @@ MOD_INLINELIST = [
    "torch.fx._symbolic_trace",
    "torch.fx.experimental.proxy_tensor",
    "torch.fx.passes.shape_prop",
    "torch.fx.traceback",
    "torch.nn",
    "torch.overrides",
    "torch.random",


@@ -1108,14 +1108,6 @@ def is_lru_cache_wrapped_function(
    )


def is_annotate_wrapped_function(
    value: Any,
) -> bool:
    return value == torch.fx.traceback.annotate and is_function(
        inspect.getattr_static(value, "__wrapped__")
    )


_FuncTypes: TypeAlias = Union[
    types.FunctionType,
    types.BuiltinFunctionType,


@@ -29,6 +29,7 @@ from .ctx_manager import (
    DynamoConfigPatchVariable,
    ErrorOnGraphBreakVariable,
    FSDPParamGroupUseTrainingStateVariable,
    FxTracebackAnnotateVariable,
    GradIncrementNestingCtxManagerVariable,
    GradInplaceRequiresGradCtxManagerVariable,
    GradModeVariable,


@@ -1262,6 +1262,34 @@ class SDPAKernelVariable(ContextWrappingVariable):
        return "_sdpa_kernel_variadic"


class FxTracebackAnnotateVariable(ContextWrappingVariable):
    """
    fx.traceback.annotate is a context manager that allows users to annotate the
    fx graph nodes with custom metadata. In the context of Dynamo, we don't have
    to trace the body of the context manager. Instead we want to directly run
    the body of the context manager, so the Dynamo created Fx graphs have the
    right custom metadata. This variable tracker just runs __enter__ and
    __exit__ method (instead of tracing).
    """

    def __init__(self, target_values, initial_values=None, **kwargs) -> None:
        super().__init__(
            target_values=target_values, initial_values=initial_values, **kwargs
        )

    def enter(self, tx, *args):
        cm = torch.fx.traceback.annotate(self.target_values)
        cm.__enter__()
        self.set_cleanup_hook(tx, lambda: cm.__exit__(None, None, None))
        return variables.ConstantVariable.create(None)

    def module_name(self):
        return "torch.fx.traceback"

    def fn_name(self):
        return "annotate"


class StreamVariable(VariableTracker):
    def __init__(self, proxy, value, device, **kwargs) -> None:
        if proxy is not None and "example_value" in proxy.node.meta:


@@ -125,6 +125,7 @@ supported_ctx_manager_classes = dict.fromkeys(
        torch.autograd.graph.disable_saved_tensors_hooks,
        torch.cpu.amp.autocast_mode.autocast,
        torch.cuda.amp.autocast_mode.autocast,
        torch.fx.traceback.annotate,
        # We'll let Dynamo inline into the contextlib part of these context
        # manager instances, all the way till it invokes the wrapped function
        # itself (at which point we wrap it back to special context manager
@@ -325,6 +326,7 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
            DisabledSavedTensorsHooksVariable,
            DualLevelContextManager,
            FSDPParamGroupUseTrainingStateVariable,
            FxTracebackAnnotateVariable,
            GradIncrementNestingCtxManagerVariable,
            GradInplaceRequiresGradCtxManagerVariable,
            GradModeVariable,
@@ -359,6 +361,11 @@ class TorchCtxManagerClassVariable(BaseTorchVariable):
            assert len(args) <= 1 and len(kwargs) == 0
            inf_mode = args[0].as_python_constant() if len(args) == 1 else True
            return InferenceModeVariable.create(tx, inf_mode)
        elif self.value is torch.fx.traceback.annotate:
            assert len(args) <= 1 and len(kwargs) == 0
            return FxTracebackAnnotateVariable(
                args[0].as_python_constant(), source=self.source
            )
        elif inspect.isclass(self.value) and issubclass(self.value, torch.Stream):
            from torch._dynamo.variables.builder import wrap_fx_proxy_cls


@@ -277,11 +277,7 @@ def annotate(annotation_dict: dict):
    global current_meta

    has_custom = "custom" in current_meta
    old_custom = {}
    # cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
    # as dynamo doesn't support copy.copy()
    for k, v in current_meta.get("custom", {}).items():
        old_custom[k] = v  # noqa: PERF403
    old_custom = copy.copy(current_meta.get("custom", {}))

    try:
        if not has_custom: