[dtensor] handle negative dim and fix TP regression (#111750)

TP style still have some regression due to negative dim specifications,
fix it by allow DTensor API to handle negative dims and normalize them.

i.e. TP uses `Shard(-1)`, and then try to redistribute `Shard(1) -> Shard(-1)`, this should ideally be no-op but current it runs a decompose sharding phrase and it would turn this transformation to `Shard(1) -> Replicate -> Shard(-1)`, which is wrong and triggers unnecessary allgathers
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111750
Approved by: https://github.com/rohan-varma
This commit is contained in:
Wanchao Liang
2023-10-21 18:47:32 -07:00
committed by PyTorch MergeBot
parent 1d291e1f19
commit 61461f39d1
6 changed files with 44 additions and 5 deletions

View File

@ -194,6 +194,14 @@ class DTensorTest(DTensorTestBase):
expected_grad = torch.ones(3, 3) * 9
self.assertEqual(local_tensor_with_grad.grad, expected_grad)
@with_comms
def test_from_local_negative_dim(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard_spec = [Shard(-1)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
self.assertEqual(sharded_tensor.placements[0].dim, 1)
@with_comms
def test_to_local(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))