[DTensor] allow numel 1 tensor operand to be implicitly replicate DTensor (#125073)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125073
Approved by: https://github.com/wanchaol
This commit is contained in:
wz337
2024-05-08 19:47:41 +00:00
committed by PyTorch MergeBot
parent 445a0c01da
commit 603d1e6049
2 changed files with 41 additions and 1 deletions

View File

@ -763,6 +763,32 @@ class DTensorMeshTest(DTensorTestBase):
self.assertEqual(local_shard.shape, (4, 3))
self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))
@with_comms
def test_auto_implicit_replication(self):
mesh = init_device_mesh(self.device_type, (self.world_size,))
local_tensor = torch.ones(self.world_size, 3, device=self.device_type)
sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
# automatically turn tensor to DTensor replicate when ndim = 0 and numel = 1
ndim_0_tensor = torch.tensor(1, device=self.device_type)
def add_scalar_tensor_with_dtensor():
return sharded_dtensor + ndim_0_tensor
result = add_scalar_tensor_with_dtensor().to_local()
self.assertEqual(result, local_tensor + ndim_0_tensor)
self.assertNotWarn(
add_scalar_tensor_with_dtensor,
"Found a non-scalar tensor with numel=1 and ndim!=0",
)
# automatically turn tensor to DTensor replicate when ndim = 1 and numel = 1
numel_1_tensor = torch.tensor([1], device=self.device_type)
self.assertEqual(
(sharded_dtensor + numel_1_tensor).to_local(), local_tensor + numel_1_tensor
)
class TestDTensorPlacementTypes(DTensorTestBase):
@property