Compare commits

...

6 Commits

Author SHA1 Message Date
2a94e1833c claude generate tests 2025-11-14 13:10:07 -08:00
0e6ba6334d Notes on tensor_ops 2025-11-14 06:34:05 -08:00
7c7d62b701 Support mm via single-dim strategy 2025-11-14 06:33:51 -08:00
b5c4ccc570 document things 2025-11-14 06:33:04 -08:00
fb0a066b74 [DTensor] add register_single_dim_strategy
WIP for now, not ready to land

Experimenting with the idea of decomposing op strategies into
 - a core function that proposes a single-mesh-dim strategy
 - automatic expansion to the mesh inside the registration mechanism

Also, plan to add a 'ShardingPlaceholder' to use for writing
single-mesh-dim strategies in a way that can be expanded at runtime to
any type of sharding discovered in the inputs.

For now, this relies on full enumeration of the single-dim strategy onto
the full mesh, and full enumeration of the combinations of different
sharding placements discovered in the input, but we should be able to
replace this with an algorithm to expand iteratively following the path
of lowest redistribution cost.

ghstack-source-id: ed6977ea86d849b84d453408109cc4f602019c4d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167677
2025-11-14 06:32:05 -08:00
475b8fd76c [DTensor] Fix mypy on register_op_strategy
ghstack-source-id: 59ef401df5d190a4c7611a327779ba00ba2a8d7c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167673
2025-11-12 12:39:32 -08:00
4 changed files with 434 additions and 15 deletions

View File

@ -32,6 +32,7 @@ from torch.distributed.tensor._ops._einsum_strategy import (
)
from torch.distributed.tensor._ops.utils import (
register_op_strategy,
register_single_dim_strategy,
replicate_op_strategy,
)
from torch.distributed.tensor.debug import CommDebugMode
@ -655,5 +656,202 @@ TestStrategyHashingWithLocalTensor = create_local_tensor_test_class(
TestStrategyHashing,
)
class TestSingleDimStrategy(DTensorTestBase):
@with_comms
def test_register_single_dim_strategy_replaces_existing_rule(self):
"""
Test that calling register_single_dim_strategy works and replaces an existing registered rule.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_mm_like_strategy,
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test a specific input sharding combination
lhs_placement = (Shard(1),)
rhs_placement = (Shard(0),)
lhs_spec = DTensorSpec(mesh, lhs_placement, lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, rhs_placement, rhs_tensor_meta)
# Create the OpSchema for mm operation
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get the strategies from the old mm_like_strategy (what was used before)
old_style_strategy = _mm_like_strategy("mk,kn->mn", mesh, op_schema)
# Get the strategies from the new register_single_dim_strategy approach
# First, we need to get the single dim strategy function
def mm_single_dim_strategy_func(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Now expand it to full strategy using the same logic as register_single_dim_strategy
single_dim_strategies = mm_single_dim_strategy_func(op_schema)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
new_style_strategy = OpStrategy(all_strategies)
# Verify that both strategies produce the same set of shardings
old_strategy_set = {str(strategy) for strategy in old_style_strategy.strategies}
new_strategy_set = {str(strategy) for strategy in new_style_strategy.strategies}
self.assertEqual(
old_strategy_set,
new_strategy_set,
"Old and new strategies should produce the same shardings",
)
# Verify that the registration actually works by checking the propagator
propagator = DTensor._op_dispatcher.sharding_propagator
# Save the original strategy if it exists
original_strategy = None
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
original_strategy = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
try:
# Register a custom single-dim strategy
@register_single_dim_strategy(torch.ops.aten.mm.default)
def custom_mm_single_dim_strategy(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Verify the strategy was registered
self.assertIn(
torch.ops.aten.mm.default,
propagator.op_strategy_funcs,
"Strategy should be registered after calling register_single_dim_strategy",
)
# Verify it replaced any existing rule
registered_func = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
self.assertIsNotNone(
registered_func, "Registered strategy function should not be None"
)
# Test that the registered strategy produces valid output
result_strategy = registered_func(op_schema)
self.assertIsInstance(
result_strategy, OpStrategy, "Result should be an OpStrategy"
)
self.assertGreater(
len(result_strategy.strategies),
0,
"Strategy should contain at least one OpSpec",
)
finally:
# Restore original strategy if it existed
if original_strategy is not None:
propagator.op_strategy_funcs[torch.ops.aten.mm.default] = (
original_strategy
)
else:
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[torch.ops.aten.mm.default]
# Clear the cache
propagator.propagate_op_sharding.cache.cache_clear()
@with_comms
def test_single_dim_strategy_shardings_match_full_strategy(self):
"""
Verify that the shardings produced by a single-dim strategy match those produced
by the full strategy implementation.
"""
from torch.distributed.tensor._ops._matrix_ops import (
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test multiple input sharding combinations
mm_combs = (
(Shard(0), Replicate()),
(Replicate(), Shard(1)),
(Shard(1), Shard(0)),
(Replicate(), Replicate()),
)
for lhs_placement, rhs_placement in mm_combs:
lhs_spec = DTensorSpec(mesh, (lhs_placement,), lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, (rhs_placement,), rhs_tensor_meta)
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get single-dim strategies
single_dim_strategies = gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Expand to full strategy (mimicking what register_single_dim_strategy does)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
expanded_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
expanded_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
# Verify that for the given input shardings, we can find a matching strategy
# with zero redistribute cost
found_zero_cost_strategy = False
for strategy in expanded_strategies:
if strategy.input_specs == (lhs_spec, rhs_spec):
# This strategy should have zero redistribute cost since inputs match
found_zero_cost_strategy = True
# In a real strategy, redistribute costs would be computed
# Here we just verify the structure is correct
self.assertEqual(
len(strategy.input_specs),
2,
"MM should have exactly 2 input specs",
)
self.assertIsNotNone(
strategy.output_specs, "Output spec should not be None"
)
break
self.assertTrue(
found_zero_cost_strategy,
f"Should find a strategy matching input shardings {lhs_placement}, {rhs_placement}",
)
if __name__ == "__main__":
run_tests()

View File

@ -23,6 +23,7 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
prod,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
@ -237,10 +238,130 @@ def dot_strategy(op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("i,i->", mesh, op_schema)
@register_op_strategy(aten.mm.default)
def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# @register_op_strategy(aten.mm.default)
# def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# mesh = op_schema.get_mesh_from_args()
# return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
from ._einsum_strategy import EinsumDims
def gen_single_dim_einsum_strategies(
equation: str,
mesh: DeviceMesh,
*,
linearity: bool = False,
) -> list[list[Placement]]:
"""
Generate single-mesh-dim strategies for ops that follow einsum-style notation.
In principle, each mesh dim is independent of the other device mesh dims when
we generate strategies, so this function only enumerates the placement options
for one mesh dim; the caller (e.g. register_single_dim_strategy) takes the
product combination over all mesh dims. We follow the rules below for each
device mesh dim:
1. Shard on contracting dim: when both inputs are sharded on the same
contracting dim over that device dim, the result is Partial over that
device dim.
2. Shard on noncontracting dim:
2.1: Shard on batch dim: the output and both inputs should all be sharded
on the batch dim.
2.2: Shard on an lhs-only or rhs-only dim: the output and the corresponding
input should be sharded on that free dim.
3. Linearity (Partial): if enabled, set Partial on the output and inputs over
the same device mesh dim.
"""
# parse einop equation and extract dims
input_dims, output_dim = EinsumDims.parse_equation(equation)
edims = EinsumDims.parse_dims(input_dims, output_dim)
all_mesh_dim_strategies = []
# generate the placement options over a single mesh dim; the caller expands these to the full mesh via a cartesian product (e.g. one single-dim option for mm is [P(), S(1), S(0)], i.e. [output, lhs, rhs])
strategies_over_one_mesh_dim = []
# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
strategies_over_one_mesh_dim.append(placement_list)
# split batch dim
for batch_dim in edims.batch_dims:
output_batch_dim = output_dim.index(batch_dim)
placement_list = [Shard(output_batch_dim)]
for input_dim in input_dims:
input_batch_dim = input_dim.index(batch_dim)
placement_list.append(Shard(input_batch_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split contracting dim
for contracting_dim in edims.contracting_dims:
# Contracting dim can shard on same device axis for both inputs. This
# results in the output being Partial on that device axis. For example:
# bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
placement_list = [Partial()]
for input_dim in input_dims:
input_contracting_dim = input_dim.index(contracting_dim)
placement_list.append(Shard(input_contracting_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split lhs free dim
for lhs_dim in edims.lhs_out_only_dims:
lhs_free_dim_output = output_dim.index(lhs_dim)
lhs_free_dim_input = input_dims[0].index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: list[Placement] = [
Shard(lhs_free_dim_output),
Shard(lhs_free_dim_input),
Replicate(),
]
strategies_over_one_mesh_dim.append(lhs_placement_list)
# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim_output = output_dim.index(rhs_dim)
rhs_free_dim_input = input_dims[1].index(rhs_dim)
rhs_placement_list: list[Placement] = [
Shard(rhs_free_dim_output),
Replicate(),
Shard(rhs_free_dim_input),
]
strategies_over_one_mesh_dim.append(rhs_placement_list)
# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()]
for _ in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)
# generate strategies for entire mesh
# all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
# strategy_combs = itertools.product(*all_mesh_dim_strategies)
# all_strategies = []
# for strategy_comb in strategy_combs:
# spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
# strat = OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
# all_strategies.append(strat)
# return OpStrategy(all_strategies)
return strategies_over_one_mesh_dim
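# Worked example (not part of this diff): for mm's equation "mk,kn->mn" the list returned above has
# one [output, lhs, rhs] placement list per single-mesh-dim option:
#   [Replicate(), Replicate(), Replicate()]  # fully replicated
#   [Partial(), Shard(1), Shard(0)]          # both inputs shard the contracting dim k -> output is Partial
#   [Shard(0), Shard(0), Replicate()]        # shard the lhs-only free dim m
#   [Shard(1), Replicate(), Shard(1)]        # shard the rhs-only free dim n
# register_single_dim_strategy then takes the cartesian product of this list with itself mesh.ndim
# times to build the full OpStrategy.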
@register_single_dim_strategy(aten.mm.default)
def mm_single_dim_strategy(op_schema: OpSchema) -> list[list[Placement]]:
self_strategy, mat2_strategy = op_schema.args_schema
if not isinstance(self_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
if not isinstance(mat2_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
# generate all possible strategies for mm
mesh = op_schema.get_mesh_from_args()
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
@register_op_strategy(aten.addmm.default)

View File

@ -41,6 +41,8 @@ from torch.distributed.tensor.placement_types import (
aten = torch.ops.aten
# WHC - I think anywhere this is used, we can replace it with a corresponding single-dim passthrough strategy
# (any Shard, Replicate, or Partial can pass through, and then expand that to the mesh dims later)
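# A sketch (not part of this diff) of that single-dim passthrough idea, using the
# 'ShardingPlaceholder' the commit message proposes (hypothetical API, does not exist yet):
#
#     @register_single_dim_strategy(aten.clone.default)  # clone is just an illustrative op
#     def clone_single_dim_strategy(op_schema: OpSchema) -> list[list[Placement]]:
#         # [output, input]: whatever placement the input has on this mesh dim passes through
#         return [[ShardingPlaceholder(), ShardingPlaceholder()]]
#
# where the registration wrapper would expand the placeholder into each Shard/Replicate/Partial
# actually present in the inputs.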
def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
# For ops with a single tensor input, we perform a 1:1 mapping such that
# for each strategy that the input supports, we create a corresponding strategy.
@ -97,6 +99,28 @@ register_op_strategy(
)(propagate_single_input_strategy)
"""
WHC - equal_strategy is an example of baking an optimization into the sharding rule.
The unoptimized equal strategy (for one mesh dim) should look like this:
S, S -> S
R, R -> R
P, P -> P * - this could work, I think, if we supported a Partial over booleans with a corresponding reduction?
And this should be expanded to the full mesh.
But what this rule actually does is:
- compare the two tensor args to equal: look at the strategies for each, which represent the input-output sharding
relationship of the op that produced those tensor args, and pick the one whose strategy (OpSpec) has the most Shard()
placements in it. Why? Because converting the other arg from R->S is cheaper than converting S->R.
- start with the assumption that the 'equal' op has the same strategy as the op that produced its max-shard input
- then adjust the placements from Partial to Replicate, since we don't support Partial in equal
- finally, produce an OpSpec that only populates the 'output_specs' field of OpSpec
TODO: why is it OK to populate only the output_specs of an OpSpec? Is it defined to mean that all input specs are the
same as the output spec?
"""
@register_op_strategy(
[
aten.equal.default,
@ -140,6 +164,19 @@ def equal_strategy(op_schema: OpSchema) -> StrategyType:
return equal_strategy
"""
WHC - seems like we could replace this with a single-mesh-dim strategy:
S -> S
R -> R
P -> R
The P -> R case is odd, but makes sense:
* we can't support P -> P, since it would be incorrect to create a new 'partial' tensor from ones, which would no
longer be ones if we replicated them
* we don't want to omit support for an input Partial, because we'd force a replication of the input, which would be
wasteful
"""
@register_op_strategy(
[
aten.empty_like.default,
@ -481,6 +518,19 @@ def replicate_tensor_dim(
)
"""
WHC - an example of a complicated 'follow your inputs' strategy that would be useful to try out as a simple
single-dim rule; it seems very simple to write this way:
assert input and src have the same ndim
for i in range(input.ndim):
    if i != slice_dim:
        Shard(i), Shard(i) -> Shard(i)
"""
@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
# 1. number of dimensions in input and src need to match.

View File

@ -4,8 +4,7 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from typing import cast, Optional, Union
import torch
from torch._prims_common import DimsSequenceType, DimsType
@ -30,10 +29,6 @@ from torch.distributed.tensor.placement_types import (
)
_T = TypeVar("_T")
_P = ParamSpec("_P")
# convenient wrapper to register sharding propagation rules
def register_prop_rule(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
@ -54,11 +49,61 @@ def register_prop_rule(
return wrapper
def register_op_strategy(
op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# pyre-fixme[2]: Parameter must be annotated.
def register_single_dim_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
[Callable[[OpSchema], list[list[Placement]]]], Callable[[OpSchema], StrategyType]
]:
"""
Registers a simplified op strategy that only considers a single mesh dim, taking care to expand it
to cover all the mesh dims present in the runtime inputs.
"""
def expanded_registration_wrapper(
single_dim_strategy: Callable[[OpSchema], list[list[Placement]]],
) -> Callable[[OpSchema], StrategyType]:
def _expanded_strategy(op_schema: OpSchema) -> StrategyType:
"""
Expands the single-mesh-dim impl across all mesh dims, and expands ShardingPlaceholder into all
sharding types used by the inputs.
"""
inputs_strategy = op_schema.args_strategy
mesh = inputs_strategy[0].mesh
strategies_over_one_mesh_dim = single_dim_strategy(op_schema)
# TODO: handle 'ShardingPlaceholder' expansion (doesn't exist yet)
# TODO: filter out 'invalid' placements
# - ShardVar needs to say whether 'even sharding' is required or not
# copied from the einsum strategy expansion...
# TODO: identify differences between this and 'expand_' util
all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
return OpStrategy(all_strategies)
# register_op_strategy returns another wrapper that actually does the strategy registration,
# we just add another layer of wrapping that expands the single_dim_strategy into a strategy that's
# compatible with register_op_strategy
register_op_strategy(op, schema_info)(_expanded_strategy)
return _expanded_strategy
return expanded_registration_wrapper
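# Usage note (not part of this diff): for mm, the single-dim list has 4 placement options (see
# gen_single_dim_einsum_strategies in _matrix_ops.py), so on a 2-D mesh _expanded_strategy builds
# the cartesian product of 4 x 4 = 16 OpSpecs; e.g. combining [Partial(), Shard(1), Shard(0)] on
# mesh dim 0 with [Shard(0), Shard(0), Replicate()] on mesh dim 1 yields an output spec of
# (P, S(0)), an lhs spec of (S(1), S(0)), and an rhs spec of (S(0), R).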
def register_op_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
# For every ATen op that accepts any args in this list,
# the arg itself can impact the strides (and potentially the sharding strategy)
# of the output tensor.
@ -68,7 +113,9 @@ def register_op_strategy(
"memory_format",
]
def wrapper(impl):
def wrapper(
impl: Callable[[OpSchema], StrategyType],
) -> Callable[[OpSchema], StrategyType]:
if isinstance(op, list):
overloads = op
else:
@ -159,7 +206,10 @@ def prod(xs: Iterable[int]) -> int:
def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is shardable according to the spec."""
"""Check if the spec matches these criteria:
* any Shard placements in spec refer to valid tensor dims
* no empty local tensors (uneven sharding OK, as long as last rank has >0 size)
"""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):