[DCP] Enable nD device_mesh resharding DTensor in DCP and add associated tests (#106230)
This PR:
1. Drops the assert for the 1D DeviceMesh check, so that DTensors with an nD DeviceMesh are allowed when creating write items.
2. Adds tests for both placement changes and mesh changes, covering 1D and 2D scenarios.

cc @kumpera @wanchaol @fegin

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106230
Approved by: https://github.com/kumpera
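As context for the diff below, here is a minimal sketch (not part of this PR) of the save/load round trip the new tests exercise: a DTensor sharded over a 2D DeviceMesh is saved with DCP and reloaded into a DTensor with a different placement on the same mesh. It assumes a 4-rank CPU launch (e.g. `torchrun --nproc_per_node=4`) and uses only APIs that the test file below also imports; the checkpoint path `checkpoint_2d` is illustrative.

```python
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
    Replicate,
    Shard,
    distribute_tensor,
    init_device_mesh,
    zeros,
)

# Sketch only: assumes 4 ranks, e.g. `torchrun --nproc_per_node=4 sketch.py`.
dist.init_process_group("gloo")
mesh_2d = init_device_mesh("cpu", (2, dist.get_world_size() // 2))

# Save: shard a 4x4 tensor along dim 0 over both mesh dimensions.
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
to_save = {
    "dtensor": distribute_tensor(
        global_tensor, mesh_2d, placements=[Shard(0), Shard(0)]
    )
}
dist_cp.save_state_dict(
    state_dict=to_save,
    storage_writer=dist_cp.FileSystemWriter(path="checkpoint_2d"),
    planner=dist_cp.DefaultSavePlanner(),
)

# Load: DCP reshards the checkpoint into a DTensor with a different placement.
to_load = {
    "dtensor": zeros([4, 4], device_mesh=mesh_2d, placements=[Replicate(), Shard(0)])
}
dist_cp.load_state_dict(
    state_dict=to_load,
    storage_reader=dist_cp.FileSystemReader("checkpoint_2d"),
    planner=dist_cp.DefaultLoadPlanner(),
)

# Replicate everywhere so each rank can compare against the original tensor.
full = to_load["dtensor"].redistribute(mesh_2d, placements=[Replicate(), Replicate()])
assert torch.equal(full.to_local(), global_tensor)

dist.destroy_process_group()
```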
test/distributed/checkpoint/test_dtensor_resharding.py (new file, 253 lines)

@@ -0,0 +1,253 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
    zeros,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


CHECKPOINT_DIR = "checkpoint"

ONE_D_PLACEMENTS = [
    [Shard(0)],
    [Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
    ([Replicate()], [Shard(0)]),
    ([Shard(0)], [Replicate()]),
]

TWO_D_PLACEMENTS = [
    [Replicate(), Replicate()],
    [Replicate(), Shard(0)],
    [Shard(0), Replicate()],
    [Shard(0), Shard(0)],
]
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
    for p2 in TWO_D_PLACEMENTS:
        if p1 != p2:
            TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))


class TestDTensorReshardPlacementChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor with placement changes, without world_size or mesh_tensor changes.
    """

    @with_comms
    @skip_if_lt_x_gpu(2)
    @parametrize("one_d_to_one_d_placements", ONE_D_TO_ONE_D_PLACEMENTS)
    @with_temp_dir
    def test_1d_to_1d_reshard_placement_change(self, one_d_to_one_d_placements) -> None:
        CHECKPOINT_DIR = self.temp_dir
        original_placement, new_placement = one_d_to_one_d_placements

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)
        dtensor = distribute_tensor(
            global_tensor, device_mesh, placements=original_placement
        )
        state_dict_to_save = {"dtensor": dtensor}

        dist_cp.save_state_dict(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
            planner=dist_cp.DefaultSavePlanner(),
        )

        zero_dtensor = zeros([4, 4], device_mesh=device_mesh, placements=new_placement)
        state_dict_to_load = {"dtensor": zero_dtensor}

        dist_cp.load_state_dict(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=dist_cp.DefaultLoadPlanner(),
        )

        # materialize the whole tensor to compare with the original global_tensor
        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            device_mesh,
            placements=[Replicate()],
        )
        self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

        # redistribute the tensor back to its original placement for comparison.
        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            device_mesh,
            placements=original_placement,
        )
        self.assertEqual(
            state_dict_to_save["dtensor"].to_local(),
            state_dict_to_load["dtensor"].to_local(),
        )

    @with_comms
    @skip_if_lt_x_gpu(2)
    @parametrize("two_d_to_two_d_placements", TWO_D_TO_TWO_D_PLACEMENTS)
    @with_temp_dir
    def test_2d_to_2d_reshard_placement_change(self, two_d_to_two_d_placements) -> None:
        CHECKPOINT_DIR = self.temp_dir
        original_placement, new_placement = two_d_to_two_d_placements

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)
        dtensor = distribute_tensor(
            global_tensor,
            mesh_2d,
            placements=original_placement,
        )
        state_dict_to_save = {"dtensor": dtensor}

        dist_cp.save_state_dict(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
            planner=dist_cp.DefaultSavePlanner(),
        )

        zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
        state_dict_to_load = {"dtensor": zero_dtensor}

        dist_cp.load_state_dict(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=dist_cp.DefaultLoadPlanner(),
        )

        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            mesh_2d,
            placements=[Replicate(), Replicate()],
        )
        self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            mesh_2d,
            placements=original_placement,
        )
        self.assertEqual(
            state_dict_to_save["dtensor"].to_local(),
            state_dict_to_load["dtensor"].to_local(),
        )


class TestDTensorReshardMeshChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor with placement changes and mesh_tensor changes.
    """

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_1d_to_2d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_1d in ONE_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            mesh_1d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_1d, placements=placements_1d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            for placements_2d in TWO_D_PLACEMENTS:
                mesh_shape = (2, self.world_size // 2)
                mesh_2d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_2d, placements=placements_2d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_2d,
                    placements=[Replicate(), Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_2d_to_1d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_2d in TWO_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_2d, placements=placements_2d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            for placements_1d in ONE_D_PLACEMENTS:
                mesh_shape = (self.world_size,)
                mesh_1d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_1d, placements=placements_1d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_1d,
                    placements=[Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )


# TODO: Add dtensor resharding test when world size changes.


instantiate_parametrized_tests(TestDTensorReshardPlacementChange)
if __name__ == "__main__":
    run_tests()

@@ -58,8 +58,6 @@ def _sharded_tensor_metadata(


 def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
-    # TODO: remove assert after 2D tests are added.
-    assert tensor.device_mesh.ndim == 1, "Only 1D DeviceMeshes can currently be handled."
     sizes, offsets = compute_local_shape_and_global_offset(
         tensor.shape, tensor.device_mesh, tensor.placements
     )
@@ -222,8 +220,6 @@ def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:


 def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
-    # TODO: remove assert after 2D tests are added.
-    assert tensor.device_mesh.ndim == 1, "Only 1D DeviceMeshes can currently be handled."
     sizes, offsets = compute_local_shape_and_global_offset(
         tensor.shape, tensor.device_mesh, tensor.placements
     )
@@ -31,7 +31,7 @@ def with_temp_dir(
         self.temp_dir = object_list[0]

         try:
-            func(self)
+            func(self, *args, **kwargs)
         finally:
             if dist.get_rank() == 0:
                 shutil.rmtree(self.temp_dir, ignore_errors=True)