[hop] run local_map with interpreter to preserve fx_traceback annotations (#165336)
We hit an issue when using fx_traceback.annotate with HOPs that trace joint graphs. HOPs have bodies that have already been traced by Dynamo and, after Animesh's PR, those bodies do carry the annotations. But when we lower that Dynamo HOP body to aten, in either pre-dispatch or post-dispatch, we need to propagate the annotations to the aten nodes. AOTAutograd does this indirectly by piggybacking off the `PropagateUnbackedSymInts` fx.Interpreter; I'm not sure whether all HOPs should use it to trace their joints. This PR adds an interpreter to local_map's implementation.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165336
Approved by: https://github.com/yushangdi
committed by PyTorch MergeBot
parent 12fa4192c5
commit 21697feff2
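The mechanism the fix relies on can be sketched outside of local_map: torch.fx.Interpreter re-executes a traced graph node by node and republishes each node's metadata while it runs, so a make_fx retrace that goes through the interpreter (rather than calling the original Python function) carries fx_traceback annotations down to the newly created aten nodes. The snippet below is a minimal, hypothetical illustration of that idea, not code from this PR; the function and variable names are made up, and it assumes fx_traceback.annotate / preserve_node_meta record the "custom" metadata during a plain make_fx trace (in the PR the annotated body comes from Dynamo).

# Minimal sketch (not from this PR): annotations recorded on a traced graph
# survive a second make_fx lowering because the retrace goes through
# torch.fx.Interpreter, which exposes each node's meta while executing it.
import torch
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def fn(x):
    # hypothetical annotated body, standing in for the Dynamo-traced HOP body
    with fx_traceback.annotate({"inside_local_map": 0}):
        return torch.matmul(x, x.t())


x = torch.randn(4, 4)

with fx_traceback.preserve_node_meta():
    # first trace: produce the annotated graph module
    gm = make_fx(fn)(x)
    # second trace: lower again, but run the graph through an Interpreter so
    # node.meta (including the custom annotation) is propagated to the new nodes
    interp = torch.fx.Interpreter(gm)
    lowered = make_fx(interp.run)(x)

for node in lowered.graph.nodes:
    if node.op == "call_function":
        print(node.target, node.meta.get("custom"))

The same pattern shows up twice in the diff below: prepare_fw_with_masks and autograd_key now run the forward graph module through torch.fx.Interpreter instead of calling it directly.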
@@ -11,11 +11,13 @@ import torch._dynamo
 import torch._functorch
 import torch._inductor
 import torch._inductor.decomposition
+import torch.fx.traceback as fx_traceback
 import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
 from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
 from torch._subclasses.fake_tensor import FakeTensorMode
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.utils.checkpoint import create_selective_checkpoint_contexts

@@ -130,6 +132,12 @@ def save_scalar_muls(ctx, op, *args, **kwargs):
     return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE


+def save_mm(ctx, op, *args, **kwargs):
+    if op == torch.ops.aten.mm.default:
+        return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
+    return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
+
+
 def create_model(attention_fn, nheads, dim1, dim2, sac_policy=None):
     class LocalMapTransformerBlock(nn.Module):
         def __init__(self, nheads, dim1, dim2):
@@ -556,8 +564,10 @@ class GraphModule(torch.nn.Module):
             out = x.view(-1) + 10
             return (out.view(x.shape),)

-        # pretend this is a GraphModule for testing convenience
-        fn.meta = {
+        x = torch.randn(10, 80)
+        gm = make_fx(fn)(x)
+
+        gm.meta = {
             "local_map_kwargs": {
                 "in_placements": ((Shard(0), Replicate(), Replicate()),),
                 "out_placements": ((Shard(0), Replicate(), Replicate()),),
@@ -568,7 +578,7 @@ class GraphModule(torch.nn.Module):
         with FakeTensorMode():
             global_tensor = torch.randn(80, 80, requires_grad=True)
             with torch._higher_order_ops.local_map.defer_inlining():
-                out = torch._higher_order_ops.local_map_hop(fn, global_tensor)
+                out = torch._higher_order_ops.local_map_hop(gm, global_tensor)
                 out[0].sum().backward()
         self.assertEqual(global_tensor.shape, (80, 80))

@@ -715,6 +725,65 @@ class GraphModule(torch.nn.Module):
         inputs = (torch.randn(80, 80),)
         ap_style_initial_capture(model, inputs)

+    @unittest.skipIf(*get_skip_reasons())
+    def test_fx_annotations(self):
+        @local_map(
+            out_placements=((Replicate(), Replicate(), Replicate()),),
+            in_placements=(
+                (Replicate(), Replicate(), Replicate()),
+                (Replicate(), Replicate(), Replicate()),
+                None,
+            ),
+            redistribute_inputs=True,
+            in_grad_placements=None,
+            device_mesh=self.mesh,
+        )
+        def fn(w, x, id):
+            with fx_traceback.annotate({"inside_local_map": id}):
+                return torch.matmul(x, w.t())
+
+        context_fn = functools.partial(create_selective_checkpoint_contexts, save_mm)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w = nn.Linear(80, 80)
+
+            def forward(self, x):
+                a = fn(self.w.weight, x, 0)
+                b = torch.utils.checkpoint.checkpoint(
+                    fn, self.w.weight, x, 1, use_reentrant=False, context_fn=context_fn
+                )
+                return a.sum() + b.sum()
+
+        model = MyModule()
+        with FakeTensorMode():
+            fw_inputs = (torch.randn(80, 80),)
+
+        with fx_traceback.preserve_node_meta():
+            joint_gm_deferred = ap_style_initial_capture(model, fw_inputs)
+            joint_inputs = [
+                n.meta["val"]
+                for n in joint_gm_deferred.graph.nodes
+                if n.op == "placeholder"
+            ]
+            # TODO: need a local shape interpreter for cases where the graph specializes on shapes
+            interp = torch.fx.Interpreter(joint_gm_deferred)
+            joint_gm_inlined = make_fx(interp.run)(*joint_inputs)
+
+        mm_nodes = joint_gm_inlined.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.mm.default
+        )
+        self.assertEqual(len(mm_nodes), 4)
+        self.assertNotIn("partitioner_tag", mm_nodes[0].meta)
+        self.assertNotIn("partitioner_tag", mm_nodes[1].meta)
+        self.assertEqual(mm_nodes[2].meta["partitioner_tag"], "is_backward")
+        self.assertEqual(mm_nodes[3].meta["partitioner_tag"], "is_backward")
+        self.assertEqual(mm_nodes[0].meta["custom"]["inside_local_map"], 0)
+        self.assertEqual(mm_nodes[1].meta["custom"]["inside_local_map"], 1)
+        self.assertEqual(mm_nodes[2].meta["custom"]["inside_local_map"], 1)
+        self.assertEqual(mm_nodes[3].meta["custom"]["inside_local_map"], 0)
+

 if __name__ == "__main__":
     run_tests()
@@ -257,9 +257,14 @@ def create_hop_fw_bw(
     primals = primals_and_tangents[:num_fw_inputs]
     tangents = primals_and_tangents[num_fw_inputs:]

-    def prepare_fw_with_masks(fn: Callable[..., Any]) -> Callable[..., Any]:
+    def prepare_fw_with_masks(
+        fw_gm: torch.fx.GraphModule,
+    ) -> Callable[..., Any]:
         def fw_with_masks(*args: Any) -> tuple[tuple[Any], list[bool]]:
-            fw_out = fn(*args)
+            # The Interpreter here is required to propagate metadata
+            # from the dynamo graph body to the local_map graph body.
+            # This is required for fx_traceback.annotate to work.
+            fw_out = torch.fx.Interpreter(fw_gm).run(*args)
             assert isinstance(fw_out, tuple), (
                 "Dynamo traced submodule should return tuple"
             )
@@ -293,6 +298,11 @@ def create_hop_fw_bw(
         *[example_grads[i] for i in filtered_grads_idx],
     ]
     joint_hop_gm = make_fx(joint_f)(*primals_and_tangents)
+    from torch._functorch._aot_autograd.graph_capture import (
+        copy_fwd_metadata_to_bw_nodes,
+    )
+
+    copy_fwd_metadata_to_bw_nodes(joint_hop_gm)

     from torch._functorch._aot_autograd.graph_compile import prepare_for_partitioner
     from torch._inductor.compile_fx import partition_fn
@@ -437,7 +447,8 @@ def autograd_key(
             fw_gm, bw_gm, num_fw_ins, num_fw_outs, filtered_grads_idx, *args, **kwargs
         )

-    return fw_gm(*args, **kwargs)
+    # TODO: get rid of this when we can install as a subgraph
+    return torch.fx.Interpreter(fw_gm).run(*args, **kwargs)


@local_map_hop.py_functionalize_impl