[dtensor] add test for local_map decorator (#127752)

**Summary**
This PR is a follow-up of #126924 to address reviewer's comments:
1) add a test case to show the use of `local_map` as a function decorator.
2) simplify the logic of handling different data types of `out_placements`.
3) correct variable naming in test cases to match math formulas.

**Test**
see #126924

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127752
Approved by: https://github.com/wanchaol
This commit is contained in:
Xilun Wu
2024-08-26 16:19:38 -07:00
committed by PyTorch MergeBot
parent 8de0d7690c
commit 0159ebb654
2 changed files with 75 additions and 56 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates # Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"] # Owner(s): ["oncall: distributed"]
from functools import partial
import torch import torch
import torch.distributed._functional_collectives as funcol import torch.distributed._functional_collectives as funcol
@ -22,6 +23,11 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
funcol_py = torch.ops.c10d_functional funcol_py = torch.ops.c10d_functional
row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh
col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh
replicate = [Replicate()] # replicate placements on 1-d mesh
def equal_allgather_forward(device_mesh, X, Y): def equal_allgather_forward(device_mesh, X, Y):
eq = torch.tensor([torch.equal(X, Y)], device=X.device) eq = torch.tensor([torch.equal(X, Y)], device=X.device)
eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh) eq_gather = funcol.all_gather_tensor(eq, 0, device_mesh)
@ -42,6 +48,16 @@ def mm_allreduce_forward(device_mesh, A, B):
return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait() return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
@partial(
local_map,
out_placements=replicate,
in_placements=(None, col_wise, row_wise),
)
def mm_allreduce_forward_decorated(device_mesh, A, B):
partial_sum_tensor = torch.mm(A, B)
return funcol.all_reduce(partial_sum_tensor, "sum", device_mesh).wait()
def mul_forward(X, scalar): # no device mesh needed since we don't do collective def mul_forward(X, scalar): # no device mesh needed since we don't do collective
return torch.mul(X, scalar) return torch.mul(X, scalar)
@ -59,20 +75,19 @@ class TestLocalMap(DTensorTestBase):
) )
comm_mode = CommDebugMode() comm_mode = CommDebugMode()
# Y = W @ X # Y = X @ W
W = torch.randn(12, 8, device=self.device_type, requires_grad=False) X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
X = torch.randn(8, 16, device=self.device_type, requires_grad=False) W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
Y = torch.mm(W, X) Y = torch.mm(X, W)
row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh
col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh
replicate = [Replicate()]
W_dt = distribute_tensor(
W, device_mesh, col_wise
) # col-wisely sharded W tensor
X_dt = distribute_tensor( X_dt = distribute_tensor(
X, device_mesh, row_wise X, device_mesh, col_wise
) # row-wisely sharded X tensor ) # col-wisely sharded X tensor
W_dt = distribute_tensor(
W, device_mesh, row_wise
) # row-wisely sharded W tensor
# Test 1: use the function returned from calling local_map
# get the function wrapped with DTensor/Tensor convertion # get the function wrapped with DTensor/Tensor convertion
# mm_allreduce_forward is a function that applies to Tensors with manual collective # mm_allreduce_forward is a function that applies to Tensors with manual collective
# local_mm_allreduce_forward is the function that does the same but applies to # local_mm_allreduce_forward is the function that does the same but applies to
@ -84,7 +99,19 @@ class TestLocalMap(DTensorTestBase):
device_mesh=device_mesh, device_mesh=device_mesh,
) )
with comm_mode: with comm_mode:
Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)
# output redistribution to Replicate
self.assertEqual(comm_mode.get_total_counts(), 1)
# check output placements
for placement in Y_dt.placements:
self.assertTrue(placement.is_replicate())
# check output value
self.assertEqual(Y_dt.to_local(), Y)
# Test 2: use the local_map decorator
with comm_mode:
Y_dt = mm_allreduce_forward_decorated(device_mesh, X_dt, W_dt)
# output redistribution to Replicate # output redistribution to Replicate
self.assertEqual(comm_mode.get_total_counts(), 1) self.assertEqual(comm_mode.get_total_counts(), 1)
@ -106,7 +133,6 @@ class TestLocalMap(DTensorTestBase):
# X.equal(Y) # X.equal(Y)
X = torch.randn(8, 8, device=self.device_type, requires_grad=False) X = torch.randn(8, 8, device=self.device_type, requires_grad=False)
Y = torch.randn(8, 8, device=self.device_type, requires_grad=False) Y = torch.randn(8, 8, device=self.device_type, requires_grad=False)
row_wise = [Shard(0)]
X_dt = distribute_tensor(X, device_mesh, row_wise) X_dt = distribute_tensor(X, device_mesh, row_wise)
Y_dt = distribute_tensor(Y, device_mesh, row_wise) Y_dt = distribute_tensor(Y, device_mesh, row_wise)
local_equal_allgather_forward = local_map( local_equal_allgather_forward = local_map(
@ -122,7 +148,6 @@ class TestLocalMap(DTensorTestBase):
# Test 2: directly return out if no argument is DTensor # Test 2: directly return out if no argument is DTensor
# matmul in DDP # matmul in DDP
replicate = [Replicate()]
X = torch.randn( X = torch.randn(
4 // self.world_size, 4, device=self.device_type, requires_grad=False 4 // self.world_size, 4, device=self.device_type, requires_grad=False
) )
@ -151,17 +176,15 @@ class TestLocalMap(DTensorTestBase):
) )
comm_mode = CommDebugMode() comm_mode = CommDebugMode()
# Y = W @ X # Y = X @ W
W = torch.randn(12, 8, device=self.device_type, requires_grad=False) X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
X = torch.randn(8, 16, device=self.device_type, requires_grad=False) W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
Y = torch.mm(W, X) Y = torch.mm(X, W)
row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh X_dt = distribute_tensor(
replicate = [Replicate()] # replicate placements on 1-d mesh X, device_mesh, row_wise
W_dt = distribute_tensor( ) # row-wisely sharded X tensor
W, device_mesh, row_wise W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor
) # row-wisely sharded W tensor
X_dt = distribute_tensor(X, device_mesh, replicate) # replicate X tensor
# Test 1: explicitly pass `in_placements` # Test 1: explicitly pass `in_placements`
local_mm_forward = local_map( local_mm_forward = local_map(
@ -171,7 +194,7 @@ class TestLocalMap(DTensorTestBase):
device_mesh=device_mesh, device_mesh=device_mesh,
) )
with comm_mode: with comm_mode:
Y_dt = local_mm_forward(W_dt, X_dt) Y_dt = local_mm_forward(X_dt, W_dt)
# no communication should occur in this case # no communication should occur in this case
self.assertEqual(comm_mode.get_total_counts(), 0) self.assertEqual(comm_mode.get_total_counts(), 0)
@ -186,7 +209,7 @@ class TestLocalMap(DTensorTestBase):
device_mesh=device_mesh, device_mesh=device_mesh,
) )
with comm_mode: with comm_mode:
Y_dt = local_mm_forward(W_dt, X_dt) Y_dt = local_mm_forward(X_dt, W_dt)
self.assertEqual(comm_mode.get_total_counts(), 0) self.assertEqual(comm_mode.get_total_counts(), 0)
for placement in Y_dt.placements: for placement in Y_dt.placements:
@ -194,15 +217,16 @@ class TestLocalMap(DTensorTestBase):
self.assertEqual(Y_dt.full_tensor(), Y) self.assertEqual(Y_dt.full_tensor(), Y)
# Test 3: `None` placements for non-Tensor input argument # Test 3: `None` placements for non-Tensor input argument
# Y = X * 2.0
local_mul_forward = local_map( local_mul_forward = local_map(
mul_forward, mul_forward,
in_placements=(row_wise, None), in_placements=(row_wise, None),
out_placements=row_wise, out_placements=row_wise,
device_mesh=device_mesh, device_mesh=device_mesh,
) )
Y = torch.mul(W, 2.0) Y = torch.mul(X, 2.0)
with comm_mode: with comm_mode:
Y_dt = local_mul_forward(W_dt, 2.0) Y_dt = local_mul_forward(X_dt, 2.0)
self.assertEqual(comm_mode.get_total_counts(), 0) self.assertEqual(comm_mode.get_total_counts(), 0)
for placement in Y_dt.placements: for placement in Y_dt.placements:
@ -210,12 +234,6 @@ class TestLocalMap(DTensorTestBase):
self.assertEqual(Y_dt.full_tensor(), Y) self.assertEqual(Y_dt.full_tensor(), Y)
# Test 4: `None` placements for Tensor input argument # Test 4: `None` placements for Tensor input argument
X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
X_dt = distribute_tensor(
X, device_mesh, row_wise
) # row-wisely sharded X tensor
W_dt = distribute_tensor(W, device_mesh, replicate) # replicate W tensor
local_mm_forward = local_map( local_mm_forward = local_map(
mm_forward, mm_forward,
out_placements=None, out_placements=None,
@ -265,20 +283,17 @@ class TestLocalMap(DTensorTestBase):
) )
comm_mode = CommDebugMode() comm_mode = CommDebugMode()
# Y = W @ X # Y = X @ W
W = torch.randn(12, 8, device=self.device_type, requires_grad=False) X = torch.randn(16, 8, device=self.device_type, requires_grad=False)
X = torch.randn(8, 16, device=self.device_type, requires_grad=False) W = torch.randn(8, 12, device=self.device_type, requires_grad=False)
Y = torch.mm(W, X) Y = torch.mm(X, W)
row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh
col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh
replicate = [Replicate()]
W_dt = distribute_tensor(
W, device_mesh, row_wise
) # row-wisely sharded W tensor which will be redistributed
X_dt = distribute_tensor( X_dt = distribute_tensor(
X, device_mesh, col_wise X, device_mesh, row_wise
) # col-wisely sharded X tensor which will be redistributed ) # row-wisely sharded X tensor which will be redistributed
W_dt = distribute_tensor(
W, device_mesh, col_wise
) # col-wisely sharded W tensor which will be redistributed
# Test 1: allow input redistribution # Test 1: allow input redistribution
local_mm_allreduce_forward = local_map( local_mm_allreduce_forward = local_map(
@ -289,7 +304,7 @@ class TestLocalMap(DTensorTestBase):
redistribute_inputs=True, redistribute_inputs=True,
) )
with comm_mode: with comm_mode:
Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)
# 2 for input redistribution and 1 for output # 2 for input redistribution and 1 for output
self.assertEqual(comm_mode.get_total_counts(), 3) self.assertEqual(comm_mode.get_total_counts(), 3)
@ -306,7 +321,7 @@ class TestLocalMap(DTensorTestBase):
redistribute_inputs=False, redistribute_inputs=False,
) )
with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"): with self.assertRaisesRegex(ValueError, "set redistribute_inputs=True"):
Y_dt = local_mm_allreduce_forward(device_mesh, W_dt, X_dt) Y_dt = local_mm_allreduce_forward(device_mesh, X_dt, W_dt)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -194,13 +194,17 @@ def local_map(
flat_out, out_spec = pytree.tree_flatten(out) flat_out, out_spec = pytree.tree_flatten(out)
flat_dist_out = [] flat_dist_out = []
for idx, out in enumerate(flat_out): out_placements_tuple = (
spec = ( out_placements
out_placements[idx] if isinstance(out_placements, tuple)
if isinstance(out_placements, tuple) else (out_placements,)
else out_placements )
) assert len(flat_out) == len(out_placements_tuple), (
"local_map requires one PlacementType be provided for each output value,"
f" received {len(out_placements_tuple)} out_placements but"
f" {len(flat_out)} is expected!"
)
for out, spec in zip(flat_out, out_placements_tuple):
if isinstance(out, torch.Tensor): if isinstance(out, torch.Tensor):
assert not isinstance( assert not isinstance(
out, DTensor out, DTensor