# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import torch
import torch.nn.functional as F
from numpy.testing import assert_array_equal
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import (
    DeviceMesh,
    distribute_tensor,
    DTensor,
    init_device_mesh,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


c10d_functional = torch.ops.c10d_functional
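

# A small two-layer MLP used by the tensor parallel tests below;
# reset_parameters() fills the weights and biases with fixed constants so that
# results are deterministic across ranks.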
class DummyMLP(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.net1 = torch.nn.Linear(5, 1024, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(1024, 4, device=device)

    def forward(self, x):
        return self.net2(F.relu(self.net1(x)))

    def reset_parameters(self, *args, **kwargs):
        with torch.no_grad():
            self.net1.weight.fill_(0.5)
            self.net2.weight.fill_(1)
            self.net1.bias.fill_(1.5)
            self.net2.bias.fill_(1.2)


class DTensorTest(DTensorTestBase):
    @with_comms
    def test_dtensor_constructor(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3, requires_grad=True)
        dist_tensor_shape = torch.Size([self.world_size * 3, 3])
        dist_tensor = DTensor(
            local_tensor,
            device_mesh,
            placements,
            shape=dist_tensor_shape,
            dtype=local_tensor.dtype,
            requires_grad=True,
            stride=local_tensor.stride(),
        )
        self.assertEqual(dist_tensor.size(), torch.Size((self.world_size * 3, 3)))
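
        # Passing requires_grad=False while the local tensor has
        # requires_grad=True is expected to trigger the "To construct"
        # UserWarning below.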
        with self.assertWarnsRegex(UserWarning, "To construct"):
            DTensor(
                local_tensor,
                device_mesh,
                placements,
                shape=dist_tensor_shape,
                dtype=local_tensor.dtype,
                requires_grad=False,
                stride=local_tensor.stride(),
            )

    @with_comms
    def test_meta_dtensor(self):
        device_mesh = self.build_device_mesh()
        dist_specs = [[Shard(0)], [Replicate()]]
        meta_tensor = torch.randn(1024, 2048, device="meta")
        for dist_spec in dist_specs:
            # Test distribute_tensor on meta tensor
            meta_dtensor = distribute_tensor(meta_tensor, device_mesh, dist_spec)
            self.assertTrue(meta_dtensor.is_meta)
            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
            torch.nn.init.constant_(meta_dtensor, 1.2)
            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.2)
            self.assertFalse(meta_dtensor.is_meta)
            self.assertEqual(meta_dtensor.device.type, self.device_type)
            self.assertEqual(meta_dtensor.to_local(), value_tensor)

            # Test from_local on meta tensor
            meta_dtensor = DTensor.from_local(meta_tensor, device_mesh, dist_spec)
            meta_dtensor = torch.empty_like(meta_dtensor, device=self.device_type)
            torch.nn.init.constant_(meta_dtensor, 1.5)
            self.assertEqual(meta_dtensor.device.type, self.device_type)
            value_tensor = torch.empty_like(meta_dtensor.to_local()).fill_(1.5)
            self.assertEqual(meta_dtensor.to_local(), value_tensor)
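
    # Parallelize a module whose parameters start on the meta device, then
    # materialize it with to_empty() + reset_parameters() and check that it
    # matches a module parallelized directly on the real device.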
    @with_comms
    def test_modules_w_meta_dtensor(self):
        model = DummyMLP("meta")
        device_mesh = self.build_device_mesh()
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model_tp = parallelize_module(model, device_mesh, parallelize_plan)
        model_tp.to_empty(device=self.device_type)
        model_tp.reset_parameters()
        optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
        model_regular = DummyMLP(self.device_type)
        model_regular_tp = parallelize_module(
            model_regular, device_mesh, parallelize_plan
        )
        optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
        model_regular_tp.reset_parameters()
        torch.manual_seed(0)
        inp = torch.randn(20, 5, device=self.device_type)

        output = model_tp(inp)
        output_regular = model_regular_tp(inp)
        self.assertEqual(output, output_regular)

        output.sum().backward()
        output_regular.sum().backward()

        optim.step()
        optim_regular.step()

        torch.manual_seed(1)
        inp = torch.randn(20, 5, device=self.device_type)
        self.assertEqual(model_tp(inp), model_regular_tp(inp))

    @with_comms
    def test_dtensor_stride(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        shard0_spec = [Shard(0)]
        local_tensor = torch.randn(4, 8)
        global_shape = torch.Size([self.world_size * 4, 8])
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard0_spec)
        # sharding dim 0 won't affect the stride
        self.assertEqual(dist_tensor.stride(), (8, 1))

        shard1_spec = [Shard(1)]
        local_tensor = torch.randn(8, 4)
        global_shape = torch.Size([8, self.world_size * 4])
        dist_tensor = DTensor.from_local(local_tensor, device_mesh, shard1_spec)
        # sharding dim 1 will affect the stride once the DTensor is initialized
        self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))

        # if initialized from a transposed matrix
        local_tensor = torch.randn(8, 4, 8)
        local_tensor_t = local_tensor.permute(1, 2, 0)
        global_shape = torch.Size([4, self.world_size * 8, 8])
        self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
        dist_tensor = DTensor.from_local(local_tensor_t, device_mesh, shard1_spec)
        global_stride = (8 * self.world_size, 1, 32 * self.world_size)
        self.assertEqual(dist_tensor.stride(), global_stride)
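
    # from_local() wraps an existing local shard into a DTensor; the global
    # size is derived from the placements, and gradients flow back through the
    # wrapping to the original local tensor.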
    @with_comms
    def test_from_local(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.size(), torch.Size([self.world_size * 3, 3]))

        replica_spec = [Replicate()]
        ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec)
        self.assertEqual(ddp_tensor.size(), local_tensor.size())

        partial_spec = [_Partial()]
        partial_tensor = DTensor.from_local(local_tensor, device_mesh, partial_spec)
        self.assertEqual(partial_tensor.size(), local_tensor.size())

        # test that dist tensor works with torch.Tensor during backwards
        local_tensor_with_grad = torch.randn(3, 3, requires_grad=True)
        # do some operations on the local tensor
        local_tensor_temp = local_tensor_with_grad * 3
        # create the dist tensor from a non-leaf local tensor; the dist tensor
        # created should also be a non-leaf node
        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
        self.assertFalse(dist_tensor.is_leaf)
        # do some random operations on the dist tensor
        output = dist_tensor * 3
        self.assertIsInstance(output, DTensor)
        # trigger .backward() on the dist tensor directly
        local_grad = torch.ones(3, 3)
        grad_output = DTensor.from_local(local_grad, device_mesh, placements)
        # run backward directly on the dist tensor
        output.backward(grad_output)
        # check that gradients flow back to the original torch.Tensor
        self.assertIsNotNone(local_tensor_with_grad.grad)
        expected_grad = torch.ones(3, 3) * 9
        self.assertEqual(local_tensor_with_grad.grad, expected_grad)
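
    # With uneven sharding the global shape and stride cannot be inferred from
    # the local shard alone, so from_local() must receive them explicitly
    # (and, per the error message below, both must be passed together).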
    @with_comms
    def test_from_local_uneven_sharding(self):
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)

        uneven_dim0_size = self.world_size + 1
        global_tensor = torch.randn(uneven_dim0_size, 2)
        shard_placement = Shard(0)
        tensor_list, _ = shard_placement._split_tensor(
            global_tensor,
            device_mesh.size(mesh_dim=0),
            with_padding=False,
            contiguous=True,
        )

        dtensor = DTensor.from_local(
            tensor_list[self.rank],
            device_mesh,
            (Shard(0),),
            shape=global_tensor.size(),
            stride=global_tensor.stride(),
        )

        self.assertEqual(dtensor.size(), global_tensor.size())
        self.assertEqual(dtensor.stride(), global_tensor.stride())

    @with_comms
    def test_from_local_uneven_sharding_raise_error(self):
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)

        uneven_dim0_size = self.world_size + 1
        global_tensor = torch.randn(uneven_dim0_size, 2)
        shard_placement = Shard(0)
        tensor_list, _ = shard_placement._split_tensor(
            global_tensor,
            device_mesh.size(mesh_dim=0),
            with_padding=False,
            contiguous=True,
        )

        with self.assertRaisesRegex(
            RuntimeError, "Please pass both shape and stride at the same time."
        ):
            dtensor = DTensor.from_local(
                tensor_list[self.rank],
                device_mesh,
                (Shard(0),),
                shape=global_tensor.size(),
            )

        with self.assertRaisesRegex(
            RuntimeError, "Please pass both shape and stride at the same time."
        ):
            dtensor = DTensor.from_local(
                tensor_list[self.rank],
                device_mesh,
                (Shard(0),),
                stride=global_tensor.stride(),
            )
    @with_comms
    def test_from_local_negative_dim(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(-1)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.placements[0].dim, 1)

    @with_comms
    def test_to_local(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        dist_tensor_shape = torch.Size([self.world_size * 3, 3])
        local_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )

        sharded_tensor = DTensor(
            local_tensor_with_grad,
            device_mesh,
            placements,
            shape=dist_tensor_shape,
            dtype=local_tensor_with_grad.dtype,
            requires_grad=True,
            stride=local_tensor_with_grad.stride(),
        )
        self.assertEqual(sharded_tensor.size(), dist_tensor_shape)
        self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad)

        # test that dist tensor works with torch.Tensor during backwards
        # the dist tensor created is a leaf node, do some operations on it
        temp_st = sharded_tensor * 3

        # do some operations on the local tensor of the dist tensor
        new_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        res = temp_st.to_local() + new_tensor_with_grad
        # call backward directly on torch.Tensor, and check that it works by
        # propagating through the dist tensor
        res.sum().backward()
        self.assertIsNotNone(sharded_tensor.grad)

        self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3)

        # test the case when the grad stride is different from the fwd input
        res = sharded_tensor.to_local()
        model = torch.nn.ReLU()
        res.register_hook(lambda grad: grad.t())
        target = torch.randn(3, 3, device=self.device_type)
        mae_loss = torch.nn.L1Loss()
        output = mae_loss(model(res), target)
        # The manual change to the grad stride makes the subsequent copy op
        # fail, so we need a try/except here.
        try:
            output.backward()
        except RuntimeError:
            self.assertEqual(sharded_tensor.grad.stride(), [1, 3 * self.world_size])
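
    # grad_placements hints at how the gradient of the returned local tensor is
    # laid out (partial here), so the backward pass can insert the matching
    # collective (a reduce_scatter in this test).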
    @with_comms
    def test_to_local_grad_hint(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        comm_mode = CommDebugMode()

        with comm_mode:
            local_out = sharded_dtensor.redistribute(placements=[Replicate()]).to_local(
                grad_placements=[_Partial()]
            )
            local_out.backward(torch.ones_like(local_out))

        self.assertEqual(
            comm_mode.comm_counts[c10d_functional.all_gather_into_tensor], 1
        )
        self.assertEqual(
            comm_mode.comm_counts[c10d_functional.reduce_scatter_tensor], 1
        )

        replica_grad = sharded_dtensor.grad.full_tensor()
        self.assertEqual(replica_grad, global_tensor * self.world_size)

    @with_comms
    def test_full_tensor_sync(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        full_out = sharded_dtensor.full_tensor()
        self.assertFalse(isinstance(full_out, AsyncCollectiveTensor))
        self.assertEqual(full_out, global_tensor)

    @with_comms
    def test_full_tensor_grad_hint(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = (Shard(0),)
        global_tensor = torch.ones(8, 3, requires_grad=True)

        sharded_dtensor = distribute_tensor(global_tensor, device_mesh, placements)
        local_out = sharded_dtensor.full_tensor(grad_placements=[_Partial()])
        local_out.sum().backward()

        replica_grad = sharded_dtensor.grad.full_tensor()
        self.assertEqual(replica_grad, global_tensor * self.world_size)

    @with_comms
    def test_dtensor_new_empty_strided(self):
        device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        local_tensor = torch.randn(8, 8, requires_grad=True, device=self.device_type)
        my_dtensor = distribute_tensor(local_tensor, device_mesh, [Shard(0)])
        new_strided_dtensor = my_dtensor.new_empty_strided(
            (8, 8), (8, 1), requires_grad=True
        )
        # test that the op produces a new dtensor and autograd works
        self.assertEqual(new_strided_dtensor.shape, my_dtensor.shape)
        new_strided_dtensor.sum().backward()
        self.assertIsNotNone(new_strided_dtensor.grad)
        self.assertIsInstance(new_strided_dtensor.grad, DTensor)

        # test that backward of new_empty_strided with sharding works correctly
        my_dtensor.to_local().sum().backward()
        local_tensor.sum().backward()
        self.assertEqual(my_dtensor.grad, new_strided_dtensor.grad)
        self.assertEqual(
            my_dtensor.grad.redistribute(placements=[Replicate()]).to_local(),
            local_tensor.grad,
        )

    @with_comms
    def test_dtensor_async_output(self):
        # Tests that if the output of some dtensor operations isn't used in any compute,
        # the output should be an AsyncCollectiveTensor (representing the fact that
        # we haven't synced the collective yet).
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        def fn(dt):
            dt_out_redistribute = dt.redistribute(mesh, [Replicate()], async_op=True)
            # Make sure we haven't synced yet
            # TODO: figure out why this is returning None
            # self.assertTrue(_tensor_needs_wait(dt_out_redistribute))
            dt_out_redistribute_view = dt_out_redistribute.view(
                dt_out_redistribute.shape
            )
            local_tensor = dt_out_redistribute_view.to_local()
            return local_tensor

        x = torch.ones((4, 2), device=self.device_type)
        dt = distribute_tensor(x, mesh, [Shard(0)])
        out = fn(dt)
        # Make sure we haven't synced yet
        self.assertEqual(type(out), AsyncCollectiveTensor)
        self.assertFalse(out.completed)
        out_view = out.view(-1)

        # Assert that the output is an `AsyncCollectiveTensor`
        self.assertEqual(type(out_view), AsyncCollectiveTensor)
        self.assertFalse(out.completed)

        # Use the data, requiring a sync
        ref = torch.ones((4, 2), device=self.device_type) + 1
        ref = ref.view(-1)
        out_data = out_view + 1
        self.assertEqual(type(out_data), torch.Tensor)
        self.assertEqual(out_data, ref)

        # test the async_op=False default
        sync_out = dt.redistribute(mesh, [Replicate()])
        self.assertFalse(isinstance(sync_out, AsyncCollectiveTensor))
        self.assertEqual(sync_out.to_local(), x)
    @with_comms
    def test_from_local_then_to_local(self):
        # this test ensures that the end-to-end flow
        # torch.Tensor -> dist tensor -> torch.Tensor works
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]

        # step 1. construct the local tensor
        local_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        # do some operations on the local tensor
        local_tensor_temp = local_tensor_with_grad + 8
        # step 2. create the dist tensor from a non-leaf local tensor; the dist
        # tensor created should also be a non-leaf node
        dist_tensor = DTensor.from_local(local_tensor_temp, device_mesh, placements)
        self.assertFalse(dist_tensor.is_leaf)
        # do some random operations on the dist tensor
        output = dist_tensor * 6
        self.assertIsInstance(output, DTensor)

        # step 3. do some operations on the local tensor of the dist tensor
        new_tensor_with_grad = torch.randn(
            3, 3, device=self.device_type, requires_grad=True
        )
        res = output.to_local() + new_tensor_with_grad
        # call backward directly on torch.Tensor, and check that it works by
        # propagating all the way back to the original torch.Tensor
        res.sum().backward()
        self.assertIsNotNone(local_tensor_with_grad.grad)

        expected_grad = torch.ones(3, 3) * 6
        self.assertEqual(local_tensor_with_grad.grad, expected_grad)

    @with_comms
    def test_dtensor_spec_read_only_after_set(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)

        # modify placements; dist_tensor's spec should not be changed
        placements[0] = Replicate()
        self.assertTrue(sharded_tensor.placements is not placements)
        self.assertNotEqual(sharded_tensor.placements, placements)

    @with_comms
    def test_dtensor_spec_hash(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        local_tensor2 = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        sharded_tensor2 = DTensor.from_local(local_tensor2, device_mesh, placements)
        # note that DTensorSpec does not hold real tensor data, so the hash is
        # the same as long as the mesh, placements and tensor properties match
        self.assertEqual(hash(sharded_tensor._spec), hash(sharded_tensor2._spec))

        # changing the placements would change the hash
        local_tensor3 = torch.ones(3, 3)
        replica_spec = [Replicate()]
        replica_tensor = DTensor.from_local(
            local_tensor3, device_mesh, replica_spec, run_check=False
        )
        self.assertNotEqual(hash(sharded_tensor._spec), hash(replica_tensor._spec))

    @with_comms
    def test_dtensor_properties(self):
        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        self.assertEqual(sharded_tensor.device.type, self.device_type)
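
    # DTensors should round-trip through torch.save()/torch.load().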
    @with_comms
    def test_dtensor_save_load(self):
        import io

        device_mesh = self.build_device_mesh()
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        sharded_tensor = DTensor.from_local(local_tensor, device_mesh, placements)
        buffer = io.BytesIO()
        torch.save(sharded_tensor, buffer)
        buffer.seek(0)
        reloaded_st = torch.load(buffer)
        self.assertEqual(sharded_tensor, reloaded_st)
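

# DTensorMeshTest exercises DTensors on multi-dimensional meshes and on
# sub-meshes that cover only a subset of ranks; world_size is fixed to 8.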
class DTensorMeshTest(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    def sub_mesh_assert_equal(self, mesh, exp_in_mesh, exp_out_of_mesh, tensor):
        if self.rank in mesh:
            self.assertEqual(tensor, exp_in_mesh)
        else:
            self.assertEqual(tensor, exp_out_of_mesh)

    @with_comms
    def test_dtensor_device_mesh_device_conversion(self):
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # construct from a cpu local tensor with cuda device mesh
        # should automatically convert the dist tensor to cuda
        placements = [Shard(0)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

    @with_comms
    def test_dtensor_api_device_mesh_context_manager(self):
        with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
            placements = [Shard(0)]
            local_tensor = torch.randn(3, 3)
            sharded_tensor = DTensor.from_local(
                local_tensor, device_mesh=mesh, placements=placements
            )

        with DeviceMesh(self.device_type, list(range(self.world_size))):
            placements = [Shard(0)]
            local_tensor = torch.randn(3, 3)
            sharded_tensor = DTensor.from_local(local_tensor, placements=placements)
            replica_spec = [Replicate()]
            replica_tensor = sharded_tensor.redistribute(placements=replica_spec)
            self.assertEqual(
                replica_tensor.size(), torch.Size([3 * self.world_size, 3])
            )

        with DeviceMesh(self.device_type, torch.arange(self.world_size)):
            placements = [Shard(0)]
            global_shape = torch.Size([3 * self.world_size, 3])
            global_tensor = torch.randn(global_shape)
            sharded_tensor = distribute_tensor(global_tensor, placements=placements)
            self.assertEqual(sharded_tensor.to_local().shape, torch.Size([3, 3]))

            mesh_2d = DeviceMesh(
                self.device_type, torch.arange(self.world_size).reshape(2, 4)
            )

            with mesh_2d:
                shard_2d_spec = [Shard(0), Replicate()]
                tensor_2d = distribute_tensor(global_tensor, placements=shard_2d_spec)
                self.assertEqual(tensor_2d.to_local().shape, torch.Size([3 * 4, 3]))

            sharded_after_2d = distribute_tensor(global_tensor, placements=placements)
            self.assertEqual(sharded_after_2d.to_local().shape, torch.Size([3, 3]))
    @with_comms
    def test_dtensor_2d_mesh(self):
        mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
        # construct a cuda device mesh
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # construct a dist tensor on a 2d device mesh and test that it works
        placements = [Shard(0), Shard(1)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(
            dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
        )
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

        # if we shard on the same tensor dimension,
        # the global tensor size should still be constructed correctly
        shard_same_dim_spec = [Shard(0), Shard(0)]
        local_tensor = torch.randn(3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_same_dim_spec)
        self.assertEqual(dist_tensor.size(), torch.Size([3 * self.world_size, 3]))

    @with_comms
    def test_device_mesh_nd(self):
        # construct a cuda device mesh
        mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
        mesh = DeviceMesh(self.device_type, mesh_tensor)

        # construct a dist tensor on a 3d device mesh and test that it works
        placements = [Shard(0), Shard(1), Shard(2)]
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

        # construct a dist tensor on a 3d device mesh with some shards on the same dim
        placements = [Shard(0), Shard(0), Shard(2)]
        local_tensor = torch.randn(3, 3, 3)
        dist_tensor = DTensor.from_local(local_tensor, mesh, placements)
        self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
        self.assertEqual(dist_tensor.device.type, self.device_type)
        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)

    @with_comms
    def test_dtensor_spec_local_shard_offset(self):
        device_mesh = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 4)
        )
        tensor_shape = (3 * self.world_size, 3 * self.world_size)
        # sharding specs and their corresponding local shard offsets
        shard_spec_and_offsets = [
            (
                [Shard(0), Replicate()],
                (3 * (self.world_size // 2) * (self.rank // 4), 0),
            ),
            (
                [Shard(1), Replicate()],
                (0, 3 * (self.world_size // 2) * (self.rank // 4)),
            ),
            (
                [Replicate(), Shard(0)],
                (3 * (self.world_size // 4) * (self.rank % 4), 0),
            ),
            (
                [Replicate(), Shard(1)],
                (0, 3 * (self.world_size // 4) * (self.rank % 4)),
            ),
        ]

        from torch.distributed._tensor._utils import (
            compute_local_shape_and_global_offset,
        )

        # loop through all sharding specs and check local shard offsets
        logical_tensor = torch.randn(tensor_shape)
        for placements, expected_shard_offsets in shard_spec_and_offsets:
            dtensor = distribute_tensor(logical_tensor, device_mesh, placements)
            _, offset = compute_local_shape_and_global_offset(
                dtensor.shape, device_mesh, dtensor.placements
            )
            self.assertEqual(expected_shard_offsets, offset)
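
    # A DTensor created on a sub-mesh only covers the ranks in that mesh;
    # ranks outside of it hold an empty local tensor.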
    @with_comms
    def test_from_local_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])
        local_tensor = torch.ones(3, 4)

        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        self.assertEqual(dtensor.size(), torch.Size([6, 4]))

        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4),
            torch.tensor([]),
            dtensor.to_local(),
        )

        # for a dtensor created on a sub-mesh, operations should only be
        # applied to the local shards inside the mesh, not the whole world,
        # so only ranks 0 and 2 actually run the computation
        dtensor = dtensor + 2
        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.ones(3, 4) + 2,
            torch.tensor([]),
            dtensor.to_local(),
        )
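
    # Ops whose results are not sharded tensors (scalars, 0-d tensors, lists
    # of tensors) should still return well-defined values on ranks outside the
    # sub-mesh.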
    @with_comms
    def test_default_value_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])

        # test scalar return value
        local_tensor1 = torch.ones(4, 3)
        local_tensor2 = torch.ones(4, 3)
        dtensor1 = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
        dtensor2 = DTensor.from_local(local_tensor2, mesh, [Shard(0)])
        local_res = dtensor1.equal(dtensor2)  # equal returns local result
        self.sub_mesh_assert_equal(
            mesh.mesh,
            True,
            True,
            local_res,
        )

        # test 0-d tensor return value
        local_tensor = torch.ones(4, 3)
        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)]).sum()
        self.sub_mesh_assert_equal(
            mesh.mesh,
            torch.tensor(12.0),
            torch.tensor(0.0),
            dtensor.to_local(),
        )

        # test List[torch.Tensor] return value
        local_tensor = torch.ones(3, 4)
        dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])
        dtensor_list = dtensor.split([2, 2], dim=1)
        self.sub_mesh_assert_equal(
            mesh.mesh,
            [torch.ones(3, 2)] * 2,
            [torch.tensor([])] * 2,
            [dt.to_local() for dt in dtensor_list],
        )

    @with_comms
    def test_redistribute_sub_mesh(self):
        mesh = DeviceMesh(self.device_type, [0, 2])

        # test redistribute on a submesh
        local_tensor1 = torch.ones(4, 3)
        sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])
        replicated_dtensor = sharded_dtensor.redistribute(placements=[Replicate()])
        self.sub_mesh_assert_equal(
            mesh.mesh, torch.ones(8, 3), torch.tensor([]), replicated_dtensor.to_local()
        )
        sharded_again = replicated_dtensor.redistribute(placements=[Shard(0)])
        self.sub_mesh_assert_equal(
            mesh.mesh, torch.ones(4, 3), torch.tensor([]), sharded_again.to_local()
        )
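
    # implicit_replication() lets plain torch.Tensors participate in DTensor
    # ops by treating them as replicated across the mesh.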
    @with_comms
    def test_implicit_replication(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        local_tensor1 = torch.ones(4, 3)
        sharded_dtensor = DTensor.from_local(local_tensor1, mesh, [Shard(0)])

        from torch.distributed._tensor.experimental import implicit_replication

        with implicit_replication():
            out_dt = sharded_dtensor + torch.ones(3, device=self.device_type)
            self.assertEqual(out_dt.placements, [Shard(0)])
            self.assertEqual(out_dt.shape, (4 * self.world_size, 3))
            local_shard = out_dt.to_local()
            self.assertEqual(local_shard.shape, (4, 3))
            self.assertEqual(local_shard, torch.ones(4, 3) + torch.ones(3))

    @with_comms
    def test_auto_implicit_replication(self):
        mesh = init_device_mesh(self.device_type, (self.world_size,))

        local_tensor = torch.ones(self.world_size, 3, device=self.device_type)
        sharded_dtensor = DTensor.from_local(local_tensor, mesh, [Shard(0)])

        # automatically turn tensor to DTensor replicate when ndim = 0 and numel = 1
        ndim_0_tensor = torch.tensor(1, device=self.device_type)

        def add_scalar_tensor_with_dtensor():
            return sharded_dtensor + ndim_0_tensor

        result = add_scalar_tensor_with_dtensor().to_local()
        self.assertEqual(result, local_tensor + ndim_0_tensor)
        self.assertNotWarn(
            add_scalar_tensor_with_dtensor,
            "Found a non-scalar tensor with numel=1 and ndim!=0",
        )

        # automatically turn tensor to DTensor replicate when ndim = 1 and numel = 1
        numel_1_tensor = torch.tensor([1], device=self.device_type)
        self.assertEqual(
            (sharded_dtensor + numel_1_tensor).to_local(), local_tensor + numel_1_tensor
        )
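

# Tests for the Shard placement's low-level _split_tensor helper, covering
# padding and empty-shard handling.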
class TestDTensorPlacementTypes(DTensorTestBase):
    @property
    def world_size(self):
        return 8

    def _create_tensor(self, size):
        # Keep everything deterministic.
        torch.manual_seed(0)
        tensor = torch.rand(size)
        if self.device_type == "cuda":
            return tensor.cuda()
        else:
            return tensor
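
    # _split_tensor with with_padding=True returns one shard per rank together
    # with the per-rank pad sizes; a size-0 input yields empty shards and no
    # padding.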
    @with_comms
    def test_split_tensor_1D(self) -> None:
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
        shard_placement = Shard(0)

        for size in range(8):
            tensor = self._create_tensor(size)
            splitted_tensor_list, pad_sizes = shard_placement._split_tensor(
                tensor,
                mesh.size(),
                with_padding=True,
                contiguous=True,
            )
            if size == 0:
                # when the tensor size is 0, no padding is needed on any rank
                expected_pad_sizes = []
                assert_array_equal(expected_pad_sizes, pad_sizes)

                is_tensor_empty = [
                    False if splitted_tensor.numel() > 0 else True
                    for splitted_tensor in splitted_tensor_list
                ]
                expected_is_tensor_empty = [True] * self.world_size
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
            else:
                expected_pad_sizes = [
                    0 if idx < size else 1
                    for idx, _ in enumerate(range(self.world_size))
                ]
                assert_array_equal(expected_pad_sizes, pad_sizes)

                from torch.distributed._tensor._collective_utils import unpad_tensor

                unpadded_list = [
                    unpad_tensor(tensor, shard_placement.dim, pad_sizes[i])
                    if pad_sizes[i] > 0
                    else tensor
                    for i, tensor in enumerate(splitted_tensor_list)
                ]
                expected_is_tensor_empty = [
                    False if idx < size else True
                    for idx, _ in enumerate(range(self.world_size))
                ]
                is_tensor_empty = [
                    False if unpadded_tensor.numel() > 0 else True
                    for unpadded_tensor in unpadded_list
                ]
                assert_array_equal(expected_is_tensor_empty, is_tensor_empty)


if __name__ == "__main__":
    run_tests()