Compare commits

...

3 Commits

Author SHA1 Message Date
869e28a943 [DTensor] add register_single_dim_strategy
WIP for now, not ready to land

Experimenting with the idea of decomposing op strategies into
 - a core function that proposes a single-mesh-dim strategy
 - automatic expansion to the full mesh inside the registration mechanism

Also, plan to add a 'ShardingPlaceholder' to use for writing
single-mesh-dim strategies in a way that can be expanded at runtime to
any type of sharding discovered in the inputs.

For now, this relies on fully enumerating the single-dim strategy onto
the full mesh (see the standalone sketch after the commit list), and on
fully enumerating the combinations of the different sharding placements
discovered in the inputs, but we should be able to replace this with an
algorithm that expands iteratively, following the path of lowest
redistribution cost.

[ghstack-poisoned]
2025-11-12 13:14:49 -08:00
190e6387a3 Update on "[DTensor] Fix mypy on register_op_strategy"
cc H-Huang awgu wanchaol fegin fduwjj wz337 d4l3k pragupta msaroufim dcci

[ghstack-poisoned]
2025-11-12 12:39:32 -08:00
374a2c16e2 [DTensor] Fix mypy on register_op_strategy
[ghstack-poisoned]
2025-11-12 12:37:56 -08:00
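
To make the first commit's proposal concrete, here is a minimal, standalone sketch of the "enumerate a single-mesh-dim strategy across the full mesh" idea. Everything in it (the candidate strings, expand_to_mesh, mesh_ndim) is an illustrative assumption, not code from this patch; it only mirrors the itertools.product expansion used in the diff below.

import itertools

# Stand-ins for the placements a single-dim strategy might propose for one
# mesh dim (e.g. DTensor's Replicate() / Shard(0)).
single_dim_candidates = ["Replicate", "Shard(0)"]

def expand_to_mesh(candidates, mesh_ndim):
    # Full enumeration: pick one candidate per mesh dim, in every combination,
    # analogous to the product(*all_mesh_dim_strategies) call in the patch.
    return list(itertools.product(candidates, repeat=mesh_ndim))

# For a 2-D mesh this yields 2**2 = 4 combinations:
# ('Replicate', 'Replicate'), ('Replicate', 'Shard(0)'),
# ('Shard(0)', 'Replicate'), ('Shard(0)', 'Shard(0)')
print(expand_to_mesh(single_dim_candidates, mesh_ndim=2))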

View File

@@ -4,8 +4,7 @@ import functools
 import itertools
 import operator
 from collections.abc import Callable, Iterable, Sequence
-from typing import cast, Optional, TypeVar, Union
-from typing_extensions import ParamSpec
+from typing import cast, Optional, Union
 
 import torch
 from torch._prims_common import DimsSequenceType, DimsType
@@ -30,10 +29,6 @@ from torch.distributed.tensor.placement_types import (
 )
 
 
-_T = TypeVar("_T")
-_P = ParamSpec("_P")
-
-
 # convenient wrapper to register sharding propagation rules
 def register_prop_rule(
     op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
@@ -54,11 +49,56 @@ def register_prop_rule(
     return wrapper
 
 
-def register_op_strategy(
-    op, schema_info=None
-) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
-    # pyre-fixme[2]: Parameter must be annotated.
+def register_single_dim_strategy(
+    op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
+    schema_info: Optional[RuntimeSchemaInfo] = None,
+) -> Callable[[Callable[[OpSchema], list[OpSpec]]], Callable[[OpSchema], StrategyType]]:
+    """
+    Registers a simplified op strategy that only considers a single mesh dim, taking care to expand it
+    to cover all the mesh dims present in the runtime inputs.
+    """
+
+    def expanded_registration_wrapper(
+        single_dim_strategy: Callable[[OpSchema], list[OpSpec]],
+    ) -> Callable[[OpSchema], StrategyType]:
+        def _expanded_strategy(op_schema: OpSchema) -> StrategyType:
+            """
+            Expands the single_mesh_dim impl across all mesh dims, and expands ShardingPlaceholder into all
+            sharding types used by inputs.
+            """
+            inputs_strategy = op_schema.args_strategy
+            mesh = inputs_strategy[0].mesh
+            strategies_over_one_mesh_dim = single_dim_strategy(op_schema)
+
+            # copied from einsum strategy..
+            # TODO: identify differences between this and 'expand_' util
+            # TODO: handle 'ShardingPlaceholder' expansion (doesn't exist yet)
+            all_mesh_dim_strategies = [strategies_over_one_mesh_dim] * mesh.ndim
+            strategy_combs = itertools.product(*all_mesh_dim_strategies)
+            all_strategies = []
+            for strategy_comb in strategy_combs:
+                spec_list = [
+                    DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)
+                ]
+                all_strategies.append(
+                    OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:])
+                )
+            return OpStrategy(all_strategies)
+
+        # register_op_strategy returns another wrapper that actually does the strategy registration,
+        # we just add another layer of wrapping that expands the single_dim_strategy into a strategy that's
+        # compatible with register_op_strategy
+        register_op_strategy(op, schema_info)(_expanded_strategy)
+        return _expanded_strategy
+
+    return expanded_registration_wrapper
+
+
+def register_op_strategy(
+    op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]],
+    schema_info: Optional[RuntimeSchemaInfo] = None,
+) -> Callable[[Callable[[OpSchema], StrategyType]], Callable[[OpSchema], StrategyType]]:
 
     # For every ATen op that accepts any args in this list,
     # the arg itself can impact the strides (and potentially the sharding strategy)
     # of the output tensor.
@@ -68,7 +108,9 @@ def register_op_strategy(
         "memory_format",
     ]
 
-    def wrapper(impl):
+    def wrapper(
+        impl: Callable[[OpSchema], StrategyType],
+    ) -> Callable[[OpSchema], StrategyType]:
         if isinstance(op, list):
             overloads = op
         else:
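
As a usage illustration only: assuming the WIP register_single_dim_strategy above is importable from this module, registering a single-dim strategy for a unary op might look like the sketch below. The op choice (aten.relu.default) and the [output_placement, input_placement] row layout are assumptions inferred from the einsum-style expansion in _expanded_strategy (the WIP annotation says list[OpSpec], while the body consumes rows of placements; the sketch follows the body).

import torch
from torch.distributed.tensor._op_schema import OpSchema
from torch.distributed.tensor.placement_types import Replicate, Shard

# Hypothetical example, not part of this PR: relu already has a pointwise
# strategy in DTensor; it is used here only because it is a simple unary op.
@register_single_dim_strategy(torch.ops.aten.relu.default)
def relu_single_dim_strategy(op_schema: OpSchema):
    # Each row describes one option along a single mesh dim as
    # [output_placement, input_placement]; _expanded_strategy zips the rows
    # into DTensorSpecs and enumerates them over every mesh dim at runtime.
    return [
        [Replicate(), Replicate()],  # replicate along this mesh dim
        [Shard(0), Shard(0)],        # shard tensor dim 0 along this mesh dim
    ]

On a 2-D mesh, the registration wrapper would expand these two rows into the four placement combinations enumerated by itertools.product, producing one OpSpec per combination.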