[dtensor] Introduce full_tensor API to DTensor (#112224)

This PR introduces a `full_tensor` API to DTensor. There are so many
call sites that exercise the `redistribute(replicate)` path that I feel
it deserves a separate API; it is mostly just syntactic sugar.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112224
Approved by: https://github.com/wz337
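
For context, a minimal sketch of the equivalence this sugar provides (the 4-rank mesh, device type, and tensor shape below are illustrative assumptions, and an initialized process group is assumed; this is not code from the PR itself):

import torch
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor

# Build a DTensor sharded on dim 0 across the mesh.
mesh = DeviceMesh("cuda", list(range(4)))  # assumes 4 ranks are available
dt = distribute_tensor(torch.ones(8, 3), mesh, [Shard(0)])

# Existing pattern at many call sites: replicate, then take the local tensor.
full_old = dt.redistribute(placements=[Replicate()]).to_local()

# New sugar introduced by this PR.
full_new = dt.full_tensor()

assert torch.equal(full_old, full_new)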
Wanchao Liang
2023-10-27 14:42:47 -07:00
committed by PyTorch MergeBot
parent e2cd69a770
commit 2f09da3a21
11 changed files with 67 additions and 69 deletions

@@ -265,9 +265,20 @@ class DTensorTest(DTensorTestBase):
         )
         local_out.sum().backward()
-        replica_grad = sharded_dtensor.grad.redistribute(
-            placements=[Replicate()]
-        ).to_local()
+        replica_grad = sharded_dtensor.grad.full_tensor()
         self.assertEqual(replica_grad, global_tensor * self.world_size)
 
+    @with_comms
+    def test_full_tensor_grad_hint(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = (Shard(0),)
+        global_tensor = torch.ones(8, 3, requires_grad=True)
+        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+        local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()])
+        local_out.sum().backward()
+        replica_grad = sharded_dtensor.grad.full_tensor()
+        self.assertEqual(replica_grad, global_tensor * self.world_size)
+
     @with_comms
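
The new test_full_tensor_grad_hint case above also exercises the optional grad_placements hint. A minimal sketch of that usage (the mesh setup and the _Partial import path are assumptions based on the test, not verbatim from this PR):

import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
from torch.distributed._tensor.placement_types import _Partial

mesh = DeviceMesh("cuda", list(range(4)))  # assumes 4 ranks are available
global_tensor = torch.ones(8, 3, requires_grad=True)
dt = distribute_tensor(global_tensor, mesh, [Shard(0)])

# The hint declares that the gradient flowing back into the full tensor is
# partial (pending sum) across ranks rather than already replicated, so the
# backward pass sums the per-rank contributions.
full = dt.full_tensor(grad_placements=[_Partial()])
full.sum().backward()

# Each rank contributes an all-ones gradient, so after the partial reduction
# the result is global_tensor * world_size, as the test above asserts.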