[dtensor] handle negative dim and fix TP regression (#111750)
TP-style sharding still has a regression due to negative dim specifications; fix it by allowing the DTensor API to handle negative dims and normalize them. For example, TP uses `Shard(-1)` and then tries to redistribute `Shard(1) -> Shard(-1)`. This should ideally be a no-op, but currently it runs a decompose-sharding phase that turns the transformation into `Shard(1) -> Replicate -> Shard(-1)`, which is wrong and triggers unnecessary allgathers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/111750
Approved by: https://github.com/rohan-varma
committed by: PyTorch MergeBot
parent: 1d291e1f19
commit: 61461f39d1
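To make the regression concrete, here is a minimal sketch (not part of the PR) of the TP-style pattern the fix targets. It assumes an already-initialized process group, a 1-D DeviceMesh over all ranks, and the public torch.distributed._tensor API exercised in the tests below; the helper name is hypothetical. After normalization, redistributing `Shard(1) -> Shard(-1)` on a 2-D tensor is recognized as a no-op instead of being decomposed into `Shard(1) -> Replicate -> Shard(-1)`.

import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

def negative_dim_redistribute_demo(device_type: str, world_size: int) -> None:
    # hypothetical helper for illustration; assumes torch.distributed is
    # already initialized (e.g. under the DTensor test harness)
    mesh = DeviceMesh(device_type, list(range(world_size)))  # 1-D mesh over all ranks
    x = torch.randn(12, 3)
    dt = distribute_tensor(x, mesh, [Shard(1)])  # shard the last dim
    # Shard(-1) normalizes to Shard(1), so this should issue no allgather
    dt2 = dt.redistribute(mesh, [Shard(-1)])
    assert dt2.placements[0].dim == 1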
@@ -57,6 +57,12 @@ class DTensorAPITest(DTensorTestBase):
         self.assertTrue(dist_tensor.requires_grad)
         self.assertTrue(dist_tensor.is_leaf)
 
+        # test negative dim
+        shard_minus_spec = [Shard(-1)]
+        tensor_to_shard = torch.randn(3, 3 * self.world_size)
+        dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec)
+        self.assertEqual(dist_tensor.placements[0].dim, 1)
+
     @with_comms
     def test_distribute_tensor_errors(self):
         device_mesh = DeviceMesh(
@@ -194,6 +194,14 @@ class DTensorTest(DTensorTestBase):
         expected_grad = torch.ones(3, 3) * 9
         self.assertEqual(local_tensor_with_grad.grad, expected_grad)
 
+    @with_comms
+    def test_from_local_negative_dim(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(-1)]
+        local_tensor = torch.randn(3, 3)
+        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
+        self.assertEqual(sharded_tensor.placements[0].dim, 1)
+
     @with_comms
     def test_to_local(self):
         device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@@ -67,6 +67,7 @@ class DistMathOpsTest(DTensorTestBase):
                 dist_y = torch.nn.functional.softmax(
                     dist_x, dim=softmax_dim, dtype=torch.float32
                 )
+                shard_dim = shard_dim + dist_y.ndim if shard_dim < 0 else shard_dim
                 self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim))
                 dist_y = dist_y.redistribute(device_mesh, [Replicate()])
                 self.assertEqual(dist_y.to_local(), local_y)
@@ -102,6 +103,7 @@ class DistMathOpsTest(DTensorTestBase):
                     dist_softmax = dist_x.softmax(dim=softmax_dim)
             else:
                 dist_softmax = dist_x.softmax(dim=softmax_dim)
+                shard_dim = shard_dim + dist_x.ndim if shard_dim < 0 else shard_dim
                 self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim))
                 dist_y = dist_softmax.sum()
                 dist_y = dist_y.redistribute(device_mesh, [Replicate()])
@@ -222,6 +222,18 @@ class RedistributeTest(DTensorTestBase):
                 torch.ones(local_shape) * self.world_size,
             )
 
+    @with_comms
+    def test_redistribute_negative_shard_dim(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
+        shard_spec = [Shard(1)]
+        shard_minus_spec = [Shard(-1)]
+
+        shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
+        self.assertEqual(shard_tensor.placements[0].dim, 1)
+        reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
+        self.assertEqual(reshard_tensor.placements[0].dim, 1)
+
 
 class MultiDimRedistributeTest(DTensorTestBase):
     @property
@@ -143,7 +143,12 @@ def compute_global_tensor_info(
     for idx, placement in enumerate(placements):
         mesh_dim_size = mesh.size(idx)
         if placement.is_shard():
-            shard_dim = cast(Shard, placement).dim
+            shard_placement = cast(Shard, placement)
+            if shard_placement.dim < 0:
+                # normalize shard dim to be positive
+                shard_placement.dim += len(tensor_shape)
+            shard_dim = shard_placement.dim
 
             local_dim_size = tensor_shape[shard_dim]
             tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
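For reference, a worked illustration (not part of the diff) of what the normalization in compute_global_tensor_info does to the global shape, assuming a rank-2 local shape of (3, 3) and a mesh dimension of size 4:

# Worked illustration with hypothetical values: Shard(-1) on a rank-2 local
# tensor normalizes to dim 1, so the global shape scales on that dim.
tensor_shape = [3, 3]   # local shape
mesh_dim_size = 4       # assumed mesh dimension size
shard_dim = -1          # Shard(-1)
if shard_dim < 0:
    shard_dim += len(tensor_shape)   # -1 -> 1
tensor_shape[shard_dim] *= mesh_dim_size
assert tensor_shape == [3, 12]       # global shape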
@@ -393,15 +393,18 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
         if placements is None:
             raise RuntimeError("placements is needed for redistribute!")
 
-        # Early return the original DTensor if the placements are the same.
-        if self._spec.placements == placements:
-            return self
-
         for placement in placements:
             if placement.is_partial():
                 raise RuntimeError(
                     "Can not redistribute to _Partial, _Partial is for internal use only!"
                 )
+            elif isinstance(placement, Shard) and placement.dim < 0:
+                # normalize shard dim to be positive
+                placement.dim += self.ndim
+
+        # Early return the original DTensor if the placements are the same.
+        if self._spec.placements == placements:
+            return self
 
         # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
         return Redistribute.apply(self, device_mesh, placements)
@@ -519,6 +522,9 @@ def distribute_tensor(
     for idx, placement in enumerate(placements):
         if placement.is_shard():
             placement = cast(Shard, placement)
+            if placement.dim < 0:
+                # normalize shard placement dim
+                placement.dim += tensor.ndim
             local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
         elif placement.is_replicate():
             placement = cast(Replicate, placement)