Files
pytorch/test/distributed/test_dynamo_distributed.py
nullplay ac529df244 Native matmul (#157743)
### Implementation of #151705

This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly—without relying on predefined templates.

To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:

1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)

### Summary of This PR

This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.

### Notable Changes

1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tiling suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I’ve copied the configuration from the existing matmul templates. However, this may not be optimal—it currently takes a long time to tune, and I think there must be a better way to tackle this.

@eellison @jansel @PaulZhang12 @shunting314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
2025-10-14 04:22:30 +00:00

2114 lines
79 KiB
Python

# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import random
import unittest
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from unittest.mock import patch
import numpy as np
import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as dist
import torch.optim as optim
from torch import nn
from torch._C import FileCheck
from torch._dynamo import config
from torch._dynamo.backends.distributed import DDPOptimizer
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import same
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.compiler import set_enable_guard_collectives
from torch.distributed._functional_collectives import _maybe_wrap_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
lambda_auto_wrap_policy,
transformer_auto_wrap_policy,
)
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import (
_dynamo_dist_per_rank_init,
DynamoDistributedMultiProcTestCase,
DynamoDistributedSingleProcTestCase,
import_transformers_or_skip,
requires_accelerator_dist_backend,
skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda_and_triton
def reset_rng_state():
    """Seed the torch, Python ``random`` and NumPy generators with 1337.

    Tests call this before both the eager and the compiled run so the two
    runs draw identical random streams.
    """
    seed = 1337
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
def init_weights(m):
    """Initialise ``nn.Linear`` layers in place; any other module is left untouched.

    Intended for ``Module.apply``. The weight gets Xavier-uniform
    initialisation. The bias, when the layer has one, is filled with 0.01.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # ``nn.Linear(..., bias=False)`` stores ``bias = None``. Calling
    # ``m.bias.data.fill_`` on such a layer would raise AttributeError.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)
@contextmanager
def enable_guard_collectives():
old = set_enable_guard_collectives(True)
try:
yield
finally:
set_enable_guard_collectives(old)
class ToyModel(nn.Module):
    """Four ``Linear`` + ``ReLU`` pairs: in_feat -> hidden -> hidden -> hidden -> out_feat.

    If ``ctx_manager`` is given, it is called with no arguments on every
    forward, and the network runs inside the returned context manager.
    """

    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        widths = (in_feat, hidden_feat, hidden_feat, hidden_feat, out_feat)
        layers = []
        # Layers are created input-to-output, so parameter creation (and RNG
        # consumption) happens in the same order as a hand-written stack.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        self.net = nn.Sequential(*layers)

    def forward(self, inputs):
        if self.ctx_manager is None:
            return self.net(inputs)
        with self.ctx_manager():
            return self.net(inputs)
def get_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    """Build an initialised ``ToyModel`` on ``device`` together with a random batch.

    Returns:
        ``(model, inputs, outputs)``, where ``outputs`` is the eager forward
        result for ``inputs``. Tests use it as the reference for compiled runs.
    """
    model = ToyModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    model.apply(init_weights)
    # Sample on the default (CPU) generator, then move to ``device``.
    batch = torch.rand(bsz, in_feat).to(device)
    return model, batch, model(batch)
class MutatingModel(nn.Module):
    """Same Linear/ReLU stack as ``ToyModel``, but ``forward`` writes ``self.state``.

    ``state`` starts at 1. Every forward sets it to 2 and scales the output by
    it. ``ctx_manager`` is stored but never used by ``forward``.
    """

    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        widths = (in_feat, hidden_feat, hidden_feat, hidden_feat, out_feat)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        self.net = nn.Sequential(*stack)
        self.state = 1

    def forward(self, inputs):
        # The attribute write must precede the read on the next line.
        self.state = 2
        return self.net(inputs) * self.state
def get_mutating_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    """Build an initialised ``MutatingModel`` on ``device`` plus a random batch.

    Returns:
        ``(model, inputs, outputs)``, where ``outputs`` is the eager forward
        result. That reference call has already set ``model.state`` to 2.
    """
    model = MutatingModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    model.apply(init_weights)
    batch = torch.rand(bsz, in_feat).to(device)
    reference = model(batch)
    return model, batch, reference
class ForcedGetAttrMod(torch.nn.Module):
    """Module holding one registered and one unregistered ``Linear``.

    ``forced_linear`` is written directly into ``__dict__``, which bypasses
    ``nn.Module.__setattr__``. It is therefore not registered as a submodule
    and is reached by plain attribute lookup. The
    ``test_fsdp_unspecialized_forced_getattr_*`` tests use this module.
    """

    def __init__(self, device):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        # Deliberately bypass submodule registration. Because it is unregistered,
        # ``Module.to`` presumably won't reach it, hence the explicit move here.
        self.__dict__["forced_linear"] = torch.nn.Linear(1, 1).to(device=device)
        self.counter = 0

    def forward(self, x):
        # Plain Python int attribute, mutated on every call.
        self.counter += 1
        return x * self.linear(x) * self.forced_linear.weight
def get_forced_getattr_module(device):
    """Return ``(module, input, eager_output)`` for a ``ForcedGetAttrMod`` on ``device``.

    The eager reference call increments ``module.counter`` to 1.
    """
    module = ForcedGetAttrMod(device).to(device=device)
    sample = torch.randn(1, 1, device=device)
    reference = module(sample)
    return module, sample, reference
class ToyInnerModel(nn.Module):
    """Two stacked 100x100 ``Linear`` layers with no activation between them."""

    def __init__(self) -> None:
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(100, 100), nn.Linear(100, 100))

    def forward(self, inputs):
        return self.layers(inputs)
class ToyOuterModel(nn.Module):
    """``ToyInnerModel -> ReLU -> ToyInnerModel -> ReLU``, with both inner models on ``device``.

    ``test_fsdp_activation_checkpointing`` wraps and checkpoints the two
    ``ToyInnerModel`` instances individually.
    """

    def __init__(self, device):
        super().__init__()
        first, second = (ToyInnerModel().to(device) for _ in range(2))
        self.layers = nn.Sequential(first, nn.ReLU(), second, nn.ReLU())

    def forward(self, inputs):
        return self.layers(inputs)
def get_toy_model_for_activation_checkpointing(device):
    """Return ``(model, inputs)``: an initialised ``ToyOuterModel`` and a 100x100 random batch."""
    model = ToyOuterModel(device).to(device)
    model.apply(init_weights)
    batch = torch.rand(100, 100).to(device)
    return model, batch
def find_first_node(gm, func):
    """Return the first node of ``gm.graph`` whose ``target`` is ``func``, or ``None``.

    Targets are compared by identity (``is``), not by equality.
    """
    return next((node for node in gm.graph.nodes if node.target is func), None)
def apply_fsdp_with_checkpointing(
    model, wrap_policy, checkpoint_policy, use_activation_checkpointing=True
):
    """Wrap a deep copy of ``model`` in FSDP and optionally add activation checkpointing.

    Args:
        model: module to wrap. It is deep-copied, so the caller's instance is
            left untouched and can be wrapped again.
        wrap_policy: ``auto_wrap_policy`` forwarded to FSDP.
        checkpoint_policy: predicate passed as ``check_fn``. It selects which
            submodules get non-reentrant checkpoint wrappers.
        use_activation_checkpointing: when False, only FSDP wrapping is applied.

    Returns:
        The FSDP-wrapped copy.
    """
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )

    wrapped = FSDP(
        copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
    )
    if not use_activation_checkpointing:
        return wrapped
    apply_activation_checkpointing(
        wrapped,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=checkpoint_policy,
    )
    return wrapped
def get_custom_model(device):
    """Build a three-stage Linear/ReLU model that exercises DDPOptimizer edge cases.

    Returns:
        ``(model, inputs, correct_outputs)``. ``inputs`` is a tuple containing
        the *same* 512x512 tensor twice, to exercise duplicated graph inputs.
        ``correct_outputs`` is the eager result.
    """

    class MyCustomLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            tmp = torch.mm(x, self.weight.t())
            # test an edge case where torch.where.scalar was decomposed to
            # aten.where.self(tensor, tensor, tensor) and the scalar tensors
            # T(0.3) and T(0.6) were not wrapped in FakeTensors during
            # DDPOptimizer compilation
            return tmp + torch.where(tmp < 0.5, 0.3, 0.6)

    class MyLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            mods = [
                (MyLinear(), torch.nn.ReLU()),
                # sandwich the custom in the middle so it comes before and after
                (MyCustomLinear(), torch.nn.ReLU()),
                (MyLinear(), torch.nn.ReLU()),
            ]
            self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

        def forward(self, x, y):
            # test special case where the 0th bucket (layers close to graph input) is at capacity, which would
            # trigger a new bucket, but there are only trivial ops without parameters to put into the new bucket.
            # optimize this case by fusing that 'empty bucket' back together with the previous full one
            return self.seq(x + y)

    m = MyModule().to(device)
    m.apply(init_weights)
    inputs = torch.rand((512, 512)).to(device)
    # test duplicated inputs
    inputs = (inputs, inputs)
    correct_outputs = m(*inputs)
    return m, inputs, correct_outputs
def get_hf_bert(rank):
    """Create a randomly initialised HF BERT masked-LM on ``<accelerator>:<rank>``.

    Falls back to ``cpu:<rank>`` when no accelerator is present.

    Returns:
        ``(model, inputs)``. ``inputs`` holds random ``input_ids`` and
        ``labels`` of shape (4, 512), and the model is put in training mode.

    Raises:
        unittest.SkipTest: if ``transformers`` cannot be imported.

    Note: use @import_transformers_or_skip on your test case if you use this
    in a multiprocessing test.
    """
    try:
        from transformers import AutoModelForMaskedLM, BertConfig
    except ImportError as e:
        raise unittest.SkipTest("Unable to import transformers") from e

    acc = torch.accelerator.current_accelerator()
    device = f"{acc.type if acc else 'cpu'}:{rank}"
    batch_size, max_length = 4, 512
    # Named ``bert_config`` so it does not shadow the module-level dynamo ``config``.
    bert_config = BertConfig()
    model = AutoModelForMaskedLM.from_config(bert_config).to(device)

    def random_ids():
        return torch.randint(0, bert_config.vocab_size, (batch_size, max_length)).to(
            device
        )

    input_ids = random_ids()
    labels = random_ids()
    model.train()
    return model, {"input_ids": input_ids, "labels": labels}
class CheckSplitsCompiler:
    """Pass-through compile backend that counts how many graphs it is handed."""

    def __init__(self) -> None:
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        """Record one call and return ``gm`` unchanged, so the captured graph runs as-is."""
        self.compiler_called = self.compiler_called + 1
        return gm
# This simulates DDP, but it doesn't actually do any process communication;
# it just has enough properties so that the dynamo distributed optimization is
# able to optimize. Feel free to simulate more properties as necessary. The
# other important thing is patching _active_ddp_module, which is what actually
# triggers DDP optimization
class FakeDDP(nn.Module):
    """Minimal single-process stand-in for ``DistributedDataParallel``.

    Args:
        module: the wrapped module. ``forward`` delegates to ``module.forward``.
        bucket_cap_mb: bucket size in MiB, exposed in bytes as
            ``bucket_bytes_cap``. Presumably this is one of the DDP properties
            the optimizer reads to size buckets (see the comment above).
    """

    def __init__(self, module, bucket_cap_mb=25):
        super().__init__()
        self.module = module
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        """Mark ``self`` as ``DDP._active_ddp_module`` for the duration of the block."""
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        # Compare with None explicitly (PEP 8) rather than relying on the
        # truthiness of an nn.Module. Only install ``self`` when nothing is
        # active, so a module a test assigned manually is left in place.
        if DDP._active_ddp_module is None:
            with self._inside_ddp_forward():
                return self.module.forward(*inputs, **kwargs)
        return self.module.forward(*inputs, **kwargs)
def run_hf_bert_ddp(self, model, inputs, backend):
    """Assert that ``torch.compile(model, backend=backend)`` matches eager execution.

    Runs one forward + backward eagerly and one compiled, each preceded by
    ``reset_rng_state()``. It then compares what ``collect_results`` gathers
    from each run (the model, logits, loss and inputs).

    Args:
        self: the calling ``TestCase``, used for the final assertion.
        model: module returning an object with ``.loss`` and ``.logits``.
        inputs: dict of keyword inputs for ``model``.
        backend: ``torch.compile`` backend name.
    """
    reset_rng_state()
    eager_out = model(**inputs)
    eager_loss = eager_out.loss
    eager_loss.backward()

    reset_rng_state()
    compiled = torch.compile(model, backend=backend)
    compiled_out = compiled(**inputs)
    compiled_loss = compiled_out.loss
    compiled_loss.backward()

    flat_inputs = list(inputs.values())
    expected = collect_results(model, eager_out.logits, eager_loss, flat_inputs)
    actual = collect_results(
        compiled, compiled_out.logits, compiled_loss, flat_inputs
    )
    self.assertTrue(same(expected, actual))
class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
    """Single-process DDPOptimizer tests; no process group is created.

    Models are wrapped in ``FakeDDP``, or ``DDP._active_ddp_module`` is set
    manually. Either way Dynamo believes it is tracing under DDP.
    """

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @unittest.skipIf(
        torch._inductor.config.triton.native_matmul,
        "FIXME : native matmul fails. RuntimeError: Cannot access data pointer of Tensor",
    )
    def test_hf_bert_ddp_inductor(self):
        """HF BERT under FakeDDP: inductor results must match eager."""
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        """HF BERT under FakeDDP: aot_eager results must match eager."""
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @patch.object(config, "optimize_ddp", True)
    def test_issue90375(self):
        """Regression test for issue 90375: a model with no parameters and no inputs."""
        class Model(nn.Module):
            def forward(self):
                return torch.randn(3) * torch.randn(3)

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(model, backend="aot_eager")
        opt_model()

    @patch.object(config, "optimize_ddp", True)
    def test_symbol_splitting(self):
        """Compile with ``dynamic=True`` so the (doubled) batch size is symbolic.

        Presumably checks that symbolic sizes survive the DDP graph split;
        the test only asserts that compilation and execution succeed.
        """
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x):
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = x + y @ self.weight2
                return z

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512))

    @patch.object(config, "optimize_ddp", True)
    def test_ddp_optimizer_inductor_strides_dont_specialize(self):
        """A second input differing only in the ``mark_dynamic`` dim must not recompile."""
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc_0 = nn.Linear(768, 768)
                self.fc_1 = nn.Linear(768, 768)

            def forward(self, x):
                x = self.fc_0(x)
                x = self.fc_1(x)
                return x

        model = Model()
        model = FakeDDP(model)
        inp = torch.randn((16, 18, 768))
        inp2 = torch.randn((16, 20, 768))
        torch._dynamo.mark_dynamic(inp, 1)
        torch._dynamo.mark_dynamic(inp2, 1)
        torch._dynamo.utils.clear_compilation_metrics()
        torch._dynamo.reset()
        try:
            # Set the active DDP module by hand for the whole test, so the
            # compiled call below runs with DDP optimization active.
            DDP._active_ddp_module = model
            opt_model = torch.compile(model)
            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
            opt_model(inp)
            compile_count_before = len(torch._dynamo.utils.get_compilation_metrics())
            opt_model(inp2)
            compile_count_after = len(torch._dynamo.utils.get_compilation_metrics())
            # no recompiles
            self.assertEqual(compile_count_before, compile_count_after)
        finally:
            DDP._active_ddp_module = None

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_direct(self):
        """``u0`` comes from ``.tolist()`` and scales the output directly."""
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, _ = y.tolist()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * u0
                return z

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_indirect(self):
        """``u0`` from ``.tolist()`` is used only as the size of ``torch.ones(u0)``."""
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, _ = y.tolist()
                a = torch.ones(u0)
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_torture_multi(self):
        """With ``bucket_cap_mb=1``, ``u0`` is defined and used in different partitions.

        The in-body comments mark the intended partition boundaries.
        """
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))
                self.weight3 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                # partition one (contains the u0 def)
                u0, _ = y.tolist()
                x = torch.cat([x, x])
                y1 = x @ self.weight1
                # partition two (contains the variable)
                y2 = y1 @ self.weight2
                a = torch.ones(u0)
                # partition three
                z = (x + y2 @ self.weight3) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model, bucket_cap_mb=1)
        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
    def test_unbacked_symbol_splitting_no_binding(self):
        """``nonzero()`` yields a data-dependent size never bound to a Python name."""
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                nz = y.nonzero()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * (nz + 1).sum()
                return z

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))

    @patch.object(config, "optimize_ddp", True)
    def test_call_method_forward(self):
        """A submodule invoked through an explicit ``.forward(...)`` call (``m[1].forward``)."""
        class Model(nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                layers = []
                for _ in range(2):
                    layer = nn.ModuleList(
                        [
                            nn.LayerNorm(96),
                            nn.MultiheadAttention(
                                embed_dim=96, num_heads=4, batch_first=True
                            ),
                        ]
                    )
                    layers.append(layer)
                self.layers = nn.ModuleList(layers)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # x: [Batch, Freq, Time, Feature]
                B, F, T, H = x.shape
                for m in self.layers:
                    x = x.reshape(B * F, T, H)
                    x = m[0](x)
                    x, _ = m[1].forward(x, x, x)
                    x = x.reshape(B, F, T, H)
                return x

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(model)
        opt_model(torch.randn(2, 129, 100, 96))
# Are these tests failing? Check whether TestFakeDistributedSingleProc has a
# single-process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it in a single process!
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestMultiProc(DynamoDistributedMultiProcTestCase):
"""
Note: MultiProcTestCase spawns processes per test and is slow.
Prefer MultiThreadedTestCase for most tests. Perhaps use this one
sparingly for integration tests.
"""
device_type = (
acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
)
@skip_if_lt_x_gpu(2)
@config.patch(optimize_ddp=False, enable_compiler_collectives=True)
def test_ddp_baseline_aot_eager_multiprocess(self):
    """Real DDP + aot_eager, with the DDP optimizer disabled, must match eager output."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        self.assertFalse(config.optimize_ddp)
        m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
        m = DDP(m, device_ids=[self.rank])
        m = torch.compile(m, backend="aot_eager")
        outputs = m(inputs)
        self.assertTrue(same(correct_outputs, outputs))
def _test_hf_bert_ddp_inductor(self, static_graph):
    """Shared body: HF BERT under real DDP, comparing inductor against eager.

    Args:
        static_graph: forwarded to ``DDP(static_graph=...)``.
    """
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        bert, bert_inputs = get_hf_bert(self.rank)
        wrapped = DDP(bert, static_graph=static_graph)
        run_hf_bert_ddp(self, wrapped, bert_inputs, "inductor")
@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(optimize_ddp=True, enable_compiler_collectives=True)
@patch.object(torch._inductor.config, "fallback_random", True)
def test_hf_bert_ddp_inductor(self):
    """HF BERT, real DDP (dynamic graph), DDP optimizer on: inductor vs eager."""
    self._test_hf_bert_ddp_inductor(static_graph=False)
@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(optimize_ddp=True, enable_compiler_collectives=True)
@patch.object(torch._inductor.config, "fallback_random", True)
def test_hf_bert_ddp_inductor_static_graph(self):
    """Same as ``test_hf_bert_ddp_inductor`` but with ``DDP(static_graph=True)``."""
    self._test_hf_bert_ddp_inductor(static_graph=True)
def _test_hf_bert_aot_eager(self, static_graph):
    """Shared body: HF BERT under real DDP, comparing aot_eager against eager.

    Args:
        static_graph: forwarded to ``DDP(static_graph=...)``.
    """
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        bert, bert_inputs = get_hf_bert(self.rank)
        wrapped = DDP(bert, static_graph=static_graph)
        run_hf_bert_ddp(self, wrapped, bert_inputs, "aot_eager")
@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@config.patch(optimize_ddp=True, enable_compiler_collectives=True)
def test_hf_bert_ddp_aot_eager(self):
    """HF BERT, real DDP (dynamic graph), DDP optimizer on: aot_eager vs eager."""
    self._test_hf_bert_aot_eager(static_graph=False)
@skip_if_lt_x_gpu(2)
@import_transformers_or_skip()
@config.patch(optimize_ddp=True, enable_compiler_collectives=True)
def test_hf_bert_ddp_aot_eager_static_graph(self):
    """Same as ``test_hf_bert_ddp_aot_eager`` but with ``DDP(static_graph=True)``."""
    self._test_hf_bert_aot_eager(static_graph=True)
@skip_if_lt_x_gpu(2)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(optimize_ddp=False, enable_compiler_collectives=True)
def test_ddp_activation_checkpointing(self):
    """DDP over a model whose ``Linear`` layers use non-reentrant checkpointing.

    Compiled with the default backend, with the DDP optimizer disabled. The
    output must match eager.
    """
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )

    class MyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = torch.nn.Linear(64, 32)
            self.fc2 = torch.nn.Linear(32, 16)
            self.fc3 = torch.nn.Linear(16, 8)

        def forward(self, inp):
            return self.fc3(self.fc2(self.fc1(inp)))

    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        self.assertFalse(config.optimize_ddp)
        model = MyModel().to(device=self.device_type)

        # Activation checkpointing for Linear layers.
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        check_fn = lambda submodule: isinstance(  # noqa: E731
            submodule, torch.nn.Linear
        )
        apply_activation_checkpointing(
            model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
        )

        model = DDP(model)
        x = torch.randn(10, 64).to(self.device_type)
        correct_outputs = model(x)

        opt_model = torch.compile(model)
        outputs = opt_model(x)
        self.assertTrue(same(correct_outputs, outputs))
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
def test_fsdp_aot_eager(self):
    """FSDP + aot_eager must match eager, for both a single outer wrap and per-Linear wraps."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
        fsdp_m = FSDP(m, use_orig_params=True)
        fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

        # Test with recursive wrapping, nested FSDP around each Linear
        m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
        fsdp_m = FSDP(
            m,
            auto_wrap_policy=functools.partial(
                transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
            ),
            use_orig_params=True,
        )
        fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))
@skip_if_lt_x_gpu(2)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@requires_cuda_and_triton
def test_ddp_optimizer_cudagraph(self):
    """DDP-optimizer graph splitting combined with cudagraphs (``mode="reduce-overhead"``).

    After ten training steps, this rank's cudagraph tree should have recorded
    four graphs: two forward and two backward.
    """

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # need a large channel to trigger ddp optimizer split module
            self.CHANNELS = 640
            self.convi = nn.Conv2d(46, self.CHANNELS, 3, padding=1, bias=False)
            self.convp = nn.Conv2d(
                self.CHANNELS, self.CHANNELS, 1, padding=0, bias=False
            )
            self.bni = nn.BatchNorm2d(self.CHANNELS)

        def forward(self, bitmap_channels):
            x = self.convi(bitmap_channels)
            x = self.bni(x)
            x = self.convp(x)
            return x

    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        net = Net().to(self.rank)
        optimizer = torch.optim.SGD(
            net.parameters(),
            lr=5e-2,
        )

        net = DDP(net, device_ids=[self.rank])
        opt_net = torch.compile(net, mode="reduce-overhead")
        opt_net.train()

        for _ in range(10):
            optimizer.zero_grad()
            data = torch.randn((16, 46, 8, 8), dtype=torch.float32, device="cuda")
            opt_net(data).sum().backward()

        # 2 fwd and 2 bwd graph such that 4 graphs in total.
        # new_graph_id() presumably hands out the next unused id, so 4 means
        # ids 0-3 were taken.
        graph_id = (
            torch._inductor.cudagraph_trees.get_container(self.rank)
            .tree_manager.new_graph_id()
            .id
        )
        # assertEqual reports both values on failure; assertTrue(x == 4) does not.
        self.assertEqual(graph_id, 4)
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
def test_fsdp_setattr(self):
    """FSDP-wrapped ``MutatingModel`` compiled with the eager backend.

    The output must match eager. Exactly one graph break is expected, and its
    recorded reason must mention ``setattr() on Tensor.requires_grad``.
    """
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        from torch._dynamo.utils import counters

        counters.clear()
        m, inputs, correct_outputs = get_mutating_model(
            f"{self.device_type}:{self.rank}"
        )
        fsdp_m = FSDP(m, use_orig_params=True)
        fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))
        self.assertEqual(len(counters["graph_break"]), 1)
        # Keys are the graph-break reasons. Take the first (and only) one
        # without materialising a list.
        first_graph_break = next(iter(counters["graph_break"]))
        self.assertIn("setattr() on Tensor.requires_grad", first_graph_break)
@config.patch(inline_inbuilt_nn_modules=False)
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
def test_fsdp_unspecialized_forced_getattr_no_inline(self):
    """FSDP over ``ForcedGetAttrMod`` with ``inline_inbuilt_nn_modules=False``; must match eager."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        from torch._dynamo.utils import counters

        counters.clear()
        m, inputs, correct_outputs = get_forced_getattr_module(
            f"{self.device_type}:{self.rank}"
        )
        fsdp_m = FSDP(m, use_orig_params=True)
        fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
def test_fsdp_unspecialized_forced_getattr_inline(self):
    """FSDP over ``ForcedGetAttrMod`` with the default ``inline_inbuilt_nn_modules``; must match eager."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        from torch._dynamo.utils import counters

        counters.clear()
        m, inputs, correct_outputs = get_forced_getattr_module(
            f"{self.device_type}:{self.rank}"
        )
        fsdp_m = FSDP(m, use_orig_params=True)
        fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_inductor(self):
    """FSDP + inductor must match eager, for both a single outer wrap and per-Linear wraps."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
        fsdp_m = FSDP(m, use_orig_params=True)
        fsdp_m = torch.compile(fsdp_m, backend="inductor")
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

        # Test with recursive wrapping, nested FSDP around each Linear
        m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
        fsdp_m = FSDP(
            m,
            auto_wrap_policy=functools.partial(
                transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
            ),
            use_orig_params=True,
        )
        fsdp_m = torch.compile(fsdp_m, backend="inductor")
        outputs = fsdp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))
@config.patch(enable_compiler_collectives=True)
@skip_if_lt_x_gpu(1)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_fsdp_activation_checkpointing(self):
    """FSDP + activation checkpointing on each ``ToyInnerModel``, compiled with inductor.

    The output must match eager. Dynamo must capture two frames, and the first
    captured graph must contain a ``tag_activation_checkpoint`` node.
    """
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        model, inputs = get_toy_model_for_activation_checkpointing(
            f"{self.device_type}:{self.rank}"
        )
        is_inner = lambda module: isinstance(module, ToyInnerModel)  # noqa: E731
        wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
        model = apply_fsdp_with_checkpointing(model, wrap_policy, is_inner)
        correct_outputs = model(inputs)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
        opt_model = torch.compile(model, backend=cnt)
        outputs = opt_model(inputs)
        self.assertTrue(same(correct_outputs, outputs))
        # Each FSDP module is a separate graph
        self.assertEqual(cnt.frame_count, 2)
        self.assertTrue(
            find_first_node(cnt.graphs[0], tag_activation_checkpoint) is not None
        )
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
@patch.object(torch._inductor.config.triton, "cudagraphs", False)
@patch.object(torch._inductor.config, "fallback_random", True)
@config.patch(enable_compiler_collectives=True)
@unittest.skipIf(
    PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
    "Inaccurate results with fused SDPA kernels",
)
def test_hf_bert_fsdp(self):
    """HF BERT under FSDP (no recursive wrapping): inductor vs eager.

    Eager and compiled runs each wrap their own deep copy of the same base
    model and start from ``reset_rng_state()``.
    """

    def apply_fsdp(model, wrap_policy):
        # Deep-copy so the eager and compiled runs each own separate parameters.
        model = FSDP(
            copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
        )
        return model

    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        for wrap_policy, test_instance in (
            (None, "FSDP without recursive wrapping"),
        ):
            print(f"Running hf_bert test for {test_instance}")
            model, inputs = get_hf_bert(self.rank)
            reset_rng_state()
            eager_model = apply_fsdp(model, wrap_policy)
            correct_outputs = eager_model(**inputs)
            correct_loss = correct_outputs.loss
            correct_loss.backward()

            reset_rng_state()
            opt_model = apply_fsdp(model, wrap_policy)
            opt_model = torch.compile(opt_model, backend="inductor")
            opt_outputs = opt_model(**inputs)
            opt_loss = opt_outputs.loss
            opt_loss.backward()

            inputs_flat = [inputs[k] for k in inputs]
            correct_results = collect_results(
                eager_model, correct_outputs.logits, correct_loss, inputs_flat
            )
            opt_results = collect_results(
                opt_model, opt_outputs.logits, opt_loss, inputs_flat
            )
            self.assertTrue(same(correct_results, opt_results))
@import_transformers_or_skip()
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
# TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
@patch.object(torch._inductor.config.triton, "cudagraphs", False)
@patch.object(torch._inductor.config, "fallback_random", True)
@config.patch(guard_nn_modules=True, enable_compiler_collectives=True)
def test_hf_bert_fsdp_activation_checkpointing(self):
    """HF BERT with FSDP + activation checkpointing on every ``BertLayer``: inductor vs eager."""
    from transformers.models.bert.modeling_bert import BertLayer

    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        for wrap_policy, test_instance in (
            (
                functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer,)
                ),
                "FSDP with recursive wrapping BertLayer instances",
            ),
        ):
            print(
                f"Running hf_bert_activation_checkpointing test for {test_instance}"
            )
            model, inputs = get_hf_bert(self.rank)
            check_fn = lambda submodule: isinstance(  # noqa: E731
                submodule, BertLayer
            )
            reset_rng_state()
            # apply_fsdp_with_checkpointing deep-copies ``model``, so both runs
            # start from the same weights.
            eager_model = apply_fsdp_with_checkpointing(
                model, wrap_policy, check_fn
            )
            correct_outputs = eager_model(**inputs)
            correct_loss = correct_outputs.loss
            correct_loss.backward()

            reset_rng_state()
            opt_model = apply_fsdp_with_checkpointing(model, wrap_policy, check_fn)
            opt_model = torch.compile(opt_model, backend="inductor")
            opt_outputs = opt_model(**inputs)
            opt_loss = opt_outputs.loss
            opt_loss.backward()

            inputs_flat = [inputs[k] for k in inputs]
            correct_results = collect_results(
                eager_model, correct_outputs.logits, correct_loss, inputs_flat
            )
            opt_results = collect_results(
                opt_model, opt_outputs.logits, opt_loss, inputs_flat
            )
            self.assertTrue(same(correct_results, opt_results))
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_tensor(self):
    """Ranks feed different batch-size sequences through a compiled DDP model.

    With compiler collectives enabled, every rank must end up with the same
    number of compilations.
    """
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):

        class SimpleModel(nn.Module):
            def __init__(self, input_size, output_size):
                super().__init__()
                self.linear = nn.Linear(input_size, output_size)

            def forward(self, x):
                return self.linear(x)

        torch._dynamo.utils.clear_compilation_metrics()

        model = SimpleModel(10, 2).to(self.rank)
        # Compile only ``forward``; DDP then wraps the module as usual.
        model.forward = torch.compile(model.forward)
        ddp_model = DDP(model, device_ids=[self.rank])

        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # One (data, labels) batch of size ``s``.
        def B(s):
            return [torch.randn(s, 10), torch.randint(0, 2, (s,))]

        # Per-rank batch sizes differ on purpose.
        if self.rank == 0:
            dataloader = [B(5), B(8), B(6)]
        else:
            dataloader = [B(6), B(6), B(3)]

        for data, labels in dataloader:
            data, labels = data.to(self.rank), labels.to(self.rank)
            optimizer.zero_grad()
            output = ddp_model(data)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()

        metrics = torch._dynamo.utils.get_compilation_metrics()
        # Number of compiles same on all nodes
        res = [None] * self.world_size
        torch.distributed.all_gather_object(res, len(metrics))
        for r in res[1:]:
            self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_scalar(self):
    """Ranks pass different Python-int sizes to a compiled function.

    With compiler collectives enabled, every rank must end up with the same
    number of compilations.
    """
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        # TODO: It should be possible to build this device string inside the
        # compiled function; it is hoisted out here (the original note was
        # cut off and does not say why).
        device = f"{self.device_type}:{self.rank}"

        @torch.compile()
        def f(x, y):
            return x + torch.ones(y, device=device).sum()

        # Per-rank int sequences differ on purpose.
        if self.rank == 0:
            dataloader = [3, 3, 7]
        else:
            dataloader = [3, 4, 9]

        for data in dataloader:
            f(torch.randn(5, device=self.rank), data)

        metrics = torch._dynamo.utils.get_compilation_metrics()
        # Number of compiles same on all nodes
        res = [None] * self.world_size
        torch.distributed.all_gather_object(res, len(metrics))
        for r in res[1:]:
            self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_automatic_dynamic_speculation_divergence(self):
    """Rank 0 sees sizes [4, 4] and the other ranks see [3, 4].

    The ranks' size histories diverge, yet compile counts must still agree
    across ranks.
    """
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        @torch.compile()
        def f(x, y):
            zx = x.shape  # noqa: F841
            zy = y.shape  # noqa: F841
            return x.sum() + y.sum()

        if self.rank == 0:
            dataloader = [4, 4]
        else:
            dataloader = [3, 4]

        for data in dataloader:
            f(
                torch.randn(data, device=self.rank),
                torch.randn(data, device=self.rank),
            )

        metrics = torch._dynamo.utils.get_compilation_metrics()
        # Number of compiles same on all nodes
        res = [None] * self.world_size
        torch.distributed.all_gather_object(res, len(metrics))
        for r in res[1:]:
            self.assertEqual(res[0], r)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_graph_break_empty_graph_still_collective(self):
    """A graph break (the print) before any tensor work must not desync ranks."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        @torch.compile()
        def f(x, y):
            z = y  # noqa: F841
            print("woof")
            zx = x.shape  # noqa: F841
            zy = y.shape  # noqa: F841
            return x.sum() + y.sum()

        lengths = [5, 5, 6] if self.rank == 0 else [3, 4, 5]
        for length in lengths:
            lhs = torch.randn(length, device=self.rank)
            rhs = torch.randn(length, device=self.rank)
            f(lhs, rhs)

        # Every rank must end up with the same number of compilations.
        compile_count = len(torch._dynamo.utils.get_compilation_metrics())
        gathered = [None] * self.world_size
        torch.distributed.all_gather_object(gathered, compile_count)
        first, *others = gathered
        for other in others:
            self.assertEqual(first, other)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_dim_mismatch(self):
    """Rank 0 passes 2-D inputs, other ranks 1-D; compile counts must still match."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        @torch.compile()
        def f(x, y):
            zx = x.shape  # noqa: F841
            zy = y.shape  # noqa: F841
            return x.sum() + y.sum()

        shapes = [[4, 2]] if self.rank == 0 else [[3]]
        for shape in shapes:
            lhs = torch.randn(shape, device=self.rank)
            rhs = torch.randn(shape, device=self.rank)
            f(lhs, rhs)

        compile_count = len(torch._dynamo.utils.get_compilation_metrics())
        gathered = [None] * self.world_size
        torch.distributed.all_gather_object(gathered, compile_count)
        first, *others = gathered
        for other in others:
            self.assertEqual(first, other)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_missing_source(self):
    """Each rank indexes a different list element of tensors; compile counts must match."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        @torch.compile()
        def f(rank, xs):
            return xs[rank].sum()

        tensors = [torch.randn(10, device=self.rank) for _ in range(self.world_size)]
        f(self.rank, tensors)

        compile_count = len(torch._dynamo.utils.get_compilation_metrics())
        gathered = [None] * self.world_size
        torch.distributed.all_gather_object(gathered, compile_count)
        first, *others = gathered
        for other in others:
            self.assertEqual(first, other)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_scalar_missing_source(self):
    """Each rank indexes a different int from a list; compile counts must match."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        @torch.compile()
        def f(rank, xs):
            return torch.tensor(xs[rank], device=self.rank)

        scalars = [10 + offset for offset in range(self.world_size)]
        f(self.rank, scalars)

        compile_count = len(torch._dynamo.utils.get_compilation_metrics())
        gathered = [None] * self.world_size
        torch.distributed.all_gather_object(gathered, compile_count)
        first, *others = gathered
        for other in others:
            self.assertEqual(first, other)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@config.patch(enable_compiler_collectives=True)
def test_compiler_collectives_type_mismatch(self):
    """Rank 0 passes a tensor while other ranks pass an int; compile counts must match."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        @torch.compile()
        def f(x):
            if isinstance(x, int):
                return torch.tensor(x, device=self.rank)
            else:
                return x.sum()

        arg = torch.randn(10, device=self.rank) if self.rank == 0 else 12
        f(arg)

        # The mirrored case below deadlocks, so it is not exercised; it seems
        # not to be supported:
        #
        #   if self.rank == 0:
        #       x = torch.randn(12, device=self.rank)
        #   else:
        #       x = 10
        #   f(x)

        compile_count = len(torch._dynamo.utils.get_compilation_metrics())
        gathered = [None] * self.world_size
        torch.distributed.all_gather_object(gathered, compile_count)
        first, *others = gathered
        for other in others:
            self.assertEqual(first, other)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@enable_guard_collectives()
def test_guard_collective(self):
    """A guard failure on only some ranks must still yield equal compile counts."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()

        @torch.compile()
        def f(x):
            return x.sum()

        f(torch.randn(10, device=self.rank))

        # Only non-zero ranks change the input length, so only they would
        # recompile on their own.
        second_len = 10 if self.rank == 0 else 12
        f(torch.randn(second_len, device=self.rank))

        compile_count = len(torch._dynamo.utils.get_compilation_metrics())
        gathered = [None] * self.world_size
        torch.distributed.all_gather_object(gathered, compile_count)
        first, *others = gathered
        for other in others:
            self.assertEqual(first, other)
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
def test_get_pg_attr(self):
    """A compiled fn that queries group membership must observe a rebound ``pg``."""
    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        pg = dist.distributed_c10d._get_default_group()
        device = f"{self.device_type}:{self.rank}"

        @torch.compile(fullgraph=True)
        def f(x):
            if dist.distributed_c10d._rank_not_in_group(pg):
                return x + 1
            return x - 1

        x = torch.ones(4, device=device)
        # This rank is in the default group, so the "member" branch runs.
        self.assertEqual(f(x), x - 1)

        # Rebind the closed-over group to the non-member sentinel.
        pg = dist.distributed_c10d.GroupMember.NON_GROUP_MEMBER
        self.assertEqual(f(x), x + 1)
@skipIfXpu  # ProcessGroupXCCL doesn't support _set_default_timeout yet.
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", False)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
def test_asymmetric_compilation(self):
    """
    Rank 0 sleeps while tracing for longer than the default collective timeout.

    Other ranks add an ephemeral timeout so the allreduce after compilation
    succeeds, and every rank must report the same number of compilations.
    """
    from torch._dynamo.comptime import comptime

    with _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()
        device = f"{self.device_type}:{self.rank}"
        pg = dist.distributed_c10d._get_default_group()
        cnt = torch._dynamo.testing.CompileCounter()
        sleep_time = 5

        @torch.compile(backend=cnt)
        def f(x):
            if self.rank == 0:
                comptime.sleep(sleep_time)
            y = 2 * x
            return y.sum()

        short_timeout = timedelta(seconds=sleep_time - 2)
        pg._get_backend(torch.device(device))._set_default_timeout(short_timeout)

        x = torch.ones(4, device=device)

        # NCCL startup is lazy, so run one collective before compiling.
        warmup = pg.allreduce(x)
        warmup.wait()

        f(x)

        if self.rank != 0:
            # test fails with NCCL timeout without this line
            dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
                timedelta(seconds=sleep_time)
            )

        after_compile = pg.allreduce(x)
        after_compile.wait()
        torch.accelerator.synchronize(device)

        # Every rank must end up with the same number of compilations.
        compile_count = len(torch._dynamo.utils.get_compilation_metrics())
        gathered = [None] * self.world_size
        torch.distributed.all_gather_object(gathered, compile_count)
        first, *others = gathered
        for other in others:
            self.assertEqual(first, other)
@skipIfXpu  # ProcessGroupXCCL doesn't support _set_default_timeout yet.
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@patch.object(torch._inductor.config, "fx_graph_cache", True)
@patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
@patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
def test_asymmetric_compilation_with_fx_cache(self):
    """
    After a dynamo reset, rank 0 recompiles with a fresh FX graph cache (miss)
    while the other ranks hit the shared cache; the follow-up allreduce must
    still complete on every rank.
    """
    from torch._dynamo.utils import counters
    from torch._inductor.utils import fresh_cache

    def check_fx_cache(misses, hits):
        # The inductor FX-graph cache counters must match; no bypasses expected.
        stats = counters["inductor"]
        self.assertEqual(stats["fxgraph_cache_miss"], misses)
        self.assertEqual(stats["fxgraph_cache_hit"], hits)
        self.assertEqual(stats["fxgraph_cache_bypass"], 0)

    with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size):
        torch._dynamo.utils.clear_compilation_metrics()
        device = f"{self.device_type}:{self.rank}"
        pg = dist.distributed_c10d._get_default_group()

        @torch.compile
        def f(x):
            y = 2 * x
            return y.sum()

        pg._get_backend(torch.device(device))._set_default_timeout(
            timedelta(seconds=5)
        )

        counters.clear()
        x = torch.ones(4, device=device)

        # First compile on every rank populates the cache.
        f(x)
        check_fx_cache(misses=1, hits=0)

        handle = pg.allreduce(x)
        handle.wait()
        torch.accelerator.synchronize(device)

        torch._dynamo.reset()

        if self.rank == 0:
            # A fresh cache on rank 0 forces a second miss.
            with fresh_cache():
                f(x)
            check_fx_cache(misses=2, hits=0)
        else:
            f(x)
            check_fx_cache(misses=1, hits=1)

        handle = pg.allreduce(x)
        handle.wait()
        torch.accelerator.synchronize(device)
@requires_accelerator_dist_backend(["nccl", "xccl"])
@unittest.skipUnless(torch.accelerator.is_available(), "Requires accelerator")
class TestSingleProc(DynamoDistributedSingleProcTestCase):
    """
    The test harness initializes the dist process group.

    Test simple things here, since they are easier to debug. Use TestMultiProc
    for things that really need to run on multiple nodes.
    """

    device_type = (
        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    )

    def get_model(
        self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
    ):
        """Build an initialized ToyModel on ``self.device``.

        Returns ``(model, inputs, eager_outputs)``, where the outputs are the
        model's eager result on a random ``(bsz, in_feat)`` input.
        """
        m = ToyModel(
            in_feat=in_feat,
            hidden_feat=hidden_feat,
            out_feat=out_feat,
            ctx_manager=ctx_manager,
        ).to(self.device)
        m.apply(init_weights)
        inputs = torch.rand(bsz, in_feat).to(self.device)
        outputs = m(inputs)
        return m, inputs, outputs

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch.compile(ddp_m, backend="aot_eager")
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch.compile(ddp_m, backend="inductor")
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """
        # The docstring must be the first statement to count as a docstring.
        self.assertTrue(config.optimize_ddp)

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch.compile(backend=check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

        # ensure compatibility with dynamo explain
        explain_out = torch._dynamo.explain(ddp_m)(inputs)
        break_reasons = explain_out.break_reasons
        self.assertEqual(len(break_reasons), 3)
        self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split_ctx_manager(self):
        """
        Ensures that we get the right number of splits and that the respective
        context managers' effects are applied to the computation.
        """
        for get_compiler in [
            lambda: CheckSplitsCompiler(),
            lambda: None,
        ]:
            for ctx_manager, output_test in [
                (
                    lambda: torch.autocast(
                        torch.device(self.device).type, torch.float16
                    ),
                    lambda out: self.assertEqual(out.dtype, torch.float16),
                ),
                (torch.enable_grad, lambda out: self.assertTrue(out.requires_grad)),
                (torch.no_grad, lambda out: self.assertFalse(out.requires_grad)),
            ]:
                m, inputs, correct_outputs = self.get_model(
                    out_feat=1000,
                    hidden_feat=1000,
                    in_feat=1000,
                    ctx_manager=ctx_manager,
                )
                # Each 1000 x 1000 float32 weight matrix is 1000 * 1000 * 4
                # bytes, i.e. ~4MB. A bucket cap just below that size makes
                # DDPOptimizer split the graph (4 subgraphs are expected below).
                bucket_cap_mb = 3.5
                ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)

                compiler = get_compiler()

                @torch.compile(backend=compiler.compile_fn if compiler else "aot_eager")
                def opt_fn(inputs):
                    return ddp_m(inputs)

                opt_outputs = opt_fn(inputs)
                self.assertTrue(same(correct_outputs, opt_outputs))
                if compiler:
                    self.assertEqual(compiler.compiler_called, 4)

                output_test(opt_outputs)

                # ensure compatibility with dynamo explain
                explain_out = torch._dynamo.explain(ddp_m)(inputs)
                break_reasons = explain_out.break_reasons
                self.assertEqual(len(break_reasons), 4)
                self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

    def _run_flex_attention_ddp(self, attention):
        """Compile a small flex-attention model, wrap it in DDP, run one forward.

        ``attention`` is the callable the model uses as its attention op, e.g.
        ``flex_attention`` itself or a separately compiled copy of it.
        """

        class Model(torch.nn.Module):
            def __init__(self, S, H, D):
                super().__init__()

                self.S = S
                self.H = H
                self.D = D

                alibi_bias = self.generate_alibi_bias(H)
                self.register_buffer("alibi_bias", alibi_bias, persistent=True)
                self.attention = attention

                self.project_qk = torch.nn.Linear(H * D, H * D * 2)
                self.project_v = torch.nn.Linear(H * D, H * D)

            def forward(self, hidden_states):
                batch_size, _, _ = hidden_states.size()

                query, key = self.project_qk(hidden_states).chunk(2, dim=2)
                query = query.view(self.S, batch_size, self.H, self.D)
                query = query.permute(1, 2, 0, 3)

                key = key.view(self.S, batch_size, self.H, self.D)
                key = key.permute(1, 2, 0, 3)

                value = self.project_v(hidden_states)
                value = value.view(self.S, batch_size, self.H, self.D)
                value = value.permute(1, 2, 0, 3)

                return self.attention(query, key, value, score_mod=self.alibi_score_mod)

            def generate_alibi_bias(self, num_heads):
                alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)]
                return torch.tensor(alibi_bias)

            def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
                bias = (q_idx - kv_idx) * self.alibi_bias[h]
                return score + bias

        B = 16
        H = 12
        S = 512
        D = 64

        model = Model(S, H, D)
        model.to(self.device_type)
        model = torch.compile(model)
        model = DDP(model, device_ids=self.device_ids)

        hidden_states = torch.randn(B, S, H * D).to(self.device_type)
        model(hidden_states)
        torch.accelerator.synchronize()

    @skipIfXpu  # XPU device doesn't support flex_attention yet.
    @patch.object(config, "optimize_ddp", True)
    def test_compiled_flex_attention_full_model_ddp(self):
        # Only the whole model is compiled; flex_attention is called inside it.
        self._run_flex_attention_ddp(flex_attention)

    @skipIfXpu  # XPU device doesn't support flex_attention yet.
    @patch.object(config, "optimize_ddp", True)
    def test_compiled_flex_attention_local_ddp(self):
        # flex_attention is additionally wrapped in its own torch.compile.
        self._run_flex_attention_ddp(torch.compile(flex_attention))

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor(self):
        """
        Same as above, but using inductor backend.
        We observed issues with inductor/fx interface in the past.
        """
        self.assertTrue(config.optimize_ddp)

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch.compile(backend="inductor")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))

    @torch._inductor.config.patch(
        {"layout_optimization": True, "keep_output_stride": False}
    )
    @patch.object(config, "optimize_ddp", True)
    def _test_graph_split_inductor_layout_optimizations_impl(self, context):
        """Run a DDP-wrapped conv model through inductor under ``context()``."""
        self.assertTrue(config.optimize_ddp)
        # channel dim must be > 64 for inductor to do layout optimization and use NHWC
        channel_dim = 512

        class ToyModelConv(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Four independent (Conv2d, ReLU) pairs, created in order.
                layers = []
                for _ in range(4):
                    layers += [
                        nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
                        nn.ReLU(),
                    ]
                self.net = nn.Sequential(*layers)

            def forward(self, inputs):
                return self.net(inputs)

        def get_model():
            m = ToyModelConv().to(self.device)
            m.apply(init_weights)
            inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device)
            outputs = m(inputs)
            return m, inputs, outputs

        with context():
            m, inputs, correct_outputs = get_model()
            ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

            @torch.compile(backend="inductor")
            def opt_fn(inputs):
                return ddp_m(inputs)

            opt_outputs = opt_fn(inputs)
            self.assertTrue(same(correct_outputs, opt_outputs))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_layout_optimizations_training(self):
        self._test_graph_split_inductor_layout_optimizations_impl(
            contextlib.nullcontext
        )

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_layout_optimizations_inference(self):
        self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad)

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor_transpose(self):
        """DDP + inductor on a transpose/flatten model, with growing batch sizes."""
        self.assertTrue(config.optimize_ddp)

        B = 100
        N = 30
        D = 50
        K = 70

        class Foo(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear0 = nn.Linear(N, K)
                self.linear1 = torch.nn.Linear(D * K, 2048)

            def forward(self, x):
                xt = x.transpose(2, 1)
                xt = self.linear0(xt).flatten(1)
                return self.linear1(xt)

        mod = Foo().to(self.device)

        compiled_mod = torch.compile(mod, backend="inductor")
        ddp_compiled_mod = DDP(compiled_mod, device_ids=self.device_ids)

        x = torch.randn((B, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x), ddp_compiled_mod(x)))

        x_1 = torch.randn((B * 2, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_1), ddp_compiled_mod(x_1)))

        x_2 = torch.randn((B * 3, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_2), ddp_compiled_mod(x_2)))

    @patch.object(config, "optimize_ddp", True)
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits. (Based on model parameters fitting in the bucket)
        """
        # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
        m, inputs, correct_outputs = self.get_model(hidden_feat=5)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
        check_splits_compiler = CheckSplitsCompiler()

        @torch.compile(backend=check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check AotAutograd family of compilers work,
        since they require example inputs propagated between graph splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch.compile(backend="aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch.compile(backend=check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(*inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
        """A function with no tensor ops must compile and return the world size."""

        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch.compile(fn, backend="inductor")
        # Let compilation errors propagate so a failure shows its real cause
        # instead of an opaque `None != 1` mismatch.
        res = opt_fn()[0]
        self.assertEqual(res, 1)

    @patch.object(config, "optimize_ddp", False)
    def test_ignored_parameters(self):
        """
        Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module.
        Hooks up graph-split optimizer manually so it can peek at internal state.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
        DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
        parameter_ids_to_ignore = [
            id(ddp_m.module.get_parameter(p)) for p in ddp_m.parameters_to_ignore
        ]

        check_splits_compiler = CheckSplitsCompiler()
        ddp_optimizer = DDPOptimizer(
            bucket_bytes_cap=ddp_m.bucket_bytes_cap,
            backend_compile_fn=check_splits_compiler.compile_fn,
        )

        @torch.compile(backend=ddp_optimizer.compile_fn)
        def opt_fn(inputs):
            return ddp_m(*inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 2)
        for b in ddp_optimizer.buckets:
            for p_id in b.param_ids:
                self.assertNotIn(p_id, parameter_ids_to_ignore)

    @patch.object(config, "optimize_ddp", True)
    def test_higher_order_op(self):
        """Smoke test: DDP + activation checkpointing compiles with aot_eager."""
        from torch.utils.checkpoint import checkpoint

        N = 1000

        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(N, N)
                self.linear2 = torch.nn.Linear(N, N)

            def forward(self, x):
                a = self.linear1(x)
                a = self.linear2(a)
                return a

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_mod1 = InnerModule()
                self.inner_mod2 = InnerModule()

            def forward(self, x):
                a = checkpoint(self.inner_mod1, x, use_reentrant=False)
                a = torch.cos(a)
                a = checkpoint(self.inner_mod2, a, use_reentrant=False)
                a = torch.cos(a)
                return a

        mod = MockModule().to(self.device_type)
        mod = DDP(mod, bucket_cap_mb=1)
        x = torch.randn(N, N, device=self.device_type, requires_grad=True)
        args = (x,)

        backend = "aot_eager"
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        torch.compile(mod, backend=cnt)(*args)

    def test_fsdp_orig_params_assert(self):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        m, inputs, _ = get_model(f"{self.device_type}:{self.rank}")
        fsdp_m = FSDP(m, use_orig_params=False)
        # Test is that this function call does not throw an exception.
        fsdp_m = torch.compile(fsdp_m)

    def test_fsdp_skip_guards(self):
        """
        It's currently difficult to test dynamo guards. Most guards tests are indirect: modify something and
        observe that the guard in question failed. In this case, since the FSDP guards were already deemed
        useless and skipping them is expected to have no practical effect, it's pretty contrived to even try to
        make those guards fail. Instead, we observe the 'guard source' printed by dynamo's comptime print_guards
        function.

        Note: comptime prints the guards before the time they get installed or not installed, so in both cases
        (skip or no skip) the same guards get printed. The difference is that in the skip case, they show up
        with a special 'guard source' which will cause them to not be installed. So all we check for is the expected
        guard source 'local_fsdp_module'.
        """
        global GUARDS_FILE
        # NOTE(review): this buffer is shared by both loop iterations, so in the
        # second iteration the first FileCheck pattern may match text printed
        # during the first iteration. Consider resetting it per iteration.
        GUARDS_FILE = StringIO()

        for skip_guards, expected_guard_source in (
            (True, "local_fsdp_module"),
            (False, "local_unspecialized_nn_module"),
        ):
            torch._dynamo.reset()

            class ToyModel(nn.Module):
                def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                    super().__init__()
                    self.net = nn.Sequential(
                        *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
                    )

                def forward(self, inputs):
                    out = self.net(inputs)

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            device = f"{self.device_type}:{self.rank}"
            m = ToyModel(
                in_feat=10,
                hidden_feat=5000,
                out_feat=5,
            ).to(device)
            inputs = torch.rand(20, 10).to(device)
            m.apply(init_weights)
            correct_outputs = m(inputs)
            fsdp_m = FSDP(m, use_orig_params=True)

            with torch._dynamo.config.patch(skip_fsdp_guards=skip_guards):
                opt_m = torch.compile(fsdp_m, backend="aot_eager")
                outputs = opt_m(inputs)

            # far from an exhaustive check of all the expected guards, just check a couple of them.
            FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
                f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
            ).check(
                f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
            ).run(GUARDS_FILE.getvalue())

            self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_skip_register_attr_or_module(self):
        """
        ensure FSDP module is not registered as attributes
        in the fx graph
        see `not source.guard_source().is_fsdp_module()`
        before calling `register_attr_or_module`
        in variables/builder.py
        """

        class ToyModel(nn.Module):
            def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                super().__init__()
                self.net = nn.Sequential(
                    *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                    + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                )

            def forward(self, inputs):
                out = self.net(inputs)
                return out

        torch._dynamo.reset()

        device = f"{self.device_type}:{self.rank}"
        m = ToyModel(
            in_feat=10,
            hidden_feat=5000,
            out_feat=5,
        ).to(device)
        inputs = torch.rand(20, 10).to(device)
        m.apply(init_weights)
        correct_outputs = m(inputs)
        fsdp_m = FSDP(m, use_orig_params=True)

        def debug_compiler(gm, _):
            # No get_attr node may refer to one of the FSDP-managed parameters.
            for node in gm.graph.nodes:
                if node.op == "get_attr":
                    for name in [
                        "l__self___net_0_weight",
                        "l__self___net_0_bias",
                        "l__self___net_2_weight",
                        "l__self___net_2_bias",
                    ]:
                        self.assertNotIn(
                            name,
                            node.name,
                            f"FSDP module {name} should not be registered as attributes",
                        )
            return gm

        opt_m = torch.compile(fsdp_m, backend=debug_compiler)
        outputs = opt_m(inputs)

        self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_dup_tensors_same_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with the same
        source are de-duplicated, meaning that they are each only passed once
        as a graph input.
        """
        device_type = self.device_type

        class DuplicateModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._param = torch.randn((3,), device=device_type)
                self._buf = torch.nn.Buffer(
                    torch.randn((3,), requires_grad=False, device=device_type)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use `_param` and `_buf` each twice in this compiled forward
                # to exercise if they are de-duplicated by TorchDynamo
                z = x + self._buf + self._buf
                z += self._param + self._param
                return z

        model = DuplicateModule()
        fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True)
        fsdp_model = torch.compile(fsdp_model, backend="aot_eager")
        inp = torch.randn((2, 3), device=self.device_type)
        local_out = model(inp)
        fsdp_out = fsdp_model(inp)
        self.assertEqual(local_out, fsdp_out)

    @patch.object(config, "guard_nn_modules", True)
    def test_fsdp_dup_tensors_diff_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with different
        source do not result in incorrect AOTAutograd de-dup guards like
        ``a is b``, where ``a`` and ``b`` are certainly not the same. We check
        this by checking for per-invocation recompiles.
        """
        device_type = self.device_type

        class BufModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._buf = nn.Buffer(
                    torch.randn((3,), requires_grad=False, device=device_type)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self._buf

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self._param = nn.Parameter(torch.randn((1,), device=device_type))
                self._buf_module = BufModule()
                # Share the buffer, meaning same tensor but different source
                self._buf = self._buf_module._buf

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use the same buffer tensor twice in the compiled forward,
                # including a data mutation to trigger de-dup logic
                self._buf.mul_(2)
                z = x + self._buf
                z = self._buf_module(z)
                z += self._param
                return z

        fsdp_model = FSDP(Model(), use_orig_params=True)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        fsdp_model = torch.compile(fsdp_model, backend=cnt)
        inp = torch.randn((2, 3), device=self.device_type)
        for _ in range(15):
            fsdp_model(inp)
        # Check for no recompiles (if there were incorrect de-dup guards, then
        # the frame count would be equal to the number of forward calls)
        self.assertEqual(cnt.frame_count, 1)

    def test_fsdp_staticmethod(self):
        """
        Tests that Dynamo compiles staticmethods for FSDP-managed modules
        correctly both when the staticmethod is invoked from the class and from
        the object itself.
        """
        device_type = self.device_type

        class ModuleWithStaticMethod(nn.Module):
            def __init__(self, use_self: bool):
                super().__init__()
                self._use_self = use_self
                torch.manual_seed(42)  # force `_param` to be deterministic
                self._param = nn.Parameter(torch.randn((3,), device=device_type))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                if self._use_self:
                    z = self._add(x, self._param)
                else:
                    z = ModuleWithStaticMethod._add(x, self._param)
                z *= 2
                return z

            @staticmethod
            def _add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        model = ModuleWithStaticMethod(False)
        x = torch.randn((2, 3), device=self.device_type)
        ref_out = model(x)
        test_outs: list[torch.Tensor] = []

        for use_self in (False, True):
            model = ModuleWithStaticMethod(use_self)
            fsdp_model = FSDP(model, use_orig_params=True)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            fsdp_model = torch.compile(fsdp_model, backend=cnt)
            test_outs.append(fsdp_model(x))
            # Exactly one compiled frame is expected for the single forward
            # call; a recompile could indicate that arguments are passed to the
            # staticmethod incorrectly (e.g. doubly passing `self`).
            self.assertEqual(cnt.frame_count, 1)
        for test_out in test_outs:
            self.assertEqual(test_out, ref_out)

    def test_async_subclass_no_specialize(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def f(x):
            return x + 1

        f(_maybe_wrap_tensor(torch.randn(10)))
        f(_maybe_wrap_tensor(torch.randn(12)))

        self.assertEqual(cnt.frame_count, 1)
if __name__ == "__main__":
    # Hand off to the dynamo test runner when executed as a script.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()