Fix einsum strategy shard dim > ndim (#157593)
Previously we didn't constrain the Shard dim to be less than the tensor's ndim. This caused an invalid strategy like `(RR, RS(2)) -> RS(2)` for einsum `bmk,kn->bmn` on a 2D mesh: the rhs operand `kn` has only two dimensions, so `S(2)` is out of range for it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157593
Approved by: https://github.com/wconstab, https://github.com/wanchaol
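To make the failure mode concrete, here is a minimal standalone sketch (plain Python, illustrative names only, not part of the change) of the shard-dim computation before and after the fix, for the rhs-only free dim `n` of `bmk,kn->bmn`:

    lhs, rhs, out = "bmk", "kn", "bmn"
    rhs_only_dim = "n"

    # Before: the input's shard dim was taken from the *output* subscripts,
    # so the 2-D rhs "kn" ended up with Shard(2), which is out of range.
    old_rhs_shard_dim = out.index(rhs_only_dim)         # 2

    # After: each tensor's shard dim comes from its own subscripts.
    new_rhs_shard_dim_output = out.index(rhs_only_dim)  # 2, valid for the 3-D output "bmn"
    new_rhs_shard_dim_input = rhs.index(rhs_only_dim)   # 1, valid for the 2-D rhs "kn"

In the diff below this shows up as `lhs_free_dim_input = input_dims[0].index(lhs_dim)` and `rhs_free_dim_input = input_dims[1].index(rhs_dim)`.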
@@ -98,6 +98,16 @@ class TestEinsumStrategies(DTensorOpTestBase):
         all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
         self.assertEqual(len(all_strats.strategies), 5)
 
+    def test_bmm_diffinndim_2d_mesh(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+        all_strats = gen_einsum_strategies("bmk,kn->bmn", mesh)
+        self.assertEqual(len(all_strats.strategies), 25)
+
+    def test_bmm_diffoutndim_2d_mesh(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+        all_strats = gen_einsum_strategies("bmk,k->bm", mesh)
+        self.assertEqual(len(all_strats.strategies), 16)
+
     def test_bmm_2d_mesh(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
 
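The new expected counts follow from how `gen_einsum_strategies` builds strategies per mesh dim (see the second hunk below): with linearity disabled there is one replicate-all strategy plus one strategy per batch, contracting, lhs-only and rhs-only dim, and the final list is the cartesian product over the mesh dims. A rough standalone check of the arithmetic (this helper is illustrative only, not the `EinsumDims` implementation):

    def count_einsum_strategies(lhs: str, rhs: str, out: str, mesh_ndim: int) -> int:
        dims = set(lhs) | set(rhs)
        contracting = [d for d in dims if d not in out]          # shared, reduced away
        batch = [d for d in out if d in lhs and d in rhs]        # shared, kept in output
        lhs_only = [d for d in out if d in lhs and d not in rhs]
        rhs_only = [d for d in out if d in rhs and d not in lhs]
        per_mesh_dim = 1 + len(batch) + len(contracting) + len(lhs_only) + len(rhs_only)
        return per_mesh_dim ** mesh_ndim

    print(count_einsum_strategies("bmk", "kn", "bmn", mesh_ndim=2))  # 5 ** 2 = 25
    print(count_einsum_strategies("bmk", "k", "bm", mesh_ndim=2))    # 4 ** 2 = 16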
@@ -90,80 +90,93 @@ def gen_einsum_strategies(
 ) -> OpStrategy:
     """
     Generate a strategy list for the ops that follow einsum style notation.
 
     In principle, each mesh dim is independent of other device mesh dim when we
     generate strategies. So we generate strategy over each device mesh dim and
     do product combination on all mesh dims. We basically follow the below rule
     for each device mesh dim:
 
     1. Shard on contracting dim: When both inputs shard on contracting dim over
     the same device dim. The result will be Partial over that device dim.
 
     2. Shard on noncontracting dim:
     2.1: Shard on batch dim: output, both inputs all should shard on batch
     dim.
     2.2: Shard on lhs only dim or rhs only dim: both output and lhs or rhs
     input should shard on this free dim.
 
     3. Linearity (Partial): If enabled, set Partial on output and inputs over
     the same device mesh dim.
     """
     # parse einop equation and extract dims
     input_dims, output_dim = EinsumDims.parse_equation(equation)
     edims = EinsumDims.parse_dims(input_dims, output_dim)
 
-    all_mesh_dim_strategies = []
-
-    # generate strategies for each mesh dim
-    for mesh_dim in range(mesh.ndim):
-        mesh_dim_strategies = []
+    # generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R]
+    strategies_over_one_mesh_dim = []
 
-        # placement list stores placements of [output, input1, input2, ...]
-        # first we always have replicate all for inputs and output
-        placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
-        mesh_dim_strategies.append(placement_list)
+    # placement list stores placements of [output, input1, input2, ...]
+    # first we always have replicate all for inputs and output
+    placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
+    strategies_over_one_mesh_dim.append(placement_list)
 
-        # split batch dim
-        for batch_dim in edims.batch_dims:
-            output_batch_dim = output_dim.index(batch_dim)
-            placement_list = [Shard(output_batch_dim)]
-            for input_dim in input_dims:
-                input_batch_dim = input_dim.index(batch_dim)
-                placement_list.append(Shard(input_batch_dim))
+    # split batch dim
+    for batch_dim in edims.batch_dims:
+        output_batch_dim = output_dim.index(batch_dim)
+        placement_list = [Shard(output_batch_dim)]
+        for input_dim in input_dims:
+            input_batch_dim = input_dim.index(batch_dim)
+            placement_list.append(Shard(input_batch_dim))
 
-            mesh_dim_strategies.append(placement_list)
+        strategies_over_one_mesh_dim.append(placement_list)
 
-        # split contracting dim
-        for contracting_dim in edims.contracting_dims:
-            placement_list = [Partial()]
-            for input_dim in input_dims:
-                input_contracting_dim = input_dim.index(contracting_dim)
-                placement_list.append(Shard(input_contracting_dim))
+    # split contracting dim
+    for contracting_dim in edims.contracting_dims:
+        # Contracting dim can shard on same device axis for both inputs. This
+        # results in the output being Partial on that device axis. For example:
+        # bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
+        placement_list = [Partial()]
+        for input_dim in input_dims:
+            input_contracting_dim = input_dim.index(contracting_dim)
+            placement_list.append(Shard(input_contracting_dim))
 
-            mesh_dim_strategies.append(placement_list)
+        strategies_over_one_mesh_dim.append(placement_list)
 
-        # split lhs free dim
-        for lhs_dim in edims.lhs_out_only_dims:
-            lhs_free_dim = output_dim.index(lhs_dim)
-            # this means split the lhs input and output
-            # i.e. S(0), R -> S(0)
-            lhs_placement_list: list[Placement] = [
-                Shard(lhs_free_dim),
-                Shard(lhs_free_dim),
-                Replicate(),
-            ]
-            mesh_dim_strategies.append(lhs_placement_list)
+    # split lhs free dim
+    for lhs_dim in edims.lhs_out_only_dims:
+        lhs_free_dim_output = output_dim.index(lhs_dim)
+        lhs_free_dim_input = input_dims[0].index(lhs_dim)
+        # this means split the lhs input and output
+        # i.e. S(0), R -> S(0)
+        lhs_placement_list: list[Placement] = [
+            Shard(lhs_free_dim_output),
+            Shard(lhs_free_dim_input),
+            Replicate(),
+        ]
+        strategies_over_one_mesh_dim.append(lhs_placement_list)
 
-        # split rhs free dim
-        for rhs_dim in edims.rhs_out_only_dims:
-            rhs_free_dim = output_dim.index(rhs_dim)
-            rhs_placement_list: list[Placement] = [
-                Shard(rhs_free_dim),
-                Replicate(),
-                Shard(rhs_free_dim),
-            ]
-            mesh_dim_strategies.append(rhs_placement_list)
+    # split rhs free dim
+    for rhs_dim in edims.rhs_out_only_dims:
+        rhs_free_dim_output = output_dim.index(rhs_dim)
+        rhs_free_dim_input = input_dims[1].index(rhs_dim)
+        rhs_placement_list: list[Placement] = [
+            Shard(rhs_free_dim_output),
+            Replicate(),
+            Shard(rhs_free_dim_input),
+        ]
+        strategies_over_one_mesh_dim.append(rhs_placement_list)
 
-        # linearity strategy
-        if linearity:
-            linearity_placement_list: list[Placement] = [Partial()]
-            for input_dim in input_dims:
-                linearity_placement_list.append(Partial())
-            mesh_dim_strategies.append(linearity_placement_list)
-
-        all_mesh_dim_strategies.append(mesh_dim_strategies)
+    # linearity strategy
+    if linearity:
+        linearity_placement_list: list[Placement] = [Partial()]
+        for input_dim in input_dims:
+            linearity_placement_list.append(Partial())
+        strategies_over_one_mesh_dim.append(linearity_placement_list)
+
+    # generate strategies for entire mesh
+    all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
     strategy_combs = itertools.product(*all_mesh_dim_strategies)
 
     # TODO: filter out invalid strategies, at this point we generate
     # all possible strategies without considering the whether the tensor
     # dim could be sharded or not, we would need to filter out invalid
     # strategies base on the actual tensor shape
     # (i.e. for Shard, tensor dim size must > mesh size)
     all_strategies = []
     for strategy_comb in strategy_combs:
         spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
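Rule 1 in the docstring above (sharding both inputs along the contracting dim leaves the output Partial on that mesh dim) can be sanity-checked on a single process; the 2-way `chunk` below is just a local stand-in for a mesh dim of size 2 and is purely illustrative:

    import torch

    b, m, k, n = 2, 3, 4, 5
    lhs = torch.randn(b, m, k)  # "bmk"
    rhs = torch.randn(k, n)     # "kn"
    full = torch.einsum("bmk,kn->bmn", lhs, rhs)

    # Shard the contracting dim k on both inputs: each "rank" computes a partial result.
    partials = [
        torch.einsum("bmk,kn->bmn", lhs_shard, rhs_shard)
        for lhs_shard, rhs_shard in zip(lhs.chunk(2, dim=2), rhs.chunk(2, dim=0))
    ]

    # Summing the partials (roughly what reducing a Partial placement does) recovers the result.
    torch.testing.assert_close(sum(partials), full)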