mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DTensor] Extend implicit replication to replicate DTensor for foreach ops so model doesn't have to be fully tp-ed when using 2D (#134551)
Fixes [134212](https://github.com/pytorch/pytorch/issues/134212) Currently, when we use 2D FSDP with TP, `optimizer.step()` would fail if the model were not fully tensor parallelized. If we don't have the entire model tensor parallelized when doing 2D, we would have both 1D and 2D DTensor parameters. As foreach is turned on by default, `optimizer.step()` would fail as cross mesh op is not allowed. Error as follows: ``` NotImplementedError: aten._foreach_mul_.Scalar: DTensor does not support cross-mesh operation yet!Got meshes: DeviceMesh('cuda', [[0, 1], [2, 3]], mesh_dim_names=('dp', 'tp')) DeviceMesh('cuda', [1, 3], mesh_dim_names=('dp',)) ``` In this PR, we extend implicit_replication to replicate DTensor in missing dimensions for foreach ops. If users don't want to fully tensor parallelize the model when using 2D, they have the option of using the `implicit_replication()` context manager for `optimizer.step()`. In this case, we would swap out the 1D DTensorSpec and replace it with 2D DTensorSpec. However, we don't want to turn this on by default yet, as we want the users to be aware that the tp dimension is replicated if a layer is not tp-ed. With implicit implication turning on, try replicate dtensor spec in missing dimension would work for most cases for foreach case except when the first DTensor in the list is one that also need to be replicated. This is currently a limitation, which I don't have a good solution yet. Currently, with this change, we can handle most of the cases except the case that the first DTensor's ndim is not the largest. ``` [2D_DTensor, 1D_DTensor...] ---> Implicit_replication() can handle this. [1D_DTensor, 2D_DTensor...] ---> Implicit_replication() can't handle this. ``` This change doesn't affect the existing default behavior, as `implicit_replication()` is not turned on by default. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134551 Approved by: https://github.com/tianyu-l
This commit is contained in:
@ -15,6 +15,7 @@ from torch.distributed._tensor import (
|
||||
init_device_mesh,
|
||||
)
|
||||
from torch.distributed._tensor.debug import CommDebugMode
|
||||
from torch.distributed._tensor.experimental import implicit_replication
|
||||
from torch.distributed._tensor.placement_types import (
|
||||
DTensorSpec,
|
||||
Partial,
|
||||
@ -778,8 +779,6 @@ class DTensorMeshTest(DTensorTestBase):
|
||||
local_tensor1 = torch.ones(4, 3)
|
||||
sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
|
||||
|
||||
from torch.distributed._tensor.experimental import implicit_replication
|
||||
|
||||
with implicit_replication():
|
||||
# We put the scalar tensor as the left operand so we can test out
|
||||
# when a non-dtensor is a the arg in the args list.
|
||||
@ -816,6 +815,41 @@ class DTensorMeshTest(DTensorTestBase):
|
||||
(numel_1_tensor + sharded_dtensor).to_local(), numel_1_tensor + local_tensor
|
||||
)
|
||||
|
||||
@with_comms
|
||||
def test_implicit_replication_for_foreach_ops(self):
|
||||
mesh = init_device_mesh(
|
||||
self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
|
||||
)
|
||||
global_tensor1 = torch.randn(4, 2)
|
||||
dtensor_2d = distribute_tensor(global_tensor1, mesh, [Shard(0), Shard(1)])
|
||||
self.assertEqual(dtensor_2d.full_tensor(), global_tensor1)
|
||||
global_tensor2 = torch.randn(4)
|
||||
dtensor_1d = distribute_tensor(global_tensor2, mesh["dp"], [Shard(0)])
|
||||
dtensor_list = [dtensor_2d, dtensor_1d]
|
||||
|
||||
# Check without implicit replication, cross mesh error raises.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "DTensor does not support cross-mesh operation yet!"
|
||||
):
|
||||
torch._foreach_mul(dtensor_list, 2.0)
|
||||
|
||||
# Check dtensor result matches tensor result.
|
||||
with implicit_replication():
|
||||
torch._foreach_mul_(dtensor_list, 2.0)
|
||||
self.assertEqual(dtensor_list[0].full_tensor(), global_tensor1 * 2.0)
|
||||
self.assertEqual(dtensor_list[1].full_tensor(), global_tensor2 * 2.0)
|
||||
|
||||
mesh_1d = DeviceMesh.from_group(mesh["tp"].get_group(), self.device_type)
|
||||
dtensor_1d = distribute_tensor(global_tensor2, mesh_1d, [Shard(0)])
|
||||
dtensor_list = [dtensor_2d, dtensor_1d]
|
||||
|
||||
# Check even with implicit replication, cross mesh error raises if different device mesh don't
|
||||
# belong to the same root mesh.
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "DTensor does not support cross-mesh operation yet!"
|
||||
):
|
||||
torch._foreach_mul_(dtensor_list, 2.0)
|
||||
|
||||
@with_comms
|
||||
def test_metadata_consistency_check(self):
|
||||
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["oncall: distributed"]
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
|
||||
import torch
|
||||
@ -7,6 +8,7 @@ import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
from torch.distributed._composable.fsdp import fully_shard
|
||||
from torch.distributed._tensor import DTensor, init_device_mesh
|
||||
from torch.distributed._tensor.experimental import implicit_replication
|
||||
from torch.distributed.checkpoint.state_dict import (
|
||||
get_model_state_dict,
|
||||
get_optimizer_state_dict,
|
||||
@ -439,8 +441,17 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
self.assertEqual(base_osd, fsdp2_tp_full_osd)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_save_with_fsdp2_tp_and_load_with_tp(self):
|
||||
self.run_subtests(
|
||||
{"allow_implicit_replication": [True, False]},
|
||||
self._test_save_with_fsdp2_tp_and_load_with_tp,
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def _test_save_with_fsdp2_tp_and_load_with_tp(
|
||||
self, allow_implicit_replication: bool
|
||||
):
|
||||
"""
|
||||
Test that we can save a model with FSDP2 + TP on 2d mesh and load it with TP.
|
||||
"""
|
||||
@ -449,6 +460,11 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
base_model = nn.Sequential(MLP(mlp_dim), MLP(mlp_dim), MLP(mlp_dim))
|
||||
return base_model
|
||||
|
||||
cm = (
|
||||
implicit_replication()
|
||||
if allow_implicit_replication
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
tp_parallelize_plan = {
|
||||
"0.in_proj": ColwiseParallel(),
|
||||
"0.out_proj": RowwiseParallel(),
|
||||
@ -457,108 +473,124 @@ class TestFullyShardWithDistributedStateDict(FSDPTest):
|
||||
"2.in_proj": ColwiseParallel(),
|
||||
"2.out_proj": RowwiseParallel(),
|
||||
}
|
||||
if allow_implicit_replication:
|
||||
# intentionally pop the plans for some tp layers so that the model is not fully tensor parallelized
|
||||
tp_parallelize_plan.pop("0.in_proj")
|
||||
tp_parallelize_plan.pop("0.out_proj")
|
||||
|
||||
# init device mesh
|
||||
dp_size = 2
|
||||
global_mesh_1d = init_device_mesh(
|
||||
"cuda", (self.world_size,), mesh_dim_names=("tp",)
|
||||
)
|
||||
global_mesh_2d = init_device_mesh(
|
||||
"cuda", (dp_size, self.world_size // dp_size), mesh_dim_names=("dp", "tp")
|
||||
)
|
||||
dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]
|
||||
|
||||
for save_full_state_dict in [True, False]:
|
||||
# Save state dict with original model
|
||||
base_model = _get_base_model().cuda()
|
||||
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
|
||||
|
||||
# Save state dict with FSDP2 + TP model
|
||||
fsdp2_tp_model = copy.deepcopy(base_model)
|
||||
fsdp2_tp_model = parallelize_module(
|
||||
fsdp2_tp_model,
|
||||
device_mesh=tp_mesh,
|
||||
parallelize_plan=tp_parallelize_plan,
|
||||
)
|
||||
for module in fsdp2_tp_model:
|
||||
fully_shard(module, mesh=dp_mesh)
|
||||
fully_shard(fsdp2_tp_model, mesh=dp_mesh)
|
||||
fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)
|
||||
|
||||
# one-step training to modify state dict
|
||||
inp = torch.randn((2,), device=self.rank)
|
||||
base_model(inp).sum().backward()
|
||||
base_optim.step()
|
||||
fsdp2_tp_model(inp).sum().backward()
|
||||
fsdp2_tp_optim.step()
|
||||
|
||||
# obtain the unsharded state dict
|
||||
base_msd = get_model_state_dict(
|
||||
base_model,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
base_osd = get_optimizer_state_dict(
|
||||
base_model,
|
||||
base_optim,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
|
||||
# obtain FSDP2 + TP state dict
|
||||
fsdp2_tp_msd = get_model_state_dict(
|
||||
fsdp2_tp_model,
|
||||
options=StateDictOptions(full_state_dict=save_full_state_dict),
|
||||
)
|
||||
fsdp2_tp_osd = get_optimizer_state_dict(
|
||||
fsdp2_tp_model,
|
||||
fsdp2_tp_optim,
|
||||
options=StateDictOptions(full_state_dict=save_full_state_dict),
|
||||
)
|
||||
|
||||
fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd}
|
||||
dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir)
|
||||
|
||||
fsdp2_tp_full_msd = get_model_state_dict(
|
||||
fsdp2_tp_model,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
fsdp2_tp_full_osd = get_optimizer_state_dict(
|
||||
fsdp2_tp_model,
|
||||
fsdp2_tp_optim,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
|
||||
# Load state dict into model with TP applied
|
||||
tp_model = _get_base_model()
|
||||
tp_model = parallelize_module(
|
||||
tp_model,
|
||||
device_mesh=global_mesh_1d,
|
||||
parallelize_plan=tp_parallelize_plan,
|
||||
)
|
||||
tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)
|
||||
|
||||
tp_state_dict = {
|
||||
"model": get_model_state_dict(tp_model),
|
||||
"optim": get_optimizer_state_dict(tp_model, tp_optim),
|
||||
with cm:
|
||||
tp_parallelize_plan = {
|
||||
"0.in_proj": ColwiseParallel(),
|
||||
"0.out_proj": RowwiseParallel(),
|
||||
"1.in_proj": ColwiseParallel(),
|
||||
"1.out_proj": RowwiseParallel(),
|
||||
"2.in_proj": ColwiseParallel(),
|
||||
"2.out_proj": RowwiseParallel(),
|
||||
}
|
||||
dcp.load(tp_state_dict, checkpoint_id=self.temp_dir)
|
||||
tp_model.load_state_dict(tp_state_dict["model"])
|
||||
tp_optim.load_state_dict(tp_state_dict["optim"])
|
||||
|
||||
tp_full_msd = get_model_state_dict(
|
||||
tp_model,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
# init device mesh
|
||||
dp_size = 2
|
||||
global_mesh_1d = init_device_mesh(
|
||||
"cuda", (self.world_size,), mesh_dim_names=("tp",)
|
||||
)
|
||||
tp_full_osd = get_optimizer_state_dict(
|
||||
tp_model,
|
||||
tp_optim,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
global_mesh_2d = init_device_mesh(
|
||||
"cuda",
|
||||
(dp_size, self.world_size // dp_size),
|
||||
mesh_dim_names=("dp", "tp"),
|
||||
)
|
||||
dp_mesh, tp_mesh = global_mesh_2d["dp"], global_mesh_2d["tp"]
|
||||
|
||||
# Compare full state dict to make sure they are the same.
|
||||
self.assertEqual(base_msd, tp_full_msd)
|
||||
self.assertEqual(base_osd, tp_full_osd)
|
||||
self.assertEqual(fsdp2_tp_full_msd, tp_full_msd)
|
||||
self.assertEqual(fsdp2_tp_full_osd, tp_full_osd)
|
||||
for save_full_state_dict in [True, False]:
|
||||
# Save state dict with original model
|
||||
base_model = _get_base_model().cuda()
|
||||
base_optim = torch.optim.AdamW(base_model.parameters(), lr=0.1)
|
||||
|
||||
# Save state dict with FSDP2 + TP model
|
||||
fsdp2_tp_model = copy.deepcopy(base_model)
|
||||
fsdp2_tp_model = parallelize_module(
|
||||
fsdp2_tp_model,
|
||||
device_mesh=tp_mesh,
|
||||
parallelize_plan=tp_parallelize_plan,
|
||||
)
|
||||
for module in fsdp2_tp_model:
|
||||
fully_shard(module, mesh=dp_mesh)
|
||||
fully_shard(fsdp2_tp_model, mesh=dp_mesh)
|
||||
fsdp2_tp_optim = torch.optim.AdamW(fsdp2_tp_model.parameters(), lr=0.1)
|
||||
|
||||
# one-step training to modify state dict
|
||||
inp = torch.randn((2,), device=self.rank)
|
||||
base_model(inp).sum().backward()
|
||||
base_optim.step()
|
||||
fsdp2_tp_model(inp).sum().backward()
|
||||
fsdp2_tp_optim.step()
|
||||
|
||||
# obtain the unsharded state dict
|
||||
base_msd = get_model_state_dict(
|
||||
base_model,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
base_osd = get_optimizer_state_dict(
|
||||
base_model,
|
||||
base_optim,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
|
||||
# obtain FSDP2 + TP state dict
|
||||
fsdp2_tp_msd = get_model_state_dict(
|
||||
fsdp2_tp_model,
|
||||
options=StateDictOptions(full_state_dict=save_full_state_dict),
|
||||
)
|
||||
fsdp2_tp_osd = get_optimizer_state_dict(
|
||||
fsdp2_tp_model,
|
||||
fsdp2_tp_optim,
|
||||
options=StateDictOptions(full_state_dict=save_full_state_dict),
|
||||
)
|
||||
|
||||
fsdp2_tp_state_dict = {"model": fsdp2_tp_msd, "optim": fsdp2_tp_osd}
|
||||
dcp.save(fsdp2_tp_state_dict, checkpoint_id=self.temp_dir)
|
||||
|
||||
fsdp2_tp_full_msd = get_model_state_dict(
|
||||
fsdp2_tp_model,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
fsdp2_tp_full_osd = get_optimizer_state_dict(
|
||||
fsdp2_tp_model,
|
||||
fsdp2_tp_optim,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
|
||||
# Load state dict into model with TP applied
|
||||
tp_model = _get_base_model()
|
||||
tp_model = parallelize_module(
|
||||
tp_model,
|
||||
device_mesh=global_mesh_1d,
|
||||
parallelize_plan=tp_parallelize_plan,
|
||||
)
|
||||
tp_optim = torch.optim.AdamW(tp_model.parameters(), lr=0.1)
|
||||
|
||||
tp_state_dict = {
|
||||
"model": get_model_state_dict(tp_model),
|
||||
"optim": get_optimizer_state_dict(tp_model, tp_optim),
|
||||
}
|
||||
dcp.load(tp_state_dict, checkpoint_id=self.temp_dir)
|
||||
tp_model.load_state_dict(tp_state_dict["model"])
|
||||
tp_optim.load_state_dict(tp_state_dict["optim"])
|
||||
|
||||
tp_full_msd = get_model_state_dict(
|
||||
tp_model,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
tp_full_osd = get_optimizer_state_dict(
|
||||
tp_model,
|
||||
tp_optim,
|
||||
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
|
||||
)
|
||||
|
||||
# Compare full state dict to make sure they are the same.
|
||||
self.assertEqual(base_msd, tp_full_msd)
|
||||
self.assertEqual(base_osd, tp_full_osd)
|
||||
self.assertEqual(fsdp2_tp_full_msd, tp_full_msd)
|
||||
self.assertEqual(fsdp2_tp_full_osd, tp_full_osd)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -309,16 +309,20 @@ class OpDispatcher:
|
||||
|
||||
for arg in args_list:
|
||||
if isinstance(arg, dtensor.DTensor):
|
||||
args_schema.append(arg._spec)
|
||||
local_args.append(arg._local_tensor)
|
||||
if mesh is not None:
|
||||
if mesh != arg.device_mesh:
|
||||
raise NotImplementedError(
|
||||
f"{op_call}: DTensor does not support cross-mesh operation yet!"
|
||||
f"Got meshes: {mesh} {arg.device_mesh}"
|
||||
)
|
||||
if mesh is not None and mesh != arg.device_mesh:
|
||||
# TODO: try replicate dtensor spec in missing dimension would work
|
||||
# for most cases for foreach case except when the first DTensor in
|
||||
# the list is one that also need to be replicated. We need to revisit
|
||||
# how we want to handle this corner case. For now, this case would hit
|
||||
# the cross mesh error even if implicit replication is turned on.
|
||||
spec = self._try_replicate_dtensor_spec_in_missing_dim(
|
||||
op_call, arg, mesh
|
||||
)
|
||||
args_schema.append(spec)
|
||||
else:
|
||||
mesh = arg.device_mesh
|
||||
args_schema.append(arg._spec)
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
|
||||
args_schema.append(
|
||||
@ -331,15 +335,15 @@ class OpDispatcher:
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, dtensor.DTensor):
|
||||
kwargs_schema[k] = v._spec
|
||||
local_kwargs[k] = v._local_tensor
|
||||
if mesh is not None:
|
||||
if mesh != v.device_mesh:
|
||||
raise NotImplementedError(
|
||||
f"{op_call}: DTensor does not support cross-mesh operation yet!"
|
||||
)
|
||||
if mesh is not None and mesh != v.device_mesh:
|
||||
spec = self._try_replicate_dtensor_spec_in_missing_dim(
|
||||
op_call, v, mesh
|
||||
)
|
||||
kwargs_schema[k] = spec
|
||||
else:
|
||||
mesh = v.device_mesh
|
||||
kwargs_schema[k] = v._spec
|
||||
elif isinstance(v, torch.Tensor):
|
||||
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
|
||||
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
|
||||
@ -426,3 +430,39 @@ class OpDispatcher:
|
||||
" torch.Tensor to DTensor before calling distributed operators!"
|
||||
)
|
||||
return replication_spec
|
||||
|
||||
def _try_replicate_dtensor_spec_in_missing_dim(
|
||||
self,
|
||||
op_call: torch._ops.OpOverload,
|
||||
dtensor_arg: "dtensor.DTensor",
|
||||
mesh: "DeviceMesh",
|
||||
) -> DTensorSpec:
|
||||
# util function to produce a new spec for a DTensor arg/kwarg
|
||||
# that puts Replicate() placement in the missing dimension for foreach ops
|
||||
from torch.distributed.device_mesh import _mesh_resources
|
||||
|
||||
cur_mesh = dtensor_arg.device_mesh
|
||||
root_mesh = _mesh_resources.get_root_mesh(cur_mesh)
|
||||
if (
|
||||
self._allow_implicit_replication
|
||||
and "foreach" in op_call.__name__
|
||||
and root_mesh == mesh
|
||||
):
|
||||
placements = [Replicate() for _ in range(root_mesh.ndim)]
|
||||
cur_mesh_root_idx = _mesh_resources.get_root_mesh_dim(cur_mesh)
|
||||
placements[cur_mesh_root_idx] = dtensor_arg.placements[0] # type: ignore[call-overload]
|
||||
replicate_spec = DTensorSpec(
|
||||
root_mesh,
|
||||
tuple(placements),
|
||||
tensor_meta=TensorMeta(
|
||||
shape=dtensor_arg.shape,
|
||||
stride=dtensor_arg.stride(),
|
||||
dtype=dtensor_arg.dtype,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{op_call}: DTensor does not support cross-mesh operation yet! "
|
||||
f"Got meshes: {mesh} {cur_mesh}"
|
||||
)
|
||||
return replicate_spec
|
||||
|
Reference in New Issue
Block a user