Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[DTensor] Add a private util for sharding tensor (#142288)
Locally shards a full tensor according to the indicated sharding arrangement and returns a DTensor containing the local shard.

Warning: this is a private API intended to skip the communication that `distribute_tensor` would otherwise perform. It is only applicable when all ranks hold the same `full_tensor`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142288
Approved by: https://github.com/wz337
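As context for the diff below, here is a minimal usage sketch contrasting the new private helper with the public `distribute_tensor` path. It is not taken from the PR: the torchrun launch, the gloo/CPU mesh, and the script scaffolding are assumptions; only the `_shard_tensor(full_tensor, placements, device_mesh)` call shape follows the tests added in this commit.

# Usage sketch (not part of the PR). Assumes 4 ranks launched with
# `torchrun --nproc-per-node=4 shard_demo.py`; gloo/CPU is used only to
# keep the example self-contained.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor._api import _shard_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Every rank builds the same full tensor -- the precondition for _shard_tensor.
full = torch.arange(16).reshape(4, 4)

# Public path: distribute_tensor scatters rank 0's data (communication).
dt_scattered = distribute_tensor(full, mesh, [Shard(0)])

# Private path: each rank slices its shard out of its own copy (no communication).
dt_sliced = _shard_tensor(full, [Shard(0)], mesh)

assert torch.equal(dt_scattered.to_local(), dt_sliced.to_local())
dist.destroy_process_group()

Note that `_shard_tensor` takes the placements before the mesh, the reverse of `distribute_tensor`'s (tensor, device_mesh, placements) order.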
@@ -25,6 +25,7 @@ from torch.distributed._tensor.placement_types import (
     Shard,
     TensorMeta,
 )
+from torch.distributed.tensor._api import _shard_tensor
 from torch.distributed.tensor.debug import CommDebugMode
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -572,6 +573,45 @@ class DTensorTest(DTensorTestBase):
         )
         self._attempt_load_from_subprocess(filename, import_string, err_msg)
 
+    @with_comms
+    def test_shard_tensor(self):
+        ws = self.world_size
+        device_mesh = DeviceMesh(self.device_type, list(range(ws)))
+        full_tensor = torch.arange(ws * ws).reshape(ws, ws)
+
+        # Shard by row
+        placements = [Shard(0)]
+        sharded_tensor = _shard_tensor(full_tensor, placements, device_mesh)
+        self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
+        self.assertEqual(sharded_tensor.placements, placements)
+        local_tensor = sharded_tensor.to_local()
+        self.assertEqual(local_tensor, full_tensor[range(self.rank, self.rank + 1), :])
+
+        # Shard by column
+        placements = [Shard(1)]
+        sharded_tensor = _shard_tensor(full_tensor, placements, device_mesh)
+        self.assertEqual(sharded_tensor.size(), torch.Size([ws, ws]))
+        self.assertEqual(sharded_tensor.placements, placements)
+        local_tensor = sharded_tensor.to_local()
+        self.assertEqual(local_tensor, full_tensor[:, range(self.rank, self.rank + 1)])
+
+        # assert full tensor is not changed
+        self.assertEqual(full_tensor, torch.arange(ws * ws).reshape(ws, ws))
+
+    @with_comms
+    def test_shard_tensor_2d(self):
+        ws = self.world_size
+        full_tensor = torch.arange(ws).reshape(2, ws // 2)
+        device_mesh = DeviceMesh(self.device_type, full_tensor)
+
+        # Shard by row and column
+        placements = [Shard(0), Shard(1)]
+        sharded_tensor = _shard_tensor(full_tensor, placements, device_mesh)
+        self.assertEqual(sharded_tensor.size(), torch.Size([2, ws // 2]))
+        self.assertEqual(sharded_tensor.placements, placements)
+        local_tensor = sharded_tensor.to_local()
+        self.assertEqual(local_tensor.item(), self.rank)
+
 
 class DTensorMeshTest(DTensorTestBase):
     @property