# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
from functools import partial

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import (
    distribute_tensor,
    DTensor,
    Partial,
    Replicate,
    Shard,
)
from torch.distributed.tensor.debug import CommDebugMode
from torch.distributed.tensor.experimental import local_map
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)


funcol_py = torch.ops.c10d_functional


row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
replicate = [Replicate()]  # replicate placements on 1-d mesh


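# The helpers below are written against plain torch.Tensor and issue their own
# collectives explicitly. `local_map` (under test here) wraps such a function
# so it can be applied to DTensors: DTensor arguments are converted to their
# local shards per `in_placements`, and Tensor outputs are wrapped back into
# DTensors per `out_placements`.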
def equal_allgather_forward(device_mesh, X, Y):
    eq = torch.tensor([torch.equal(X, Y)], device=X.device)
    eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh)
    return torch.all(eq_gather).item()


def mm_all_gather_forward(device_mesh, A, B):
    local_mm_result = torch.mm(A, B)
    return funcol.all_gather_tensor(local_mm_result, 0, device_mesh).wait()


def mm_forward(A, B):  # no device mesh needed since we don't do collective
    return torch.mm(A, B)


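# With A sharded column-wise and B sharded row-wise, each rank's torch.mm(A, B)
# is a partial sum of the full product, so summing the per-rank results with
# all_reduce("sum") recovers the complete output (Partial -> Replicate).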
def mm_allreduce_forward(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


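# Same computation as mm_allreduce_forward, but wrapped at definition time:
# functools.partial(local_map, ...) is applied as a decorator, which is
# equivalent to calling local_map(mm_allreduce_forward_decorated, ...).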
@partial(
    local_map,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
def mm_allreduce_forward_decorated(device_mesh, A, B):
    partial_sum_tensor = torch.mm(A, B)
    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


def mul_forward(X, scalar):  # no device mesh needed since we don't do collective
    return torch.mul(X, scalar)


class TestLocalMap(DTensorTestBase):
    @property
    def world_size(self):
        return 2

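    # shapes in the test below: X is (16, 8) sharded on dim 1 -> (16, 4) per
    # rank; W is (8, 12) sharded on dim 0 -> (4, 12) per rank; each local mm
    # yields a (16, 12) partial result that the all_reduce completes.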
    # simple correctness check
    @with_comms
    def test_local_map_correctness(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, col_wise
        )  # col-wisely sharded X tensor
        W_dt = distribute_tensor(
            W, device_mesh, row_wise
        )  # row-wisely sharded W tensor

        # Test 1: use the function returned from calling local_map
        # get the function wrapped with DTensor/Tensor conversion
        # mm_allreduce_forward is a function that applies to Tensors with manual collective
        # local_mm_allreduce_forward is the function that does the same but applies to
        # DTensors' `_local_tensor`.
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

        # output redistribution to Replicate
        self.assertEqual(comm_mode.get_total_counts(), 1)
        # check output placements
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        # check output value
        self.assertEqual(Y_dt.to_local(), Y)

        # Test 2: use the local_map decorator
        with comm_mode:
            Y_dt = mm_allreduce_forward_decorated(device_mesh, X_dt, W_dt)

        # output redistribution to Replicate
        self.assertEqual(comm_mode.get_total_counts(), 1)
        # check output placements
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        # check output value
        self.assertEqual(Y_dt.to_local(), Y)

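    # `out_placements` controls how Tensor outputs are wrapped back into
    # DTensors; with `out_placements=None`, or when no input is a DTensor,
    # the function's output is returned as-is.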
    # check for `out_placements`
    @with_comms
    def test_local_map_out_placements(self):
        # Test 1: wrap out into DTensor w/ `out_placements`
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # X.equal(Y)
        X = torch.randn(8, 8, device=self.device_type, requires_grad=False)
        Y = torch.randn(8, 8, device=self.device_type, requires_grad=False)
        X_dt = distribute_tensor(X, device_mesh, row_wise)
        Y_dt = distribute_tensor(Y, device_mesh, row_wise)
        local_equal_allgather_forward = local_map(
            equal_allgather_forward,
            out_placements=None,
        )
        with comm_mode:
            equal_dt = local_equal_allgather_forward(device_mesh, X_dt, Y_dt)  # a bool

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertTrue(not equal_dt)
        self.assertTrue(not (X.equal(Y)))

        # Test 2: directly return out if no argument is DTensor
        # matmul in DDP
        X = torch.randn(
            4 // self.world_size, 4, device=self.device_type, requires_grad=False
        )
        W = torch.randn(4, 4, device=self.device_type, requires_grad=False)
        local_mm_all_gather_forward = local_map(
            mm_all_gather_forward,
            out_placements=row_wise,
            in_placements=(None, row_wise, replicate),
        )
        with comm_mode:
            Y = local_mm_all_gather_forward(device_mesh, X, W)

        self.assertEqual(comm_mode.get_total_counts(), 1)
        self.assertEqual(
            comm_mode.get_comm_counts()[funcol_py.all_gather_into_tensor], 1
        )
        X_replicate = funcol.all_gather_tensor(X, 0, device_mesh).wait()
        Y_replicate = torch.mm(X_replicate, W)
        self.assertEqual(Y, Y_replicate)  # Y is a torch.Tensor

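    # `in_placements` declares, per positional argument, the placements each
    # DTensor input is expected to have; a `None` entry marks an argument that
    # is passed through untouched (e.g. a scalar or a plain Tensor).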
    # check for `in_placements` handling
    @with_comms
    def test_local_map_in_placements(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, row_wise
        )  # row-wisely sharded X tensor
        W_dt = distribute_tensor(W, device_mesh, replicate)  # replicate W tensor

        # Test 1: explicitly pass `in_placements`
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            in_placements=(row_wise, replicate),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_forward(X_dt, W_dt)

        # no communication should occur in this case
        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 2: `in_placements=None`
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt = local_mm_forward(X_dt, W_dt)

        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 3: `None` placements for non-Tensor input argument
        # Y = X * 2.0
        local_mul_forward = local_map(
            mul_forward,
            in_placements=(row_wise, None),
            out_placements=row_wise,
            device_mesh=device_mesh,
        )
        Y = torch.mul(X, 2.0)
        with comm_mode:
            Y_dt = local_mul_forward(X_dt, 2.0)

        self.assertEqual(comm_mode.get_total_counts(), 0)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_shard(dim=0))
        self.assertEqual(Y_dt.full_tensor(), Y)

        # Test 4: `None` placements for Tensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=None,
            in_placements=(None, None),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())

        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(
            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
            torch.mm(X, W),
        )

        # Test 5: Some placements for Tensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=None,
            in_placements=(replicate, row_wise),
            device_mesh=device_mesh,
        )
        with comm_mode:
            Y_dt_local = local_mm_forward(X_dt.to_local(), W_dt.to_local())

        self.assertEqual(comm_mode.get_total_counts(), 0)
        self.assertEqual(
            DTensor.from_local(Y_dt_local, device_mesh, row_wise).full_tensor(),
            torch.mm(X, W),
        )

        # Test 6: expect error - `None` placements for DTensor input argument
        local_mm_forward = local_map(
            mm_forward,
            out_placements=row_wise,
            in_placements=(row_wise, None),
            device_mesh=device_mesh,
        )
        with self.assertRaisesRegex(AssertionError, "expects placements"):
            Y_dt = local_mm_forward(X_dt, W_dt)

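    # `redistribute_inputs` decides what happens when a DTensor input does not
    # match its entry in `in_placements`: with True the input is redistributed
    # to the expected placements first (costing extra collectives); with False
    # the mismatch is an error.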
    # check for `redistribute_inputs` handling
    @with_comms
    def test_local_map_redistribute(self):
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        comm_mode = CommDebugMode()

        # Y = X @ W
        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
        Y = torch.mm(X, W)

        X_dt = distribute_tensor(
            X, device_mesh, row_wise
        )  # row-wisely sharded X tensor which will be redistributed
        W_dt = distribute_tensor(
            W, device_mesh, col_wise
        )  # col-wisely sharded W tensor which will be redistributed

        # Test 1: allow input redistribution
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
            redistribute_inputs=True,
        )
        with comm_mode:
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

        # 2 for input redistribution and 1 for output
        self.assertEqual(comm_mode.get_total_counts(), 3)
        for placement in Y_dt.placements:
            self.assertTrue(placement.is_replicate())
        self.assertEqual(Y_dt.to_local(), Y)

        # Test 2: no input redistribution is allowed
        local_mm_allreduce_forward = local_map(
            mm_allreduce_forward,
            out_placements=replicate,
            in_placements=(None, col_wise, row_wise),
            device_mesh=device_mesh,
            redistribute_inputs=False,
        )
        with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"):
            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

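    # `in_grad_placements` declares the placements the gradient flowing back
    # into each input should be interpreted with, e.g. Shard(0) for a
    # batch-sharded activation gradient and Partial() for a weight gradient
    # that still needs to be summed across ranks.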
    # check for `in_grad_placements` handling
    @with_comms()
    def test_local_map_with_grad_placement(self):
        """
        Test that the gradient result is correct when we specify the right
        `in_grad_placements`.
        """
        device_mesh = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        torch.manual_seed(12)

        # ground truth output, consider X as a batch of 2 on dim 0.
        X = torch.randn(4, 2, device=self.device_type, requires_grad=True)
        X1, X2 = torch.chunk(X, 2, dim=0)
        X1 = X1.detach().requires_grad_()
        X2 = X2.detach().requires_grad_()
        W = torch.randn(2, 4, device=self.device_type, requires_grad=True)
        Y1 = torch.mm(X1, W)
        Y2 = torch.mm(X2, W)
        loss = Y1.sum() + Y2.sum()
        loss.backward()

        in_placement_mismatch_choice = (False, True)
        for is_in_placement_mismatch in in_placement_mismatch_choice:
            if is_in_placement_mismatch:
                # in_placements for local_map() will take effect
                X_dt = distribute_tensor(X, device_mesh, replicate)
            else:
                # in_placements for local_map() will not take effect
                X_dt = distribute_tensor(X, device_mesh, row_wise)
            W_dt = distribute_tensor(W, device_mesh, replicate)
            in_grad_placements = ([Shard(0)], [Partial()])

            local_mm_forward = local_map(
                mm_forward,
                out_placements=[Shard(0)],
                in_placements=(row_wise, replicate),
                in_grad_placements=in_grad_placements,
                device_mesh=device_mesh,
                redistribute_inputs=True,
            )
            Y_dt = local_mm_forward(X_dt, W_dt)
            self.assertEqual(Y_dt.full_tensor(), torch.cat([Y1, Y2], dim=0))

            # Note: this is a way to simulate how DDP works. We don't need to
            # all_gather the loss. Instead, we all_reduce the gradient of each
            # distributed weight.
            loss = Y_dt.to_local().sum()
            loss.backward()

            if not is_in_placement_mismatch:
                self.assertEqual(X_dt.grad.placements, in_grad_placements[0])
                self.assertEqual(W_dt.grad.placements, in_grad_placements[1])
            # regardless of is_in_placement_mismatch, grad output should always
            # match
            self.assertEqual(
                X_dt.grad.full_tensor(), torch.cat([X1.grad, X2.grad], dim=0)
            )
            self.assertEqual(W_dt.grad.full_tensor(), W.grad)

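    # local_map is also expected to handle DTensor arguments whose meshes
    # differ from the `device_mesh` it was given; the test below feeds a
    # 1-d-mesh DTensor and a 2-d-mesh DTensor into the same wrapped function
    # and checks that the output lands on the 2-d mesh.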
    @skip_if_lt_x_gpu(4)
    @with_comms
    def test_multi_mesh_inputs(self):
        """
        Test that the function can be applied to DTensors that live
        on different device meshes.
        """
        mesh_full = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size,)
        )
        mesh_2d = init_device_mesh(
            device_type=self.device_type, mesh_shape=(self.world_size // 2, 2)
        )
        comm_mode = CommDebugMode()

        X = torch.randn(8, 32, device=self.device_type, requires_grad=False)
        x_placements = [Shard(1)]
        W = torch.randn(16, 8, device=self.device_type, requires_grad=False)
        w_placements = [Shard(0), Shard(1)]

        X_dt = distribute_tensor(X, mesh_full, x_placements)
        W_dt = distribute_tensor(W, mesh_2d, w_placements)

        # local output shape should be (8, 4)
        output_placements = [Replicate(), Shard(1)]

        local_mm_forward = local_map(
            mm_forward,
            out_placements=output_placements,
            in_placements=(x_placements, w_placements),
            device_mesh=mesh_2d,
        )

        with comm_mode:
            Y_dt = local_mm_forward(X_dt, W_dt)

        self.assertEqual(comm_mode.get_total_counts(), 0)
        # output local shape should be (8, 4)
        self.assertEqual(Y_dt.to_local().shape, (8, 4))
        # output lives in mesh_2d
        self.assertEqual(Y_dt.device_mesh, mesh_2d)


if __name__ == "__main__":
    run_tests()