[dtensor] add test for local_map decorator (#127752)
**Summary**
This PR is a follow-up of #126924 to address the reviewer's comments:
1) add a test case to show the use of `local_map` as a function decorator.
2) simplify the logic of handling different data types of `out_placements`.
3) correct variable naming in test cases to match the math formulas.

**Test**
see #126924

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127752
Approved by: https://github.com/wanchaol
Commit 0159ebb654 (parent 8de0d7690c), committed by PyTorch MergeBot.
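For context, the sketch below shows the two ways of applying `local_map` that the new test exercises: wrapping an existing Tensor-level function after the fact, and turning `local_map` into a decorator via `functools.partial`. It is assembled from the diff that follows rather than copied from the tree; the import paths, the CPU/gloo mesh, the fixed seed, and the `torchrun`-style launch (which sets `WORLD_SIZE`) are assumptions.

```python
import os
from functools import partial

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import Replicate, Shard, distribute_tensor
from torch.distributed._tensor.experimental import local_map
from torch.distributed.device_mesh import init_device_mesh

row_wise = [Shard(0)]      # shard dim 0 across the 1-d mesh
col_wise = [Shard(1)]      # shard dim 1 across the 1-d mesh
replicate = [Replicate()]  # replicate across the 1-d mesh


def mm_allreduce_forward(device_mesh, A, B):
    # operates on plain local Tensors; the collective is issued manually
    return funcol.all_reduce(torch.mm(A, B), "sum", device_mesh).wait()


# Decorator form (what the new test adds): functools.partial pre-binds every
# local_map argument except the function itself, so it can sit above a def.
@partial(
    local_map,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
)
def mm_allreduce_forward_decorated(device_mesh, A, B):
    return funcol.all_reduce(torch.mm(A, B), "sum", device_mesh).wait()


if __name__ == "__main__":
    # Assumes a launcher such as torchrun set WORLD_SIZE; "cpu"/gloo keeps the
    # sketch runnable without GPUs.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    mesh = init_device_mesh("cpu", (world_size,))

    torch.manual_seed(0)  # keep the full tensors identical on every rank
    X = torch.randn(16, 8)
    W = torch.randn(8, 12)
    X_dt = distribute_tensor(X, mesh, col_wise)  # col-wise sharded X
    W_dt = distribute_tensor(W, mesh, row_wise)  # row-wise sharded W

    # Wrapper form: convert the Tensor-level function after the fact.
    local_mm_allreduce_forward = local_map(
        mm_allreduce_forward,
        out_placements=replicate,
        in_placements=(None, col_wise, row_wise),
        device_mesh=mesh,
    )
    Y1 = local_mm_allreduce_forward(mesh, X_dt, W_dt)      # replicated DTensor, Y = X @ W
    Y2 = mm_allreduce_forward_decorated(mesh, X_dt, W_dt)  # same result via the decorator
```

Both call paths produce the same replicated DTensor; the decorator form simply pre-binds every `local_map` argument except the function being wrapped.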
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # Owner(s): ["oncall: distributed"]
+from functools import partial

 import torch
 import torch.distributed._functional_collectives as funcol
@@ -22,6 +23,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
 funcol_py = torch.ops.c10d_functional


+row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
+col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
+replicate = [Replicate()]  # replicate placements on 1-d mesh
+
+
 def equal_allgather_forward(device_mesh, X, Y):
     eq = torch.tensor([torch.equal(X, Y)], device=X.device)
     eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh)
@@ -42,6 +48,16 @@ def mm_allreduce_forward(device_mesh, A, B):
     return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()


+@partial(
+    local_map,
+    out_placements=replicate,
+    in_placements=(None, col_wise, row_wise),
+)
+def mm_allreduce_forward_decorated(device_mesh, A, B):
+    partial_sum_tensor = torch.mm(A, B)
+    return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
+
+
 def mul_forward(X, scalar):  # no device mesh needed since we don't do collective
     return torch.mul(X, scalar)

@@ -59,20 +75,19 @@ class TestLocalMap(DTensorTestBase):
         )
         comm_mode = CommDebugMode()

-        # Y = W @ X
-        W = torch.randn(12, 8, device=self.device_type, requires_grad=False)
-        X = torch.randn(8, 16, device=self.device_type, requires_grad=False)
-        Y = torch.mm(W, X)
+        # Y = X @ W
+        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
+        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
+        Y = torch.mm(X, W)

-        row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
-        col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
-        replicate = [Replicate()]
-        W_dt = distribute_tensor(
-            W, device_mesh, col_wise
-        )  # col-wisely sharded W tensor
         X_dt = distribute_tensor(
-            X, device_mesh, row_wise
-        )  # row-wisely sharded X tensor
+            X, device_mesh, col_wise
+        )  # col-wisely sharded X tensor
+        W_dt = distribute_tensor(
+            W, device_mesh, row_wise
+        )  # row-wisely sharded W tensor

+        # Test 1: use the function returned from calling local_map
         # get the function wrapped with DTensor/Tensor convertion
         # mm_allreduce_forward is a function that applies to Tensors with manual collective
         # local_mm_allreduce_forward is the function that does the same but applies to
@@ -84,7 +99,19 @@ class TestLocalMap(DTensorTestBase):
             device_mesh=device_mesh,
         )
         with comm_mode:
-            Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)
+            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

+        # output redistribution to Replicate
+        self.assertEqual(comm_mode.get_total_counts(), 1)
+        # check output placements
+        for placement in Y_dt.placements:
+            self.assertTrue(placement.is_replicate())
+        # check output value
+        self.assertEqual(Y_dt.to_local(), Y)
+
+        # Test 2: use the local_map decorator
+        with comm_mode:
+            Y_dt = mm_allreduce_forward_decorated(device_mesh, X_dt, W_dt)
+
         # output redistribution to Replicate
         self.assertEqual(comm_mode.get_total_counts(), 1)
@@ -106,7 +133,6 @@ class TestLocalMap(DTensorTestBase):
         # X.equal(Y)
         X = torch.randn(8, 8, device=self.device_type, requires_grad=False)
         Y = torch.randn(8, 8, device=self.device_type, requires_grad=False)
-        row_wise = [Shard(0)]
         X_dt = distribute_tensor(X, device_mesh, row_wise)
         Y_dt = distribute_tensor(Y, device_mesh, row_wise)
         local_equal_allgather_forward = local_map(
@@ -122,7 +148,6 @@ class TestLocalMap(DTensorTestBase):

         # Test 2: directly return out if no argument is DTensor
         # matmul in DDP
-        replicate = [Replicate()]
         X = torch.randn(
             4 // self.world_size, 4, device=self.device_type, requires_grad=False
         )
@@ -151,17 +176,15 @@ class TestLocalMap(DTensorTestBase):
         )
         comm_mode = CommDebugMode()

-        # Y = W @ X
-        W = torch.randn(12, 8, device=self.device_type, requires_grad=False)
-        X = torch.randn(8, 16, device=self.device_type, requires_grad=False)
-        Y = torch.mm(W, X)
+        # Y = X @ W
+        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
+        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
+        Y = torch.mm(X, W)

-        row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
-        replicate = [Replicate()]  # replicate placements on 1-d mesh
-        W_dt = distribute_tensor(
-            W, device_mesh, row_wise
-        )  # row-wisely sharded W tensor
-        X_dt = distribute_tensor(X, device_mesh, replicate)  # replicate X tensor
+        X_dt = distribute_tensor(
+            X, device_mesh, row_wise
+        )  # row-wisely sharded X tensor
+        W_dt = distribute_tensor(W, device_mesh, replicate)  # replicate W tensor

         # Test 1: explicitly pass `in_placements`
         local_mm_forward = local_map(
@@ -171,7 +194,7 @@ class TestLocalMap(DTensorTestBase):
             device_mesh=device_mesh,
         )
         with comm_mode:
-            Y_dt = local_mm_forward(W_dt, X_dt)
+            Y_dt = local_mm_forward(X_dt, W_dt)

         # no communication should occur in this case
         self.assertEqual(comm_mode.get_total_counts(), 0)
@@ -186,7 +209,7 @@ class TestLocalMap(DTensorTestBase):
             device_mesh=device_mesh,
         )
         with comm_mode:
-            Y_dt = local_mm_forward(W_dt, X_dt)
+            Y_dt = local_mm_forward(X_dt, W_dt)

         self.assertEqual(comm_mode.get_total_counts(), 0)
         for placement in Y_dt.placements:
@@ -194,15 +217,16 @@ class TestLocalMap(DTensorTestBase):
         self.assertEqual(Y_dt.full_tensor(), Y)

         # Test 3: `None` placements for non-Tensor input argument
+        # Y = X * 2.0
         local_mul_forward = local_map(
             mul_forward,
             in_placements=(row_wise, None),
             out_placements=row_wise,
             device_mesh=device_mesh,
         )
-        Y = torch.mul(W, 2.0)
+        Y = torch.mul(X, 2.0)
         with comm_mode:
-            Y_dt = local_mul_forward(W_dt, 2.0)
+            Y_dt = local_mul_forward(X_dt, 2.0)

         self.assertEqual(comm_mode.get_total_counts(), 0)
         for placement in Y_dt.placements:
@@ -210,12 +234,6 @@ class TestLocalMap(DTensorTestBase):
         self.assertEqual(Y_dt.full_tensor(), Y)

         # Test 4: `None` placements for Tensor input argument
-        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
-        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
-        X_dt = distribute_tensor(
-            X, device_mesh, row_wise
-        )  # row-wisely sharded X tensor
-        W_dt = distribute_tensor(W, device_mesh, replicate)  # replicate W tensor
         local_mm_forward = local_map(
             mm_forward,
             out_placements=None,
@@ -265,20 +283,17 @@ class TestLocalMap(DTensorTestBase):
         )
         comm_mode = CommDebugMode()

-        # Y = W @ X
-        W = torch.randn(12, 8, device=self.device_type, requires_grad=False)
-        X = torch.randn(8, 16, device=self.device_type, requires_grad=False)
-        Y = torch.mm(W, X)
+        # Y = X @ W
+        X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
+        W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
+        Y = torch.mm(X, W)

-        row_wise = [Shard(0)]  # row-wise sharding placements on 1-d mesh
-        col_wise = [Shard(1)]  # col-wise sharding placements on 1-d mesh
-        replicate = [Replicate()]
-        W_dt = distribute_tensor(
-            W, device_mesh, row_wise
-        )  # row-wisely sharded W tensor which will be redistributed
         X_dt = distribute_tensor(
-            X, device_mesh, col_wise
-        )  # col-wisely sharded X tensor which will be redistributed
+            X, device_mesh, row_wise
+        )  # row-wisely sharded X tensor which will be redistributed
+        W_dt = distribute_tensor(
+            W, device_mesh, col_wise
+        )  # col-wisely sharded W tensor which will be redistributed

         # Test 1: allow input redistribution
         local_mm_allreduce_forward = local_map(
@@ -289,7 +304,7 @@ class TestLocalMap(DTensorTestBase):
             redistribute_inputs=True,
         )
         with comm_mode:
-            Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)
+            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)

         # 2 for input redistribution and 1 for output
         self.assertEqual(comm_mode.get_total_counts(), 3)
@@ -306,7 +321,7 @@ class TestLocalMap(DTensorTestBase):
             redistribute_inputs=False,
         )
         with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"):
-            Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt)
+            Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)


 if __name__ == "__main__":
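The last three test hunks exercise `redistribute_inputs`. A hedged restatement of what they assert, continuing the earlier sketch (same imports, `mesh`, placements, and `mm_allreduce_forward`; the names `relaxed` and `strict` are made up here, and the collective counts in the comments come from the test's own assertions):

```python
# Inputs deliberately sharded opposite to what in_placements declares, so
# local_map has to reshard them before running the local computation.
X_dt = distribute_tensor(X, mesh, row_wise)  # in_placements wants col-wise
W_dt = distribute_tensor(W, mesh, col_wise)  # in_placements wants row-wise

relaxed = local_map(
    mm_allreduce_forward,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
    device_mesh=mesh,
    redistribute_inputs=True,  # reshard mismatched inputs under the hood
)
Y_dt = relaxed(mesh, X_dt, W_dt)  # per the test: 2 redistribution collectives + 1 all_reduce

strict = local_map(
    mm_allreduce_forward,
    out_placements=replicate,
    in_placements=(None, col_wise, row_wise),
    device_mesh=mesh,
    redistribute_inputs=False,
)
# strict(mesh, X_dt, W_dt) raises ValueError asking to set redistribute_inputs=True
```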
@@ -194,13 +194,17 @@ def local_map(
             flat_out, out_spec = pytree.tree_flatten(out)

             flat_dist_out = []
-            for idx, out in enumerate(flat_out):
-                spec = (
-                    out_placements[idx]
-                    if isinstance(out_placements, tuple)
-                    else out_placements
-                )
+            out_placements_tuple = (
+                out_placements
+                if isinstance(out_placements, tuple)
+                else (out_placements,)
+            )
+            assert len(flat_out) == len(out_placements_tuple), (
+                "local_map requires one PlacementType be provided for each output value,"
+                f" received {len(out_placements_tuple)} out_placements but"
+                f" {len(flat_out)} is expected!"
+            )
+            for out, spec in zip(flat_out, out_placements_tuple):
                 if isinstance(out, torch.Tensor):
                     assert not isinstance(
                         out, DTensor
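This library-side hunk is item 2 of the summary: rather than re-deriving a per-output spec inside the loop, `out_placements` is normalized once into a tuple and length-checked against the flattened outputs. A standalone restatement of that logic follows; the helper name and the stand-in `PlacementType` alias are hypothetical, since the real code runs inline inside `local_map`'s wrapper and uses DTensor placement types:

```python
from typing import Any, Optional, Sequence, Tuple, Union

# Stand-in for the DTensor placement type used by the real code.
PlacementType = Optional[Sequence[Any]]


def normalize_out_placements(
    out_placements: Union[PlacementType, Tuple[PlacementType, ...]],
    num_flat_outputs: int,
) -> Tuple[PlacementType, ...]:
    # A tuple already provides one PlacementType per flattened output;
    # any other value is a single PlacementType and gets wrapped into a 1-tuple.
    out_placements_tuple = (
        out_placements if isinstance(out_placements, tuple) else (out_placements,)
    )
    assert num_flat_outputs == len(out_placements_tuple), (
        "local_map requires one PlacementType be provided for each output value,"
        f" received {len(out_placements_tuple)} out_placements but"
        f" {num_flat_outputs} is expected!"
    )
    return out_placements_tuple


# e.g. a single-output function declared with out_placements=[Replicate()]:
# normalize_out_placements([Replicate()], 1) == ([Replicate()],)
```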