[DTensor] Extend implicit replication to replicate DTensor for foreach ops so model doesn't have to be fully tp-ed when using 2D (#134551)

Fixes [134212](https://github.com/pytorch/pytorch/issues/134212)

Currently, when we use 2D FSDP with TP, `optimizer.step()` fails if the model is not fully tensor parallelized. If the entire model is not tensor parallelized when doing 2D, we end up with a mix of 1D and 2D DTensor parameters. Since foreach is turned on by default, `optimizer.step()` fails because cross-mesh ops are not allowed. The error is as follows:

```
NotImplementedError: aten._foreach_mul_.Scalar: DTensor does not support cross-mesh operation yet!Got meshes: DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')) DeviceMesh('cuda', [1, 3], mesh_dim_names=('dp',))
```

In this PR, we extend implicit replication to replicate DTensors in the missing mesh dimensions for foreach ops. If users don't want to fully tensor parallelize the model when using 2D, they have the option of wrapping `optimizer.step()` in the `implicit_replication()` context manager. In that case, we swap out the 1D DTensorSpec and replace it with a 2D DTensorSpec. However, we don't want to turn this on by default yet, because we want users to be aware that the tp dimension is replicated whenever a layer is not tp-ed.
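
For example, a minimal usage sketch (hypothetical `model`, `inp`, and `optimizer` names, assuming a partially tensor-parallelized model trained with 2D FSDP + TP):

```
from torch.distributed._tensor.experimental import implicit_replication

# `model` is only partially tensor parallelized, so its parameters are a mix of
# 1D (FSDP-only) and 2D (FSDP + TP) DTensors; `optimizer` is a foreach optimizer
# such as torch.optim.AdamW over model.parameters().
model(inp).sum().backward()

with implicit_replication():
    # The un-tp-ed 1D parameters are treated as replicated on the "tp" mesh
    # dimension, so the foreach ops inside step() no longer hit the cross-mesh error.
    optimizer.step()
optimizer.zero_grad()
```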

With implicit replication turned on, replicating the DTensor spec in the missing dimension works for most foreach cases, except when the first DTensor in the list is itself one that needs to be replicated. This is a current limitation for which I don't have a good solution yet. In other words, with this change we can handle most cases, except the case where the first DTensor's ndim is not the largest:
```
[2D_DTensor, 1D_DTensor, ...] ---> implicit_replication() can handle this.
[1D_DTensor, 2D_DTensor, ...] ---> implicit_replication() can't handle this.
```
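
A hedged, end-to-end sketch of both orderings (assuming a 4-GPU run launched with torchrun; the mesh shape and tensor sizes mirror the new unit test below and are not required values):

```
import torch
from torch.distributed._tensor import Shard, distribute_tensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication

# 2D ("dp", "tp") root mesh, one 2D DTensor and one 1D DTensor on the "dp" submesh.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
dtensor_2d = distribute_tensor(torch.randn(4, 2), mesh_2d, [Shard(0), Shard(1)])
dtensor_1d = distribute_tensor(torch.randn(4), mesh_2d["dp"], [Shard(0)])

# Without the context manager, either ordering raises the cross-mesh error shown above.
with implicit_replication():
    # Handled: the op's mesh is inferred from dtensor_2d (the root mesh), so
    # dtensor_1d's spec is expanded with Replicate() on the missing "tp" dimension.
    torch._foreach_mul_([dtensor_2d, dtensor_1d], 2.0)

    # Not handled: the mesh is inferred from dtensor_1d's 1D submesh first, so
    # dtensor_2d still raises the cross-mesh NotImplementedError (ends the sketch).
    torch._foreach_mul_([dtensor_1d, dtensor_2d], 2.0)
```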

This change doesn't affect the existing default behavior, as `implicit_replication()` is not turned on by default.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134551
Approved by: https://github.com/tianyu-l
Author: wz337
Committed: 2024-08-29 09:01:29 +00:00 by PyTorch MergeBot
Parent: 3645634f3c
Commit: cfb642bb6b
3 changed files with 218 additions and 112 deletions


@@ -15,6 +15,7 @@ from torch.distributed._tensor import (
    init_device_mesh,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed._tensor.placement_types import (
    DTensorSpec,
    Partial,
@@ -778,8 +779,6 @@ class DTensorMeshTest(DTensorTestBase):
local_tensor1 = torch.ones(4, 3)
sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
from torch.distributed._tensor.experimental import implicit_replication
with implicit_replication():
# We put the scalar tensor as the left operand so we can test out
# when a non-dtensor is a the arg in the args list.
@@ -816,6 +815,41 @@ class DTensorMeshTest(DTensorTestBase):
            (numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor
        )

    @with_comms
    def test_implicit_replication_for_foreach_ops(self):
        mesh = init_device_mesh(
            self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
        )
        global_tensor1 = torch.randn(4, 2)
        dtensor_2d = distribute_tensor(global_tensor1, mesh, [Shard(0), Shard(1)])
        self.assertEqual(dtensor_2d.full_tensor(), global_tensor1)
        global_tensor2 = torch.randn(4)
        dtensor_1d = distribute_tensor(global_tensor2, mesh["dp"], [Shard(0)])
        dtensor_list = [dtensor_2d, dtensor_1d]

        # Check without implicit replication, cross mesh error raises.
        with self.assertRaisesRegex(
            RuntimeError, "DTensor does not support cross-mesh operation yet!"
        ):
            torch._foreach_mul(dtensor_list, 2.0)

        # Check dtensor result matches tensor result.
        with implicit_replication():
            torch._foreach_mul_(dtensor_list, 2.0)
            self.assertEqual(dtensor_list[0].full_tensor(), global_tensor1 * 2.0)
            self.assertEqual(dtensor_list[1].full_tensor(), global_tensor2 * 2.0)

            mesh_1d = DeviceMesh.from_group(mesh["tp"].get_group(), self.device_type)
            dtensor_1d = distribute_tensor(global_tensor2, mesh_1d, [Shard(0)])
            dtensor_list = [dtensor_2d, dtensor_1d]

            # Check even with implicit replication, cross mesh error raises if different device mesh don't
            # belong to the same root mesh.
            with self.assertRaisesRegex(
                RuntimeError, "DTensor does not support cross-mesh operation yet!"
            ):
                torch._foreach_mul_(dtensor_list, 2.0)

    @with_comms
    def test_metadata_consistency_check(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))


@@ -1,5 +1,6 @@
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import torch
@@ -7,6 +8,7 @@ import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed._tensor.experimental import implicit_replication
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
@@ -439,8 +441,17 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
self.assertEqual(base_osd, fsdp2_tp_full_osd)
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_save_with_fsdp2_tp_and_load_with_tp(self):
self.run_subtests(
{"allow_implicit_replication": [True, False]},
self._test_save_with_fsdp2_tp_and_load_with_tp,
)
@skip_if_lt_x_gpu(4)
@with_temp_dir
def _test_save_with_fsdp2_tp_and_load_with_tp(
self, allow_implicit_replication: bool
):
"""
Test that we can save a model with FSDP2 + TP on 2d mesh and load it with TP.
"""
@@ -449,6 +460,11 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
return base_model
cm = (
implicit_replication()
if allow_implicit_replication
else contextlib.nullcontext()
)
tp_parallelize_plan = {
"0.in_proj": ColwiseParallel(),
"0.out_proj": RowwiseParallel(),
@@ -457,108 +473,124 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
"2.in_proj": ColwiseParallel(),
"2.out_proj": RowwiseParallel(),
}
if allow_implicit_replication:
# intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized
tp_parallelize_plan.pop("0.in_proj")
tp_parallelize_plan.pop("0.out_proj")
# init device mesh
dp_size = 2
global_mesh_1d = init_device_mesh(
"cuda", (self.world_size,), mesh_dim_names=("tp",)
)
global_mesh_2d = init_device_mesh(
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
)
dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]
for save_full_state_dict in [True, False]:
# Save state dict with original model
base_model = _get_base_model().cuda()
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
# Save state dict with FSDP2 + TP model
fsdp2_tp_model = copy.deepcopy(base_model)
fsdp2_tp_model = parallelize_module(
fsdp2_tp_model,
device_mesh=tp_mesh,
parallelize_plan=tp_parallelize_plan,
)
for module in fsdp2_tp_model:
fully_shard(module, mesh=dp_mesh)
fully_shard(fsdp2_tp_model, mesh=dp_mesh)
fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)
# one-step training to modify state dict
inp = torch.randn((2,), device=self.rank)
base_model(inp).sum().backward()
base_optim.step()
fsdp2_tp_model(inp).sum().backward()
fsdp2_tp_optim.step()
# obtain the unsharded state dict
base_msd = get_model_state_dict(
base_model,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
base_osd = get_optimizer_state_dict(
base_model,
base_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
# obtain FSDP2 + TP state dict
fsdp2_tp_msd = get_model_state_dict(
fsdp2_tp_model,
options=StateDictOptions(full_state_dict=save_full_state_dict),
)
fsdp2_tp_osd = get_optimizer_state_dict(
fsdp2_tp_model,
fsdp2_tp_optim,
options=StateDictOptions(full_state_dict=save_full_state_dict),
)
fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd}
dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir)
fsdp2_tp_full_msd = get_model_state_dict(
fsdp2_tp_model,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
fsdp2_tp_full_osd = get_optimizer_state_dict(
fsdp2_tp_model,
fsdp2_tp_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
# Load state dict into model with TP applied
tp_model = _get_base_model()
tp_model = parallelize_module(
tp_model,
device_mesh=global_mesh_1d,
parallelize_plan=tp_parallelize_plan,
)
tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)
tp_state_dict = {
"model": get_model_state_dict(tp_model),
"optim": get_optimizer_state_dict(tp_model, tp_optim),
with cm:
tp_parallelize_plan = {
"0.in_proj": ColwiseParallel(),
"0.out_proj": RowwiseParallel(),
"1.in_proj": ColwiseParallel(),
"1.out_proj": RowwiseParallel(),
"2.in_proj": ColwiseParallel(),
"2.out_proj": RowwiseParallel(),
}
dcp.load(tp_state_dict, checkpoint_id=self.temp_dir)
tp_model.load_state_dict(tp_state_dict["model"])
tp_optim.load_state_dict(tp_state_dict["optim"])
tp_full_msd = get_model_state_dict(
tp_model,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
# init device mesh
dp_size = 2
global_mesh_1d = init_device_mesh(
"cuda", (self.world_size,), mesh_dim_names=("tp",)
)
tp_full_osd = get_optimizer_state_dict(
tp_model,
tp_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
global_mesh_2d = init_device_mesh(
"cuda",
(dp_size, self.world_size // dp_size),
mesh_dim_names=("dp", "tp"),
)
dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]
# Compare full state dict to make sure they are the same.
self.assertEqual(base_msd, tp_full_msd)
self.assertEqual(base_osd, tp_full_osd)
self.assertEqual(fsdp2_tp_full_msd, tp_full_msd)
self.assertEqual(fsdp2_tp_full_osd, tp_full_osd)
for save_full_state_dict in [True, False]:
# Save state dict with original model
base_model = _get_base_model().cuda()
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
# Save state dict with FSDP2 + TP model
fsdp2_tp_model = copy.deepcopy(base_model)
fsdp2_tp_model = parallelize_module(
fsdp2_tp_model,
device_mesh=tp_mesh,
parallelize_plan=tp_parallelize_plan,
)
for module in fsdp2_tp_model:
fully_shard(module, mesh=dp_mesh)
fully_shard(fsdp2_tp_model, mesh=dp_mesh)
fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)
# one-step training to modify state dict
inp = torch.randn((2,), device=self.rank)
base_model(inp).sum().backward()
base_optim.step()
fsdp2_tp_model(inp).sum().backward()
fsdp2_tp_optim.step()
# obtain the unsharded state dict
base_msd = get_model_state_dict(
base_model,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
base_osd = get_optimizer_state_dict(
base_model,
base_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
# obtain FSDP2 + TP state dict
fsdp2_tp_msd = get_model_state_dict(
fsdp2_tp_model,
options=StateDictOptions(full_state_dict=save_full_state_dict),
)
fsdp2_tp_osd = get_optimizer_state_dict(
fsdp2_tp_model,
fsdp2_tp_optim,
options=StateDictOptions(full_state_dict=save_full_state_dict),
)
fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd}
dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir)
fsdp2_tp_full_msd = get_model_state_dict(
fsdp2_tp_model,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
fsdp2_tp_full_osd = get_optimizer_state_dict(
fsdp2_tp_model,
fsdp2_tp_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
# Load state dict into model with TP applied
tp_model = _get_base_model()
tp_model = parallelize_module(
tp_model,
device_mesh=global_mesh_1d,
parallelize_plan=tp_parallelize_plan,
)
tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)
tp_state_dict = {
"model": get_model_state_dict(tp_model),
"optim": get_optimizer_state_dict(tp_model, tp_optim),
}
dcp.load(tp_state_dict, checkpoint_id=self.temp_dir)
tp_model.load_state_dict(tp_state_dict["model"])
tp_optim.load_state_dict(tp_state_dict["optim"])
tp_full_msd = get_model_state_dict(
tp_model,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
tp_full_osd = get_optimizer_state_dict(
tp_model,
tp_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
# Compare full state dict to make sure they are the same.
self.assertEqual(base_msd, tp_full_msd)
self.assertEqual(base_osd, tp_full_osd)
self.assertEqual(fsdp2_tp_full_msd, tp_full_msd)
self.assertEqual(fsdp2_tp_full_osd, tp_full_osd)
if __name__ == "__main__":


@@ -309,16 +309,20 @@ class OpDispatcher:
for arg in args_list:
if isinstance(arg, dtensor.DTensor):
args_schema.append(arg._spec)
local_args.append(arg._local_tensor)
if mesh is not None:
if mesh != arg.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
f"Got meshes: {mesh} {arg.device_mesh}"
)
if mesh is not None and mesh != arg.device_mesh:
# TODO: try replicate dtensor spec in missing dimension would work
# for most cases for foreach case except when the first DTensor in
# the list is one that also need to be replicated. We need to revisit
# how we want to handle this corner case. For now, this case would hit
# the cross mesh error even if implicit replication is turned on.
spec = self._try_replicate_dtensor_spec_in_missing_dim(
op_call, arg, mesh
)
args_schema.append(spec)
else:
mesh = arg.device_mesh
args_schema.append(arg._spec)
elif isinstance(arg, torch.Tensor):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
args_schema.append(
@@ -331,15 +335,15 @@ class OpDispatcher:
for k, v in kwargs.items():
if isinstance(v, dtensor.DTensor):
kwargs_schema[k] = v._spec
local_kwargs[k] = v._local_tensor
if mesh is not None:
if mesh != v.device_mesh:
raise NotImplementedError(
f"{op_call}: DTensor does not support cross-mesh operation yet!"
)
if mesh is not None and mesh != v.device_mesh:
spec = self._try_replicate_dtensor_spec_in_missing_dim(
op_call, v, mesh
)
kwargs_schema[k] = spec
else:
mesh = v.device_mesh
kwargs_schema[k] = v._spec
elif isinstance(v, torch.Tensor):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
@@ -426,3 +430,39 @@ class OpDispatcher:
                " torch.Tensor to DTensor before calling distributed operators!"
            )
        return replication_spec

    def _try_replicate_dtensor_spec_in_missing_dim(
        self,
        op_call: torch._ops.OpOverload,
        dtensor_arg: "dtensor.DTensor",
        mesh: "DeviceMesh",
    ) -> DTensorSpec:
        # util function to produce a new spec for a DTensor arg/kwarg
        # that puts Replicate() placement in the missing dimension for foreach ops
        from torch.distributed.device_mesh import _mesh_resources

        cur_mesh = dtensor_arg.device_mesh
        root_mesh = _mesh_resources.get_root_mesh(cur_mesh)
        if (
            self._allow_implicit_replication
            and "foreach" in op_call.__name__
            and root_mesh == mesh
        ):
            placements = [Replicate() for _ in range(root_mesh.ndim)]
            cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh)
            placements[cur_mesh_root_idx] = dtensor_arg.placements[0]  # type: ignore[call-overload]
            replicate_spec = DTensorSpec(
                root_mesh,
                tuple(placements),
                tensor_meta=TensorMeta(
                    shape=dtensor_arg.shape,
                    stride=dtensor_arg.stride(),
                    dtype=dtensor_arg.dtype,
                ),
            )
        else:
            raise NotImplementedError(
                f"{op_call}: DTensor does not support cross-mesh operation yet! "
                f"Got meshes: {mesh} {cur_mesh}"
            )

        return replicate_spec
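
To make the spec expansion above concrete, here is a stand-alone illustration (a sketch mirroring the helper's placement construction, not the dispatcher code itself), assuming a 1D DTensor sharded as `[Shard(0)]` on the `"dp"` submesh of a 2D `("dp", "tp")` root mesh:

```
# Runs without any process group; only illustrates how the missing mesh
# dimension is filled with Replicate() while the existing placement is kept.
from torch.distributed._tensor.placement_types import Replicate, Shard

root_mesh_ndim = 2                 # ("dp", "tp") root mesh
cur_mesh_root_idx = 0              # the 1D submesh maps to root-mesh dim 0 ("dp")
dtensor_placements = (Shard(0),)   # original 1D placements

# Start fully replicated on the root mesh, then keep the original placement on
# the root dimension that the 1D submesh corresponds to.
placements = [Replicate() for _ in range(root_mesh_ndim)]
placements[cur_mesh_root_idx] = dtensor_placements[0]

print(tuple(placements))  # expected: (Shard(dim=0), Replicate())
```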