[dtensor] handle negative dim and fix TP regression (#111750)
TP style still has some regression due to negative dim specifications; fix it by allowing the DTensor API to handle negative dims and normalize them. For example, TP uses `Shard(-1)` and then tries to redistribute `Shard(1) -> Shard(-1)`. This should ideally be a no-op, but currently it runs a decompose-sharding phase that turns the transformation into `Shard(1) -> Replicate -> Shard(-1)`, which is wrong and triggers unnecessary allgathers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111750
Approved by: https://github.com/rohan-varma
committed by PyTorch MergeBot
parent 1d291e1f19
commit 61461f39d1
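A minimal sketch of the normalization this fix describes, in plain Python; the helper name normalize_dim is hypothetical and the actual PyTorch implementation may differ:

# Hypothetical helper illustrating negative-dim normalization; the real
# logic lives inside DTensor's placement handling.
def normalize_dim(dim: int, ndim: int) -> int:
    """Map a possibly negative dim (e.g. -1) to its non-negative index."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-D tensor")
    return dim % ndim

# For a 2-D tensor, Shard(-1) normalizes to Shard(1), so Shard(1) and
# Shard(-1) describe the same placement and a redistribution between them
# can be recognized as a no-op.
assert normalize_dim(-1, 2) == 1
assert normalize_dim(1, 2) == 1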
@@ -194,6 +194,14 @@ class DTensorTest(DTensorTestBase):
         expected_grad = torch.ones(3, 3) * 9
         self.assertEqual(local_tensor_with_grad.grad, expected_grad)
 
+    @with_comms
+    def test_from_local_negative_dim(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(-1)]
+        local_tensor = torch.randn(3, 3)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        self.assertEqual(sharded_tensor.placements[0].dim, 1)
+
     @with_comms
     def test_to_local(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))