[DTensor] Add a private util for sharding tensor (#142288)

Locally shards a full tensor according to the indicated sharding arrangement and returns a DTensor containing the local shard.

Warning: this is a private API intended to skip the communication otherwise required by `distribute_tensor`. It is only applicable when all ranks hold the same `full_tensor`.
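
For illustration, a minimal usage sketch under that constraint, where every rank constructs the identical `full_tensor`. The 4-rank world size and `cpu` mesh are assumptions, and the snippet would be launched with something like `torchrun --nproc-per-node=4`:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor._api import _shard_tensor

# Assumed setup: 4 ranks on a 1-D CPU mesh.
mesh = init_device_mesh("cpu", (4,))

# Precondition: every rank builds the *same* full tensor.
full = torch.arange(16).reshape(4, 4)

# distribute_tensor issues collectives to move each shard to its owner rank.
dt_comm = distribute_tensor(full, mesh, [Shard(0)])

# _shard_tensor slices this rank's shard from its local copy (no communication).
dt_local = _shard_tensor(full, [Shard(0)], mesh)

assert torch.equal(dt_comm.to_local(), dt_local.to_local())
```

Since both paths start from identical inputs, the resulting local shards agree; with rank-divergent inputs only `distribute_tensor` gives a well-defined result, which is why the helper stays private.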

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142288
Approved by: https://github.com/wz337
Author: Ke Wen
Date: 2024-12-06 21:27:14 -08:00
Committed by: PyTorch MergeBot
Parent: 2d9b081012
Commit: a58d2f14e8
2 changed files with 87 additions and 0 deletions
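
Before the diff, a small single-process sketch of the slicing semantics the new tests check; the world size of 4 and rank 2 are example values, not part of the PR:

```python
import torch

ws, rank = 4, 2                       # example values; rank varies per process
full = torch.arange(ws * ws).reshape(ws, ws)

local_row = full[rank : rank + 1, :]  # what Shard(0) leaves on this rank
local_col = full[:, rank : rank + 1]  # what Shard(1) leaves on this rank

print(local_row)  # tensor([[ 8,  9, 10, 11]]), i.e. row 2 of full
print(local_col)  # column 2 of full, shape (4, 1)
```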


@@ -25,6 +25,7 @@ from torch.distributed._tensor.placement_types import (
     Shard,
     TensorMeta,
 )
+from torch.distributed.tensor._api import _shard_tensor
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -572,6 +573,45 @@ class DTensorTest(DTensorTestBase):
         )
         self._attempt_load_from_subprocess(filename, import_string, err_msg)
 
+    @with_comms
+    def test_shard_tensor(self):
+        ws = self.world_size
+        device_mesh = DeviceMesh(self.device_type, list(range(ws)))
+        full_tensor = torch.arange(ws * ws).reshape(ws, ws)
+        # Shard by row
+        placements = [Shard(0)]
+        sharded_tensor = _shard_tensor(full_tensor, placements, device_mesh)
+        self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
+        self.assertEqual(sharded_tensor.placements, placements)
+        local_tensor = sharded_tensor.to_local()
+        self.assertEqual(local_tensor, full_tensor[range(self.rank, self.rank + 1), :])
+        # Shard by column
+        placements = [Shard(1)]
+        sharded_tensor = _shard_tensor(full_tensor, placements, device_mesh)
+        self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
+        self.assertEqual(sharded_tensor.placements, placements)
+        local_tensor = sharded_tensor.to_local()
+        self.assertEqual(local_tensor, full_tensor[:, range(self.rank, self.rank + 1)])
+        # assert the full tensor is unchanged
+        self.assertEqual(full_tensor, torch.arange(ws * ws).reshape(ws, ws))
+
+    @with_comms
+    def test_shard_tensor_2d(self):
+        ws = self.world_size
+        full_tensor = torch.arange(ws).reshape(2, ws // 2)
+        device_mesh = DeviceMesh(self.device_type, full_tensor)
+        # Shard by row and column
+        placements = [Shard(0), Shard(1)]
+        sharded_tensor = _shard_tensor(full_tensor, placements, device_mesh)
+        self.assertEqual(sharded_tensor.size(), torch.Size([2, ws // 2]))
+        self.assertEqual(sharded_tensor.placements, placements)
+        local_tensor = sharded_tensor.to_local()
+        # The mesh layout mirrors the full tensor's values, so each rank's
+        # shard is the single element equal to its own rank id.
+        self.assertEqual(local_tensor.item(), self.rank)
+
+
 class DTensorMeshTest(DTensorTestBase):
     @property