pytorch/test/distributed/_tensor/test_dtensor.py
Wanchao Liang 242e03ba86 [dtensor] add async_op option to redistribute and some refactor (#121477)
The async output option was previously only available via the `full_tensor()` call, but I think it's generally good to expose it on the `redistribute` call directly so that users can control it.

This PR adds an `async_op` option to the `redistribute` call, letting users control whether the tensor redistribution is performed asynchronously or not.

By default we set `async_op` to False, following the semantics of the c10d collectives.
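
A minimal usage sketch (not part of the diff; it assumes a process group has already been initialized by the launcher, e.g. `torchrun` with more than one rank, and mirrors the imports used in the test below):

```python
import torch
import torch.distributed as dist
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed._tensor.placement_types import Replicate, Shard

# assumes dist.init_process_group(...) was already called by the launcher
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = DeviceMesh(device_type, list(range(dist.get_world_size())))
dt = distribute_tensor(torch.ones(8, 4), mesh, [Shard(0)])

# async_op=True: the local tensor wraps an AsyncCollectiveTensor, i.e. the
# all-gather issued by the redistribute has not been waited on yet
async_out = dt.redistribute(mesh, [Replicate()], async_op=True)
assert isinstance(async_out.to_local(), AsyncCollectiveTensor)

# async_op defaults to False: the collective is synced before returning,
# so the local tensor is a plain torch.Tensor
sync_out = dt.redistribute(mesh, [Replicate()])
assert not isinstance(sync_out.to_local(), AsyncCollectiveTensor)
```

The wait happens implicitly the first time the async local tensor is used in real compute, which is what `test_dtensor_async_output` below exercises.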

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121477
Approved by: https://github.com/wz337
2024-03-09 06:17:23 +00:00


# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
import torch.nn.functional as F
from numpy.testing import assert_array_equal
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
init_device_mesh,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_distributed import run_with_both_funcol_impls
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)

c10d_functional = torch.ops.c10d_functional


class DummyMLP(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.net1 = torch.nn.Linear(5, 1024, device=device)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(1024, 4, device=device)
def forward(self, x):
return self.net2(F.relu(self.net1(x)))
def reset_parameters(self, *args, **kwargs):
with torch.no_grad():
self.net1.weight.fill_(0.5)
self.net2.weight.fill_(1)
self.net1.bias.fill_(1.5)
self.net2.bias.fill_(1.2)


@instantiate_parametrized_tests
class DTensorTest(DTensorTestBase):
@with_comms
@run_with_both_funcol_impls
def test_dtensor_constructor(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(0)]
local_tensor = torch.randn(3, 3, requires_grad=True)
dist_tensor_shape = torch.Size([self.world_size * 3, 3])
dist_tensor = DTensor(
local_tensor,
device_mesh,
placements,
shape=dist_tensor_shape,
dtype=local_tensor.dtype,
requires_grad=True,
stride=local_tensor.stride(),
)
self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3)))
with self.assertWarnsRegex(UserWarning, "To construct"):
DTensor(
local_tensor,
device_mesh,
placements,
shape=dist_tensor_shape,
dtype=local_tensor.dtype,
requires_grad=False,
stride=local_tensor.stride(),
)
@with_comms
@run_with_both_funcol_impls
def test_meta_dtensor(self):
device_mesh = self.build_device_mesh()
dist_specs = [[Shard(0)], [Replicate()]]
meta_tensor = torch.randn(1024, 2048, device="meta")
for dist_spec in dist_specs:
# Test distribute_tensor on meta tensor
meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
self.assertTrue(meta_dtensor.is_meta)
meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
torch.nn.init.constant_(meta_dtensor, 1.2)
value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.2)
self.assertFalse(meta_dtensor.is_meta)
self.assertEqual(meta_dtensor.device.type, self.device_type)
self.assertEqual(meta_dtensor.to_local(), value_tensor)
# Test from_local on meta tensor
meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
torch.nn.init.constant_(meta_dtensor, 1.5)
self.assertEqual(meta_dtensor.device.type, self.device_type)
value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5)
self.assertEqual(meta_dtensor.to_local(), value_tensor)
@with_comms
@run_with_both_funcol_impls
def test_modules_w_meta_dtensor(self):
model = DummyMLP("meta")
device_mesh = self.build_device_mesh()
parallelize_plan = {
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
}
model_tp = parallelize_module(model, device_mesh, parallelize_plan)
model_tp.to_empty(device=self.device_type)
model_tp.reset_parameters()
optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
model_regular = DummyMLP(self.device_type)
model_regular_tp = parallelize_module(
model_regular, device_mesh, parallelize_plan
)
optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
model_regular_tp.reset_parameters()
torch.manual_seed(0)
inp = torch.randn(20, 5, device=self.device_type)
output = model_tp(inp)
output_regular = model_regular_tp(inp)
self.assertEqual(output, output_regular)
output.sum().backward()
output_regular.sum().backward()
optim.step()
optim_regular.step()
torch.manual_seed(1)
inp = torch.randn(20, 5, device=self.device_type)
self.assertEqual(model_tp(inp), model_regular_tp(inp))
@with_comms
@run_with_both_funcol_impls
def test_dtensor_stride(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard0_spec = [Shard(0)]
local_tensor = torch.randn(4, 8)
global_shape = torch.Size([self.world_size * 4, 8])
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
# won't affect stride
self.assertEqual(dist_tensor.stride(), (8, 1))
shard1_spec = [Shard(1)]
local_tensor = torch.randn(8, 4)
global_shape = torch.Size([8, self.world_size * 4])
dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
        # will affect the stride after the DTensor is initialized
self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))
# if initialized from a transposed mat
local_tensor = torch.randn(8, 4, 8)
local_tensor_t = local_tensor.permute(1, 2, 0)
global_shape = torch.Size([4, self.world_size * 8, 8])
self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
global_stride = (8 * self.world_size, 1, 32 * self.world_size)
self.assertEqual(dist_tensor.stride(), global_stride)
@with_comms
@run_with_both_funcol_impls
def test_from_local(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
self.assertEqual(sharded_tensor.size(), torch.Size([self.world_size * 3, 3]))
replica_spec = [Replicate()]
ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec)
self.assertEqual(ddp_tensor.size(), local_tensor.size())
partial_spec = [_Partial()]
partial_tensor = DTensor.from_local(local_tensor, device_mesh, partial_spec)
self.assertEqual(partial_tensor.size(), local_tensor.size())
# test dist tensor works with torch.Tensor during backwards
local_tensor_with_grad = torch.randn(3, 3, requires_grad=True)
# do some operations on local tensor
local_tensor_temp = local_tensor_with_grad * 3
        # create the dist tensor from a non-leaf local tensor; the dist tensor
        # created should also be a non-leaf node
dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
self.assertFalse(dist_tensor.is_leaf)
# do some random operations on dist tensor
output = dist_tensor * 3
self.assertIsInstance(output, DTensor)
# trigger .backward() on dist tensor directly
local_grad = torch.ones(3, 3)
grad_output = DTensor.from_local(local_grad, device_mesh, placements)
# run backward directly on dist tensor
output.backward(grad_output)
        # check that gradients flow back to the original torch.Tensor
self.assertIsNotNone(local_tensor_with_grad.grad)
expected_grad = torch.ones(3, 3) * 9
self.assertEqual(local_tensor_with_grad.grad, expected_grad)
@with_comms
@run_with_both_funcol_impls
def test_from_local_uneven_sharding(self):
mesh_shape = (self.world_size,)
device_mesh = init_device_mesh(self.device_type, mesh_shape)
uneven_dim0_size = self.world_size + 1
global_tensor = torch.randn(uneven_dim0_size, 2)
shard_placement = Shard(0)
tensor_list, _ = shard_placement._split_tensor(
global_tensor,
device_mesh.size(mesh_dim=0),
with_padding=False,
contiguous=True,
)
dtensor = DTensor.from_local(
tensor_list[self.rank],
device_mesh,
(Shard(0),),
shape=global_tensor.size(),
stride=global_tensor.stride(),
)
self.assertEqual(dtensor.size(), global_tensor.size())
self.assertEqual(dtensor.stride(), global_tensor.stride())
@with_comms
@run_with_both_funcol_impls
def test_from_local_uneven_sharding_raise_error(self):
mesh_shape = (self.world_size,)
device_mesh = init_device_mesh(self.device_type, mesh_shape)
uneven_dim0_size = self.world_size + 1
global_tensor = torch.randn(uneven_dim0_size, 2)
shard_placement = Shard(0)
tensor_list, _ = shard_placement._split_tensor(
global_tensor,
device_mesh.size(mesh_dim=0),
with_padding=False,
contiguous=True,
)
with self.assertRaisesRegex(
RuntimeError, "Please pass both shape and stride at the same time."
):
dtensor = DTensor.from_local(
tensor_list[self.rank],
device_mesh,
(Shard(0),),
shape=global_tensor.size(),
)
with self.assertRaisesRegex(
RuntimeError, "Please pass both shape and stride at the same time."
):
dtensor = DTensor.from_local(
tensor_list[self.rank],
device_mesh,
(Shard(0),),
stride=global_tensor.stride(),
)
@with_comms
@run_with_both_funcol_impls
def test_from_local_negative_dim(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(-1)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
self.assertEqual(sharded_tensor.placements[0].dim, 1)
@with_comms
def test_to_local(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = (Shard(0),)
dist_tensor_shape = torch.Size([self.world_size * 3, 3])
local_tensor_with_grad = torch.randn(
3, 3, device=self.device_type, requires_grad=True
)
sharded_tensor = DTensor(
local_tensor_with_grad,
device_mesh,
placements,
shape=dist_tensor_shape,
dtype=local_tensor_with_grad.dtype,
requires_grad=True,
stride=local_tensor_with_grad.stride(),
)
self.assertEqual(sharded_tensor.size(), dist_tensor_shape)
self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad)
# test dist tensor works with torch.Tensor during backwards
        # the dist tensor created is a leaf node; do some operations on it
temp_st = sharded_tensor * 3
# do some operation on local tensor of the dist tensor
new_tensor_with_grad = torch.randn(
3, 3, device=self.device_type, requires_grad=True
)
res = temp_st.to_local() + new_tensor_with_grad
# call backward directly on torch.Tensor, and see if it works by
# propagating through dist tensor
res.sum().backward()
self.assertIsNotNone(sharded_tensor.grad)
self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3)
# test the case when grad stride is different from fwd input.
res = sharded_tensor.to_local()
model = torch.nn.ReLU()
res.register_hook(lambda grad: grad.t())
target = torch.randn(3, 3, device=self.device_type)
mae_loss = torch.nn.L1Loss()
output = mae_loss(model(res), target)
        # The manual change to the grad stride makes the subsequent copy op fail,
        # so we need a try-except here.
try:
output.backward()
except RuntimeError:
self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
@with_comms
@run_with_both_funcol_impls
def test_to_local_grad_hint(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = (Shard(0),)
global_tensor = torch.ones(8, 3, requires_grad=True)
sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
comm_mode = CommDebugMode()
with comm_mode:
local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
grad_placements=[_Partial()]
)
local_out.backward(torch.ones_like(local_out))
self.assertEqual(
comm_mode.comm_counts[c10d_functional.all_gather_into_tensor], 1
)
self.assertEqual(
comm_mode.comm_counts[c10d_functional.reduce_scatter_tensor], 1
)
replica_grad = sharded_dtensor.grad.full_tensor()
self.assertEqual(replica_grad, global_tensor * self.world_size)
@with_comms
@run_with_both_funcol_impls
def test_full_tensor_sync(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = (Shard(0),)
global_tensor = torch.ones(8, 3, requires_grad=True)
sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
full_out = sharded_dtensor.full_tensor()
self.assertFalse(isinstance(full_out, AsyncCollectiveTensor))
self.assertEqual(full_out, global_tensor)
@with_comms
@run_with_both_funcol_impls
def test_full_tensor_grad_hint(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = (Shard(0),)
global_tensor = torch.ones(8, 3, requires_grad=True)
sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()])
local_out.sum().backward()
replica_grad = sharded_dtensor.grad.full_tensor()
self.assertEqual(replica_grad, global_tensor * self.world_size)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_new_empty_strided(self):
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
new_strided_dtensor = my_dtensor.new_empty_strided(
(8, 8), (8, 1), requires_grad=True
)
        # test that the op produces a new dtensor and that autograd works
self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
new_strided_dtensor.sum().backward()
self.assertIsNotNone(new_strided_dtensor.grad)
self.assertIsInstance(new_strided_dtensor.grad, DTensor)
        # test that backward through new_empty_strided with sharding works correctly
my_dtensor.to_local().sum().backward()
local_tensor.sum().backward()
self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
self.assertEqual(
my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
local_tensor.grad,
)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_async_output(self):
# Tests that if the output of some dtensor operations isn't used in any compute,
# the output should be an AsyncCollectiveTensor (representing the fact that
# we haven't synced the collective yet).
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
def fn(dt):
dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True)
# Make sure we haven't synced yet
# TODO: figure out why this is returning None
# self.assertTrue(_tensor_needs_wait(dt_out_redistribute))
dt_out_redistribute_view = dt_out_redistribute.view(
dt_out_redistribute.shape
)
local_tensor = dt_out_redistribute_view.to_local()
return local_tensor
x = torch.ones((4, 2), device=self.device_type)
dt = distribute_tensor(x, mesh, [Shard(0)])
out = fn(dt)
# Make sure we haven't synced yet
self.assertEqual(type(out), AsyncCollectiveTensor)
self.assertFalse(out.completed)
out_view = out.view(-1)
        # Assert that the output is an `AsyncCollectiveTensor`
self.assertEqual(type(out_view), AsyncCollectiveTensor)
self.assertFalse(out.completed)
        # Use the data, requiring a sync
ref = torch.ones((4, 2), device=self.device_type) + 1
ref = ref.view(-1)
out_data = out_view + 1
self.assertEqual(type(out_data), torch.Tensor)
self.assertEqual(out_data, ref)
        # test that async_op defaults to False
sync_out = dt.redistribute(mesh, [Replicate()])
self.assertFalse(isinstance(sync_out, AsyncCollectiveTensor))
self.assertEqual(sync_out.to_local(), x)
@with_comms
@run_with_both_funcol_impls
def test_from_local_then_to_local(self):
        # this test ensures the end-to-end torch.Tensor -> dist tensor -> torch.Tensor flow works
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(0)]
        # step 1. construct a local tensor
local_tensor_with_grad = torch.randn(
3, 3, device=self.device_type, requires_grad=True
)
# do some operations on local tensor
local_tensor_temp = local_tensor_with_grad + 8
        # step 2. create the dist tensor from a non-leaf local tensor; the dist
        # tensor created should also be a non-leaf node
dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
self.assertFalse(dist_tensor.is_leaf)
# do some random operations on dist tensor
output = dist_tensor * 6
self.assertIsInstance(output, DTensor)
# step 3. do some operation on local tensor of the dist tensor
new_tensor_with_grad = torch.randn(
3, 3, device=self.device_type, requires_grad=True
)
res = output.to_local() + new_tensor_with_grad
# call backward directly on torch.Tensor, and see if it works by
# propagating all the way back to the original torch.Tensor
res.sum().backward()
self.assertIsNotNone(local_tensor_with_grad.grad)
expected_grad = torch.ones(3, 3) * 6
self.assertEqual(local_tensor_with_grad.grad, expected_grad)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_spec_read_only_after_set(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
# modify placements, and dist_tensor's spec should not be changed
placements[0] = Replicate()
self.assertTrue(sharded_tensor.placements is not placements)
self.assertNotEqual(sharded_tensor.placements, placements)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_spec_hash(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
local_tensor2 = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, placements)
        # note that DTensorSpec does not hold real tensor data, so the hash is the same
        # as long as the mesh, placements and tensor properties are the same
self.assertEqual(hash(sharded_tensor._spec), hash(sharded_tensor2._spec))
        # changing the placements would change the hash
local_tensor3 = torch.ones(3, 3)
replica_spec = [Replicate()]
replica_tensor = DTensor.from_local(
local_tensor3, device_mesh, replica_spec, run_check=False
)
self.assertNotEqual(hash(sharded_tensor._spec), hash(replica_tensor._spec))
@with_comms
@run_with_both_funcol_impls
def test_dtensor_properties(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
self.assertEqual(sharded_tensor.device.type, self.device_type)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_save_load(self):
import io
device_mesh = self.build_device_mesh()
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
buffer = io.BytesIO()
torch.save(sharded_tensor, buffer)
buffer.seek(0)
reloaded_st = torch.load(buffer)
self.assertEqual(sharded_tensor, reloaded_st)


@instantiate_parametrized_tests
class DTensorMeshTest(DTensorTestBase):
@property
def world_size(self):
return 8
def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
if self.rank in mesh:
self.assertEqual(tensor, exp_in_mesh)
else:
self.assertEqual(tensor, exp_out_of_mesh)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_device_mesh_device_conversion(self):
# construct a cuda device mesh
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
# construct from a cpu local tensor with cuda device mesh
# should automatically convert the dist tensor to cuda
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(dist_tensor.device.type, self.device_type)
self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_api_device_mesh_context_manager(self):
with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(
local_tensor, device_mesh=mesh, placements=placements
)
with DeviceMesh(self.device_type, list(range(self.world_size))):
placements = [Shard(0)]
local_tensor = torch.randn(3, 3)
sharded_tensor = DTensor.from_local(local_tensor, placements=placements)
replica_spec = [Replicate()]
replica_tensor = sharded_tensor.redistribute(placements=replica_spec)
self.assertEqual(
replica_tensor.size(), torch.Size([3 * self.world_size, 3])
)
with DeviceMesh(self.device_type, torch.arange(self.world_size)):
placements = [Shard(0)]
global_shape = torch.Size([3 * self.world_size, 3])
global_tensor = torch.randn(global_shape)
sharded_tensor = distribute_tensor(global_tensor, placements=placements)
self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))
mesh_2d = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, 4)
)
with mesh_2d:
shard_2d_spec = [Shard(0), Replicate()]
tensor_2d = distribute_tensor(global_tensor, placements=shard_2d_spec)
self.assertEqual(tensor_2d.to_local().shape, torch.Size([3 * 4, 3]))
sharded_after_2d = distribute_tensor(global_tensor, placements=placements)
self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3]))
@with_comms
@run_with_both_funcol_impls
def test_dtensor_2d_mesh(self):
mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
# construct a cuda device mesh
mesh = DeviceMesh(self.device_type, mesh_tensor)
        # construct a dist tensor on a 2d device mesh and test if it works
placements = [Shard(0), Shard(1)]
local_tensor = torch.randn(3, 3)
dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(
dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
)
self.assertEqual(dist_tensor.device.type, self.device_type)
self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
        # if we shard on the same tensor dimension,
        # we should still correctly construct the global tensor size
shard_same_dim_spec = [Shard(0), Shard(0)]
local_tensor = torch.randn(3, 3)
dist_tensor = DTensor.from_local(local_tensor, mesh, shard_same_dim_spec)
self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))
@with_comms
@run_with_both_funcol_impls
def test_device_mesh_nd(self):
# construct a cuda device mesh
mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
mesh = DeviceMesh(self.device_type, mesh_tensor)
        # construct a dist tensor on a 3d device mesh and test if it works
placements = [Shard(0), Shard(1), Shard(2)]
local_tensor = torch.randn(3, 3, 3)
dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
self.assertEqual(dist_tensor.device.type, self.device_type)
self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
# construct a dist tensor on 3d device mesh with some shards on same dim
placements = [Shard(0), Shard(0), Shard(2)]
local_tensor = torch.randn(3, 3, 3)
dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
self.assertEqual(dist_tensor.device.type, self.device_type)
self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
@with_comms
@run_with_both_funcol_impls
def test_dtensor_spec_local_shard_offset(self):
device_mesh = DeviceMesh(
self.device_type, torch.arange(self.world_size).reshape(2, 4)
)
tensor_shape = (3 * self.world_size, 3 * self.world_size)
        # sharding specs and their corresponding local shard offsets
shard_spec_and_offsets = [
(
[Shard(0), Replicate()],
(3 * (self.world_size // 2) * (self.rank // 4), 0),
),
(
[Shard(1), Replicate()],
(0, 3 * (self.world_size // 2) * (self.rank // 4)),
),
(
[Replicate(), Shard(0)],
(3 * (self.world_size // 4) * (self.rank % 4), 0),
),
(
[Replicate(), Shard(1)],
(0, 3 * (self.world_size // 4) * (self.rank % 4)),
),
]
from torch.distributed._tensor._utils import (
compute_local_shape_and_global_offset,
)
# loop through all sharding specs and check local shard offsets
logical_tensor = torch.randn(tensor_shape)
for placements, expected_shard_offsets in shard_spec_and_offsets:
dtensor = distribute_tensor(logical_tensor, device_mesh, placements)
_, offset = compute_local_shape_and_global_offset(
dtensor.shape, device_mesh, dtensor.placements
)
self.assertEqual(expected_shard_offsets, offset)
@with_comms
@run_with_both_funcol_impls
def test_from_local_sub_mesh(self):
mesh = DeviceMesh(self.device_type, [0, 2])
local_tensor = torch.ones(3, 4)
dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
self.assertEqual(dtensor.size(), torch.Size([6, 4]))
self.sub_mesh_assert_equal(
mesh.mesh,
torch.ones(3, 4),
torch.tensor([]),
dtensor.to_local(),
)
        # test a dtensor created on a submesh: the operation should only
        # be applied to the local shard inside the mesh, not the whole
        # world, so only ranks 0/2 actually run the computation
dtensor = dtensor + 2
self.sub_mesh_assert_equal(
mesh.mesh,
torch.ones(3, 4) + 2,
torch.tensor([]),
dtensor.to_local(),
)
@with_comms
@run_with_both_funcol_impls
def test_default_value_sub_mesh(self):
mesh = DeviceMesh(self.device_type, [0, 2])
# test scalar return value
local_tensor1 = torch.ones(4, 3)
local_tensor2 = torch.ones(4, 3)
dtensor1 = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
dtensor2 = DTensor.from_local(local_tensor2, mesh, [Shard(0)])
local_res = dtensor1.equal(dtensor2) # equal returns local result
self.sub_mesh_assert_equal(
mesh.mesh,
True,
True,
local_res,
)
# test 0-d tensor return value
local_tensor = torch.ones(4, 3)
dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)]).sum()
self.sub_mesh_assert_equal(
mesh.mesh,
torch.tensor(12.0),
torch.tensor(0.0),
dtensor.to_local(),
)
# test List[torch.Tensor] return value
local_tensor = torch.ones(3, 4)
dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
dtensor_list = dtensor.split([2, 2], dim=1)
self.sub_mesh_assert_equal(
mesh.mesh,
[torch.ones(3, 2)] * 2,
[torch.tensor([])] * 2,
[dt.to_local() for dt in dtensor_list],
)
@with_comms
@run_with_both_funcol_impls
def test_redistribute_sub_mesh(self):
mesh = DeviceMesh(self.device_type, [0, 2])
# test redistribute on a submesh
local_tensor1 = torch.ones(4, 3)
sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
replicated_dtensor = sharded_dtensor.redistribute(placements=[Replicate()])
self.sub_mesh_assert_equal(
mesh.mesh, torch.ones(8, 3), torch.tensor([]), replicated_dtensor.to_local()
)
sharded_again = replicated_dtensor.redistribute(placements=[Shard(0)])
self.sub_mesh_assert_equal(
mesh.mesh, torch.ones(4, 3), torch.tensor([]), sharded_again.to_local()
)
@with_comms
@run_with_both_funcol_impls
def test_implicit_replication(self):
mesh = init_device_mesh(self.device_type, (self.world_size,))
local_tensor1 = torch.ones(4, 3)
sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
from torch.distributed._tensor.experimental import implicit_replication
with implicit_replication():
out_dt = sharded_dtensor + torch.ones(3, device=self.device_type)
self.assertEqual(out_dt.placements, [Shard(0)])
self.assertEqual(out_dt.shape, (4 * self.world_size, 3))
local_shard = out_dt.to_local()
self.assertEqual(local_shard.shape, (4, 3))
self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))


@instantiate_parametrized_tests
class TestDTensorPlacementTypes(DTensorTestBase):
@property
def world_size(self):
return 8
def _create_tensor(self, size):
# Keep everything deterministic.
torch.manual_seed(0)
tensor = torch.rand(size)
if self.device_type == "cuda":
return tensor.cuda()
else:
return tensor
@with_comms
@run_with_both_funcol_impls
def test_split_tensor_1D(self) -> None:
mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
shard_placement = Shard(0)
for size in range(8):
tensor = self._create_tensor(size)
splitted_tensor_list, pad_sizes = shard_placement._split_tensor(
tensor,
mesh.size(),
with_padding=True,
contiguous=True,
)
if size == 0:
                # when the tensor size is 0, no padding is needed on any rank.
expected_pad_sizes = [0] * self.world_size
assert_array_equal(expected_pad_sizes, pad_sizes)
is_tensor_empty = [
False if splitted_tensor.numel() > 0 else True
for splitted_tensor in splitted_tensor_list
]
expected_is_tensor_empty = [True] * self.world_size
assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
else:
expected_pad_sizes = [
0 if idx < size else 1
for idx, _ in enumerate(range(self.world_size))
]
assert_array_equal(expected_pad_sizes, pad_sizes)
unpadded_list = [
shard_placement._unpad_tensor(tensor, pad_sizes[i])
if pad_sizes[i] > 0
else tensor
for i, tensor in enumerate(splitted_tensor_list)
]
expected_is_tensor_empty = [
False if idx < size else True
for idx, _ in enumerate(range(self.world_size))
]
is_tensor_empty = [
False if unpadded_tensor.numel() > 0 else True
for unpadded_tensor in unpadded_list
]
assert_array_equal(expected_is_tensor_empty, is_tensor_empty)


if __name__ == "__main__":
run_tests()