# Owner(s): ["oncall: distributed"]
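"""
Tests for DTensor einsum dimension parsing, einsum sharding-strategy
generation, and the collective cost model used to rank redistributions.
"""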

from itertools import chain

import torch
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema, OpSpec, OpStrategy
from torch.distributed.tensor._ops._einsum_strategy import (
    EinsumDims,
    gen_einsum_strategies,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase


class TestEinsumDims(TestCase):
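    """Checks that EinsumDims classifies equation dims into batch, contracting,
    and lhs/rhs output-only dims."""
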
    def test_batch_dims(self):
        equation = "abc,abc->abc"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["a", "b", "c"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, [])
        self.assertEqual(edims.rhs_out_only_dims, [])

    def test_mm_dims(self):
        equation = "mk,kn->mn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, [])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

    def test_bmm_dims(self):
        equation = "bmk,bkn->bmn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b"])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

        equation = "bcmk,bckn->bcmn"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b", "c"])
        self.assertEqual(edims.contracting_dims, ["k"])
        self.assertEqual(edims.lhs_out_only_dims, ["m"])
        self.assertEqual(edims.rhs_out_only_dims, ["n"])

    def test_free_dims(self):
        equation = "abc,ab->abc"
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["a", "b"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, ["c"])
        self.assertEqual(edims.rhs_out_only_dims, [])

        equation = "abd,bf->abfd"  # codespell:ignore
        input_dims, output_dim = EinsumDims.parse_equation(equation)
        edims = EinsumDims.parse_dims(input_dims, output_dim)

        self.assertEqual(edims.batch_dims, ["b"])
        self.assertEqual(edims.contracting_dims, [])
        self.assertEqual(edims.lhs_out_only_dims, ["a", "d"])
        self.assertEqual(edims.rhs_out_only_dims, ["f"])


class TestEinsumStrategies(DTensorOpTestBase):
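    """Checks how many sharding strategies gen_einsum_strategies enumerates for
    common einsum equations on 1D and 2D device meshes."""
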
    @property
    def world_size(self) -> int:
        return 4

    def test_mm_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
        self.assertEqual(len(all_strats.strategies), 4)

    def test_mm_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))

        all_strats = gen_einsum_strategies("mk,kn->mn", mesh)
        self.assertEqual(len(all_strats.strategies), 16)

    def test_bmm_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
        self.assertEqual(len(all_strats.strategies), 5)

    def test_bmm_diffinndim_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
        all_strats = gen_einsum_strategies("bmk,kn->bmn", mesh)
        self.assertEqual(len(all_strats.strategies), 25)

    def test_bmm_diffoutndim_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
        all_strats = gen_einsum_strategies("bmk,k->bm", mesh)
        self.assertEqual(len(all_strats.strategies), 16)

    def test_bmm_2d_mesh(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))

        all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
        self.assertEqual(len(all_strats.strategies), 25)

    def test_pointwise_1d_mesh(self):
        mesh = self.build_device_mesh()

        simple_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh)
        self.assertEqual(len(simple_strats.strategies), 5)

        broadcast_strats = gen_einsum_strategies("bcd,abcd->abcd", mesh)
        self.assertEqual(len(broadcast_strats.strategies), 5)

    def test_linearity_1d_mesh(self):
        mesh = self.build_device_mesh()

        all_strats = gen_einsum_strategies("abcd,abcd->abcd", mesh, linearity=True)
        self.assertEqual(len(all_strats.strategies), 6)


class TestCostModel(DTensorOpTestBase):
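    """Exercises redistribute_cost directly and through the matrix-op
    strategies (addmm, mm, bmm) that use it to rank candidate shardings."""
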
    def _extract_tensor_meta(self, t) -> TensorMeta:
        return TensorMeta(t.shape, t.stride(), t.dtype)

    @property
    def world_size(self) -> int:
        return 4

    def test_redistribute_cost_mesh_1d(self):
        mesh_1d = self.build_device_mesh()
        shard_placement = (Shard(0),)
        replica_placement = (Replicate(),)
        partial_placement = (Partial(),)

        global_tensor = torch.randn(10, 10)
        global_tensor_meta = self._extract_tensor_meta(global_tensor)

        # shard spec
        shard_spec = DTensorSpec(mesh_1d, shard_placement, global_tensor_meta)
        # replica spec
        replica_spec = DTensorSpec(mesh_1d, replica_placement, global_tensor_meta)
        # partial spec
        partial_spec = DTensorSpec(mesh_1d, partial_placement, global_tensor_meta)

        # make sure reshard cost is 0 for the same spec redistribute
        for spec in [shard_spec, replica_spec, partial_spec]:
            cost = redistribute_cost(spec, spec)
            self.assertEqual(cost, 0)

        # shard -> replicate
        allgather_cost = redistribute_cost(shard_spec, replica_spec)
        # partial -> shard
        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
        # partial -> replicate
        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
        self.assertEqual(allgather_cost, reduce_scatter_cost)
        self.assertTrue(allreduce_cost + 1 < allgather_cost + reduce_scatter_cost)
        # shard -> partial
        cost = redistribute_cost(shard_spec, partial_spec)
        self.assertEqual(cost, float("inf"))

    def test_redistribute_cost_latency(self):
        # test cost model on addmm op
        from torch.distributed.tensor._ops._matrix_ops import addmm_strategy

        mesh = self.build_device_mesh()
        shard0_placement = (Shard(0),)
        partial_placement = (Partial(),)
        shard1_placement = (Shard(1),)

        shard0_tensor_meta = self._extract_tensor_meta(torch.randn(8))
        partial_tensor_meta = self._extract_tensor_meta(torch.randn(50, 6))
        shard1_tensor_meta = self._extract_tensor_meta(torch.randn(6, 8))

        # shard0 spec
        shard0_spec = DTensorSpec(mesh, shard0_placement, shard0_tensor_meta)
        # partial spec
        partial_spec = DTensorSpec(mesh, partial_placement, partial_tensor_meta)
        # shard1 spec
        shard1_spec = DTensorSpec(mesh, shard1_placement, shard1_tensor_meta)

        op_schema = OpSchema(
            torch.ops.aten.addmm.default,
            (
                OpStrategy([OpSpec(shard0_spec)]),
                OpStrategy([OpSpec(partial_spec)]),
                OpStrategy([OpSpec(shard1_spec)]),
            ),
            {},
        )

        output_strategy = addmm_strategy(op_schema)
        strategy_costs = {}
        for strategy in output_strategy.strategies:
            total_redistribute_cost = sum(chain.from_iterable(strategy.redistribute_cost))
            strategy_costs[str(strategy)] = total_redistribute_cost

        # assert that the cost model accounts for collective latency
        # (i.e. issuing multiple communications is penalized)
        self.assertTrue(
            strategy_costs["(S(0), R, S(1)) -> S(1)"]
            < strategy_costs["(R, S(0), R) -> S(0)"]
        )
        # assert a single allreduce is the best one
        self.assertEqual(
            strategy_costs["(S(0), R, S(1)) -> S(1)"], min(strategy_costs.values())
        )

    def test_redistribute_cost_mesh_2d(self):
        mesh_2d = DeviceMesh(
            self.device_type, torch.arange(self.world_size).reshape(2, 2)
        )
        shard_placement = (Shard(0), Shard(0))
        replica_placement = (Replicate(), Replicate())
        partial_placement = (Partial(), Partial())

        global_tensor = torch.randn(8, 8)
        global_tensor_meta = self._extract_tensor_meta(global_tensor)

        # shard spec
        shard_spec = DTensorSpec(mesh_2d, shard_placement, global_tensor_meta)
        # replica spec
        replica_spec = DTensorSpec(mesh_2d, replica_placement, global_tensor_meta)
        # partial spec
        partial_spec = DTensorSpec(mesh_2d, partial_placement, global_tensor_meta)

        # make sure reshard cost is 0 for the same spec redistribute
        for spec in [shard_spec, replica_spec, partial_spec]:
            cost = redistribute_cost(spec, spec)
            self.assertEqual(cost, 0)

        # shard -> replicate
        allgather_cost = redistribute_cost(shard_spec, replica_spec)
        # partial -> replicate
        allreduce_cost = redistribute_cost(partial_spec, replica_spec)
        # partial -> shard
        reduce_scatter_cost = redistribute_cost(partial_spec, shard_spec)
        self.assertTrue(allreduce_cost > allgather_cost)
        self.assertTrue(allreduce_cost > reduce_scatter_cost)

    def test_mm_strategies(self):
        from torch.distributed.tensor._ops._matrix_ops import mm_strategy

        mesh = self.build_device_mesh()
        lhs_tensor = torch.randn(6, 8)
        rhs_tensor = torch.randn(8, 12)
        lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
        rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)

        mm_combs = (
            (Shard(0), Replicate()),
            (Replicate(), Shard(1)),
            (Shard(1), Shard(0)),
            (Replicate(), Replicate()),
        )
        for lhs, rhs in mm_combs:
            lhs_spec = DTensorSpec(mesh, (lhs,), lhs_tensor_meta)
            rhs_spec = DTensorSpec(mesh, (rhs,), rhs_tensor_meta)

            op_schema = OpSchema(
                torch.ops.aten.mm.default,
                (
                    OpStrategy([OpSpec(lhs_spec)]),
                    OpStrategy([OpSpec(rhs_spec)]),
                ),
                {},
            )
            # test the strategy
            res_strategies = mm_strategy(op_schema)

            for strtgy in res_strategies.strategies:
                if strtgy.input_specs == (lhs_spec, rhs_spec):
                    self.assertEqual(strtgy.redistribute_cost, [[0.0], [0.0]])
                    break

            op_schema = OpSchema(
                torch.ops.aten.mm.default,
                (lhs_spec, rhs_spec),
                {},
            )
            # test sharding prop
            output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding_non_cached(
                op_schema
            )
            self.assertFalse(output_sharding.needs_redistribute)

    def test_bmm_strategies(self):
        from torch.distributed.tensor._ops._matrix_ops import bmm_strategy

        mesh = self.build_device_mesh()
        lhs_tensor = torch.randn(8, 6, 8)
        rhs_tensor = torch.randn(8, 8, 12)
        lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
        rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)

        bmm_combs = (
            (Shard(0), Shard(0)),
            (Shard(1), Replicate()),
            (Replicate(), Shard(2)),
            (Shard(2), Shard(1)),
            (Replicate(), Replicate()),
        )
        for lhs, rhs in bmm_combs:
            lhs_spec = DTensorSpec(mesh, (lhs,), lhs_tensor_meta)
            rhs_spec = DTensorSpec(mesh, (rhs,), rhs_tensor_meta)

            op_schema = OpSchema(
                torch.ops.aten.bmm.default,
                (
                    OpStrategy([OpSpec(lhs_spec)]),
                    OpStrategy([OpSpec(rhs_spec)]),
                ),
                {},
            )
            # test the strategy
            res_strategies = bmm_strategy(op_schema)

            for strtgy in res_strategies.strategies:
                if strtgy.input_specs == (lhs_spec, rhs_spec):
                    self.assertEqual(strtgy.redistribute_cost, [[0.0], [0.0]])
                    break

            op_schema = OpSchema(
                torch.ops.aten.bmm.default,
                (lhs_spec, rhs_spec),
                {},
            )
            # test sharding prop
            output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding_non_cached(
                op_schema
            )
            self.assertFalse(output_sharding.needs_redistribute)


if __name__ == "__main__":
    run_tests()