Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[dtensor] fix new_empty_strided op (#107835)
This PR fixes the new_empty_strided op to fall back to replication from sharding when necessary. This is a quick fix to resolve https://github.com/pytorch/pytorch/issues/107661. We'll need to think more about the behavior of this op when it comes to sharding: one possibility is to follow the input sharding, but given that the output shape of this op might not match the input's, it's hard to say we should follow the input sharding. Further improvement is needed once we figure out the op's semantics.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107835
Approved by: https://github.com/fduwjj
Committed by: PyTorch MergeBot
Parent: 46cd2fef3f
Commit: 74ff028839
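The behavior the fix targets can be sketched as follows. This is not code from the PR: it assumes an initialized process group (e.g. launched via torchrun), a world size that divides the first dimension of the tensor, a device type matching the backend, and the DTensor entry points under torch.distributed._tensor.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# assumes dist.init_process_group(...) has already run on every rank
world_size = dist.get_world_size()
mesh = DeviceMesh("cpu", torch.arange(world_size))  # device type is an assumption

local = torch.randn(8, 8, requires_grad=True)
sharded = distribute_tensor(local, mesh, [Shard(0)])  # row-wise sharded DTensor

# Before this fix, new_empty_strided on a sharded DTensor could fail
# (see issue 107661); with the fix, the output becomes replicated when
# the sharding cannot be kept ("replicate from sharding when necessary").
out = sharded.new_empty_strided((8, 8), (8, 1), requires_grad=True)
out.sum().backward()        # autograd still flows through the DTensor output
assert out.grad is not None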
@@ -245,6 +245,29 @@ class DTensorTest(DTensorTestBase):
         except RuntimeError:
             self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
 
+    @with_comms
+    def test_dtensor_new_empty_strided(self):
+        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
+        my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
+        new_strided_dtensor = my_dtensor.new_empty_strided(
+            (8, 8), (8, 1), requires_grad=True
+        )
+        # test the op produces new dtensor and autograd works
+        self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
+        new_strided_dtensor.sum().backward()
+        self.assertIsNotNone(new_strided_dtensor.grad)
+        self.assertIsInstance(new_strided_dtensor.grad, DTensor)
+
+        # test backward new_empty_strided with sharding works correctly
+        my_dtensor.to_local().sum().backward()
+        local_tensor.sum().backward()
+        self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
+        self.assertEqual(
+            my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
+            local_tensor.grad,
+        )
+
     @with_comms
     def test_dtensor_async_output(self):
         # Tests that if the output of some dtensor operations isn't used in any compute,
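To exercise the new test locally, something along these lines should work; the file path is an assumption based on where DTensorTest lives, and the test case spawns its own worker processes through the distributed test harness:

python test/distributed/_tensor/test_dtensor.py -k test_dtensor_new_empty_strided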