Compare commits

...

7 Commits

SHA1 Message Date
653c0ecf35 WIP add pointwise strategy 2025-11-17 21:18:29 -08:00
057434a442 claude generate tests 2025-11-17 21:03:07 -08:00
9fd0af1c3b Notes on tensor_ops 2025-11-17 21:03:07 -08:00
53305e5379 Support mm via single-dim strategy 2025-11-17 21:03:07 -08:00
ea5f2aceda document things 2025-11-17 21:01:28 -08:00
83557a528f [DTensor] add register_single_dim_strategy
WIP for now, not ready to land

Experimenting with the idea of decomposing op strategies into
 - a core function that proposes a single-mesh-dim strategy
 - automatic expansion to the mesh inside the registration mechanism

Also, plan to add a 'ShardingPlaceholder' to use for writing
single-mesh-dim strategies in a way that can be expanded at runtime to
any type of sharding discovered in the inputs.

For now, this relies on full enumeration of the single-dim strategy onto
the full mesh, and full enumeration of the combinations of different
sharding placements discovered in the input, but we should be able to
replace this with an algorithm to expand iteratively following the path
of lowest redistribution cost.
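
As a rough illustration of that expansion (not the PR's code; strings stand in for real Placement objects, and DTensorSpec/OpSpec construction is omitted): a single-mesh-dim strategy is a list of [output, input1, input2, ...] placement rows, and expanding it to an N-dim mesh is the cartesian product of that list with itself N times, regrouped per tensor.

import itertools

# single-mesh-dim mm strategy, one row per sharding choice, [output, lhs, rhs]
single_dim = [
    ["R", "R", "R"],        # replicate everything
    ["P", "S(1)", "S(0)"],  # shard the contracting dim on both inputs -> Partial output
    ["S(0)", "S(0)", "R"],  # shard the lhs-only dim
    ["S(1)", "R", "S(1)"],  # shard the rhs-only dim
]

def expand_to_mesh(single_dim, mesh_ndim):
    full = []
    for comb in itertools.product(single_dim, repeat=mesh_ndim):
        # comb[d][t] is tensor t's placement on mesh dim d; regroup per tensor
        full.append([tuple(comb[d][t] for d in range(mesh_ndim))
                     for t in range(len(comb[0]))])
    return full

# 4 single-dim choices on a 2-D mesh -> 4**2 = 16 candidate shardings
assert len(expand_to_mesh(single_dim, mesh_ndim=2)) == 16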

ghstack-source-id: ed6977ea86d849b84d453408109cc4f602019c4d
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167677
2025-11-14 15:27:09 -08:00
54d05a0874 [DTensor] Fix mypy on register_op_strategy
ghstack-source-id: 59ef401df5d190a4c7611a327779ba00ba2a8d7c
Pull Request resolved: https://github.com/pytorch/pytorch/pull/167673
2025-11-14 15:27:09 -08:00
5 changed files with 496 additions and 19 deletions

View File

@@ -32,6 +32,7 @@ from torch.distributed.tensor._ops._einsum_strategy import (
)
from torch.distributed.tensor._ops.utils import (
register_op_strategy,
register_single_dim_strategy,
replicate_op_strategy,
)
from torch.distributed.tensor.debug import CommDebugMode
@@ -655,5 +656,202 @@ TestStrategyHashingWithLocalTensor = create_local_tensor_test_class(
TestStrategyHashing,
)
class TestSingleDimStrategy(DTensorTestBase):
@with_comms
def test_register_single_dim_strategy_replaces_existing_rule(self):
"""
Test that calling register_single_dim_strategy works and replaces an existing registered rule.
"""
from torch.distributed.tensor._ops._matrix_ops import (
_mm_like_strategy,
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test a specific input sharding combination
lhs_placement = (Shard(1),)
rhs_placement = (Shard(0),)
lhs_spec = DTensorSpec(mesh, lhs_placement, lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, rhs_placement, rhs_tensor_meta)
# Create the OpSchema for mm operation
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get the strategies from the old mm_like_strategy (what was used before)
old_style_strategy = _mm_like_strategy("mk,kn->mn", mesh, op_schema)
# Get the strategies from the new register_single_dim_strategy approach
# First, we need to get the single dim strategy function
def mm_single_dim_strategy_func(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Now expand it to full strategy using the same logic as register_single_dim_strategy
single_dim_strategies = mm_single_dim_strategy_func(op_schema)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
new_style_strategy = OpStrategy(all_strategies)
# Verify that both strategies produce the same set of shardings
old_strategy_set = {str(strategy) for strategy in old_style_strategy.strategies}
new_strategy_set = {str(strategy) for strategy in new_style_strategy.strategies}
self.assertEqual(
old_strategy_set,
new_strategy_set,
"Old and new strategies should produce the same shardings",
)
# Verify that the registration actually works by checking the propagator
propagator = DTensor._op_dispatcher.sharding_propagator
# Save the original strategy if it exists
original_strategy = None
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
original_strategy = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
try:
# Register a custom single-dim strategy
@register_single_dim_strategy(torch.ops.aten.mm.default)
def custom_mm_single_dim_strategy(op_schema: OpSchema):
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Verify the strategy was registered
self.assertIn(
torch.ops.aten.mm.default,
propagator.op_strategy_funcs,
"Strategy should be registered after calling register_single_dim_strategy",
)
# Verify it replaced any existing rule
registered_func = propagator.op_strategy_funcs[torch.ops.aten.mm.default]
self.assertIsNotNone(
registered_func, "Registered strategy function should not be None"
)
# Test that the registered strategy produces valid output
result_strategy = registered_func(op_schema)
self.assertIsInstance(
result_strategy, OpStrategy, "Result should be an OpStrategy"
)
self.assertGreater(
len(result_strategy.strategies),
0,
"Strategy should contain at least one OpSpec",
)
finally:
# Restore original strategy if it existed
if original_strategy is not None:
propagator.op_strategy_funcs[torch.ops.aten.mm.default] = (
original_strategy
)
else:
if torch.ops.aten.mm.default in propagator.op_strategy_funcs:
del propagator.op_strategy_funcs[torch.ops.aten.mm.default]
# Clear the cache
propagator.propagate_op_sharding.cache.cache_clear()
@with_comms
def test_single_dim_strategy_shardings_match_full_strategy(self):
"""
Verify that the shardings produced by a single-dim strategy match those produced
by the full strategy implementation.
"""
from torch.distributed.tensor._ops._matrix_ops import (
gen_single_dim_einsum_strategies,
)
mesh = self.build_device_mesh()
# Create test inputs
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
# Test multiple input sharding combinations
mm_combs = (
(Shard(0), Replicate()),
(Replicate(), Shard(1)),
(Shard(1), Shard(0)),
(Replicate(), Replicate()),
)
for lhs_placement, rhs_placement in mm_combs:
lhs_spec = DTensorSpec(mesh, (lhs_placement,), lhs_tensor_meta)
rhs_spec = DTensorSpec(mesh, (rhs_placement,), rhs_tensor_meta)
op_schema = OpSchema(
torch.ops.aten.mm.default,
(
OpStrategy([OpSpec(lhs_spec)]),
OpStrategy([OpSpec(rhs_spec)]),
),
{},
)
# Get single-dim strategies
single_dim_strategies = gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
# Expand to full strategy (mimicking what register_single_dim_strategy does)
all_mesh_dim_strategies = [single_dim_strategies] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
expanded_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
expanded_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
)
# Verify that for the given input shardings, we can find a matching strategy
# with zero redistribute cost
found_zero_cost_strategy = False
for strategy in expanded_strategies:
if strategy.input_specs == (lhs_spec, rhs_spec):
# This strategy should have zero redistribute cost since inputs match
found_zero_cost_strategy = True
# In a real strategy, redistribute costs would be computed
# Here we just verify the structure is correct
self.assertEqual(
len(strategy.input_specs),
2,
"MM should have exactly 2 input specs",
)
self.assertIsNotNone(
strategy.output_specs, "Output spec should not be None"
)
break
self.assertTrue(
found_zero_cost_strategy,
f"Should find a strategy matching input shardings {lhs_placement}, {rhs_placement}",
)
if __name__ == "__main__":
run_tests()

View File

@@ -23,6 +23,7 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
prod,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
@@ -237,10 +238,119 @@ def dot_strategy(op_schema: OpSchema) -> OpStrategy:
return _mm_like_strategy("i,i->", mesh, op_schema)
@register_op_strategy(aten.mm.default)
def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# @register_op_strategy(aten.mm.default)
# def mm_strategy(op_schema: OpSchema) -> OpStrategy:
# mesh = op_schema.get_mesh_from_args()
# return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
from ._einsum_strategy import EinsumDims
def gen_single_dim_einsum_strategies(
equation: str,
mesh: DeviceMesh,
*,
linearity: bool = False,
) -> list[list[Placement]]:
"""
Generate a strategy list for ops that follow einsum-style notation.
In principle, each mesh dim is independent of the other mesh dims when we
generate strategies, so we generate strategies over each device mesh dim and
take the cartesian product over all mesh dims. For each device mesh dim we
follow the rules below:
1. Shard on the contracting dim: when both inputs shard the contracting dim
over the same device dim, the result is Partial over that device dim.
2. Shard on a noncontracting dim:
2.1: Batch dim: the output and both inputs all shard the batch dim.
2.2: lhs-only or rhs-only dim: the output and the corresponding (lhs or rhs)
input shard this free dim.
3. Linearity (Partial): if enabled, set Partial on the output and inputs over
the same device mesh dim.
A worked mm example follows this function.
"""
# parse einop equation and extract dims
input_dims, output_dim = EinsumDims.parse_equation(equation)
edims = EinsumDims.parse_dims(input_dims, output_dim)
# generate strategies for each mesh dim and do cartesian product for final strategy. E.g., for a 2D mesh, we can have [P(),R,R]
strategies_over_one_mesh_dim = []
# placement list stores placements of [output, input1, input2, ...]
# first we always have replicate all for inputs and output
placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1)
strategies_over_one_mesh_dim.append(placement_list)
# split batch dim
for batch_dim in edims.batch_dims:
output_batch_dim = output_dim.index(batch_dim)
placement_list = [Shard(output_batch_dim)]
for input_dim in input_dims:
input_batch_dim = input_dim.index(batch_dim)
placement_list.append(Shard(input_batch_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split contracting dim
for contracting_dim in edims.contracting_dims:
# Contracting dim can shard on same device axis for both inputs. This
# results in the output being Partial on that device axis. For example:
# bmk_{x},k_{x}n -> bmn{Ux} (becomes partial over device axis x)
placement_list = [Partial()]
for input_dim in input_dims:
input_contracting_dim = input_dim.index(contracting_dim)
placement_list.append(Shard(input_contracting_dim))
strategies_over_one_mesh_dim.append(placement_list)
# split lhs free dim
for lhs_dim in edims.lhs_out_only_dims:
lhs_free_dim_output = output_dim.index(lhs_dim)
lhs_free_dim_input = input_dims[0].index(lhs_dim)
# this means split the lhs input and output
# i.e. S(0), R -> S(0)
lhs_placement_list: list[Placement] = [
Shard(lhs_free_dim_output),
Shard(lhs_free_dim_input),
Replicate(),
]
strategies_over_one_mesh_dim.append(lhs_placement_list)
# split rhs free dim
for rhs_dim in edims.rhs_out_only_dims:
rhs_free_dim_output = output_dim.index(rhs_dim)
rhs_free_dim_input = input_dims[1].index(rhs_dim)
rhs_placement_list: list[Placement] = [
Shard(rhs_free_dim_output),
Replicate(),
Shard(rhs_free_dim_input),
]
strategies_over_one_mesh_dim.append(rhs_placement_list)
# linearity strategy
if linearity:
linearity_placement_list: list[Placement] = [Partial()]
for _ in input_dims:
linearity_placement_list.append(Partial())
strategies_over_one_mesh_dim.append(linearity_placement_list)
return strategies_over_one_mesh_dim
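# Worked example (illustrative): for plain mm, "mk,kn->mn" has no batch dims,
# one contracting dim (k), and one free dim per operand, so for each mesh dim
# this is expected to yield four rows in [output, input1, input2] order:
#
#   [Replicate(), Replicate(), Replicate()]  # replicate everything
#   [Partial(),   Shard(1),    Shard(0)]     # shard k on both inputs -> Partial output
#   [Shard(0),    Shard(0),    Replicate()]  # shard the lhs-only dim m
#   [Shard(1),    Replicate(), Shard(1)]     # shard the rhs-only dim n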
@register_single_dim_strategy(aten.mm.default)
def mm_single_dim_strategy(op_schema: OpSchema) -> list[Placement]:
self_strategy, mat2_strategy = op_schema.args_schema
if not isinstance(self_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(self_strategy)}")
if not isinstance(mat2_strategy, OpStrategy):
raise AssertionError(f"Expected OpStrategy, got {type(mat2_strategy)}")
# generate all possible strategies for mm
mesh = op_schema.get_mesh_from_args()
return _mm_like_strategy("mk,kn->mn", mesh, op_schema)
return gen_single_dim_einsum_strategies("mk,kn->mn", mesh)
@register_op_strategy(aten.addmm.default)

View File

@ -18,6 +18,7 @@ from torch.distributed.tensor._ops.utils import (
map_placements_after_broadcast,
normalize_dim,
register_op_strategy,
register_single_dim_strategy,
)
from torch.distributed.tensor.placement_types import (
Partial,
@@ -488,6 +489,58 @@ def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType:
return pointwise_strategy(op_schema, linearity=linearity_type)
def single_mesh_dim_pointwise_strategy(op_schema: OpSchema, linearity: int = -1) -> list[list[Placement]]:
return single_mesh_dim_common_pointwise_strategy(op_schema.args_schema, linearity)
def single_mesh_dim_common_pointwise_strategy(
args_schema: Sequence[object],
linearity: int = -1,
scalar_tensor_idx: Optional[int] = None
) -> list[list[Placement]]:
"""
Common single-mesh-dim strategy for pointwise operations.
Args:
args_schema: Input arguments schema
linearity: depending on the operator, we support different types of linearity:
-1: the operation does not support linearity
0: a unary operation that supports linearity; the output propagates Partial.
1: a binary operation that supports additive linearity; it requires every
operand to be Partial, and the output propagates Partial.
2: a binary operation that supports multiplicative linearity; it requires the
primary operand to be Partial and the other operands to be Replicate, and the
output propagates Partial.
scalar_tensor_idx: index of a Replicate scalar tensor whose mesh is allowed to
differ from the mesh of the other args (not handled yet; see the TODO below).
A worked broadcasting example follows this function.
"""
# handle broadcasting
common_shape = torch.broadcast_shapes(
*[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)]
)
placements_list = []
for i in range(len(common_shape)):
# Shard output dim i, and then shard the corresponding arguments if they have a corresponding (non broadcast) dim
shard_placements = [Shard(i)]
for arg in args_schema:
if isinstance(arg, OpStrategy):
common_dim_to_arg_dim = infer_broadcast_dims_map(common_shape, arg.shape)
if common_dim_to_arg_dim[i] >= 0:
shard_placements.append(Shard(common_dim_to_arg_dim[i]))
else:
shard_placements.append(Replicate())
placements_list.append(shard_placements)
if linearity > 0:
# TODO implement partial
# TODO: can the same op support both add and multiplicative linearity?
pass
# TODO: handle scalar_tensor_idx
return placements_list
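# Worked example (illustrative): given two OpStrategy args with global shapes
# (2, 3) and (3,), the broadcast shape is (2, 3) and the returned list is
# expected to be, in [output, input1, input2] order:
#
#   [Shard(0), Shard(0), Replicate()]  # dim 0 is broadcast for the second arg
#   [Shard(1), Shard(1), Shard(0)]     # dim 1 maps to dim 0 of the second arg
#
# Note this WIP version emits only Shard rows; Partial handling is left to the
# TODOs above, and the implicit Replicate row is added during mesh expansion.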
def common_pointwise_strategy(
args_schema: Sequence[object],
followed_strategy: OpStrategy,
@@ -623,11 +676,15 @@ for op in linear_pointwise_ops:
linear_pointwise_strategy
)
for op in pointwise_ops:
register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
pointwise_strategy
)
# for op in pointwise_ops:
# register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
# pointwise_strategy
# )
for op in pointwise_ops:
register_single_dim_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
single_mesh_dim_pointwise_strategy
)
# TODO: add all for_each ops
for_each_ops = [

View File

@@ -42,6 +42,8 @@ from torch.fx.experimental.symbolic_shapes import statically_known_true
aten = torch.ops.aten
# WHC- I think anywhere this is used, we can replace it with a corresponding single-dim passthrough strategy
# (any Shard, Replicate, or Partial can pass through, and then expand that to the mesh dims later)
def propagate_single_input_strategy(op_schema: OpSchema) -> StrategyType:
# For ops with a single tensor input, we perform a 1:1 mapping such that
# for each strategy that the input supports, we create a corresponding strategy.
@@ -98,6 +100,28 @@ register_op_strategy(
)(propagate_single_input_strategy)
"""
WHC- equal_strategy is an example of baking an optimization into the sharding rule.
The unoptimized equal strategy (for one mesh dim) should look like this:
S, S -> S
R, R -> R
P, P -> P * - this could work, I think, if we supported a Partial over booleans with an AND reduction?
And this should be expanded to the full mesh.
But what this rule actually does is:
- compare the two tensor args to equal: look at the strategies for each, which represent the I-O sharding relationship for the
op that produced those tensor args, and pick the one whose strategy (OpSpec) has the most Shard() placements in it.
Why? Because converting the other arg from R->S is cheaper than converting S->R.
- start with the assumption that the 'equal' op has the same strategy as the op that produced its max-shard input
- then adjust the placements from Partial to Replicate, since we don't support Partial in equal
- finally, produce an OpSpec that only populates the 'output_specs' of OpSpec
TODO: why is it ok to populate only the output_specs of an OpSpec? Is it defined to mean that all input specs are the same as the output spec?
"""
@register_op_strategy(
[
aten.equal.default,
@@ -141,6 +165,19 @@ def equal_strategy(op_schema: OpSchema) -> StrategyType:
return equal_strategy
"""
WHC
seems like we could replace this with a single-mesh-dim strategy:
S->S
R->R
P->R
The P->R thing is odd, but makes sense:
* can't support P->P since it would be incorrect to create a new 'partial' tensor from ones, which would no longer be ones if we replicated them
* don't want to omit support for an input Partial because we'd force a replication on the input, which would be wasteful
"""
@register_op_strategy(
[
aten.empty_like.default,
@@ -489,6 +526,19 @@ def replicate_tensor_dim(
)
"""
WHC- an example of a complicated 'follow your inputs' strategy that would be useful to try out as a simple rule.
It seems very simple to write this way:
assert input, src same ndim
for i in range(input.ndim):
if i != slice_dim:
Shard(i), Shard(i) -> Shard(i)
"""
@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2))
def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType:
# 1. number of dimensions in input and src need to match.

View File

@@ -4,8 +4,7 @@ import functools
import itertools
import operator
from collections.abc import Callable, Iterable, Sequence
from typing import cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from typing import cast, Optional, Union
import torch
from torch._prims_common import DimsSequenceType, DimsType
@@ -28,10 +27,7 @@ from torch.distributed.tensor.placement_types import (
Replicate,
Shard,
)
_T = TypeVar("_T")
_P = ParamSpec("_P")
# from torch.testing._internal.distributed._tensor.common_dtensor import redistribute
# convenient wrapper to register sharding propagation rules
@@ -54,11 +50,69 @@ def register_prop_rule(
return wrapper
def register_op_strategy(
op, schema_info=None
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# pyre-fixme[2]: Parameter must be annotated.
def _expand_single_dim_strategy_to_mesh(single_dim_strategy: Callable[[OpSchema], list[list[Placement]]]) -> Callable[[OpSchema], StrategyType]:
"""
Expands the single-mesh-dim impl across all mesh dims, and expands ShardingPlaceholder into all
sharding types used by the inputs.
"""
def expanded_strategy(op_schema: OpSchema) -> StrategyType:
strategies_over_one_mesh_dim = single_dim_strategy(op_schema)
inputs_strategy = op_schema.args_strategy
mesh = inputs_strategy[0].mesh
# TODO: handle 'ShardingPlaceholder' expansion (doesn't exist yet)
# TODO: add Replicate since its implicit in single_dim strategies
# TODO: filter out 'invalid' placements
# - ShardVar needs to say whether 'even sharding' is required or not
# copied from einsum strategy..
# TODO: identify differences between this and 'expand_' util
all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
strategy_combs = itertools.product(*all_mesh_dim_strategies)
all_strategies = []
for strategy_comb in strategy_combs:
spec_list = [
DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
]
arg_specs = spec_list[1:]
src_strategies = [s for s in op_schema.args_schema if isinstance(s, OpStrategy)]
assert len(arg_specs) == len(src_strategies), "expected one src strategy per arg spec"
all_strategies.append(
OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:], redistribute_cost=[
generate_redistribute_costs(src_strategy, arg_spec) for (src_strategy, arg_spec) in zip(src_strategies, arg_specs)
])
)
return OpStrategy(all_strategies)
return expanded_strategy
def register_single_dim_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[
[Callable[[OpSchema], list[list[Placement]]]], Callable[[OpSchema], list[list[Placement]]]
]:
"""
Registers a simplified op strategy that only considers a single mesh dim, taking care to expand it
to cover all the mesh dims present in the runtime inputs.
"""
def expanded_registration_wrapper(
single_dim_strategy: Callable[[OpSchema], list[list[Placement]]],
) -> Callable[[OpSchema], list[list[Placement]]]:
_expanded_strategy = _expand_single_dim_strategy_to_mesh(single_dim_strategy)
register_op_strategy(op, schema_info)(_expanded_strategy)
return single_dim_strategy
return expanded_registration_wrapper
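# Hypothetical usage sketch (the op choice and rule are illustrative only): the
# decorator registers the mesh-expanded wrapper with the sharding propagator but
# returns the original single-dim function, so it can still be called or tested directly.
#
#   @register_single_dim_strategy(torch.ops.aten.relu.default)
#   def relu_single_dim_strategy(op_schema: OpSchema) -> list[list[Placement]]:
#       ndim = op_schema.args_schema[0].ndim
#       # rows in [output, input] order: replicate, then shard each tensor dim
#       return [[Replicate(), Replicate()]] + [
#           [Shard(d), Shard(d)] for d in range(ndim)
#       ]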
def register_op_strategy(
op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
# For every ATen op that accepts any args in this list,
# the arg itself can impact the strides (and potentially the sharding strategy)
# of the output tensor.
@@ -68,7 +122,9 @@ def register_op_strategy(
"memory_format",
]
def wrapper(impl):
def wrapper(
impl: Callable[[OpSchema], StrategyType],
) -> Callable[[OpSchema], StrategyType]:
if isinstance(op, list):
overloads = op
else:
@@ -159,7 +215,10 @@ def prod(xs: Iterable[int]) -> int:
def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool:
"""Check if the shape is shardable according to the spec."""
"""Check if the spec matches these criteria:
* any Shard placements in spec refer to valid tensor dims
* no empty local tensors (uneven sharding OK, as long as last rank has >0 size)
"""
# number of shards in each tensor dimension
shards_map = [1] * len(shape)
for i, placement in enumerate(spec.placements):
@@ -225,6 +284,9 @@ def infer_broadcast_dims_map(
) -> list[int]:
# infer the broadcast dims map, where it maps from the common shape dim to the input shape dim
# this is aligned with the broadcast semantics
# e.g. if common_shape = [1, 2, 3, 4] and input_shape = [2, 3, 4],
# broadcast_dims_map will be [-1, 0, 1, 2]
# meaning that dim 0 in the output has no mapping to the input, and dim 1 in the output maps to dim 0 in the input
common_ndim = len(common_shape)
input_ndim = len(input_shape)
broadcast_dims_map = [-1] * common_ndim