[DCP] Enable nD device_mesh resharding DTensor in DCP and add associated tests (#106230)

This PR:
     1. Drops the 1D DeviceMesh assert so that DTensors on nD DeviceMeshes are accepted when creating a WriteItem.
     2. Adds tests for both placement changes and mesh changes, covering 1D and 2D scenarios.
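
For context, a condensed sketch of the flow this enables (not part of the PR; it mirrors the new tests below and assumes torch.distributed is already initialized across 4 GPUs): a DTensor saved from a 2D DeviceMesh can now be checkpointed with DCP and loaded back under a different placement.

    import torch
    import torch.distributed.checkpoint as dist_cp
    from torch.distributed._tensor import (
        distribute_tensor,
        init_device_mesh,
        Replicate,
        Shard,
        zeros,
    )

    # Saving from a 2D mesh used to trip the 1D-only assert in DCP.
    mesh_2d = init_device_mesh("cuda", (2, 2))
    dtensor = distribute_tensor(
        torch.arange(16, dtype=torch.float).view(4, 4),
        mesh_2d,
        placements=[Shard(0), Replicate()],
    )
    dist_cp.save_state_dict(
        state_dict={"dtensor": dtensor},
        storage_writer=dist_cp.FileSystemWriter(path="checkpoint"),
    )

    # Load the checkpoint back into a DTensor with a different placement.
    state_dict = {
        "dtensor": zeros([4, 4], device_mesh=mesh_2d, placements=[Replicate(), Shard(0)])
    }
    dist_cp.load_state_dict(
        state_dict=state_dict,
        storage_reader=dist_cp.FileSystemReader("checkpoint"),
    )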

cc. @kumpera  @wanchaol  @fegin
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106230
Approved by: https://github.com/kumpera
Author: Iris
Date: 2023-09-12 00:47:54 +00:00
Committed by: PyTorch MergeBot
Commit: b6f9d4dbc4 (parent: 8025b193a9)
3 changed files with 254 additions and 5 deletions

@@ -0,0 +1,253 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
    zeros,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

CHECKPOINT_DIR = "checkpoint"

ONE_D_PLACEMENTS = [
    [Shard(0)],
    [Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
    ([Replicate()], [Shard(0)]),
    ([Shard(0)], [Replicate()]),
]

TWO_D_PLACEMENTS = [
    [Replicate(), Replicate()],
    [Replicate(), Shard(0)],
    [Shard(0), Replicate()],
    [Shard(0), Shard(0)],
]
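
# Build every ordered pair of distinct 2D placements (12 pairs) to drive the
# 2D-to-2D placement-change test below.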
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
    for p2 in TWO_D_PLACEMENTS:
        if p1 != p2:
            TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))


class TestDTensorReshardPlacementChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor with placement changes, while the
    world_size and mesh_tensor stay unchanged.
    """

    @with_comms
    @skip_if_lt_x_gpu(2)
    @parametrize("one_d_to_one_d_placements", ONE_D_TO_ONE_D_PLACEMENTS)
    @with_temp_dir
    def test_1d_to_1d_reshard_placement_change(self, one_d_to_one_d_placements) -> None:
        CHECKPOINT_DIR = self.temp_dir
        original_placement, new_placement = one_d_to_one_d_placements

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)
        dtensor = distribute_tensor(
            global_tensor, device_mesh, placements=original_placement
        )
        state_dict_to_save = {"dtensor": dtensor}

        dist_cp.save_state_dict(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
            planner=dist_cp.DefaultSavePlanner(),
        )

        zero_dtensor = zeros([4, 4], device_mesh=device_mesh, placements=new_placement)
        state_dict_to_load = {"dtensor": zero_dtensor}

        dist_cp.load_state_dict(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=dist_cp.DefaultLoadPlanner(),
        )

        # materialize the whole tensor to compare with the original global_tensor
        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            device_mesh,
            placements=[Replicate()],
        )
        self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

        # redistribute the tensor back to its original placement for comparison.
        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            device_mesh,
            placements=original_placement,
        )
        self.assertEqual(
            state_dict_to_save["dtensor"].to_local(),
            state_dict_to_load["dtensor"].to_local(),
        )

    @with_comms
    @skip_if_lt_x_gpu(2)
    @parametrize("two_d_to_two_d_placements", TWO_D_TO_TWO_D_PLACEMENTS)
    @with_temp_dir
    def test_2d_to_2d_reshard_placement_change(self, two_d_to_two_d_placements) -> None:
        CHECKPOINT_DIR = self.temp_dir
        original_placement, new_placement = two_d_to_two_d_placements

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)
        dtensor = distribute_tensor(
            global_tensor,
            mesh_2d,
            placements=original_placement,
        )
        state_dict_to_save = {"dtensor": dtensor}

        dist_cp.save_state_dict(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
            planner=dist_cp.DefaultSavePlanner(),
        )

        zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
        state_dict_to_load = {"dtensor": zero_dtensor}

        dist_cp.load_state_dict(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
            planner=dist_cp.DefaultLoadPlanner(),
        )

        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            mesh_2d,
            placements=[Replicate(), Replicate()],
        )
        self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

        state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
            mesh_2d,
            placements=original_placement,
        )
        self.assertEqual(
            state_dict_to_save["dtensor"].to_local(),
            state_dict_to_load["dtensor"].to_local(),
        )


class TestDTensorReshardMeshChange(DTensorTestBase):
    """
    Test DCP resharding for DTensor with both placement changes and
    mesh_tensor changes.
    """

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_1d_to_2d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_1d in ONE_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            mesh_1d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_1d, placements=placements_1d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            for placements_2d in TWO_D_PLACEMENTS:
                mesh_shape = (2, self.world_size // 2)
                mesh_2d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_2d, placements=placements_2d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_2d,
                    placements=[Replicate(), Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_2d_to_1d_reshard_mesh_change(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        for placements_2d in TWO_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_2d, placements=placements_2d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save_state_dict(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
                planner=dist_cp.DefaultSavePlanner(),
            )

            for placements_1d in ONE_D_PLACEMENTS:
                mesh_shape = (self.world_size,)
                mesh_1d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_1d, placements=placements_1d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load_state_dict(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_1d,
                    placements=[Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )


# TODO: Add dtensor resharding test when world size changes.
instantiate_parametrized_tests(TestDTensorReshardPlacementChange)

if __name__ == "__main__":
    run_tests()

@@ -58,8 +58,6 @@ def _sharded_tensor_metadata(
 def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
-    # TODO: remove assert after 2D tests are added.
-    assert tensor.device_mesh.ndim == 1, "Only 1D DeviceMeshes can currently be handled."
     sizes, offsets = compute_local_shape_and_global_offset(
         tensor.shape, tensor.device_mesh, tensor.placements
     )
@@ -222,8 +220,6 @@ def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
 def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
-    # TODO: remove assert after 2D tests are added.
-    assert tensor.device_mesh.ndim == 1, "Only 1D DeviceMeshes can currently be handled."
     sizes, offsets = compute_local_shape_and_global_offset(
         tensor.shape, tensor.device_mesh, tensor.placements
     )
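
With the asserts gone, compute_local_shape_and_global_offset does the heavy lifting for nD meshes. For intuition, a pure-Python sketch of the per-rank shard math (not the actual implementation; it assumes even divisibility and models Shard(d) as the integer d and Replicate() as None):

    def local_shape_and_offset(global_shape, mesh_shape, mesh_coord, placements):
        """Later mesh dims subdivide the shard produced by earlier mesh dims."""
        shape, offset = list(global_shape), [0] * len(global_shape)
        for mesh_dim, placement in enumerate(placements):
            if placement is None:  # Replicate(): this mesh dim splits nothing
                continue
            d = placement  # Shard(d): split tensor dim d across this mesh dim
            shape[d] //= mesh_shape[mesh_dim]
            offset[d] += mesh_coord[mesh_dim] * shape[d]
        return tuple(shape), tuple(offset)

    # A 4x4 tensor on a 2x2 mesh with [Shard(0), Shard(0)]: the rank at mesh
    # coordinates (1, 0) owns the 1x4 slice starting at global row 2.
    assert local_shape_and_offset((4, 4), (2, 2), (1, 0), [0, 0]) == ((1, 4), (2, 0))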

@@ -31,7 +31,7 @@ def with_temp_dir(
         self.temp_dir = object_list[0]
         try:
-            func(self)
+            func(self, *args, **kwargs)
         finally:
             if dist.get_rank() == 0:
                 shutil.rmtree(self.temp_dir, ignore_errors=True)
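
This one-line fix matters because @parametrize injects the test parameters as extra call arguments; a wrapper that calls func(self) drops them and fails with a TypeError. A minimal sketch of the corrected pattern (simplified: the real helper creates the directory on rank 0 and broadcasts it to all ranks):

    import functools
    import shutil
    import tempfile

    def with_temp_dir(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            self.temp_dir = tempfile.mkdtemp()  # simplified: no rank-0 broadcast here
            try:
                func(self, *args, **kwargs)  # forward parametrize-injected arguments
            finally:
                shutil.rmtree(self.temp_dir, ignore_errors=True)
        return wrapper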