pytorch/test/distributed/_tensor/test_math_ops.py
Wanchao Liang 61461f39d1 [dtensor] handle negative dim and fix TP regression (#111750)
TP styles still have a regression due to negative dim specifications;
fix it by allowing the DTensor API to handle negative dims and normalize them.

i.e. TP uses `Shard(-1)` and then tries to redistribute `Shard(1) -> Shard(-1)`. This should ideally be a no-op, but currently it runs a decompose-sharding phase that turns the transformation into `Shard(1) -> Replicate -> Shard(-1)`, which is wrong and triggers unnecessary allgathers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111750
Approved by: https://github.com/rohan-varma
2023-10-22 04:25:45 +00:00
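
As an illustration, here is a minimal sketch of the dim normalization the fix describes (the helper name `normalize_dim` is hypothetical, not DTensor's actual internal API):

def normalize_dim(dim: int, ndim: int) -> int:
    # Map a negative dim such as -1 onto its positive index (ndim - 1),
    # so that Shard(-1) and Shard(ndim - 1) describe the same placement.
    if dim < 0:
        dim += ndim
    if not 0 <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    return dim

# On a 2-d tensor, Shard(-1) normalizes to Shard(1), so redistributing
# Shard(1) -> Shard(-1) can be recognized as a no-op instead of being
# decomposed into Shard(1) -> Replicate -> Shard(-1).
assert normalize_dim(-1, 2) == 1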


# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import itertools

import torch

from torch.distributed._tensor import distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_unless_torch_gpu,
    with_comms,
)


class DistMathOpsTest(DTensorTestBase):
    @with_comms
    def test_sum(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(0)]

        tensor_to_sum = torch.randn(12, 8, 8)
        mat1 = distribute_tensor(tensor_to_sum, device_mesh, shard_spec)

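        # True/False are passed through as the keepdim argument;
        # None means it is omitted, so sum's default applies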
        keep_dim_or_not = [True, False, None]
        for dim in range(tensor_to_sum.ndim):
            for keep_dim in keep_dim_or_not:
                sum_args = (dim, keep_dim) if keep_dim is not None else (dim,)
                dim_sumed_tensor = tensor_to_sum.sum(*sum_args)
                dt_dim_sumed_tensor = mat1.sum(*sum_args).redistribute(
                    device_mesh, [Replicate()] * device_mesh.ndim
                )
                self.assertEqual(dt_dim_sumed_tensor.to_local(), dim_sumed_tensor)

        full_sumed_tensor = tensor_to_sum.sum()
        dt_sum = mat1.sum().redistribute(device_mesh, [Replicate()] * device_mesh.ndim)
        self.assertEqual(dt_sum.to_local(), full_sumed_tensor)

    # TODO: forward test can be removed once test_softmax_with_bwd passes on CPU
    @with_comms
    def test_softmax_fwd(self):
        device_mesh = self.build_device_mesh()
        x = torch.rand(8, 12, 16, device=self.device_type)
        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for softmax_dim, shard_dim in test_list:
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            )
            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            if dims[shard_dim] == dims[softmax_dim]:
                with self.assertRaisesRegex(
                    Exception, "Cannot run .* on sharding dimension!$"
                ):
                    dist_y = torch.nn.functional.softmax(
                        dist_x, dim=softmax_dim, dtype=torch.float32
                    )
            else:
                dist_y = torch.nn.functional.softmax(
                    dist_x, dim=softmax_dim, dtype=torch.float32
                )
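                # normalize a negative shard_dim to its positive index so it
                # can be compared against the placement's shard dimension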
                shard_dim = shard_dim + dist_y.ndim if shard_dim < 0 else shard_dim
                self.assertTrue(dist_y.placements[0].is_shard(dim=shard_dim))
                dist_y = dist_y.redistribute(device_mesh, [Replicate()])
                self.assertEqual(dist_y.to_local(), local_y)

    # TODO: get test_softmax_with_bwd to pass on CPU
    # DTensor's _softmax_backward_data produces wrong results on CPU for certain
    # dimension combinations.
    # fail_on_cpu_list = [(0, -1), (1, -1)]
    @with_comms
    @skip_unless_torch_gpu
    def test_softmax_with_bwd(self):
        device_mesh = self.build_device_mesh()
        dims = range(3)  # used to convert -1 to the actual dim
        softmax_dims = [-1, 0, 1, 2]
        shard_dims = [-1, 0, 1, 2]
        test_list = list(itertools.product(softmax_dims, shard_dims))

        for params in test_list:
            softmax_dim, shard_dim = params
            x = torch.rand(8, 12, 16, device=self.device_type, requires_grad=True)
            self.assertTrue(x.requires_grad)
            local_y = torch.nn.functional.softmax(
                x, dim=softmax_dim, dtype=torch.float32
            ).sum()
            local_y.backward()

            dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
            self.assertTrue(dist_x.requires_grad)
            if dims[softmax_dim] == dims[shard_dim]:
                with self.assertRaisesRegex(
                    Exception, "Cannot run .* on sharding dimension!$"
                ):
                    dist_softmax = dist_x.softmax(dim=softmax_dim)
            else:
                dist_softmax = dist_x.softmax(dim=softmax_dim)
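                # normalize a negative shard_dim before checking the placement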
                shard_dim = shard_dim + dist_x.ndim if shard_dim < 0 else shard_dim
                self.assertTrue(dist_softmax.placements[0].is_shard(dim=shard_dim))
                dist_y = dist_softmax.sum()
                dist_y = dist_y.redistribute(device_mesh, [Replicate()])
                self.assertEqual(dist_y.to_local(), local_y)
                self.assertIsNone(dist_x.grad)
                dist_y.backward()
                self.assertIsNotNone(dist_x.grad)
                dist_x_grad = dist_x.grad.redistribute(device_mesh, [Replicate()])
                self.assertEqual(dist_x_grad.to_local(), x.grad)


if __name__ == "__main__":
    run_tests()