Fix einsum strategy shard dim > ndim (#157593)

Previously we didn't constrain the Shard dim to be less than the tensor's ndim. This caused invalid strategies like `(RR, RS(2)) -> RS(2)` for einsum `bmk,kn->bmn` on a 2-D mesh.
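To make the failure mode concrete, here is a small standalone sketch (illustrative only, not code from this diff): for `bmk,kn->bmn`, the rhs-only dim `n` sits at index 2 of the output `bmn`, but the rhs operand `kn` only has dims 0 and 1, so reusing the output index yields the out-of-range `Shard(2)`:

```python
# Illustrative sketch of the bug (not from the diff): for "bmk,kn->bmn",
# the rhs-only free dim "n" must be indexed within the rhs operand "kn",
# not within the output "bmn".
lhs, rhs, out = "bmk", "kn", "bmn"

buggy_rhs_shard_dim = out.index("n")  # 2, but the rhs tensor has ndim == 2
fixed_rhs_shard_dim = rhs.index("n")  # 1, a valid dim of "kn"

assert buggy_rhs_shard_dim >= len(rhs)  # the invalid RS(2) placement
assert fixed_rhs_shard_dim < len(rhs)   # the corrected RS(1) placement
```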

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157593
Approved by: https://github.com/wconstab, https://github.com/wanchaol
Author: zpcore
Date: 2025-07-08 20:27:17 +00:00
Committed by: PyTorch MergeBot
Parent: 06b3265cb1
Commit: a73d9e0aec
2 changed files with 80 additions and 57 deletions


@@ -98,6 +98,16 @@ class TestEinsumStrategies(DTensorOpTestBase):
         all_strats = gen_einsum_strategies("bmk,bkn->bmn", mesh)
         self.assertEqual(len(all_strats.strategies), 5)
 
+    def test_bmm_diffinndim_2d_mesh(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+        all_strats = gen_einsum_strategies("bmk,kn->bmn", mesh)
+        self.assertEqual(len(all_strats.strategies), 25)
+
+    def test_bmm_diffoutndim_2d_mesh(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
+        all_strats = gen_einsum_strategies("bmk,k->bm", mesh)
+        self.assertEqual(len(all_strats.strategies), 16)
+
     def test_bmm_2d_mesh(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size).reshape(2, 2))
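As a sanity check on these asserted counts (my arithmetic, not part of the test file): with linearity disabled, the generator emits, per mesh dim, one replicate-all strategy plus one strategy per batch, contracting, lhs-only, and rhs-only dim, and a 2-D mesh takes the cartesian product over its two mesh dims:

```python
# Sanity check on the asserted strategy counts (my arithmetic, not from
# the tests): one replicate-all strategy plus one per sharded/partial dim,
# combined independently across mesh dims.
def expected_count(n_batch: int, n_contracting: int, n_lhs_only: int,
                   n_rhs_only: int, mesh_ndim: int) -> int:
    per_mesh_dim = 1 + n_batch + n_contracting + n_lhs_only + n_rhs_only
    return per_mesh_dim**mesh_ndim

# "bmk,kn->bmn": contracting {k}, lhs-only {b, m}, rhs-only {n} -> 5**2
assert expected_count(0, 1, 2, 1, mesh_ndim=2) == 25
# "bmk,k->bm": contracting {k}, lhs-only {b, m}, no rhs-only dims -> 4**2
assert expected_count(0, 1, 2, 0, mesh_ndim=2) == 16
```

The same arithmetic matches the pre-existing assertion above: `bmk,bkn->bmn` has batch `{b}`, contracting `{k}`, lhs-only `{m}`, rhs-only `{n}`, giving 5 strategies on a 1-D mesh.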


@@ -90,80 +90,93 @@ def gen_einsum_strategies(
 ) -> OpStrategy:
     """
     Generate a strategy list for the ops that follow einsum style notation.
+
+    In principle, each mesh dim is independent of the other device mesh dims
+    when we generate strategies, so we generate strategies over each device
+    mesh dim and take the product combination across all mesh dims. We follow
+    the rules below for each device mesh dim:
+
+    1. Shard on a contracting dim: when both inputs shard the contracting dim
+       over the same device dim, the result is Partial over that device dim.
+    2. Shard on a noncontracting dim:
+       2.1. Batch dim: the output and both inputs should all shard the batch
+            dim.
+       2.2. lhs-only or rhs-only dim: the output and the lhs (or rhs) input
+            should shard this free dim.
+    3. Linearity (Partial): if enabled, set Partial on the output and inputs
+       over the same device mesh dim.
     """
     # parse einop equation and extract dims
     input_dims, output_dim = EinsumDims.parse_equation(equation)
     edims = EinsumDims.parse_dims(input_dims, output_dim)
 
-    all_mesh_dim_strategies = []
-
-    # generate strategies for each mesh dim
-    for mesh_dim in range(mesh.ndim):
-        mesh_dim_strategies = []
-
-        # placement list stores placements of [output, input1, input2, ...]
-        # first we always have replicate all for inputs and output
-        placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
-        mesh_dim_strategies.append(placement_list)
+    # generate strategies for each mesh dim and do cartesian product for the
+    # final strategy. E.g., for a 2D mesh, we can have [P(), R, R]
+    strategies_over_one_mesh_dim = []
+
+    # placement list stores placements of [output, input1, input2, ...]
+    # first we always have replicate all for inputs and output
+    placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
+    strategies_over_one_mesh_dim.append(placement_list)
 
-        # split batch dim
-        for batch_dim in edims.batch_dims:
-            output_batch_dim = output_dim.index(batch_dim)
-            placement_list = [Shard(output_batch_dim)]
-            for input_dim in input_dims:
-                input_batch_dim = input_dim.index(batch_dim)
-                placement_list.append(Shard(input_batch_dim))
-            mesh_dim_strategies.append(placement_list)
+    # split batch dim
+    for batch_dim in edims.batch_dims:
+        output_batch_dim = output_dim.index(batch_dim)
+        placement_list = [Shard(output_batch_dim)]
+        for input_dim in input_dims:
+            input_batch_dim = input_dim.index(batch_dim)
+            placement_list.append(Shard(input_batch_dim))
+        strategies_over_one_mesh_dim.append(placement_list)
 
-        # split contracting dim
-        for contracting_dim in edims.contracting_dims:
-            placement_list = [Partial()]
-            for input_dim in input_dims:
-                input_contracting_dim = input_dim.index(contracting_dim)
-                placement_list.append(Shard(input_contracting_dim))
-            mesh_dim_strategies.append(placement_list)
+    # split contracting dim
+    for contracting_dim in edims.contracting_dims:
+        # The contracting dim can shard on the same device axis for both
+        # inputs. This results in the output being Partial on that device
+        # axis. For example: bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over
+        # device axis x)
+        placement_list = [Partial()]
+        for input_dim in input_dims:
+            input_contracting_dim = input_dim.index(contracting_dim)
+            placement_list.append(Shard(input_contracting_dim))
+        strategies_over_one_mesh_dim.append(placement_list)
 
-        # split lhs free dim
-        for lhs_dim in edims.lhs_out_only_dims:
-            lhs_free_dim = output_dim.index(lhs_dim)
-            # this means split the lhs input and output
-            # i.e. S(0), R -> S(0)
-            lhs_placement_list: list[Placement] = [
-                Shard(lhs_free_dim),
-                Shard(lhs_free_dim),
-                Replicate(),
-            ]
-            mesh_dim_strategies.append(lhs_placement_list)
+    # split lhs free dim
+    for lhs_dim in edims.lhs_out_only_dims:
+        lhs_free_dim_output = output_dim.index(lhs_dim)
+        lhs_free_dim_input = input_dims[0].index(lhs_dim)
+        # this means split the lhs input and output
+        # i.e. S(0), R -> S(0)
+        lhs_placement_list: list[Placement] = [
+            Shard(lhs_free_dim_output),
+            Shard(lhs_free_dim_input),
+            Replicate(),
+        ]
+        strategies_over_one_mesh_dim.append(lhs_placement_list)
 
-        # split rhs free dim
-        for rhs_dim in edims.rhs_out_only_dims:
-            rhs_free_dim = output_dim.index(rhs_dim)
-            rhs_placement_list: list[Placement] = [
-                Shard(rhs_free_dim),
-                Replicate(),
-                Shard(rhs_free_dim),
-            ]
-            mesh_dim_strategies.append(rhs_placement_list)
+    # split rhs free dim
+    for rhs_dim in edims.rhs_out_only_dims:
+        rhs_free_dim_output = output_dim.index(rhs_dim)
+        rhs_free_dim_input = input_dims[1].index(rhs_dim)
+        rhs_placement_list: list[Placement] = [
+            Shard(rhs_free_dim_output),
+            Replicate(),
+            Shard(rhs_free_dim_input),
+        ]
+        strategies_over_one_mesh_dim.append(rhs_placement_list)
 
-        # linearity strategy
-        if linearity:
-            linearity_placement_list: list[Placement] = [Partial()]
-            for input_dim in input_dims:
-                linearity_placement_list.append(Partial())
-            mesh_dim_strategies.append(linearity_placement_list)
-
-        all_mesh_dim_strategies.append(mesh_dim_strategies)
+    # linearity strategy
+    if linearity:
+        linearity_placement_list: list[Placement] = [Partial()]
+        for input_dim in input_dims:
+            linearity_placement_list.append(Partial())
+        strategies_over_one_mesh_dim.append(linearity_placement_list)
 
     # generate strategies for entire mesh
+    all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
     strategy_combs = itertools.product(*all_mesh_dim_strategies)
 
     # TODO: filter out invalid strategies; at this point we generate all
     # possible strategies without considering whether the tensor dim can be
     # sharded or not, so we would need to filter out invalid strategies
     # based on the actual tensor shape
     # (i.e. for Shard, tensor dim size must be > mesh size)
     all_strategies = []
     for strategy_comb in strategy_combs:
         spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
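The TODO above is left open by this change. As a rough illustration only (the helper name and exact checks are mine, not code from this PR), such a shape-aware filter might look like:

```python
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Shard


def is_shardable(placements, tensor_shape, mesh: DeviceMesh) -> bool:
    # Hypothetical filter for the TODO above: reject a strategy that shards
    # a dim the tensor does not have (the class of bug fixed here), or a dim
    # that is smaller than the mesh dim it is sharded over.
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            if placement.dim >= len(tensor_shape):
                return False
            if tensor_shape[placement.dim] < mesh.size(mesh_dim):
                return False
    return True
```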