Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[dtensor] add grad placements kwarg to to_local API (#110629)
When we convert a DTensor to a local tensor, DTensor can no longer track the autograd or gradient layout of the local tensor. If the user does something unexpected with the local tensor, there needs to be a way for the user to hint at the gradient layout of that local tensor.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110629
Approved by: https://github.com/zdevito
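For illustration, a minimal sketch of how the new kwarg can be used. It mirrors the test added in the diff below; the import paths, the "cuda" device type, and the assumption of an already-initialized process group are mine and not spelled out in the commit message:

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, Shard, distribute_tensor
# Assumed import path for _Partial in this era of PyTorch; it is what the test uses.
from torch.distributed._tensor.placement_types import _Partial

# Assumes the default process group is already initialized (e.g. launched via torchrun).
world_size = dist.get_world_size()
mesh = DeviceMesh("cuda", list(range(world_size)))  # "cpu" with a gloo backend should also work

global_tensor = torch.ones(8, 3, requires_grad=True)
sharded = distribute_tensor(global_tensor, mesh, (Shard(0),))

# Replicate, then drop down to a plain torch.Tensor. grad_placements hints that the
# gradient produced for this local tensor is partial (still needs to be summed across
# ranks), so DTensor can handle it correctly during backward.
local_out = sharded.redistribute(placements=[Replicate()]).to_local(
    grad_placements=[_Partial()]
)
local_out.sum().backward()

Without the hint, DTensor would treat the gradient flowing back into the replicated tensor as replicated as well; passing `_Partial()` declares that the per-rank local grads still need to be summed, which is what the new test asserts.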
Committed by: PyTorch MergeBot
Parent: ada65508d2
Commit: c95cf4b4c9
@@ -245,6 +245,23 @@ class DTensorTest(DTensorTestBase):
         except RuntimeError:
             self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
 
+    @with_comms
+    def test_to_local_grad_hint(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = (Shard(0),)
+        global_tensor = torch.ones(8, 3, requires_grad=True)
+
+        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, shard_spec)
+        local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
+            grad_placements=[_Partial()]
+        )
+        local_out.sum().backward()
+
+        replica_grad = sharded_dtensor.grad.redistribute(
+            placements=[Replicate()]
+        ).to_local()
+        self.assertEqual(replica_grad, global_tensor * self.world_size)
+
     @with_comms
     def test_dtensor_new_empty_strided(self):
         device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))