Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[hop] support local_map + SAC (#163322)
Some ops, such as the local_map HOP in its deferred mode, are not desugared by make_fx. This means that when we apply SAC tags, we need to define dispatch rules for the SAC torch dispatch modes, as pointed out in https://github.com/pytorch/pytorch/issues/162246#issuecomment-3259176721. This PR adds those rules. It also fixes a pre-existing issue where we weren't coercing tangent layout (as AOTAutograd typically does) when partitioning the HOP joint. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163322 Approved by: https://github.com/ezyang
This commit is contained in:
committed by PyTorch MergeBot
parent 20eeb54814
commit 124dd364e9
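For orientation before the diff: a minimal sketch of the two pieces the commit message describes, written against the private helpers imported in the diff below (redirect_to_mode, the SAC dispatch modes, and the AOTAutograd memory-format utilities). The wrapper function names and the list-based metadata are hypothetical simplifications, not the PR's actual code.

# Hypothetical sketch, not the PR itself; it mirrors the registrations and the
# tangent coercion visible in the diff below. The imported modules are private.
from torch._higher_order_ops.utils import redirect_to_mode
from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode


def register_sac_dispatch_rules(hop):
    # SAC's caching/cached torch dispatch modes need explicit rules for ops that
    # make_fx does not desugar, such as the deferred local_map HOP; redirecting
    # the HOP into each mode lets the SAC policy see and tag its calls.
    redirect_to_mode(hop, _CachingTorchDispatchMode)
    redirect_to_mode(hop, _CachedTorchDispatchMode)


def coerce_tangents(grads, expected_metadata):
    # The second fix: before calling the backward HOP, coerce each incoming
    # tangent to the memory format recorded from the matching forward output,
    # the way AOTAutograd normally does at the top level.
    from torch._functorch._aot_autograd.runtime_wrappers import (
        coerce_to_expected_memory_format,
    )

    return [
        coerce_to_expected_memory_format(grad, meta)
        for grad, meta in zip(grads, expected_metadata)
    ]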
@@ -2,6 +2,8 @@
# flake8: noqa: B950


import contextlib
import functools
import unittest

import torch
@@ -12,19 +14,36 @@ import torch._inductor.decomposition
import torch.nn.functional as F
from torch import nn
from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable
from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.utils.checkpoint import create_selective_checkpoint_contexts


if torch.distributed.is_available():
    from torch.distributed._tensor.experimental import local_map
    from torch.distributed.tensor.placement_types import Replicate, Shard

from torch.testing._internal.common_utils import run_tests, TEST_WITH_CROSSREF, TestCase
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_CROSSREF,
    TEST_WITH_TORCHDYNAMO,
    TEST_WITH_TORCHINDUCTOR,
    TestCase,
)


nested_compile_region = torch.compiler.nested_compile_region


def get_skip_reasons():
    msg = ""
    if not torch.distributed.is_available():
        msg += "Torch distributed not available. "
    if TEST_WITH_TORCHINDUCTOR or TEST_WITH_TORCHDYNAMO:
        msg += "Already manually torch.compile'd. "

    return msg != "", msg


class MyTransform(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
@@ -42,7 +61,14 @@ def context_parallel_attention(query, key, value):
    return out


def create_model(attention_fn, nheads, dim1, dim2):
# NOTE: we use this function directly in the node checks
def save_scalar_muls(ctx, op, *args, **kwargs):
    if op == torch.ops.aten.mul.Scalar:
        return torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE
    return torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE


def create_model(attention_fn, nheads, dim1, dim2, sac_policy=None):
    class LocalMapTransformerBlock(nn.Module):
        def __init__(self, nheads, dim1, dim2):
            super().__init__()
@@ -54,8 +80,14 @@ def create_model(attention_fn, nheads, dim1, dim2):
            self.wo = nn.Linear(dim1, dim1, bias=bias)
            self.w1 = nn.Linear(dim1, dim2, bias=bias)
            self.w2 = nn.Linear(dim2, dim1, bias=bias)
            if sac_policy:
                self.sac_context_fn = functools.partial(
                    create_selective_checkpoint_contexts, sac_policy
                )
            else:
                self.sac_context_fn = None

        def forward(self, x):
        def _forward(self, x):
            q = self.wq(x)
            k = self.wk(x)
            v = self.wv(x)
@@ -78,41 +110,63 @@ def create_model(attention_fn, nheads, dim1, dim2):
            o = o0 + o
            return o

        def forward(self, x):
            if self.sac_context_fn is not None:
                return torch.utils.checkpoint.checkpoint(
                    self._forward,
                    x,
                    use_reentrant=False,
                    context_fn=self.sac_context_fn,
                )
            return self._forward(x)

    return LocalMapTransformerBlock(nheads, dim1, dim2)


class TestLocalMap(TestCase):
    @requires_cuda_and_triton
    @unittest.skipIf(
        not torch.distributed.is_available(), "Torch distributed not available."
    )
    def test_simple(self):
        @local_map(
            out_placements=((Shard(0), Shard(1), Shard(2)),),
            in_placements=(
                (Shard(0), Shard(1), Shard(2)),  # query
                (Shard(0), Shard(1), Replicate()),  # key
                (Shard(0), Shard(1), Replicate()),  # value
            ),
            redistribute_inputs=True,
            in_grad_placements=None,
            device_mesh=None,
        )
        def cp_decorated(query, key, value):
            return context_parallel_attention(query, key, value)
def get_local_mapped_functions():
    assert torch.distributed.is_available()

        cp_function = local_map(
            context_parallel_attention,
            out_placements=(Shard(0), Shard(1), Shard(2)),
            in_placements=(
                (Shard(0), Shard(1), Shard(2)),  # query
                (Shard(0), Shard(1), Replicate()),  # key
                (Shard(0), Shard(1), Replicate()),  # value
            ),
            redistribute_inputs=True,
            in_grad_placements=None,
            device_mesh=None,
        )
    @local_map(
        out_placements=((Shard(0), Shard(1), Shard(2)),),
        in_placements=(
            (Shard(0), Shard(1), Shard(2)),  # query
            (Shard(0), Shard(1), Replicate()),  # key
            (Shard(0), Shard(1), Replicate()),  # value
        ),
        redistribute_inputs=True,
        in_grad_placements=None,
        device_mesh=None,
    )
    def cp_decorated(query, key, value):
        return context_parallel_attention(query, key, value)

    cp_function = local_map(
        context_parallel_attention,
        out_placements=(Shard(0), Shard(1), Shard(2)),
        in_placements=(
            (Shard(0), Shard(1), Shard(2)),  # query
            (Shard(0), Shard(1), Replicate()),  # key
            (Shard(0), Shard(1), Replicate()),  # value
        ),
        redistribute_inputs=True,
        in_grad_placements=None,
        device_mesh=None,
    )

    return cp_decorated, cp_function


class TestLocalMap(TestCase):
    def setUp(self):
        self.exit_stack = contextlib.ExitStack()
        self.exit_stack.enter_context(sdpa_kernel(backends=[SDPBackend.MATH]))

    def tearDown(self):
        self.exit_stack.close()

    @unittest.skipIf(*get_skip_reasons())
    def test_simple(self):
        cp_decorated, cp_function = get_local_mapped_functions()
        bs = 8 * 1
        dim1 = 96
        dim2 = dim1 * 4
@@ -123,21 +177,24 @@ class TestLocalMap(TestCase):

        backend = EagerAndRecordGraphs()

        model = create_model(cp_decorated, nheads, dim1, dim2).cuda()
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
        model = create_model(cp_decorated, nheads, dim1, dim2)
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        model = create_model(cp_function, nheads, dim1, dim2).cuda()
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True).cuda(),)
        model = create_model(cp_function, nheads, dim1, dim2)
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        if not TEST_WITH_CROSSREF:
            self.assertEqual(len(backend.graphs), 2)
            # should see local_map_hop in both
            self.assertEqual(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.graphs[1].print_readable(print_output=False)),
            )
            self.assertExpectedInline(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                """\
@@ -193,10 +250,177 @@ class GraphModule(torch.nn.Module):
""",
            )

    @unittest.skipIf(*get_skip_reasons())
    def test_sac(self):
        cp_decorated, cp_function = get_local_mapped_functions()
        bs = 8 * 1
        dim1 = 96
        dim2 = dim1 * 4
        nheads = 16
        seq_len = 16

        from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm

        backend = AotEagerAndRecordGraphs()

        model = create_model(
            cp_decorated, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        model = create_model(
            cp_function, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        with LocalMapWrappedHigherOrderVariable.enable():
            out = torch.compile(model, backend=backend)(*inputs)
            out.sum().backward()

        if not TEST_WITH_CROSSREF:
            self.assertEqual(len(backend.graphs), 2)
            self.assertEqual(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.fw_graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.bw_graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                len(
                    backend.graphs[0].graph.find_nodes(
                        op="call_function",
                        target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
                    )
                ),
                1,
            )
            # TODO: add joint to the testing compile backend
            fw_outs = {
                n.name
                for n in backend.fw_graphs[0].graph.find_nodes(op="output")[0].args[0]
            }
            bw_ins = {
                n.name for n in backend.bw_graphs[0].graph.find_nodes(op="placeholder")
            }
            for node in backend.fw_graphs[0].graph.nodes:
                if "recompute" in node.meta:
                    expected = save_scalar_muls(None, node.target, None, None)
                    actual = node.meta["recompute"]
                    self.assertEqual(expected, actual)
                    if actual == torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE:
                        self.assertTrue(node.name in fw_outs and node.name in bw_ins)
                    elif (
                        actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
                    ):
                        # can still be in fw_outs for post-graph bytecode
                        self.assertFalse(node.name in bw_ins)

    @unittest.skipIf(*get_skip_reasons())
    def test_sac_deferred(self):
        # This test is in a bit of a weird state, it needs compositional compile API
        # so that we can defer inlining for up until AOTAutograd stage 1.
        # Then we should be inlined by stage 2. But we can't do that today.

        cp_decorated, cp_function = get_local_mapped_functions()
        bs = 8 * 1
        dim1 = 96
        dim2 = dim1 * 4
        nheads = 16
        seq_len = 16

        from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm

        backend = AotEagerAndRecordGraphs()

        model = create_model(
            cp_decorated, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        try:
            with (
                LocalMapWrappedHigherOrderVariable.enable(),
                torch._higher_order_ops.local_map.defer_inlining(),
            ):
                out = torch.compile(model, backend=backend)(*inputs)
                out.sum().backward()
        except AttributeError as e:
            # TODO: get rid of this when we can install as a subgraph
            self.assertTrue(
                "module 'torch._higher_order_ops.local_map' has no attribute 'call_local_map'"
                in str(e)
            )

        model = create_model(
            cp_function, nheads, dim1, dim2, sac_policy=save_scalar_muls
        )
        inputs = (torch.randn(bs, seq_len, dim1, requires_grad=True),)
        try:
            with (
                LocalMapWrappedHigherOrderVariable.enable(),
                torch._higher_order_ops.local_map.defer_inlining(),
            ):
                out = torch.compile(model, backend=backend)(*inputs)
                out.sum().backward()
        except AttributeError as e:
            # TODO: get rid of this when we can install as a subgraph
            self.assertTrue(
                "module 'torch._higher_order_ops.local_map' has no attribute 'call_local_map'"
                in str(e)
            )

        # TODO: re-enable tests on backward when we can install as a subgraph
        if not TEST_WITH_CROSSREF:
            self.assertEqual(len(backend.graphs), 2)
            self.assertEqual(
                normalize_gm(backend.graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.graphs[1].print_readable(print_output=False)),
            )
            self.assertEqual(
                normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
                normalize_gm(backend.fw_graphs[1].print_readable(print_output=False)),
            )
            # self.assertEqual(
            #     normalize_gm(backend.bw_graphs[0].print_readable(print_output=False)),
            #     normalize_gm(backend.bw_graphs[1].print_readable(print_output=False)),
            # )
            self.assertEqual(
                len(
                    backend.graphs[0].graph.find_nodes(
                        op="call_function",
                        target=torch._higher_order_ops.wrap.tag_activation_checkpoint,
                    )
                ),
                1,
            )
            # TODO: add joint to the testing compile backend
            fw_outs = {
                n.name
                for n in backend.fw_graphs[0].graph.find_nodes(op="output")[0].args[0]
            }
            # bw_ins = {
            #     n.name for n in backend.bw_graphs[0].graph.find_nodes(op="placeholder")
            # }
            for node in backend.fw_graphs[0].graph.nodes:
                if "recompute" in node.meta:
                    expected = save_scalar_muls(None, node.target, None, None)
                    actual = node.meta["recompute"]
                    self.assertEqual(expected, actual)
                    if actual == torch.utils.checkpoint.CheckpointPolicy.MUST_SAVE:
                        self.assertTrue(node.name in fw_outs)
                        # self.assertTrue(node.name in fw_outs and node.name in bw_ins)
                    # elif (
                    #     actual == torch.utils.checkpoint.CheckpointPolicy.MUST_RECOMPUTE
                    # ):
                    #     # can still be in fw_outs for post-graph bytecode
                    #     self.assertFalse(node.name in bw_ins)


if __name__ == "__main__":

@@ -15,6 +15,7 @@ import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._higher_order_ops.utils import (
    clone_outputs_aliasing_inputs,
    redirect_to_mode,
    save_tensors_and_symints_for_backward,
    saved_tensors_and_symints,
)
@@ -22,6 +23,7 @@ from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils.checkpoint import _CachedTorchDispatchMode, _CachingTorchDispatchMode


# Proxy the HOP instead of inlining into it
@@ -49,6 +51,10 @@ class LocalMapHOP(HigherOrderOperator):

local_map_hop = LocalMapHOP()

# Registers dispatches for SAC
redirect_to_mode(local_map_hop, _CachingTorchDispatchMode)
redirect_to_mode(local_map_hop, _CachedTorchDispatchMode)


def create_hop_fw_bw(
    fw_gm: GraphModule,
@@ -203,6 +209,8 @@ class LocalMapAutogradOp(torch.autograd.Function):
        *args: Any,
        **kwargs: Any,
    ) -> tuple[Optional[torch.Tensor], ...]:
        from torch._functorch._aot_autograd.schemas import MemoryFormatMeta

        ctx.bw_gm = bw_gm
        ctx.num_fw_ins = num_fw_ins
        ctx.filtered_grads_idx = filtered_grads_idx
@@ -214,17 +222,31 @@ class LocalMapAutogradOp(torch.autograd.Function):
        saved_activations = fw_outs_with_saved_activations[num_fw_outs:]
        save_tensors_and_symints_for_backward(ctx, saved_activations)

        ctx.expected_tangent_metadata = {
            i: MemoryFormatMeta.from_tensor(fw_outs[i]) for i in filtered_grads_idx
        }
        return fw_outs

    @staticmethod
    def backward(
        ctx: Any, *_grads: tuple[torch.Tensor]
    ) -> tuple[Optional[torch.Tensor], ...]:
        from torch._functorch._aot_autograd.runtime_wrappers import (
            coerce_to_expected_memory_format,
        )

        saved_activations = saved_tensors_and_symints(ctx)
        with torch._C._AutoDispatchBelowAutograd():
            # Filter out grads that are None or do not require_grad.
            # The AOTAutograd utils we rely on force this assumption.
            grads = [_grads[i] for i in ctx.filtered_grads_idx]
            assert len(grads) == len(ctx.expected_tangent_metadata), (
                f"{len(grads)=} vs {len(ctx.expected_tangent_metadata)}"
            )

            for i, meta in ctx.expected_tangent_metadata.items():
                grads[i] = coerce_to_expected_memory_format(grads[i], meta)

            grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)
            if len(grad_ins) != ctx.num_fw_ins:
                raise RuntimeError(