[hop] support local_map + SAC (#163322)

Some ops like local_map hop's deferred mode are not desugared by make_fx, this means that when we apply SAC tags, we will need to define dispatch rules for the SAC torch dispatch modes as pointed out here: https://github.com/pytorch/pytorch/issues/162246#issuecomment-3259176721. This PR adds those rules.

Additionally it fixes a pre-existing issue where we weren't coercing tangent layout (that AOTAutograd typically does) when partitioning the HOP joint.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163322
Approved by: https://github.com/ezyang
This commit is contained in:
Simon Fan
2025-09-23 14:55:55 -07:00
committed by PyTorch MergeBot
parent 20eeb54814
commit 124dd364e9
2 changed files with 286 additions and 40 deletions

View File

@ -2,6 +2,8 @@
# flake8: noqa: B950
import contextlib
import functools
import unittest
import torch
@ -12,19 +14,36 @@ import torch._inductor.decomposition
import torch.nn.functional as F
from torch import nn
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.utils.checkpoint import create_selective_checkpoint_contexts
if torch.distributed.is_available():
from torch.distributed._tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.common_utils import run_tests, TEST_WITH_CROSSREF, TestCase
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.common_utils import (
run_tests,
TEST_WITH_CROSSREF,
TEST_WITH_TORCHDYNAMO,
TEST_WITH_TORCHINDUCTOR,
TestCase,
)
nested_compile_region = torch.compiler.nested_compile_region
def get_skip_reasons():
msg = ""
if not torch.distributed.is_available():
msg += "Torch distributed not available. "
if TEST_WITH_TORCHINDUCTOR or TEST_WITH_TORCHDYNAMO:
msg += "Already manually torch.compile'd. "
return msg != "", msg
class MyTransform(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
@ -42,7 +61,14 @@ def context_parallel_attention(query, key, value):
return out
def create_model(attention_fn, nheads, dim1, dim2):
# NOTE: we use this function directly in the node checks
def save_scalar_muls(ctx, op, *args, **kwargs):
if op == torch.ops.aten.mul.Scalar:
return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
def create_model(attention_fn, nheads, dim1, dim2, sac_policy=None):
class LocalMapTransformerBlock(nn.Module):
def __init__(self, nheads, dim1, dim2):
super().__init__()
@ -54,8 +80,14 @@ def create_model(attention_fn, nheads, dim1, dim2):
self.wo = nn.Linear(dim1, dim1, bias=bias)
self.w1 = nn.Linear(dim1, dim2, bias=bias)
self.w2 = nn.Linear(dim2, dim1, bias=bias)
if sac_policy:
self.sac_context_fn = functools.partial(
create_selective_checkpoint_contexts, sac_policy
)
else:
self.sac_context_fn = None
def forward(self, x):
def _forward(self, x):
q = self.wq(x)
k = self.wk(x)
v = self.wv(x)
@ -78,41 +110,63 @@ def create_model(attention_fn, nheads, dim1, dim2):
o = o0 + o
return o
def forward(self, x):
if self.sac_context_fn is not None:
return torch.utils.checkpoint.checkpoint(
self._forward,
x,
use_reentrant=False,
context_fn=self.sac_context_fn,
)
return self._forward(x)
return LocalMapTransformerBlock(nheads, dim1, dim2)
class TestLocalMap(TestCase):
@requires_cuda_and_triton
@unittest.skipIf(
not torch.distributed.is_available(), "Torch distributed not available."
)
def test_simple(self):
@local_map(
out_placements=((Shard(0), Shard(1), Shard(2)),),
in_placements=(
(Shard(0), Shard(1), Shard(2)), # query
(Shard(0), Shard(1), Replicate()), # key
(Shard(0), Shard(1), Replicate()), # value
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
)
def cp_decorated(query, key, value):
return context_parallel_attention(query, key, value)
def get_local_mapped_functions():
assert torch.distributed.is_available()
cp_function = local_map(
context_parallel_attention,
out_placements=(Shard(0), Shard(1), Shard(2)),
in_placements=(
(Shard(0), Shard(1), Shard(2)), # query
(Shard(0), Shard(1), Replicate()), # key
(Shard(0), Shard(1), Replicate()), # value
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
)
@local_map(
out_placements=((Shard(0), Shard(1), Shard(2)),),
in_placements=(
(Shard(0), Shard(1), Shard(2)), # query
(Shard(0), Shard(1), Replicate()), # key
(Shard(0), Shard(1), Replicate()), # value
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
)
def cp_decorated(query, key, value):
return context_parallel_attention(query, key, value)
cp_function = local_map(
context_parallel_attention,
out_placements=(Shard(0), Shard(1), Shard(2)),
in_placements=(
(Shard(0), Shard(1), Shard(2)), # query
(Shard(0), Shard(1), Replicate()), # key
(Shard(0), Shard(1), Replicate()), # value
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
)
return cp_decorated, cp_function
class TestLocalMap(TestCase):
def setUp(self):
self.exit_stack = contextlib.ExitStack()
self.exit_stack.enter_context(sdpa_kernel(backends=[SDPBackend.MATH]))
def tearDown(self):
self.exit_stack.close()
@unittest.skipIf(*get_skip_reasons())
def test_simple(self):
cp_decorated, cp_function = get_local_mapped_functions()
bs = 8 * 1
dim1 = 96
dim2 = dim1 * 4
@ -123,21 +177,24 @@ class TestLocalMap(TestCase):
backend = EagerAndRecordGraphs()
model = create_model(cp_decorated, nheads, dim1, dim2).cuda()
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
model = create_model(cp_decorated, nheads, dim1, dim2)
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
with LocalMapWrappedHigherOrderVariable.enable():
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
model = create_model(cp_function, nheads, dim1, dim2).cuda()
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
model = create_model(cp_function, nheads, dim1, dim2)
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
with LocalMapWrappedHigherOrderVariable.enable():
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
if not TEST_WITH_CROSSREF:
self.assertEqual(len(backend.graphs), 2)
# should see local_map_hop in both
self.assertEqual(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
)
self.assertExpectedInline(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
"""\
@ -193,10 +250,177 @@ class GraphModule(torch.nn.Module):
""",
)
@unittest.skipIf(*get_skip_reasons())
def test_sac(self):
cp_decorated, cp_function = get_local_mapped_functions()
bs = 8 * 1
dim1 = 96
dim2 = dim1 * 4
nheads = 16
seq_len = 16
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
backend = AotEagerAndRecordGraphs()
model = create_model(
cp_decorated, nheads, dim1, dim2, sac_policy=save_scalar_muls
)
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
with LocalMapWrappedHigherOrderVariable.enable():
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
model = create_model(
cp_function, nheads, dim1, dim2, sac_policy=save_scalar_muls
)
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
with LocalMapWrappedHigherOrderVariable.enable():
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
if not TEST_WITH_CROSSREF:
self.assertEqual(len(backend.graphs), 2)
self.assertEqual(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
)
self.assertEqual(
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
normalize_gm(backend.fw_graphs[1].print_readable(print_output=False)),
)
self.assertEqual(
normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
normalize_gm(backend.bw_graphs[1].print_readable(print_output=False)),
)
self.assertEqual(
len(
backend.graphs[0].graph.find_nodes(
op="call_function",
target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
)
),
1,
)
# TODO: add joint to the testing compile backend
fw_outs = {
n.name
for n in backend.fw_graphs[0].graph.find_nodes(op="output")[0].args[0]
}
bw_ins = {
n.name for n in backend.bw_graphs[0].graph.find_nodes(op="placeholder")
}
for node in backend.fw_graphs[0].graph.nodes:
if "recompute" in node.meta:
expected = save_scalar_muls(None, node.target, None, None)
actual = node.meta["recompute"]
self.assertEqual(expected, actual)
if actual == torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE:
self.assertTrue(node.name in fw_outs and node.name in bw_ins)
elif (
actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
):
# can still be in fw_outs for post-graph bytecode
self.assertFalse(node.name in bw_ins)
@unittest.skipIf(*get_skip_reasons())
def test_sac_deferred(self):
# This test is in a bit of a weird state, it needs compositional compile API
# so that we can defer inlining for up until AOTAutograd stage 1.
# Then we should be inlined by stage 2. But we can't do that today.
cp_decorated, cp_function = get_local_mapped_functions()
bs = 8 * 1
dim1 = 96
dim2 = dim1 * 4
nheads = 16
seq_len = 16
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
backend = AotEagerAndRecordGraphs()
model = create_model(
cp_decorated, nheads, dim1, dim2, sac_policy=save_scalar_muls
)
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
try:
with (
LocalMapWrappedHigherOrderVariable.enable(),
torch._higher_order_ops.local_map.defer_inlining(),
):
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
except AttributeError as e:
# TODO: get rid of this when we can install as a subgraph
self.assertTrue(
"module 'torch._higher_order_ops.local_map' has no attribute 'call_local_map'"
in str(e)
)
model = create_model(
cp_function, nheads, dim1, dim2, sac_policy=save_scalar_muls
)
inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
try:
with (
LocalMapWrappedHigherOrderVariable.enable(),
torch._higher_order_ops.local_map.defer_inlining(),
):
out = torch.compile(model, backend=backend)(*inputs)
out.sum().backward()
except AttributeError as e:
# TODO: get rid of this when we can install as a subgraph
self.assertTrue(
"module 'torch._higher_order_ops.local_map' has no attribute 'call_local_map'"
in str(e)
)
# TODO: re-enable tests on backward when we can install as a subgraph
if not TEST_WITH_CROSSREF:
self.assertEqual(len(backend.graphs), 2)
self.assertEqual(
normalize_gm(backend.graphs[0].print_readable(print_output=False)),
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
)
self.assertEqual(
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
normalize_gm(backend.fw_graphs[1].print_readable(print_output=False)),
)
# self.assertEqual(
# normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
# normalize_gm(backend.bw_graphs[1].print_readable(print_output=False)),
# )
self.assertEqual(
len(
backend.graphs[0].graph.find_nodes(
op="call_function",
target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
)
),
1,
)
# TODO: add joint to the testing compile backend
fw_outs = {
n.name
for n in backend.fw_graphs[0].graph.find_nodes(op="output")[0].args[0]
}
# bw_ins = {
# n.name for n in backend.bw_graphs[0].graph.find_nodes(op="placeholder")
# }
for node in backend.fw_graphs[0].graph.nodes:
if "recompute" in node.meta:
expected = save_scalar_muls(None, node.target, None, None)
actual = node.meta["recompute"]
self.assertEqual(expected, actual)
if actual == torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE:
self.assertTrue(node.name in fw_outs)
# self.assertTrue(node.name in fw_outs and node.name in bw_ins)
# elif (
# actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
# ):
# # can still be in fw_outs for post-graph bytecode
# self.assertFalse(node.name in bw_ins)
if __name__ == "__main__":

View File

@ -15,6 +15,7 @@ import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
clone_outputs_aliasing_inputs,
redirect_to_mode,
save_tensors_and_symints_for_backward,
saved_tensors_and_symints,
)
@ -22,6 +23,7 @@ from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode
# Proxy the HOP instead of inlining into it
@ -49,6 +51,10 @@ class LocalMapHOP(HigherOrderOperator):
local_map_hop = LocalMapHOP()
# Registers dispatches for SAC
redirect_to_mode(local_map_hop, _CachingTorchDispatchMode)
redirect_to_mode(local_map_hop, _CachedTorchDispatchMode)
def create_hop_fw_bw(
fw_gm: GraphModule,
@ -203,6 +209,8 @@ class LocalMapAutogradOp(torch.autograd.Function):
*args: Any,
**kwargs: Any,
) -> tuple[Optional[torch.Tensor], ...]:
from torch._functorch._aot_autograd.schemas import MemoryFormatMeta
ctx.bw_gm = bw_gm
ctx.num_fw_ins = num_fw_ins
ctx.filtered_grads_idx = filtered_grads_idx
@ -214,17 +222,31 @@ class LocalMapAutogradOp(torch.autograd.Function):
saved_activations = fw_outs_with_saved_activations[num_fw_outs:]
save_tensors_and_symints_for_backward(ctx, saved_activations)
ctx.expected_tangent_metadata = {
i: MemoryFormatMeta.from_tensor(fw_outs[i]) for i in filtered_grads_idx
}
return fw_outs
@staticmethod
def backward(
ctx: Any, *_grads: tuple[torch.Tensor]
) -> tuple[Optional[torch.Tensor], ...]:
from torch._functorch._aot_autograd.runtime_wrappers import (
coerce_to_expected_memory_format,
)
saved_activations = saved_tensors_and_symints(ctx)
with torch._C._AutoDispatchBelowAutograd():
# Filter out grads that are None or do not require_grad.
# The AOTAutograd utils we rely on force this assumption.
grads = [_grads[i] for i in ctx.filtered_grads_idx]
assert len(grads) == len(ctx.expected_tangent_metadata), (
f"{len(grads)=} vs {len(ctx.expected_tangent_metadata)}"
)
for i, meta in ctx.expected_tangent_metadata.items():
grads[i] = coerce_to_expected_memory_format(grads[i], meta)
grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)
if len(grad_ins) != ctx.num_fw_ins:
raise RuntimeError(