[DSD] Fix distributed state dict full_state_dict option hang during set_state_dict (#135725)

Fix https://github.com/pytorch/pytorch/issues/134095
This fixes a hang with the distributed state dict `full_state_dict` option during `set_state_dict`. We switch `_distribute_tensors` in `_state_dict_utils.py` to use `DTensor.from_local` instead of `distribute_tensor` to support the FSDP2+TP 2D strided sharding use case, as `distribute_tensor` cannot handle strided sharding yet. `distribute_tensor` performs a scatter behind the scenes, while `DTensor.from_local` takes the local slice of the full tensor on each rank to create the DTensor (no collective). This means it is the user's responsibility to make sure the full tensor in the `full_state_dict` is the same across all ranks.
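For illustration only (not part of this PR's diff), here is a minimal sketch of the two construction paths, assuming a 1D `Shard(0)` device mesh and a `torchrun` launch; the tensor and variable names below are made up:

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, DTensor, Shard
from torch.distributed.tensor._utils import compute_local_shape_and_global_offset

# assumes a torchrun launch, so WORLD_SIZE / LOCAL_RANK are set
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (world_size,))
placements = [Shard(0)]

# every rank holds the same full tensor, as it would with a full_state_dict
full_tensor = torch.arange(
    2 * world_size * 4, dtype=torch.float32, device="cuda"
).reshape(2 * world_size, 4)

# path 1: distribute_tensor shards via a scatter (a collective)
dt_scatter = distribute_tensor(full_tensor, mesh, placements)

# path 2 (what the fix does): each rank computes its shard's shape/offset,
# slices it out of the full tensor locally, and wraps it with DTensor.from_local.
# no collective is issued, so full_tensor must already be identical on all ranks.
shape, offset = compute_local_shape_and_global_offset(full_tensor.shape, mesh, placements)
slices = tuple(slice(offset[i], offset[i] + shape[i]) for i in range(len(shape)))
dt_local = DTensor.from_local(full_tensor[slices], mesh, placements)

assert torch.equal(dt_scatter.to_local(), dt_local.to_local())
dist.destroy_process_group()
```

Under these assumptions both paths produce the same local shards, but only the second avoids communication, which is what `_distribute_tensors` now relies on.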
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135725
Approved by: https://github.com/fegin
Author: wz337
Date: 2024-09-12 14:53:32 -07:00
Committed by: PyTorch MergeBot
parent 6cdc70bccd
commit 0cdc6a8dcd
2 changed files with 67 additions and 2 deletions


@@ -17,7 +17,9 @@ from torch.distributed._tensor import DTensor, init_device_mesh, Replicate, Shard
 from torch.distributed.checkpoint.state_dict import (
     get_model_state_dict,
     get_optimizer_state_dict,
+    set_model_state_dict,
+    set_optimizer_state_dict,
     StateDictOptions,
 )
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -335,6 +337,57 @@ class TestFullyShard2DTraining(FSDPTest):
         self.assertEqual(loss_no_cp2, loss_cp2)
 
 
+class TestFullyShard2DStateDict(DTensorTestBase):
+    @property
+    def backend(self):
+        # need to specify gloo backend for testing cpu offload
+        return "cpu:gloo,cuda:nccl"
+
+    @with_comms
+    @skip_if_lt_x_gpu(4)
+    def test_fully_shard_tp_2d_set_full_state_dict(self):
+        dummy_model = SimpleModel().cuda()
+        mesh_2d = init_device_mesh(
+            "cuda",
+            (2, self.world_size // 2),
+            mesh_dim_names=("dp", "tp"),
+        )
+        tp_mesh = mesh_2d["tp"]
+        dp_mesh = mesh_2d["dp"]
+        parallelize_plan = {
+            "net1": ColwiseParallel(),
+            "net2": RowwiseParallel(),
+            "net3": ColwiseParallel(),
+        }
+        model = parallelize_module(dummy_model, tp_mesh, parallelize_plan)
+        fully_shard(model, mesh=dp_mesh)
+        optim = torch.optim.Adam(model.parameters(), lr=0.01)
+        model(model.get_input()).sum().backward()
+        optim.step()
+
+        # ref_msd, ref_osd are both the default sharded state dict
+        ref_msd = copy.deepcopy(get_model_state_dict(model))
+        ref_osd = copy.deepcopy(get_optimizer_state_dict(model, optimizers=optim))
+        options = StateDictOptions(
+            full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True
+        )
+        full_msd = get_model_state_dict(model, options=options)
+        full_osd = get_optimizer_state_dict(model, optimizers=optim, options=options)
+        # load full_msd and full_osd into model and optim.
+        # this loads the slice of full tensor into each rank's local DTensor.
+        set_model_state_dict(model, full_msd, options=options)
+        set_optimizer_state_dict(
+            model, optimizers=optim, optim_state_dict=full_osd, options=options
+        )
+
+        # check after setting full state dict, the model and optim default sharded state dict
+        # are the same as the initial default sharded state dict.
+        new_msd = get_model_state_dict(model)
+        new_osd = get_optimizer_state_dict(model, optimizers=optim)
+        self.assertEqual(ref_msd, new_msd)
+        self.assertEqual(ref_osd, new_osd)
+
+
 class Test2dFSDP1ParallelIntegration(DTensorTestBase):
     def init_model(self, device_type, model_parallel_size=2):
         torch.manual_seed(0)
@@ -544,6 +597,11 @@ class TestNew2dParallelTraining(DTensorTestBase):
 # TODO: update all state dict unit tests to use distributed.checkpoint.state_dict,
 # and consolidate all the state_dict test in test.distributed.checkpoint.
 class TestNew2dParallelStateDict(DTensorTestBase):
+    @property
+    def backend(self):
+        # need to specify gloo backend for testing cpu offload
+        return "cpu:gloo,cuda:nccl"
+
     @with_comms
     @skip_if_lt_x_gpu(4)
     def test_fsdp_2d_extension(self):


@@ -28,6 +28,7 @@ if dist.is_available() or TYPE_CHECKING:
     from torch.distributed import distributed_c10d
     from torch.distributed._shard.sharded_tensor import ShardedTensor
     from torch.distributed.tensor import distribute_tensor, DTensor, Replicate
+    from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
 
 
 def _identity_func(
@@ -551,8 +552,14 @@ def _distribute_tensors(
         local_state = _local_state[0]
         full_tensor = _local_state[1]
-        local_state_dict[key] = distribute_tensor(
-            full_tensor, local_state.device_mesh, local_state.placements
+        shape, offset = compute_local_shape_and_global_offset(
+            full_tensor.shape, local_state.device_mesh, local_state.placements
+        )
+        slices = [slice(offset[i], shape[i] + offset[i]) for i in range(len(shape))]
+        local_tensor = full_tensor[slices]
+        local_state_dict[key] = DTensor.from_local(
+            local_tensor, local_state.device_mesh, local_state.placements
         )