[hop] run local_map with interpreter to preserve fx_traceback annotations (#165336)

We have an issue when using fx_traceback.annotate with HOPs that trace joint graphs. HOP bodies have already been traced by Dynamo and, after Animesh's PR, those bodies do carry the annotations. But when we lower that Dynamo HOP body to aten, in either pre-dispatch or post-dispatch, we need to propagate the annotations onto the resulting aten nodes.
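
For context, here is a minimal sketch of what the annotation looks like on a traced body (based on the test added in this PR, but with a hypothetical function and shapes, not code from the PR itself): everything traced under `fx_traceback.annotate` gets the dict recorded in `node.meta["custom"]`, and it is that metadata that has to survive when the HOP body is lowered to aten.

```python
import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    # Ops traced inside this context get {"inside_local_map": 0}
    # attached under node.meta["custom"].
    with fx_traceback.annotate({"inside_local_map": 0}):
        return torch.matmul(x, x.t())


with fx_traceback.preserve_node_meta():
    gm = make_fx(fn)(torch.randn(4, 4))

for node in gm.graph.nodes:
    if node.op == "call_function":
        print(node.target, node.meta.get("custom"))
```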

AOTAutograd does this indirectly by piggybacking off the `PropagateUnbackedSymInts` fx.Interpreter. I'm not sure whether all HOPs should be using it to trace their joints. This PR adds an interpreter to local_map's implementation.
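
Continuing the sketch above (again hypothetical names and shapes, not the PR's code), the mechanism is the same one the new test exercises: re-running an already-annotated GraphModule through `torch.fx.Interpreter` while retracing with `make_fx` lets each newly created aten node inherit the interpreted node's metadata, whereas calling the GraphModule directly during retracing would drop the `custom` annotations.

```python
import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def body(x):
    with fx_traceback.annotate({"inside_local_map": 0}):
        return torch.matmul(x, x.t())


with fx_traceback.preserve_node_meta():
    gm = make_fx(body)(torch.randn(4, 4))

    # Retrace through an Interpreter: nodes created by make_fx inherit meta
    # (including "custom") from the node currently being interpreted.
    retraced = make_fx(torch.fx.Interpreter(gm).run)(torch.randn(4, 4))

for n in retraced.graph.nodes:
    if n.op == "call_function":
        # Expected to still show {'inside_local_map': 0} on the matmul nodes.
        print(n.target, n.meta.get("custom"))
```
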
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165336
Approved by: https://github.com/yushangdi
Author: Simon Fan
Date: 2025-10-13 19:42:08 -07:00
Committed by: PyTorch MergeBot
Parent: 12fa4192c5
Commit: 21697feff2
2 changed files with 86 additions and 6 deletions

@@ -11,11 +11,13 @@ import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
import torch.fx.traceback as fx_traceback
import torch.nn.functional as F
from torch import nn
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.utils.checkpoint import create_selective_checkpoint_contexts
@@ -130,6 +132,12 @@ def save_scalar_muls(ctx, op, *args, **kwargs):
return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
def save_mm(ctx, op, *args, **kwargs):
if op == torch.ops.aten.mm.default:
return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
def create_model(attention_fn, nheads, dim1, dim2, sac_policy=None):
class LocalMapTransformerBlock(nn.Module):
def __init__(self, nheads, dim1, dim2):
@@ -556,8 +564,10 @@ class GraphModule(torch.nn.Module):
out = x.view(-1) + 10
return (out.view(x.shape),)
# pretend this is a GraphModule for testing convenience
fn.meta = {
x = torch.randn(10, 80)
gm = make_fx(fn)(x)
gm.meta = {
"local_map_kwargs": {
"in_placements": ((Shard(0), Replicate(), Replicate()),),
"out_placements": ((Shard(0), Replicate(), Replicate()),),
@@ -568,7 +578,7 @@ class GraphModule(torch.nn.Module):
with FakeTensorMode():
global_tensor = torch.randn(80, 80, requires_grad=True)
with torch._higher_order_ops.local_map.defer_inlining():
out = torch._higher_order_ops.local_map_hop(fn, global_tensor)
out = torch._higher_order_ops.local_map_hop(gm, global_tensor)
out[0].sum().backward()
self.assertEqual(global_tensor.shape, (80, 80))
@@ -715,6 +725,65 @@ class GraphModule(torch.nn.Module):
inputs = (torch.randn(80, 80),)
ap_style_initial_capture(model, inputs)
@unittest.skipIf(*get_skip_reasons())
def test_fx_annotations(self):
@local_map(
out_placements=((Replicate(), Replicate(), Replicate()),),
in_placements=(
(Replicate(), Replicate(), Replicate()),
(Replicate(), Replicate(), Replicate()),
None,
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=self.mesh,
)
def fn(w, x, id):
with fx_traceback.annotate({"inside_local_map": id}):
return torch.matmul(x, w.t())
context_fn = functools.partial(create_selective_checkpoint_contexts, save_mm)
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.w = nn.Linear(80, 80)
def forward(self, x):
a = fn(self.w.weight, x, 0)
b = torch.utils.checkpoint.checkpoint(
fn, self.w.weight, x, 1, use_reentrant=False, context_fn=context_fn
)
return a.sum() + b.sum()
model = MyModule()
with FakeTensorMode():
fw_inputs = (torch.randn(80, 80),)
with fx_traceback.preserve_node_meta():
joint_gm_deferred = ap_style_initial_capture(model, fw_inputs)
joint_inputs = [
n.meta["val"]
for n in joint_gm_deferred.graph.nodes
if n.op == "placeholder"
]
# TODO: need a local shape interpreter for cases where the graph specializes on shapes
interp = torch.fx.Interpreter(joint_gm_deferred)
joint_gm_inlined = make_fx(interp.run)(*joint_inputs)
mm_nodes = joint_gm_inlined.graph.find_nodes(
op="call_function", target=torch.ops.aten.mm.default
)
self.assertEqual(len(mm_nodes), 4)
self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)
self.assertEqual(mm_nodes[1].meta["custom"]["inside_local_map"], 1)
self.assertEqual(mm_nodes[2].meta["custom"]["inside_local_map"], 1)
self.assertEqual(mm_nodes[3].meta["custom"]["inside_local_map"], 0)
if __name__ == "__main__":
run_tests()

@@ -257,9 +257,14 @@ def create_hop_fw_bw(
primals = primals_and_tangents[:num_fw_inputs]
tangents = primals_and_tangents[num_fw_inputs:]
def prepare_fw_with_masks(fn: Callable[..., Any]) -> Callable[..., Any]:
def prepare_fw_with_masks(
fw_gm: torch.fx.GraphModule,
) -> Callable[..., Any]:
def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]:
fw_out = fn(*args)
# The Interpreter here is required to propagate metadata
# from the dynamo graph body to the local_map graph body.
# This is required for fx_traceback.annotate to work.
fw_out = torch.fx.Interpreter(fw_gm).run(*args)
assert isinstance(fw_out, tuple), (
"Dynamo traced submodule should return tuple"
)
@@ -293,6 +298,11 @@ def create_hop_fw_bw(
*[example_grads[i] for i in filtered_grads_idx],
]
joint_hop_gm = make_fx(joint_f)(*primals_and_tangents)
from torch._functorch._aot_autograd.graph_capture import (
copy_fwd_metadata_to_bw_nodes,
)
copy_fwd_metadata_to_bw_nodes(joint_hop_gm)
from torch._functorch._aot_autograd.graph_compile import prepare_for_partitioner
from torch._inductor.compile_fx import partition_fn
@@ -437,7 +447,8 @@ def autograd_key(
fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx, *args, **kwargs
)
return fw_gm(*args, **kwargs)
# TODO: get rid of this when we can install as a subgraph
return torch.fx.Interpreter(fw_gm).run(*args, **kwargs)
@local_map_hop.py_functionalize_impl