mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dtensor] Introduce full_tensor API to DTensor (#112224)
This PR introduces a `full_tensor` API to DTensor, there were so many callsites that exercises the `redistribute(replicate)` path and I feel it deserves a separate API, mostly just a syntactic sugar Pull Request resolved: https://github.com/pytorch/pytorch/pull/112224 Approved by: https://github.com/wz337
This commit is contained in:
committed by
PyTorch MergeBot
parent
e2cd69a770
commit
2f09da3a21
@ -265,9 +265,20 @@ class DTensorTest(DTensorTestBase):
|
||||
)
|
||||
local_out.sum().backward()
|
||||
|
||||
replica_grad = sharded_dtensor.grad.redistribute(
|
||||
placements=[Replicate()]
|
||||
).to_local()
|
||||
replica_grad = sharded_dtensor.grad.full_tensor()
|
||||
self.assertEqual(replica_grad, global_tensor * self.world_size)
|
||||
|
||||
@with_comms
|
||||
def test_full_tensor_grad_hint(self):
|
||||
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
|
||||
shard_spec = (Shard(0),)
|
||||
global_tensor = torch.ones(8, 3, requires_grad=True)
|
||||
|
||||
sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
|
||||
local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()])
|
||||
local_out.sum().backward()
|
||||
|
||||
replica_grad = sharded_dtensor.grad.full_tensor()
|
||||
self.assertEqual(replica_grad, global_tensor * self.world_size)
|
||||
|
||||
@with_comms
|
||||
|
||||
Reference in New Issue
Block a user