[dtensor] fix new_empty_strided op (#107835)

This PR fixes the new_empty_strided op by redistributing from sharded to replicate placements
when necessary. This is a quick fix to resolve https://github.com/pytorch/pytorch/issues/107661.

We'll need to think more about the behavior of this op when it comes to
sharding. One possibility is to follow the input sharding, but since the
output shape of this op might not match the input shape, it's hard to
say we should follow the input sharding. Further improvement is needed once
we figure out the op syntax.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107835
Approved by: https://github.com/fduwjj
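
For illustration only, here is a minimal sketch (not part of this PR) of the user-visible behavior the fix enables. It assumes a process group is already initialized (e.g. launched via torchrun), uses the torch.distributed._tensor APIs as they existed around this release, and mirrors the shapes from the new test below; the placement printout is just for inspection.

```python
# Minimal sketch of the behavior this fix enables; assumes torch.distributed is
# already initialized (e.g. run with torchrun) and a backend such as gloo is set up.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

mesh = DeviceMesh("cpu", torch.arange(dist.get_world_size()))
local = torch.randn(8, 8, requires_grad=True)
dt = distribute_tensor(local, mesh, [Shard(0)])

# Before this fix, new_empty_strided on a sharded DTensor could fail
# (see issue #107661); with the fix the sharded input is redistributed to
# replicate when necessary before the local op runs.
out = dt.new_empty_strided((8, 8), (8, 1), requires_grad=True)
print(out.shape, out.placements)
```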
Wanchao Liang
2023-08-31 08:36:08 -07:00
committed by PyTorch MergeBot
parent 46cd2fef3f
commit 74ff028839
2 changed files with 24 additions and 1 deletion


@@ -245,6 +245,29 @@ class DTensorTest(DTensorTestBase):
        except RuntimeError:
            self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])

    @with_comms
    def test_dtensor_new_empty_strided(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
        my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
        new_strided_dtensor = my_dtensor.new_empty_strided(
            (8, 8), (8, 1), requires_grad=True
        )
        # test the op produces new dtensor and autograd works
        self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
        new_strided_dtensor.sum().backward()
        self.assertIsNotNone(new_strided_dtensor.grad)
        self.assertIsInstance(new_strided_dtensor.grad, DTensor)

        # test backward new_empty_strided with sharding works correctly
        my_dtensor.to_local().sum().backward()
        local_tensor.sum().backward()
        self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
        self.assertEqual(
            my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
            local_tensor.grad,
        )

    @with_comms
    def test_dtensor_async_output(self):
        # Tests that if the output of some dtensor operations isn't used in any compute,