[dtensor] add grad placements kwarg to to_local API (#110629)

When we convert to local tensor, dtensor can't track autograd or
gradient layout of the local tensor anymore, if user do sth not expected, there
needs to be a way for user to hint about the gradient layout of the
local tensor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110629
Approved by: https://github.com/zdevito
This commit is contained in:
Wanchao Liang
2023-10-05 11:53:53 -07:00
committed by PyTorch MergeBot
parent ada65508d2
commit c95cf4b4c9
2 changed files with 61 additions and 13 deletions

View File

@ -245,6 +245,23 @@ class DTensorTest(DTensorTestBase):
except RuntimeError:
self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
@with_comms
def test_to_local_grad_hint(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard_spec = (Shard(0),)
global_tensor = torch.ones(8, 3, requires_grad=True)
sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
grad_placements=[_Partial()]
)
local_out.sum().backward()
replica_grad = sharded_dtensor.grad.redistribute(
placements=[Replicate()]
).to_local()
self.assertEqual(replica_grad, global_tensor * self.world_size)
@with_comms
def test_dtensor_new_empty_strided(self):
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))