[dtensor] handle negative dim and fix TP regression (#111750)

TP styles still have a regression caused by negative dim specifications;
fix it by allowing the DTensor API to accept negative dims and normalize them.

e.g. TP uses `Shard(-1)` and then tries to redistribute `Shard(1) -> Shard(-1)`. This should ideally be a no-op, but currently it runs a decompose-sharding phase that turns the transformation into `Shard(1) -> Replicate -> Shard(-1)`, which is wrong and triggers unnecessary allgathers.
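A minimal sketch of the intended behavior after this fix (illustrative only, not part of the change; it assumes an already-initialized process group, one device per rank, and the `torch.distributed._tensor` API of this release):

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# assumes torch.distributed is initialized, e.g. via torchrun; use "cpu" with a gloo backend if no GPUs
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))
x = distribute_tensor(torch.randn(3, 3 * dist.get_world_size()), mesh, [Shard(1)])  # sharded on the last dim
y = x.redistribute(mesh, [Shard(-1)])  # Shard(-1) normalizes to Shard(1); with this fix, no collective is issued
assert y.placements[0].dim == 1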
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111750
Approved by: https://github.com/rohan-varma
Wanchao Liang
2023-10-21 18:47:32 -07:00
committed by PyTorch MergeBot
parent 1d291e1f19
commit 61461f39d1
6 changed files with 44 additions and 5 deletions


@@ -57,6 +57,12 @@ class DTensorAPITest(DTensorTestBase):
self.assertTrue(dist_tensor.requires_grad)
self.assertTrue(dist_tensor.is_leaf)
# test negative dim
shard_minus_spec = [Shard(-1)]
tensor_to_shard = torch.randn(3, 3 * self.world_size)
dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_minus_spec)
self.assertEqual(dist_tensor.placements[0].dim, 1)
@with_comms
def test_distribute_tensor_errors(self):
device_mesh = DeviceMesh(


@@ -194,6 +194,14 @@ class DTensorTest(DTensorTestBase):
expected_grad = torch.ones(3, 3) * 9
self.assertEqual(local_tensor_with_grad.grad, expected_grad)
@with_comms
def test_from_local_negative_dim(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard_spec = [Shard(-1)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, shard_spec)
self.assertEqual(sharded_tensor.placements[0].dim, 1)
@with_comms
def test_to_local(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))


@@ -67,6 +67,7 @@ class DistMathOpsTest(DTensorTestBase):
dist_y = torch.nn.functional.softmax(
dist_x, dim=softmax_dim, dtype=torch.float32
)
shard_dim = shard_dim + dist_y.ndim if shard_dim < 0 else shard_dim
self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim))
dist_y = dist_y.redistribute(device_mesh, [Replicate()])
self.assertEqual(dist_y.to_local(), local_y)
@@ -102,6 +103,7 @@ class DistMathOpsTest(DTensorTestBase):
dist_softmax = dist_x.softmax(dim=softmax_dim)
else:
dist_softmax = dist_x.softmax(dim=softmax_dim)
shard_dim = shard_dim + dist_x.ndim if shard_dim < 0 else shard_dim
self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim))
dist_y = dist_softmax.sum()
dist_y = dist_y.redistribute(device_mesh, [Replicate()])


@@ -222,6 +222,18 @@ class RedistributeTest(DTensorTestBase):
torch.ones(local_shape) * self.world_size,
)
@with_comms
def test_redistribute_negative_shard_dim(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
local_tensor = torch.randn(12, 3, device=self.device_type, requires_grad=True)
shard_spec = [Shard(1)]
shard_minus_spec = [Shard(-1)]
shard_tensor = distribute_tensor(local_tensor, device_mesh, shard_spec)
self.assertEqual(shard_tensor.placements[0].dim, 1)
reshard_tensor = shard_tensor.redistribute(device_mesh, shard_minus_spec)
self.assertEqual(reshard_tensor.placements[0].dim, 1)
class MultiDimRedistributeTest(DTensorTestBase):
@property


@@ -143,7 +143,12 @@ def compute_global_tensor_info(
for idx, placement in enumerate(placements):
mesh_dim_size = mesh.size(idx)
if placement.is_shard():
shard_dim = cast(Shard, placement).dim
shard_placement = cast(Shard, placement)
if shard_placement.dim < 0:
# normalize shard dim to be positive
shard_placement.dim += len(tensor_shape)
shard_dim = shard_placement.dim
local_dim_size = tensor_shape[shard_dim]
tensor_shape[shard_dim] = local_dim_size * mesh_dim_size
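For reference, a standalone sketch of the normalization arithmetic this hunk applies (the helper name and the example shape are illustrative, not part of the change):

def normalize_shard_dim(dim: int, ndim: int) -> int:
    # A negative shard dim counts from the end, so Shard(-1) on a 2-D tensor becomes Shard(1).
    return dim + ndim if dim < 0 else dim

# e.g. a (3, 3) local shard with Shard(-1) on a mesh dimension of size 4:
# dim -1 normalizes to 1, and the global shape becomes (3, 3 * 4) = (3, 12).
assert normalize_shard_dim(-1, 2) == 1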


@@ -393,15 +393,18 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
if placements is None:
raise RuntimeError("placements is needed for redistribute!")
# Early return the original DTensor if the placements are the same.
if self._spec.placements == placements:
return self
for placement in placements:
if placement.is_partial():
raise RuntimeError(
"Can not redistribute to _Partial, _Partial is for internal use only!"
)
elif isinstance(placement, Shard) and placement.dim < 0:
# normalize shard dim to be positive
placement.dim += self.ndim
# Early return the original DTensor if the placements are the same.
if self._spec.placements == placements:
return self
# pyre-fixme[16]: `Redistribute` has no attribute `apply`.
return Redistribute.apply(self, device_mesh, placements)
@@ -519,6 +522,9 @@ def distribute_tensor(
for idx, placement in enumerate(placements):
if placement.is_shard():
placement = cast(Shard, placement)
if placement.dim < 0:
# normalize shard placement dim
placement.dim += tensor.ndim
local_tensor = placement._shard_tensor(local_tensor, device_mesh, idx)
elif placement.is_replicate():
placement = cast(Replicate, placement)