### Implementation of #151705

This PR introduces the initial implementation of native `tl.dot` support in Inductor, with the goal of generating Triton matmul kernels directly, without relying on predefined templates. To avoid complexity and ease the review process, I plan to split this work into two phases as outlined in #151705:

1. **Basic support** (this PR)
2. **Lazy broadcasting** for optimal performance (future PR)

### Summary of This PR

This PR implements the basic functionality. It does **not** include lazy broadcasting, so the generated kernels may involve explicit `tl.reshape` and `tl.trans` operations before calling `tl.dot`, which introduces some overhead.

### Notable Changes

1. Adds a new config flag: `config.triton.enable_native_matmul`
2. Introduces a new `ops.dot` IR node in Inductor and lowers `aten.mm` and `aten.bmm` to it when native matmul is enabled
3. Enforces tiling suitable for matmul when the native matmul flag is enabled
4. Implements code generation for `ops.dot`
5. Adds Triton autotuning heuristics: for now, I've copied the configuration from the existing matmul templates. However, this may not be optimal: it currently takes a long time to tune, and I think there must be a better way to tackle this.

@eellison @jansel @PaulZhang12 @shunting314

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157743
Approved by: https://github.com/jansel
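A minimal usage sketch of the new flag, assuming it is exposed as `torch._inductor.config.triton.enable_native_matmul` as described above (the test below references `torch._inductor.config.triton.native_matmul`, so the exact attribute name may differ). With the flag set, `torch.compile` should lower `aten.mm`/`aten.bmm` to the new `ops.dot` IR node instead of selecting a predefined matmul template:

```python
import torch
import torch._inductor.config as inductor_config

# Hypothetical flag name taken from the PR description; the final API may
# differ (the test below checks config.triton.native_matmul instead).
inductor_config.triton.enable_native_matmul = True

@torch.compile(backend="inductor")
def matmul(a, b):
    # With native matmul enabled, this should be lowered to ops.dot and
    # code-generated as a tl.dot call rather than a matmul template.
    return a @ b

a = torch.randn(128, 64, device="cuda", dtype=torch.float16)
b = torch.randn(64, 256, device="cuda", dtype=torch.float16)
out = matmul(a, b)
```

Since basic support still materializes `tl.reshape`/`tl.trans` before `tl.dot`, some overhead relative to the existing templates is expected until the lazy-broadcasting phase lands.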
2114 lines · 79 KiB · Python
# Owner(s): ["module: dynamo"]
import contextlib
import copy
import functools
import random
import unittest
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from unittest.mock import patch

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.logging
import torch._dynamo.test_case
import torch.distributed as dist
import torch.optim as optim
from torch import nn
from torch._C import FileCheck
from torch._dynamo import config
from torch._dynamo.backends.distributed import DDPOptimizer
from torch._dynamo.comptime import comptime
from torch._dynamo.testing import collect_results
from torch._dynamo.utils import same
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.compiler import set_enable_guard_collectives
from torch.distributed._functional_collectives import _maybe_wrap_tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FLASH_ATTENTION,
    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
)
from torch.testing._internal.common_distributed import (
    _dynamo_dist_per_rank_init,
    DynamoDistributedMultiProcTestCase,
    DynamoDistributedSingleProcTestCase,
    import_transformers_or_skip,
    requires_accelerator_dist_backend,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda_and_triton

def reset_rng_state():
    torch.manual_seed(1337)
    random.seed(1337)
    np.random.seed(1337)


def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


@contextmanager
def enable_guard_collectives():
    old = set_enable_guard_collectives(True)
    try:
        yield
    finally:
        set_enable_guard_collectives(old)


class ToyModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )

    def forward(self, inputs):
        if self.ctx_manager is not None:
            with self.ctx_manager():
                return self.net(inputs)
        else:
            return self.net(inputs)


def get_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = ToyModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


class MutatingModel(nn.Module):
    def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None):
        super().__init__()
        self.ctx_manager = ctx_manager
        self.net = nn.Sequential(
            *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
            + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
        )
        self.state = 1

    def forward(self, inputs):
        self.state = 2
        return self.net(inputs) * self.state


def get_mutating_model(
    device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
):
    m = MutatingModel(
        in_feat=in_feat,
        hidden_feat=hidden_feat,
        out_feat=out_feat,
        ctx_manager=ctx_manager,
    ).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


class ForcedGetAttrMod(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.__dict__["forced_linear"] = torch.nn.Linear(1, 1).to(device=device)
        self.counter = 0

    def forward(self, x):
        self.counter += 1
        return x * self.linear(x) * self.forced_linear.weight


def get_forced_getattr_module(device):
    mod = ForcedGetAttrMod(device).to(device=device)
    x = torch.randn(1, 1, device=device)
    return mod, x, mod(x)

class ToyInnerModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layers = [nn.Linear(100, 100), nn.Linear(100, 100)]
        self.layers = nn.Sequential(*self.layers)

    def forward(self, inputs):
        return self.layers(inputs)


class ToyOuterModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.layers = [ToyInnerModel().to(device) for _ in range(2)]
        self.layers = nn.Sequential(
            self.layers[0], nn.ReLU(), self.layers[1], nn.ReLU()
        )

    def forward(self, inputs):
        return self.layers(inputs)


def get_toy_model_for_activation_checkpointing(device):
    m = ToyOuterModel(device).to(device)
    m.apply(init_weights)
    inputs = torch.rand(100, 100).to(device)
    return m, inputs


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def apply_fsdp_with_checkpointing(
    model, wrap_policy, checkpoint_policy, use_activation_checkpointing=True
):
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        apply_activation_checkpointing,
        checkpoint_wrapper,
        CheckpointImpl,
    )

    model = FSDP(
        copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
    )
    if use_activation_checkpointing:
        checkpoint_wrapper_fn = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=checkpoint_wrapper_fn,
            check_fn=checkpoint_policy,
        )
    return model


def get_custom_model(device):
    class MyCustomLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.weight = nn.Parameter(torch.randn(512, 512))

        def forward(self, x):
            tmp = torch.mm(x, self.weight.t())
            # test an edge case where torch.where.scalar was decomposed to aten.where.self(tensor, tensor, tensor)
            # and the tensors T(0.4) and T(0.5) were not wrapped in FakeTensors during DDPOptimizer compilation
            return tmp + torch.where(tmp < 0.5, 0.3, 0.6)

    class MyLinear(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(512, 512)

        def forward(self, x):
            return self.linear(x)

    class MyModule(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            mods = [
                (MyLinear(), torch.nn.ReLU()),
                # sandwich the custom in the middle so it comes before and after
                (MyCustomLinear(), torch.nn.ReLU()),
                (MyLinear(), torch.nn.ReLU()),
            ]
            self.seq = torch.nn.Sequential(*[x for items in mods for x in items])

        def forward(self, x, y):
            # test special case where the 0th bucket (layers close to graph input) is at capacity, which would
            # trigger a new bucket, but there are only trivial ops without parameters to put into the new bucket.
            # optimize this case by fusing that 'empty bucket' back together with the previous full one
            return self.seq(x + y)

    m = MyModule().to(device)
    m.apply(init_weights)
    inputs = torch.rand((512, 512)).to(device)
    # test duplicated inputs
    inputs = (inputs, inputs)
    correct_outputs = m(*inputs)
    return m, inputs, correct_outputs

def get_hf_bert(rank):
    # Note: use @import_transformers_or_skip on your test case if you use this
    # in a multiprocessing test
    try:
        from transformers import AutoModelForMaskedLM, BertConfig
    except ImportError as e:
        raise unittest.SkipTest("Unable to import transformers") from e

    device_type = (
        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    )
    batch_size, max_length, config, device = (
        4,
        512,
        BertConfig(),
        f"{device_type}:{rank}",
    )
    model = AutoModelForMaskedLM.from_config(config).to(device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(device)
    decoder_ids = torch.randint(0, config.vocab_size, (batch_size, max_length)).to(
        device
    )
    inputs = {"input_ids": input_ids, "labels": decoder_ids}
    model.train()
    return model, inputs


class CheckSplitsCompiler:
    def __init__(self) -> None:
        self.compiler_called = 0

    def compile_fn(self, gm, example_inputs):
        self.compiler_called += 1
        return gm


# This simulates DDP, but it doesn't actually do any process communication;
# it just has enough properties so that the dynamo distributed optimization is
# able to optimize. Feel free to simulate more properties as necessary. The
# other important thing is patching _active_ddp_module, which is what actually
# triggers DDP optimization
class FakeDDP(nn.Module):
    def __init__(self, module, bucket_cap_mb=25):
        super().__init__()
        self.module = module
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

    @contextmanager
    def _inside_ddp_forward(self):
        DDP._active_ddp_module = self
        try:
            yield
        finally:
            DDP._active_ddp_module = None

    def forward(self, *inputs, **kwargs):
        if not DDP._active_ddp_module:
            with self._inside_ddp_forward():
                return self.module.forward(*inputs, **kwargs)
        else:
            return self.module.forward(*inputs, **kwargs)

def run_hf_bert_ddp(self, model, inputs, backend):
    reset_rng_state()
    correct_outputs = model(**inputs)
    correct_loss = correct_outputs.loss
    correct_loss.backward()

    reset_rng_state()
    opt_model = torch.compile(model, backend=backend)
    opt_outputs = opt_model(**inputs)
    opt_loss = opt_outputs.loss
    opt_loss.backward()

    inputs_flat = [inputs[k] for k in inputs]
    correct_results = collect_results(
        model, correct_outputs.logits, correct_loss, inputs_flat
    )
    opt_results = collect_results(opt_model, opt_outputs.logits, opt_loss, inputs_flat)
    self.assertTrue(same(correct_results, opt_results))

class TestFakeDistributedSingleProc(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @unittest.skipIf(
        torch._inductor.config.triton.native_matmul,
        "FIXME : native matmul fails. RuntimeError: Cannot access data pointer of Tensor",
    )
    def test_hf_bert_ddp_inductor(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "inductor")

    @patch.object(config, "optimize_ddp", True)
    def test_hf_bert_ddp_aot_eager(self):
        model, inputs = get_hf_bert(0)
        model = FakeDDP(model)
        run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @patch.object(config, "optimize_ddp", True)
    def test_issue90375(self):
        class Model(nn.Module):
            def forward(self):
                return torch.randn(3) * torch.randn(3)

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(model, backend="aot_eager")
        opt_model()

    @patch.object(config, "optimize_ddp", True)
    def test_symbol_splitting(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x):
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = x + y @ self.weight2
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512))

    @patch.object(config, "optimize_ddp", True)
    def test_ddp_optimizer_inductor_strides_dont_specialize(self):
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc_0 = nn.Linear(768, 768)
                self.fc_1 = nn.Linear(768, 768)

            def forward(self, x):
                x = self.fc_0(x)
                x = self.fc_1(x)
                return x

        model = Model()
        model = FakeDDP(model)

        inp = torch.randn((16, 18, 768))
        inp2 = torch.randn((16, 20, 768))

        torch._dynamo.mark_dynamic(inp, 1)
        torch._dynamo.mark_dynamic(inp2, 1)

        torch._dynamo.utils.clear_compilation_metrics()
        torch._dynamo.reset()
        try:
            DDP._active_ddp_module = model
            opt_model = torch.compile(model)
            self.assertEqual(0, len(torch._dynamo.utils.get_compilation_metrics()))
            opt_model(inp)
            compile_count_before = len(torch._dynamo.utils.get_compilation_metrics())
            opt_model(inp2)
            compile_count_after = len(torch._dynamo.utils.get_compilation_metrics())
            # no recompiles
            self.assertEqual(compile_count_before, compile_count_after)
        finally:
            DDP._active_ddp_module = None

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_direct(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, _ = y.tolist()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * u0
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_indirect(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                u0, _ = y.tolist()
                a = torch.ones(u0)
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_scalar_outputs=True)
    def test_unbacked_symbol_splitting_torture_multi(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))
                self.weight3 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                # partition one (contains the u0 def)
                u0, _ = y.tolist()
                x = torch.cat([x, x])
                y1 = x @ self.weight1
                # partition two (contains the variable)
                y2 = y1 @ self.weight2
                a = torch.ones(u0)
                # partition three
                z = (x + y2 @ self.weight3) * a.sum()
                return z

        model = Model()
        model = FakeDDP(model, bucket_cap_mb=1)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([12, 13]))

    @config.patch(optimize_ddp=True, capture_dynamic_output_shape_ops=True)
    def test_unbacked_symbol_splitting_no_binding(self):
        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight1 = nn.Parameter(torch.randn(512, 512))
                self.weight2 = nn.Parameter(torch.randn(512, 512))

            def forward(self, x, y):
                nz = y.nonzero()
                x = torch.cat([x, x])
                y = x @ self.weight1
                z = (x + y @ self.weight2) * (nz + 1).sum()
                return z

        model = Model()
        model = FakeDDP(model)

        opt_model = torch.compile(dynamic=True)(model)
        opt_model(torch.randn(20, 512), torch.tensor([0.0, 12.0, 0.0, 11.0]))

    @patch.object(config, "optimize_ddp", True)
    def test_call_method_forward(self):
        class Model(nn.Module):
            def __init__(
                self,
            ):
                super().__init__()
                layers = []
                for _ in range(2):
                    layer = nn.ModuleList(
                        [
                            nn.LayerNorm(96),
                            nn.MultiheadAttention(
                                embed_dim=96, num_heads=4, batch_first=True
                            ),
                        ]
                    )
                    layers.append(layer)
                self.layers = nn.ModuleList(layers)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # x: [Batch, Freq, Time, Feature]
                B, F, T, H = x.shape
                for m in self.layers:
                    x = x.reshape(B * F, T, H)
                    x = m[0](x)
                    x, _ = m[1].forward(x, x, x)
                    x = x.reshape(B, F, T, H)
                return x

        model = Model()
        model = FakeDDP(model)
        opt_model = torch.compile(model)
        opt_model(torch.randn(2, 129, 100, 96))

# Are these tests failing? Check and see if TestFakeDistributedSingleProc has a
# single process version; if it's just a problem in the Dynamo distributed
# optimizer, you should be able to repro it single process!
@requires_accelerator_dist_backend(["nccl", "xccl"])
class TestMultiProc(DynamoDistributedMultiProcTestCase):
    """
    Note: MultiProcTestCase spawns processes per test and is slow.
    Prefer MultiThreadedTestCase for most tests. Perhaps use this one
    sparingly for integration tests.
    """

    device_type = (
        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    )

    @skip_if_lt_x_gpu(2)
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch.compile(m, backend="aot_eager")
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    def _test_hf_bert_ddp_inductor(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "inductor")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor(self):
        self._test_hf_bert_ddp_inductor(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    @patch.object(torch._inductor.config, "fallback_random", True)
    def test_hf_bert_ddp_inductor_static_graph(self):
        self._test_hf_bert_ddp_inductor(static_graph=True)

    def _test_hf_bert_aot_eager(self, static_graph):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_hf_bert(self.rank)
            model = DDP(model, static_graph=static_graph)
            run_hf_bert_ddp(self, model, inputs, "aot_eager")

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager(self):
        self._test_hf_bert_aot_eager(static_graph=False)

    @skip_if_lt_x_gpu(2)
    @import_transformers_or_skip()
    @config.patch(optimize_ddp=True, enable_compiler_collectives=True)
    def test_hf_bert_ddp_aot_eager_static_graph(self):
        self._test_hf_bert_aot_eager(static_graph=True)

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(optimize_ddp=False, enable_compiler_collectives=True)
    def test_ddp_activation_checkpointing(self):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            apply_activation_checkpointing,
            checkpoint_wrapper,
            CheckpointImpl,
        )

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(64, 32)
                self.fc2 = torch.nn.Linear(32, 16)
                self.fc3 = torch.nn.Linear(16, 8)

            def forward(self, inp):
                return self.fc3(self.fc2(self.fc1(inp)))

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            model = MyModel().to(device=self.device_type)

            # Activation checkpointing for Linear layers.
            non_reentrant_wrapper = functools.partial(
                checkpoint_wrapper,
                checkpoint_impl=CheckpointImpl.NO_REENTRANT,
            )
            check_fn = lambda submodule: isinstance(  # noqa: E731
                submodule, torch.nn.Linear
            )
            apply_activation_checkpointing(
                model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
            )

            model = DDP(model)
            x = torch.randn(10, 64).to(self.device_type)
            correct_outputs = model(x)

            opt_model = torch.compile(model)
            outputs = opt_model(x)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_aot_eager(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch.compile(fsdp_m, backend="aot_eager")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @skip_if_lt_x_gpu(2)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @requires_cuda_and_triton
    def test_ddp_optimizer_cudagraph(self):
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                # need a large channel to trigger ddp optimizer split module
                self.CHANNELS = 640
                self.convi = nn.Conv2d(46, self.CHANNELS, 3, padding=1, bias=False)
                self.convp = nn.Conv2d(
                    self.CHANNELS, self.CHANNELS, 1, padding=0, bias=False
                )
                self.bni = nn.BatchNorm2d(self.CHANNELS)

            def forward(self, bitmap_channels):
                x = self.convi(bitmap_channels)
                x = self.bni(x)
                x = self.convp(x)
                return x

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            net = Net().to(self.rank)
            optimizer = torch.optim.SGD(
                net.parameters(),
                lr=5e-2,
            )

            net = DDP(net, device_ids=[self.rank])
            opt_net = torch.compile(net, mode="reduce-overhead")
            opt_net.train()

            for _ in range(10):
                optimizer.zero_grad()
                data = torch.randn((16, 46, 8, 8), dtype=torch.float32, device="cuda")
                opt_net(data).sum().backward()

            # 2 fwd and 2 bwd graph such that 4 graphs in total
            graph_id = (
                torch._inductor.cudagraph_trees.get_container(self.rank)
                .tree_manager.new_graph_id()
                .id
            )
            self.assertTrue(graph_id == 4)

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_setattr(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_mutating_model(
                f"{self.device_type}:{self.rank}"
            )
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            self.assertEqual(len(counters["graph_break"]), 1)
            first_graph_break = list(counters["graph_break"].keys())[0]  # noqa: RUF015
            self.assertIn("setattr() on Tensor.requires_grad", first_graph_break)

    @config.patch(inline_inbuilt_nn_modules=False)
    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_unspecialized_forced_getattr_no_inline(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_forced_getattr_module(
                f"{self.device_type}:{self.rank}"
            )
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    def test_fsdp_unspecialized_forced_getattr_inline(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            from torch._dynamo.utils import counters

            counters.clear()
            m, inputs, correct_outputs = get_forced_getattr_module(
                f"{self.device_type}:{self.rank}"
            )
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="eager", fullgraph=False)
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_inductor(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            # Test with basic FSDP wrapping (outer wrap around whole model)
            m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
            fsdp_m = FSDP(m, use_orig_params=True)
            fsdp_m = torch.compile(fsdp_m, backend="inductor")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

            # Test with recursive wrapping, nested FSDP around each Linear
            m, inputs, correct_outputs = get_model(f"{self.device_type}:{self.rank}")
            fsdp_m = FSDP(
                m,
                auto_wrap_policy=functools.partial(
                    transformer_auto_wrap_policy, transformer_layer_cls=(nn.Linear,)
                ),
                use_orig_params=True,
            )
            fsdp_m = torch.compile(fsdp_m, backend="inductor")
            outputs = fsdp_m(inputs)
            self.assertTrue(same(correct_outputs, outputs))

    @config.patch(enable_compiler_collectives=True)
    @skip_if_lt_x_gpu(1)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_fsdp_activation_checkpointing(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            model, inputs = get_toy_model_for_activation_checkpointing(
                f"{self.device_type}:{self.rank}"
            )
            is_inner = lambda module: isinstance(module, ToyInnerModel)  # noqa: E731
            wrap_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=is_inner)
            model = apply_fsdp_with_checkpointing(model, wrap_policy, is_inner)
            correct_outputs = model(inputs)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
            opt_model = torch.compile(model, backend=cnt)
            outputs = opt_model(inputs)
            self.assertTrue(same(correct_outputs, outputs))
            # Each FSDP module is a separate graph
            self.assertEqual(cnt.frame_count, 2)
            self.assertTrue(
                find_first_node(cnt.graphs[0], tag_activation_checkpoint) is not None
            )

    @import_transformers_or_skip()
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @config.patch(enable_compiler_collectives=True)
    @unittest.skipIf(
        PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Inaccurate results with fused SDPA kernels",
    )
    def test_hf_bert_fsdp(self):
        def apply_fsdp(model, wrap_policy):
            model = FSDP(
                copy.deepcopy(model), auto_wrap_policy=wrap_policy, use_orig_params=True
            )
            return model

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            for wrap_policy, test_instance in (
                (None, "FSDP without recursive wrapping"),
            ):
                print(f"Running hf_bert test for {test_instance}")
                model, inputs = get_hf_bert(self.rank)
                reset_rng_state()
                eager_model = apply_fsdp(model, wrap_policy)
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp(model, wrap_policy)
                opt_model = torch.compile(opt_model, backend="inductor")
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(
                    eager_model, correct_outputs.logits, correct_loss, inputs_flat
                )
                opt_results = collect_results(
                    opt_model, opt_outputs.logits, opt_loss, inputs_flat
                )
                self.assertTrue(same(correct_results, opt_results))

    @import_transformers_or_skip()
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    # TODO(whc) Investigate why cudagraphs breaks inductor+fsdp for hf_bert
    @patch.object(torch._inductor.config.triton, "cudagraphs", False)
    @patch.object(torch._inductor.config, "fallback_random", True)
    @config.patch(guard_nn_modules=True, enable_compiler_collectives=True)
    def test_hf_bert_fsdp_activation_checkpointing(self):
        from transformers.models.bert.modeling_bert import BertLayer

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            for wrap_policy, test_instance in (
                (
                    functools.partial(
                        transformer_auto_wrap_policy, transformer_layer_cls=(BertLayer,)
                    ),
                    "FSDP with recursive wrapping BertLayer instances",
                ),
            ):
                print(
                    f"Running hf_bert_activation_checkpointing test for {test_instance}"
                )
                model, inputs = get_hf_bert(self.rank)
                check_fn = lambda submodule: isinstance(  # noqa: E731
                    submodule, BertLayer
                )
                reset_rng_state()
                eager_model = apply_fsdp_with_checkpointing(
                    model, wrap_policy, check_fn
                )
                correct_outputs = eager_model(**inputs)
                correct_loss = correct_outputs.loss
                correct_loss.backward()

                reset_rng_state()
                opt_model = apply_fsdp_with_checkpointing(model, wrap_policy, check_fn)
                opt_model = torch.compile(opt_model, backend="inductor")
                opt_outputs = opt_model(**inputs)
                opt_loss = opt_outputs.loss
                opt_loss.backward()

                inputs_flat = [inputs[k] for k in inputs]
                correct_results = collect_results(
                    eager_model, correct_outputs.logits, correct_loss, inputs_flat
                )
                opt_results = collect_results(
                    opt_model, opt_outputs.logits, opt_loss, inputs_flat
                )
                self.assertTrue(same(correct_results, opt_results))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_tensor(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):

            class SimpleModel(nn.Module):
                def __init__(self, input_size, output_size):
                    super().__init__()
                    self.linear = nn.Linear(input_size, output_size)

                def forward(self, x):
                    return self.linear(x)

            torch._dynamo.utils.clear_compilation_metrics()

            model = SimpleModel(10, 2).to(self.rank)
            model.forward = torch.compile(model.forward)
            ddp_model = DDP(model, device_ids=[self.rank])

            loss_fn = nn.CrossEntropyLoss()
            optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

            def B(s):
                return [torch.randn(s, 10), torch.randint(0, 2, (s,))]

            if self.rank == 0:
                dataloader = [B(5), B(8), B(6)]
            else:
                dataloader = [B(6), B(6), B(3)]

            for data, labels in dataloader:
                data, labels = data.to(self.rank), labels.to(self.rank)
                optimizer.zero_grad()
                output = ddp_model(data)
                loss = loss_fn(output, labels)
                loss.backward()
                optimizer.step()

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_scalar(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            # TODO: This should be possible to do inside the function, but
            device = f"{self.device_type}:{self.rank}"

            @torch.compile()
            def f(x, y):
                return x + torch.ones(y, device=device).sum()

            if self.rank == 0:
                dataloader = [3, 3, 7]
            else:
                dataloader = [3, 4, 9]

            for data in dataloader:
                f(torch.randn(5, device=self.rank), data)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_automatic_dynamic_speculation_divergence(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(x, y):
                zx = x.shape  # noqa: F841
                zy = y.shape  # noqa: F841
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [4, 4]
            else:
                dataloader = [3, 4]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_graph_break_empty_graph_still_collective(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(x, y):
                z = y  # noqa: F841
                print("woof")
                zx = x.shape  # noqa: F841
                zy = y.shape  # noqa: F841
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [5, 5, 6]
            else:
                dataloader = [3, 4, 5]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_dim_mismatch(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(x, y):
                zx = x.shape  # noqa: F841
                zy = y.shape  # noqa: F841
                return x.sum() + y.sum()

            if self.rank == 0:
                dataloader = [[4, 2]]
            else:
                dataloader = [[3]]

            for data in dataloader:
                f(
                    torch.randn(data, device=self.rank),
                    torch.randn(data, device=self.rank),
                )

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_missing_source(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(rank, xs):
                return xs[rank].sum()

            xs = []
            for _ in range(self.world_size):
                xs.append(torch.randn(10, device=self.rank))

            f(self.rank, xs)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_scalar_missing_source(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(rank, xs):
                return torch.tensor(xs[rank], device=self.rank)

            xs = []
            for i in range(self.world_size):
                xs.append(10 + i)

            f(self.rank, xs)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @config.patch(enable_compiler_collectives=True)
    def test_compiler_collectives_type_mismatch(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(x):
                if isinstance(x, int):
                    return torch.tensor(x, device=self.rank)
                else:
                    return x.sum()

            if self.rank == 0:
                x = torch.randn(10, device=self.rank)
            else:
                x = 12
            f(x)

            # This deadlocks, I guess we don't support this
            """
            if self.rank == 0:
                x = torch.randn(12, device=self.rank)
            else:
                x = 10
            f(x)
            """

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @enable_guard_collectives()
    def test_guard_collective(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            @torch.compile()
            def f(x):
                return x.sum()

            x = torch.randn(10, device=self.rank)
            f(x)

            if self.rank == 0:
                x = torch.randn(10, device=self.rank)
            else:
                x = torch.randn(12, device=self.rank)  # recompile on one rank
            f(x)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_get_pg_attr(self):
        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            pg = dist.distributed_c10d._get_default_group()

            device = f"{self.device_type}:{self.rank}"

            @torch.compile(fullgraph=True)
            def f(x):
                if dist.distributed_c10d._rank_not_in_group(pg):
                    return x + 1
                else:
                    return x - 1

            x = torch.ones(4, device=device)
            self.assertEqual(f(x), x - 1)

            pg = dist.distributed_c10d.GroupMember.NON_GROUP_MEMBER
            self.assertEqual(f(x), x + 1)

    @skipIfXpu  # ProcessGroupXCCL doesn't support _set_default_timeout yet.
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "fx_graph_cache", False)
    @patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
    def test_asymmetric_compilation(self):
        from torch._dynamo.comptime import comptime

        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"{self.device_type}:{self.rank}"

            pg = dist.distributed_c10d._get_default_group()

            cnt = torch._dynamo.testing.CompileCounter()
            sleep_time = 5

            @torch.compile(backend=cnt)
            def f(x):
                if self.rank == 0:
                    comptime.sleep(sleep_time)

                y = 2 * x
                return y.sum()

            backend = pg._get_backend(torch.device(device))
            backend._set_default_timeout(timedelta(seconds=sleep_time - 2))

            x = torch.ones(4, device=device)

            # NCCL startup is lazy
            w = pg.allreduce(x)
            w.wait()

            f(x)
            if self.rank != 0:
                # test fails with NCCL timeout without this line
                dist.distributed_c10d._add_ephemeral_timeout_for_all_pgs(
                    timedelta(seconds=sleep_time)
                )

            w = pg.allreduce(x)
            w.wait()
            torch.accelerator.synchronize(device)

            metrics = torch._dynamo.utils.get_compilation_metrics()
            # Number of compiles same on all nodes
            res = [None] * self.world_size
            torch.distributed.all_gather_object(res, len(metrics))
            for r in res[1:]:
                self.assertEqual(res[0], r)

    @skipIfXpu  # ProcessGroupXCCL doesn't support _set_default_timeout yet.
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(torch._inductor.config, "fx_graph_cache", True)
    @patch.object(torch._inductor.config, "fx_graph_remote_cache", False)
    @patch.object(torch._inductor.config, "sleep_sec_TESTING_ONLY", 10)
    def test_asymmetric_compilation_with_fx_cache(self):
        from torch._dynamo.utils import counters
        from torch._inductor.utils import fresh_cache

        with fresh_cache(), _dynamo_dist_per_rank_init(self.rank, self.world_size):
            torch._dynamo.utils.clear_compilation_metrics()

            device = f"{self.device_type}:{self.rank}"

            pg = dist.distributed_c10d._get_default_group()

            @torch.compile
            def f(x):
                y = 2 * x
                return y.sum()

            backend = pg._get_backend(torch.device(device))
            backend._set_default_timeout(timedelta(seconds=5))
            counters.clear()

            x = torch.ones(4, device=device)

            f(x)

            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
            self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

            w = pg.allreduce(x)
            w.wait()
            torch.accelerator.synchronize(device)
            torch._dynamo.reset()

            if self.rank == 0:
                with fresh_cache():
                    f(x)
                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
            else:
                f(x)
                self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
                self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
                self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)

            w = pg.allreduce(x)
            w.wait()
            torch.accelerator.synchronize(device)

@requires_accelerator_dist_backend(["nccl", "xccl"])
@unittest.skipUnless(torch.accelerator.is_available(), "Requires accelerator")
class TestSingleProc(DynamoDistributedSingleProcTestCase):
    """
    Test harness initializes dist process group.

    Test simple things here since they are simpler to debug.
    Use TestMultiProc for things that really need to run on multiple nodes
    """

    device_type = (
        acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
    )

    def get_model(
        self, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5, ctx_manager=None
    ):
        m = ToyModel(
            in_feat=in_feat,
            hidden_feat=hidden_feat,
            out_feat=out_feat,
            ctx_manager=ctx_manager,
        ).to(self.device)
        m.apply(init_weights)
        inputs = torch.rand(bsz, in_feat).to(self.device)
        outputs = m(inputs)
        return m, inputs, outputs

    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch.compile(ddp_m, backend="aot_eager")
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_inductor(self):
        from torch.nn.parallel import DistributedDataParallel as DDP

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids)
        ddp_m = torch.compile(ddp_m, backend="inductor")
        outputs = ddp_m(inputs)
        self.assertTrue(same(correct_outputs, outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split(self):
        assert config.optimize_ddp
        """
        Just ensures that the appropriate number of splits happen (based on
        bucket size and model parameters) - verifies the number of times
        the user-provided compiler is called by the DDPOptimizer which is
        doing the graph splitting
        """

        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        check_splits_compiler = CheckSplitsCompiler()

        @torch.compile(backend=check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

        # ensure compatibility with dynamo explain

        explain_out = torch._dynamo.explain(ddp_m)(inputs)
        break_reasons = explain_out.break_reasons
        self.assertEqual(len(break_reasons), 3)
        self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

    @patch.object(config, "optimize_ddp", True)
    def test_graph_split_ctx_manager(self):
        """
        Ensures that we get the right number of splits and that the respective
        context managers' effects are applied to the computation.
        """

        for get_compiler in [
            lambda: CheckSplitsCompiler(),
            lambda: None,
        ]:
            for ctx_manager, output_test in [
                (
                    lambda: torch.autocast(
                        torch.device(self.device).type, torch.float16
                    ),
                    lambda out: self.assertEqual(out.dtype, torch.float16),
                ),
                (torch.enable_grad, lambda out: self.assertTrue(out.requires_grad)),
                (torch.no_grad, lambda out: self.assertTrue(not out.requires_grad)),
            ]:
                m, inputs, correct_outputs = self.get_model(
                    out_feat=1000,
                    hidden_feat=1000,
                    in_feat=1000,
                    ctx_manager=ctx_manager,
                )
                # inp - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
                # hidden - 1000 * 1000 matrix of float32 (4 bytes) = 4MB
                bucket_cap_mb = 3.5  # 4MB
                ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=bucket_cap_mb)

                compiler = get_compiler()

                @torch.compile(backend=compiler.compile_fn if compiler else "aot_eager")
                def opt_fn(inputs):
                    return ddp_m(inputs)

                opt_outputs = opt_fn(inputs)
                self.assertTrue(same(correct_outputs, opt_outputs))
                if compiler:
                    self.assertEqual(compiler.compiler_called, 4)

                output_test(opt_outputs)

                # ensure compatibility with dynamo explain

                explain_out = torch._dynamo.explain(ddp_m)(inputs)
                break_reasons = explain_out.break_reasons
                self.assertEqual(len(break_reasons), 4)
                self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons))

    @skipIfXpu  # XPU device doesn't support flex_attention yet.
    @patch.object(config, "optimize_ddp", True)
    def test_compiled_flex_attention_full_model_ddp(self):
        class Model(torch.nn.Module):
            def __init__(self, S, H, D):
                super().__init__()

                self.S = S
                self.H = H
                self.D = D

                alibi_bias = self.generate_alibi_bias(H)
                self.register_buffer("alibi_bias", alibi_bias, persistent=True)
                self.attention = flex_attention

                self.project_qk = torch.nn.Linear(H * D, H * D * 2)
                self.project_v = torch.nn.Linear(H * D, H * D)

            def forward(self, hidden_states):
                batch_size, _, _ = hidden_states.size()

                query, key = self.project_qk(hidden_states).chunk(2, dim=2)
                query = query.view(self.S, batch_size, self.H, self.D)
                query = query.permute(1, 2, 0, 3)

                key = key.view(self.S, batch_size, self.H, self.D)
                key = key.permute(1, 2, 0, 3)

                value = self.project_v(hidden_states)
                value = value.view(self.S, batch_size, self.H, self.D)
                value = value.permute(1, 2, 0, 3)

                return self.attention(query, key, value, score_mod=self.alibi_score_mod)

            def generate_alibi_bias(self, num_heads):
                alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)]
                return torch.tensor(alibi_bias)

            def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
                bias = (q_idx - kv_idx) * self.alibi_bias[h]
                return score + bias

        B = 16
        H = 12
        S = 512
        D = 64

        model = Model(S, H, D)
        model.to(self.device_type)
        model = torch.compile(model)
        model = DDP(model, device_ids=self.device_ids)

        hidden_states = torch.randn(B, S, H * D).to(self.device_type)
        model(hidden_states)
        torch.accelerator.synchronize()

    @skipIfXpu  # XPU device doesn't support flex_attention yet.
    @patch.object(config, "optimize_ddp", True)
    def test_compiled_flex_attention_local_ddp(self):
        class Model(torch.nn.Module):
            def __init__(self, S, H, D):
                super().__init__()

                self.S = S
                self.H = H
                self.D = D

                alibi_bias = self.generate_alibi_bias(H)
                self.register_buffer("alibi_bias", alibi_bias, persistent=True)
                self.attention = torch.compile(flex_attention)

                self.project_qk = torch.nn.Linear(H * D, H * D * 2)
                self.project_v = torch.nn.Linear(H * D, H * D)

            def forward(self, hidden_states):
                batch_size, _, _ = hidden_states.size()

                query, key = self.project_qk(hidden_states).chunk(2, dim=2)
                query = query.view(self.S, batch_size, self.H, self.D)
                query = query.permute(1, 2, 0, 3)

                key = key.view(self.S, batch_size, self.H, self.D)
                key = key.permute(1, 2, 0, 3)

                value = self.project_v(hidden_states)
                value = value.view(self.S, batch_size, self.H, self.D)
                value = value.permute(1, 2, 0, 3)

                return self.attention(query, key, value, score_mod=self.alibi_score_mod)

            def generate_alibi_bias(self, num_heads):
                alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)]
                return torch.tensor(alibi_bias)

            def alibi_score_mod(self, score, b, h, q_idx, kv_idx):
                bias = (q_idx - kv_idx) * self.alibi_bias[h]
                return score + bias

        B = 16
        H = 12
        S = 512
        D = 64

        model = Model(S, H, D)
        model.to(self.device_type)
        model = torch.compile(model)
        model = DDP(model, device_ids=self.device_ids)

        hidden_states = torch.randn(B, S, H * D).to(self.device_type)
        model(hidden_states)
        torch.accelerator.synchronize()

    @patch.object(config, "optimize_ddp", True)
    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_graph_split_inductor(self):
        assert config.optimize_ddp
        """
        Same as above, but using inductor backend.
        We observed issues with inductor/fx interface in the past.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch.compile(backend="inductor")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))

@torch._inductor.config.patch(
|
|
{"layout_optimization": True, "keep_output_stride": False}
|
|
)
|
|
@patch.object(config, "optimize_ddp", True)
|
|
def _test_graph_split_inductor_layout_optimizations_impl(self, context):
|
|
assert config.optimize_ddp
|
|
channel_dim = 512
|
|
# channel dim must be > 64 for inductor to do layout optimization and use NHWC
|
|
|
|
class ToyModelConv(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.net = nn.Sequential(
|
|
*[
|
|
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
|
|
nn.ReLU(),
|
|
]
|
|
+ [
|
|
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
|
|
nn.ReLU(),
|
|
]
|
|
+ [
|
|
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
|
|
nn.ReLU(),
|
|
]
|
|
+ [
|
|
nn.Conv2d(channel_dim, channel_dim, 1, stride=1, bias=False),
|
|
nn.ReLU(),
|
|
]
|
|
)
|
|
|
|
def forward(self, inputs):
|
|
return self.net(inputs)
|
|
|
|
def get_model():
|
|
m = ToyModelConv().to(self.device)
|
|
m.apply(init_weights)
|
|
inputs = torch.rand(2, channel_dim, channel_dim, 128).to(self.device)
|
|
outputs = m(inputs)
|
|
return m, inputs, outputs
|
|
|
|
with context():
|
|
m, inputs, correct_outputs = get_model()
|
|
ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
|
|
|
|
@torch.compile(backend="inductor")
|
|
def opt_fn(inputs):
|
|
return ddp_m(inputs)
|
|
|
|
opt_outputs = opt_fn(inputs)
|
|
self.assertTrue(same(correct_outputs, opt_outputs))
|
|
|
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
|
def test_graph_split_inductor_layout_optimizations_training(self):
|
|
self._test_graph_split_inductor_layout_optimizations_impl(
|
|
contextlib.nullcontext
|
|
)
|
|
|
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
|
def test_graph_split_inductor_layout_optimizations_inference(self):
|
|
self._test_graph_split_inductor_layout_optimizations_impl(torch.no_grad)
|
|
|
|
@patch.object(config, "optimize_ddp", True)
|
|
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
|
|
def test_graph_split_inductor_transpose(self):
        assert config.optimize_ddp

        B = 100
        N = 30
        D = 50
        K = 70

        class Foo(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear0 = nn.Linear(N, K)
                self.linear1 = torch.nn.Linear(D * K, 2048)

            def forward(self, x):
                xt = x.transpose(2, 1)
                xt = self.linear0(xt).flatten(1)
                return self.linear1(xt)

        mod = Foo().to(self.device)

        compiled_mod = torch.compile(mod, backend="inductor")
        ddp_compiled_mod = DDP(compiled_mod, device_ids=self.device_ids)

        x = torch.randn((B, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x), ddp_compiled_mod(x)))

        x_1 = torch.randn((B * 2, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_1), ddp_compiled_mod(x_1)))

        x_2 = torch.randn((B * 3, N, D), dtype=torch.float32, device=self.device)
        self.assertTrue(same(mod(x_2), ddp_compiled_mod(x_2)))

    @patch.object(config, "optimize_ddp", True)
    def test_no_split(self):
        """
        Ensures the DDPOptimizer returns a correct, compiled module without
        introducing graph splits. (Based on model parameters fitting in the bucket)
        """
        # DDP will always do a 'first bucket' with a really small size; so only a tiny model will escape this
        m, inputs, correct_outputs = self.get_model(hidden_feat=5)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=250)
        check_splits_compiler = CheckSplitsCompiler()

        @torch.compile(backend=check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 1)

    @patch.object(config, "optimize_ddp", True)
    def test_aot_autograd(self):
        """
        Explicitly check AotAutograd family of compilers work,
        since they require example inputs propagated between graph splits.
        """
        m, inputs, correct_outputs = self.get_model()
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)

        @torch.compile(backend="aot_eager")
        def opt_fn(inputs):
            return ddp_m(inputs)

        opt_outputs = opt_fn(inputs)
        opt_outputs.sum().backward()
        self.assertTrue(same(correct_outputs, opt_outputs))

    @patch.object(config, "optimize_ddp", True)
    def test_custom_layer(self):
        """
        Ensures that the appropriate number of splits happen (based on bucket size
        and model parameters) by verifying the number of times the user-provided
        compiler is called by the DDPOptimizer, which does the graph splitting.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=1)

        check_splits_compiler = CheckSplitsCompiler()

        @torch.compile(backend=check_splits_compiler.compile_fn)
        def opt_fn(inputs):
            return ddp_m(*inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 3)

    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
    def test_empty_graph_inductor(self):
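        # Exercises compiling a function with no tensor operations (only a
        # process-group world-size query) under inductor; the result should still
        # be the correct world size (1 in this single-process test).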
        def fn():
            get_world_size = torch.distributed.distributed_c10d.get_world_size()
            return (get_world_size,)

        opt_fn = torch.compile(fn, backend="inductor")
        res = None
        try:
            res = opt_fn()[0]
        except Exception:
            pass
        self.assertEqual(res, 1)

    @patch.object(config, "optimize_ddp", False)
    def test_ignored_parameters(self):
        """
        Verifies ddp graph-split logic ignores parameters marked to ignore on DDP module.
        Hooks up graph-split optimizer manually so it can peek at internal state.
        """
        m, inputs, correct_outputs = get_custom_model(self.device)
        parameters_to_ignore = ["seq.2.weight", "seq.4.linear.bias"]
        DDP._set_params_and_buffers_to_ignore_for_model(m, parameters_to_ignore)
        ddp_m = DDP(m, device_ids=self.device_ids, bucket_cap_mb=25)
        parameter_ids_to_ignore = [
            id(ddp_m.module.get_parameter(p)) for p in ddp_m.parameters_to_ignore
        ]

        check_splits_compiler = CheckSplitsCompiler()
        ddp_optimizer = DDPOptimizer(
            bucket_bytes_cap=ddp_m.bucket_bytes_cap,
            backend_compile_fn=check_splits_compiler.compile_fn,
        )

        @torch.compile(backend=ddp_optimizer.compile_fn)
        def opt_fn(inputs):
            return ddp_m(*inputs)

        opt_outputs = opt_fn(inputs)
        self.assertTrue(same(correct_outputs, opt_outputs))
        self.assertEqual(check_splits_compiler.compiler_called, 2)
        for b in ddp_optimizer.buckets:
            for p_id in b.param_ids:
                self.assertFalse(p_id in parameter_ids_to_ignore)

    @patch.object(config, "optimize_ddp", True)
    def test_higher_order_op(self):
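        # Activation checkpointing is traced by dynamo as a higher-order op; this
        # checks that it composes with DDPOptimizer graph splitting without error.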
        from torch.utils.checkpoint import checkpoint

        N = 1000

        class InnerModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(N, N)
                self.linear2 = torch.nn.Linear(N, N)

            def forward(self, x):
                a = self.linear1(x)
                a = self.linear2(a)
                return a

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.inner_mod1 = InnerModule()
                self.inner_mod2 = InnerModule()

            def forward(self, x):
                a = checkpoint(self.inner_mod1, x, use_reentrant=False)
                a = torch.cos(a)
                a = checkpoint(self.inner_mod2, a, use_reentrant=False)
                a = torch.cos(a)
                return a

        mod = MockModule().to(self.device_type)
        mod = DDP(mod, bucket_cap_mb=1)
        x = torch.randn(N, N, device=self.device_type, requires_grad=True)
        args = (x,)

        backend = "aot_eager"
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        torch.compile(mod, backend=cnt)(*args)

    def test_fsdp_orig_params_assert(self):
        # Test with basic FSDP wrapping (outer wrap around whole model)
        m, inputs, _ = get_model(f"{self.device_type}:{self.rank}")
        fsdp_m = FSDP(m, use_orig_params=False)
        # Test is that this function call does not throw an exception.
        fsdp_m = torch.compile(fsdp_m)

    def test_fsdp_skip_guards(self):
        """
        It's currently difficult to test dynamo guards. Most guards tests are indirect: modify something and
        observe that the guard in question failed. In this case, since the FSDP guards were already deemed
        useless and skipping them is expected to have no practical effect, it's pretty contrived to even try to
        make those guards fail. Instead, we observe the 'guard source' printed by dynamo's comptime print_guards
        function.

        Note: comptime prints the guards before the time they get installed or not installed, so in both cases
        (skip or no skip) the same guards get printed. The difference is that in the skip case, they show up
        with a special 'guard source' which will cause them to not be installed. So all we check for is the expected
        guard source 'local_fsdp_module'.
        """
        global GUARDS_FILE
        GUARDS_FILE = StringIO()

        for skip_guards, expected_guard_source in (
            (True, "local_fsdp_module"),
            (False, "local_unspecialized_nn_module"),
        ):
            torch._dynamo.reset()

            class ToyModel(nn.Module):
                def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                    super().__init__()
                    self.net = nn.Sequential(
                        *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                        + [nn.Linear(hidden_feat, out_feat), nn.ReLU()]
                    )

                def forward(self, inputs):
                    out = self.net(inputs)

                    @comptime
                    def _(ctx):
                        ctx.print_guards(file=GUARDS_FILE)

                    return out

            device = f"{self.device_type}:{self.rank}"
            m = ToyModel(
                in_feat=10,
                hidden_feat=5000,
                out_feat=5,
            ).to(device)
            inputs = torch.rand(20, 10).to(device)
            m.apply(init_weights)
            correct_outputs = m(inputs)
            fsdp_m = FSDP(m, use_orig_params=True)

            with torch._dynamo.config.patch(skip_fsdp_guards=skip_guards):
                opt_m = torch.compile(fsdp_m, backend="aot_eager")
                outputs = opt_m(inputs)

            # far from an exhaustive check of all the expected guards, just check a couple of them.
            FileCheck().check("""local "L['self']" TYPE_MATCH""").check(
                f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH"""
            ).check(
                f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH"""
            ).run(GUARDS_FILE.getvalue())

            self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_skip_register_attr_or_module(self):
        """
        Ensure the FSDP module is not registered as attributes in the fx graph;
        see the `not source.guard_source().is_fsdp_module()` check before calling
        `register_attr_or_module` in variables/builder.py.
        """

        class ToyModel(nn.Module):
            def __init__(self, in_feat=10, hidden_feat=5000, out_feat=5):
                super().__init__()
                self.net = nn.Sequential(
                    *[nn.Linear(in_feat, hidden_feat), nn.ReLU()]
                    + [nn.Linear(hidden_feat, hidden_feat), nn.ReLU()]
                )

            def forward(self, inputs):
                out = self.net(inputs)
                return out

        torch._dynamo.reset()

        device = f"{self.device_type}:{self.rank}"
        m = ToyModel(
            in_feat=10,
            hidden_feat=5000,
            out_feat=5,
        ).to(device)
        inputs = torch.rand(20, 10).to(device)
        m.apply(init_weights)
        correct_outputs = m(inputs)
        fsdp_m = FSDP(m, use_orig_params=True)

        def debug_compiler(gm, _):
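            # Custom backend that inspects the FX graph and fails if any
            # FSDP-managed parameter shows up as a get_attr node.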
            for node in gm.graph.nodes:
                if node.op == "get_attr":
                    for name in [
                        "l__self___net_0_weight",
                        "l__self___net_0_bias",
                        "l__self___net_2_weight",
                        "l__self___net_2_bias",
                    ]:
                        self.assertFalse(
                            name in node.name,
                            f"FSDP module {name} should not be registered as attributes",
                        )
            return gm

        opt_m = torch.compile(fsdp_m, backend=debug_compiler)
        outputs = opt_m(inputs)

        self.assertTrue(same(correct_outputs, outputs))

    def test_fsdp_dup_tensors_same_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with the same
        source are de-duplicated, meaning that they are each only passed once
        as a graph input.
        """

        class DuplicateModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                device_type = (
                    acc.type
                    if (acc := torch.accelerator.current_accelerator())
                    else "cpu"
                )
                self._param = torch.randn((3,), device=device_type)
                self._buf = torch.nn.Buffer(
                    torch.randn((3,), requires_grad=False, device=device_type)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use `_param` and `_buf` each twice in this compiled forward
                # to exercise if they are de-duplicated by TorchDynamo
                z = x + self._buf + self._buf
                z += self._param + self._param
                return z

        model = DuplicateModule()
        fsdp_model = FSDP(copy.deepcopy(model), use_orig_params=True)
        fsdp_model = torch.compile(fsdp_model, backend="aot_eager")
        inp = torch.randn((2, 3), device=self.device_type)
        local_out = model(inp)
        fsdp_out = fsdp_model(inp)
        self.assertEqual(local_out, fsdp_out)

    @patch.object(config, "guard_nn_modules", True)
    def test_fsdp_dup_tensors_diff_source(self):
        """
        Tests that FSDP-managed modules' parameters and buffers with different
        source do not result in incorrect AOTAutograd de-dup guards like
        ``a is b``, where ``a`` and ``b`` are certainly not the same. We check
        this by checking for per-invocation recompiles.
        """

        class BufModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                device_type = (
                    acc.type
                    if (acc := torch.accelerator.current_accelerator())
                    else "cpu"
                )
                self._buf = nn.Buffer(
                    torch.randn((3,), requires_grad=False, device=device_type)
                )

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self._buf

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                device_type = (
                    acc.type
                    if (acc := torch.accelerator.current_accelerator())
                    else "cpu"
                )
                self._param = nn.Parameter(torch.randn((1,), device=device_type))
                self._buf_module = BufModule()
                # Share the buffer, meaning same tensor but different source
                self._buf = self._buf_module._buf

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # Use the same buffer tensor twice in the compiled forward,
                # including a data mutation to trigger de-dup logic
                self._buf.mul_(2)
                z = x + self._buf
                z = self._buf_module(z)
                z += self._param
                return z

        fsdp_model = FSDP(Model(), use_orig_params=True)
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        fsdp_model = torch.compile(fsdp_model, backend=cnt)
        inp = torch.randn((2, 3), device=self.device_type)
        for _ in range(15):
            fsdp_model(inp)
        # Check for no recompiles (if there were incorrect de-dup guards, then
        # the frame count would be equal to the number of forward calls)
        self.assertEqual(cnt.frame_count, 1)

    def test_fsdp_staticmethod(self):
        """
        Tests that Dynamo compiles staticmethods for FSDP-managed modules
        correctly both when the staticmethod is invoked from the class and from
        the object itself.
        """

        class ModuleWithStaticMethod(nn.Module):
            def __init__(self, use_self: bool):
                super().__init__()
                self._use_self = use_self
                torch.manual_seed(42)  # force `_param` to be deterministic
                device_type = (
                    acc.type
                    if (acc := torch.accelerator.current_accelerator())
                    else "cpu"
                )
                self._param = nn.Parameter(torch.randn((3,), device=device_type))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                if self._use_self:
                    z = self._add(x, self._param)
                else:
                    z = ModuleWithStaticMethod._add(x, self._param)
                z *= 2
                return z

            @staticmethod
            def _add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        model = ModuleWithStaticMethod(False)
        x = torch.randn((2, 3), device=self.device_type)
        ref_out = model(x)
        test_outs: list[torch.Tensor] = []

        for use_self in (False, True):
            model = ModuleWithStaticMethod(use_self)
            fsdp_model = FSDP(model, use_orig_params=True)
            cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            fsdp_model = torch.compile(fsdp_model, backend=cnt)
            test_outs.append(fsdp_model(x))
            # Check for no recompiles, which could happen if incorrectly
            # passing args to the staticmethod (e.g. doubly passing `self`)
            # 3 is expected here for 1 forward.
            # Graph 1 should be add and imul
            self.assertEqual(cnt.frame_count, 1)
        for test_out in test_outs:
            self.assertEqual(test_out, ref_out)

    def test_async_subclass_no_specialize(self):
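        # `_maybe_wrap_tensor` wraps the input in an AsyncCollectiveTensor subclass;
        # with dynamic=True, a second call with a different length should reuse the
        # same compiled frame rather than specializing and recompiling.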
        cnt = torch._dynamo.testing.CompileCounterWithBackend("eager")

        @torch.compile(backend=cnt, fullgraph=True, dynamic=True)
        def f(x):
            return x + 1

        f(_maybe_wrap_tensor(torch.randn(10)))
        f(_maybe_wrap_tensor(torch.randn(12)))

        self.assertEqual(cnt.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()