[DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725)
Fixes https://github.com/pytorch/pytorch/issues/134095, a hang in the distributed state dict full_state_dict option during set_state_dict. We switch `_distribute_tensors` in _state_dict_utils.py to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided-sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` incurs a scatter behind the scenes, while `DTensor.from_local` takes the local slice of the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to make sure the full_tensor in the full_state_dict is the same across all ranks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135725
Approved by: https://github.com/fegin
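For illustration, here is a minimal sketch of the two construction paths (not the PR's exact code; the mesh, placement, and tensor values are invented for the example, and an already-initialized process group is assumed):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

# Assumes torch.distributed is already initialized (e.g. launched via torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
placements = [Shard(0)]
full_tensor = torch.arange(16.0, device="cuda").reshape(4, 4)

# Path 1: distribute_tensor scatters the full tensor across ranks (collective).
dt_scattered = distribute_tensor(full_tensor, mesh, placements)

# Path 2 (what this PR switches to): slice this rank's shard out of the full
# tensor, then wrap it with DTensor.from_local. No collective is issued, so
# every rank must hold an identical full_tensor for the result to be correct.
shape, offset = compute_local_shape_and_global_offset(
    full_tensor.shape, mesh, placements
)
slices = tuple(slice(offset[i], offset[i] + shape[i]) for i in range(len(shape)))
dt_local = DTensor.from_local(full_tensor[slices], mesh, placements)

The diff below adds an FSDP2+TP test for the full_state_dict round trip and makes the corresponding change in _state_dict_utils.py.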
@@ -17,7 +17,9 @@ from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
+    get_optimizer_state_dict,
     set_model_state_dict,
+    set_optimizer_state_dict,
     StateDictOptions,
 )
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -335,6 +337,57 @@ class TestFullyShard2DTraining(FSDPTest):
         self.assertEqual(loss_no_cp2, loss_cp2)
 
 
+class TestFullyShard2DStateDict(DTensorTestBase):
+    @property
+    def backend(self):
+        # need to specify gloo backend for testing cpu offload
+        return "cpu:gloo,cuda:nccl"
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_fully_shard_tp_2d_set_full_state_dict(self):
+        dummy_model = SimpleModel().cuda()
+        mesh_2d = init_device_mesh(
+            "cuda",
+            (2, self.world_size // 2),
+            mesh_dim_names=("dp", "tp"),
+        )
+        tp_mesh = mesh_2d["tp"]
+        dp_mesh = mesh_2d["dp"]
+        parallelize_plan = {
+            "net1": ColwiseParallel(),
+            "net2": RowwiseParallel(),
+            "net3": ColwiseParallel(),
+        }
+        model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
+        fully_shard(model, mesh=dp_mesh)
+        optim = torch.optim.Adam(model.parameters(), lr=0.01)
+        model(model.get_input()).sum().backward()
+        optim.step()
+        # ref_msd, ref_osd are both the default sharded state dict
+        ref_msd = copy.deepcopy(get_model_state_dict(model))
+        ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim))
+
+        options = StateDictOptions(
+            full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True
+        )
+        full_msd = get_model_state_dict(model, options=options)
+        full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options)
+        # load full_msd and full_osd into model and optim.
+        # this loads the slice of full tensor into each rank's local DTensor.
+        set_model_state_dict(model, full_msd, options=options)
+        set_optimizer_state_dict(
+            model, optimizers=optim, optim_state_dict=full_osd, options=options
+        )
+
+        # check after setting full state dict, the model and optim default sharded state dict
+        # are the same as the initial default sharded state dict.
+        new_msd = get_model_state_dict(model)
+        new_osd = get_optimizer_state_dict(model, optimizers=optim)
+        self.assertEqual(ref_msd, new_msd)
+        self.assertEqual(ref_osd, new_osd)
+
+
 class Test2dFSDP1ParallelIntegration(DTensorTestBase):
     def init_model(self, device_type, model_parallel_size=2):
         torch.manual_seed(0)
@@ -544,6 +597,11 @@ class TestNew2dParallelTraining(DTensorTestBase):
 # TODO: update all state dict unit tests to use distributed.checkpoint.state_dict,
 # and consolidate all the state_dict test in test.distributed.checkpoint.
 class TestNew2dParallelStateDict(DTensorTestBase):
+    @property
+    def backend(self):
+        # need to specify gloo backend for testing cpu offload
+        return "cpu:gloo,cuda:nccl"
+
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_fsdp_2d_extension(self):
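The `backend` property above returns "cpu:gloo,cuda:nccl" because cpu_offload=True produces CPU tensors, and NCCL only handles CUDA tensors, so collectives on the offloaded tensors need gloo. A rough standalone sketch of the same mapping (the launch setup here is an assumption, not part of this PR):

import torch.distributed as dist

# Assumes a torchrun-style launch so rank and world size come from the
# environment: collectives on CPU tensors go to gloo, collectives on CUDA
# tensors go to NCCL, all within a single process group.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")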
@@ -28,6 +28,7 @@ if dist.is_available() or TYPE_CHECKING:
     from torch.distributed import distributed_c10d
     from torch.distributed._shard.sharded_tensor import ShardedTensor
     from torch.distributed.tensor import distribute_tensor, DTensor, Replicate
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 
 
 def _identity_func(
@@ -551,8 +552,14 @@ def _distribute_tensors(
 
         local_state = _local_state[0]
         full_tensor = _local_state[1]
-        local_state_dict[key] = distribute_tensor(
-            full_tensor, local_state.device_mesh, local_state.placements
+
+        shape, offset = compute_local_shape_and_global_offset(
+            full_tensor.shape, local_state.device_mesh, local_state.placements
+        )
+        slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))]
+        local_tensor = full_tensor[slices]
+        local_state_dict[key] = DTensor.from_local(
+            local_tensor, local_state.device_mesh, local_state.placements
         )
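To make the new slicing concrete, a small worked example (values chosen only for illustration): for a 4x4 full tensor on a 1-D mesh of size 2 with a Shard(0) placement, compute_local_shape_and_global_offset returns a local shape of (2, 4) on both ranks with global offsets (0, 0) and (2, 0), so the slices select rows 0-1 on rank 0 and rows 2-3 on rank 1. Because `DTensor.from_local` performs no communication, the caveat from the commit message applies: every rank must already hold the same full_tensor, either via the broadcast_from_rank0=True option used in the test above or by an explicit broadcast, as in this hedged sketch (a hypothetical helper, not code from this PR):

import torch
import torch.distributed as dist

# Hypothetical helper: make every rank see rank 0's copy of the full tensor
# before the full state dict is applied, since DTensor.from_local itself
# issues no collective.
def ensure_same_full_tensor(full_tensor: torch.Tensor) -> torch.Tensor:
    dist.broadcast(full_tensor, src=0)  # in-place: non-zero ranks receive rank 0's data
    return full_tensor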