mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DTensor] allow numel 1 tensor operand to be implicitly replicate DTensor (#125073)
Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/125073 Approved by: https://github.com/wanchaol
This commit is contained in:
@ -763,6 +763,32 @@ class DTensorMeshTest(DTensorTestBase):
|
||||
self.assertEqual(local_shard.shape, (4, 3))
|
||||
self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))
|
||||
|
||||
@with_comms
|
||||
def test_auto_implicit_replication(self):
|
||||
mesh = init_device_mesh(self.device_type, (self.world_size,))
|
||||
|
||||
local_tensor = torch.ones(self.world_size, 3, device=self.device_type)
|
||||
sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
|
||||
|
||||
# automatically turn tensor to DTensor replicate when ndim = 0 and numel = 1
|
||||
ndim_0_tensor = torch.tensor(1, device=self.device_type)
|
||||
|
||||
def add_scalar_tensor_with_dtensor():
|
||||
return sharded_dtensor + ndim_0_tensor
|
||||
|
||||
result = add_scalar_tensor_with_dtensor().to_local()
|
||||
self.assertEqual(result, local_tensor + ndim_0_tensor)
|
||||
self.assertNotWarn(
|
||||
add_scalar_tensor_with_dtensor,
|
||||
"Found a non-scalar tensor with numel=1 and ndim!=0",
|
||||
)
|
||||
|
||||
# automatically turn tensor to DTensor replicate when ndim = 1 and numel = 1
|
||||
numel_1_tensor = torch.tensor([1], device=self.device_type)
|
||||
self.assertEqual(
|
||||
(sharded_dtensor + numel_1_tensor).to_local(), local_tensor + numel_1_tensor
|
||||
)
|
||||
|
||||
|
||||
class TestDTensorPlacementTypes(DTensorTestBase):
|
||||
@property
|
||||
|
Reference in New Issue
Block a user