Compare commits

...

2 Commits

Author  SHA1        Message                                                                 Date
        d0e00d5448  remove resize and del                                                  2024-06-20 17:07:47 -07:00
        25229787d6  [Traceable FSDP2] Add unit tests for simple MLP and transformer model  2024-06-20 14:45:17 -07:00
                    ghstack-source-id: a23c48d5d56d2633f7ec7efb7b1567340d4feeb1
                    Pull Request resolved: https://github.com/pytorch/pytorch/pull/129157
2 changed files with 186 additions and 3 deletions

View File

@@ -1,16 +1,30 @@
# Owner(s): ["oncall: distributed"]

import contextlib
import itertools
import unittest

import torch
import torch._dynamo.testing
from torch import nn
from torch._dynamo import compiled_autograd
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp._fsdp_common import TrainingState
from torch.distributed._composable.fsdp._fsdp_init import (
    _get_managed_modules,
    _get_managed_states,
)
from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup
from torch.distributed._tensor import init_device_mesh
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, MLP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)
from torch.utils._triton import has_triton
@@ -64,6 +78,10 @@ class TestFullyShardCompileCompute(FSDPTest):

class TestFullyShardCompile(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(2, torch.cuda.device_count())

    def test_dynamo_trace_use_training_state(self):
        torch._dynamo.reset()
        # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager.
@@ -100,6 +118,174 @@ class TestFullyShardCompile(FSDPTest):
        self.assertEqual(cnt.op_count, 1)
        self.assertEqual(len(cnt.graphs), 1)

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    @torch._functorch.config.patch(recompute_views=True)
    def _test_traceable_fsdp(
        self, model_init_fn, input_creation_fn, backend, fullgraph
    ):
        def compiler_fn(compiled_autograd_backend):
            def _fn(gm):
                # fullgraph=True because graph-break in Compiled Autograd BWD graph is not supported by Traceable FSDP2 yet
                # (main difficulty comes from queue_callback not working well when BWD has graph break).
                return torch.compile(
                    gm, backend=compiled_autograd_backend, fullgraph=True
                )

            return _fn

        def run_all_iters(model, optim, n_iter=10, compiled_autograd_backend=None):
            torch.manual_seed(42)
            losses = []
            for i in range(n_iter):
                optim.zero_grad(set_to_none=True)
                inp = input_creation_fn()
                if compiled_autograd_backend is not None:
                    maybe_compiled_autograd_ctx = compiled_autograd.enable(
                        compiler_fn(compiled_autograd_backend)
                    )
                else:
                    maybe_compiled_autograd_ctx = contextlib.nullcontext()
                with maybe_compiled_autograd_ctx:
                    out = model(inp)
                    loss = out.sum()
                    losses.append(loss.item())
                    loss.backward()
                optim.step()
            torch.cuda.synchronize()
            return losses
        def test_compiled():
            model, optim = model_init_fn()
            # FSDP2 does lazy init using 1st run, so run it once to init using eager mode
            run_all_iters(model, optim, n_iter=1)

            model_compiled = torch.compile(model, backend=backend, fullgraph=True)
            res = run_all_iters(
                model_compiled, optim, compiled_autograd_backend=backend
            )
            optim.zero_grad(set_to_none=True)
            return res

        def test_eager():
            model, optim = model_init_fn()
            # FSDP2 does lazy init using 1st run, so run it once to init using eager mode
            run_all_iters(model, optim, n_iter=1)

            res = run_all_iters(model, optim)
            optim.zero_grad(set_to_none=True)
            return res
        losses_compiled = test_compiled()
        losses_eager = test_eager()
        for loss_compiled, loss_eager in zip(losses_compiled, losses_eager):
            self.assertTrue(
                torch.allclose(
                    torch.tensor(loss_compiled), torch.tensor(loss_eager), rtol=1e-3
                ),
                f"{loss_compiled} vs {loss_eager}",
            )

    def _create_simple_mlp_factory_fns(self):
        hidden_dim = 16

        def model_init_fn():
            torch.manual_seed(0)
            fsdp_config = {}
            model = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim, device="cuda"),
            )
            fully_shard(model, reshard_after_forward=True, **fsdp_config)
            optim = torch.optim.SGD(model.parameters(), lr=1e-6)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(0)
            inp = torch.randn((2, hidden_dim), device="cuda", requires_grad=False)
            return inp

        return model_init_fn, input_creation_fn

    @skip_if_lt_x_gpu(2)
    def test_simple_mlp_fullgraph_backend_eager(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "eager", fullgraph=True
        )

    @skip_if_lt_x_gpu(2)
    def test_simple_mlp_fullgraph_backend_aot_eager(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
        )

    @unittest.expectedFailure
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_simple_mlp_fullgraph_backend_inductor(self):
        self._test_traceable_fsdp(
            *self._create_simple_mlp_factory_fns(), "inductor", fullgraph=True
        )

    def _create_transformer_factory_fns(self):
        hidden_dim = 16

        def model_init_fn():
            torch.manual_seed(0)
            fsdp_config = {}
            mesh = init_device_mesh("cuda", (self.world_size,))
            model_args = ModelArgs(
                dim=hidden_dim,
                n_layers=2,
                n_heads=1,
                vocab_size=1024,
            )
            model = Transformer(model_args)
            for layer_id, mod in enumerate(model.layers):
                fully_shard(mod, mesh=mesh, reshard_after_forward=True, **fsdp_config)
                model.layers[layer_id] = mod
            model = fully_shard(
                model, mesh=mesh, reshard_after_forward=True, **fsdp_config
            )
            optim = torch.optim.SGD(model.parameters(), lr=1e-6)
            return model, optim

        def input_creation_fn():
            torch.manual_seed(0)
            inp = torch.zeros(
                (2, hidden_dim),
                device="cuda",
                requires_grad=False,
                dtype=torch.long,
            )
            return inp

        return model_init_fn, input_creation_fn

    @skip_if_lt_x_gpu(2)
    def test_transformer_fullgraph_backend_eager(self):
        self._test_traceable_fsdp(
            *self._create_transformer_factory_fns(), "eager", fullgraph=True
        )

    @skip_if_lt_x_gpu(2)
    def test_transformer_fullgraph_backend_aot_eager(self):
        self._test_traceable_fsdp(
            *self._create_transformer_factory_fns(), "aot_eager", fullgraph=True
        )

    @unittest.expectedFailure
    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @skip_if_lt_x_gpu(2)
    def test_transformer_fullgraph_backend_inductor(self):
        self._test_traceable_fsdp(
            *self._create_transformer_factory_fns(), "inductor", fullgraph=True
        )


if __name__ == "__main__":
    run_tests()
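
For orientation, the pattern these new unit tests exercise can be sketched outside the test harness. The snippet below is illustrative only and is not part of this diff: it composes FSDP2's fully_shard with torch.compile and Compiled Autograd, assumes launch via torchrun with one CUDA device per rank, and the tiny model, sizes, and the "aot_eager" backend are placeholder choices.

# Illustrative sketch only, not part of the diff above.
import torch
import torch.distributed as dist
from torch import nn
from torch._dynamo import compiled_autograd
from torch.distributed._composable.fsdp import fully_shard

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

hidden_dim = 16
model = nn.Sequential(
    nn.Linear(hidden_dim, hidden_dim, device="cuda"),
    nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim, device="cuda"),
)
fully_shard(model, reshard_after_forward=True)  # shard parameters across ranks
optim = torch.optim.SGD(model.parameters(), lr=1e-6)

inp = torch.randn(2, hidden_dim, device="cuda")

# FSDP2 initializes lazily on the first forward, so run one eager iteration first.
model(inp).sum().backward()
optim.step()
optim.zero_grad(set_to_none=True)

# Compile the forward, and let Compiled Autograd compile the backward graph as well.
model_c = torch.compile(model, backend="aot_eager", fullgraph=True)
with compiled_autograd.enable(
    lambda gm: torch.compile(gm, backend="aot_eager", fullgraph=True)
):
    loss = model_c(inp).sum()
    loss.backward()
optim.step()

dist.destroy_process_group()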

View File

@@ -351,9 +351,6 @@ base_dir = dirname(dirname(dirname(abspath(__file__))))

# Trace through NumPy or graphbreak
trace_numpy = True

# Trace through torch.distributed code
trace_distributed = False

# Default NumPy dtypes when tracing with torch.compile
# We default to 64bits. For efficiency, one may want to change these to float32
numpy_default_float = "float64"