Compare commits

...

3 Commits

Author SHA1 Message Date
1919e33c25 Support implicit strategy registration 2025-07-16 14:56:51 -07:00
5528bf3117 fix pytree for TupleStrategy 2025-07-16 14:49:22 -07:00
c605b9dc3a (1/N) support of replication fallback strategy
fix per fb

create a test function to check if strategy output identical strategy

nit

use expand_to_full_mesh_op_strategy to build replicate strategy

lint

fix flatten function
2025-07-16 12:36:04 -07:00
4 changed files with 545 additions and 64 deletions

View File

@ -1,18 +1,62 @@
# Owner(s): ["oncall: distributed"]
import functools
import itertools
import random
from contextlib import contextmanager
from copy import deepcopy
from itertools import chain
from unittest.mock import patch
import numpy as np
import torch
from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
from torch.distributed.tensor import (
DeviceMesh,
distribute_tensor,
DTensor,
init_device_mesh,
Partial,
Placement,
Replicate,
Shard,
)
from torch.distributed.tensor._collective_utils import redistribute_cost
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import OpSchema, OpSpec, OpStrategy
from torch.distributed.tensor._op_schema import (
OpSchema,
OpSpec,
OpStrategy,
RuntimeSchemaInfo,
StrategyType,
)
from torch.distributed.tensor._ops._einsum_strategy import (
EinsumDims,
gen_einsum_strategies,
)
from torch.distributed.tensor._ops.utils import (
generate_redistribute_costs,
implicit_strategy_context,
OpStrategyPack,
register_op_strategy,
replicate_op_strategy,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorOpTestBase
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorOpTestBase,
DTensorTestBase,
with_comms,
)
try:
from torch.utils._cxx_pytree import tree_leaves
except ImportError:
from torch.utils._pytree import tree_leaves # type: ignore[no-redef]
def extract_tensor_meta(t) -> TensorMeta:
    """Build a ``TensorMeta`` from the shape, stride and dtype of ``t``."""
    shape, stride, dtype = t.shape, t.stride(), t.dtype
    return TensorMeta(shape, stride, dtype)
class TestEinsumDims(TestCase):
@ -131,9 +175,6 @@ class TestEinsumStrategies(DTensorOpTestBase):
class TestCostModel(DTensorOpTestBase):
def _extract_tensor_meta(self, t) -> TensorMeta:
return TensorMeta(t.shape, t.stride(), t.dtype)
@property
def world_size(self) -> int:
return 4
@ -145,7 +186,7 @@ class TestCostModel(DTensorOpTestBase):
partial_placement = (Partial(),)
global_tensor = torch.randn(10, 10)
global_tensor_meta = self._extract_tensor_meta(global_tensor)
global_tensor_meta = extract_tensor_meta(global_tensor)
# shard spec
shard_spec = DTensorSpec(mesh_1d, shard_placement, global_tensor_meta)
@ -180,9 +221,9 @@ class TestCostModel(DTensorOpTestBase):
partial_placement = (Partial(),)
shard1_placement = (Shard(1),)
shard0_tensor_meta = self._extract_tensor_meta(torch.randn(8))
partial_tensor_meta = self._extract_tensor_meta(torch.randn(50, 6))
shard1_tensor_meta = self._extract_tensor_meta(torch.randn(6, 8))
shard0_tensor_meta = extract_tensor_meta(torch.randn(8))
partial_tensor_meta = extract_tensor_meta(torch.randn(50, 6))
shard1_tensor_meta = extract_tensor_meta(torch.randn(6, 8))
# shard spec
shard0_spec = DTensorSpec(mesh, shard0_placement, shard0_tensor_meta)
@ -226,7 +267,7 @@ class TestCostModel(DTensorOpTestBase):
partial_placement = (Partial(), Partial())
global_tensor = torch.randn(8, 8)
global_tensor_meta = self._extract_tensor_meta(global_tensor)
global_tensor_meta = extract_tensor_meta(global_tensor)
# shard spec
shard_spec = DTensorSpec(mesh_2d, shard_placement, global_tensor_meta)
@ -255,8 +296,8 @@ class TestCostModel(DTensorOpTestBase):
mesh = self.build_device_mesh()
lhs_tensor = torch.randn(6, 8)
rhs_tensor = torch.randn(8, 12)
lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
mm_combs = (
(Shard(0), Replicate()),
@ -301,8 +342,8 @@ class TestCostModel(DTensorOpTestBase):
mesh = self.build_device_mesh()
lhs_tensor = torch.randn(8, 6, 8)
rhs_tensor = torch.randn(8, 8, 12)
lhs_tensor_meta = self._extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = self._extract_tensor_meta(rhs_tensor)
lhs_tensor_meta = extract_tensor_meta(lhs_tensor)
rhs_tensor_meta = extract_tensor_meta(rhs_tensor)
bmm_combs = (
(Shard(0), Shard(0)),
@ -343,5 +384,317 @@ class TestCostModel(DTensorOpTestBase):
self.assertFalse(output_sharding.needs_redistribute)
# -------------Test op strategy registration-------------
# custom op without List[Tensor] as input
# reference: https://docs.pytorch.org/docs/stable/library.html#torch.library.register_autograd
@torch.library.custom_op("mylib::numpy_sin", mutates_args=())
def numpy_sin(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x_np = x.cpu().numpy()
y_np = y.cpu().numpy()
out_np = np.sin(x_np) + np.sin(y_np)
return torch.from_numpy(out_np).to(device=x.device)
def setup_context(ctx, inputs, output):
    """Save the two inputs of ``numpy_sin`` on ``ctx`` for the backward pass."""
    first, second = inputs
    ctx.save_for_backward(first, second)
def backward(ctx, grad):
    """Return the gradients of ``sin(x) + sin(y)`` w.r.t. ``x`` and ``y``.

    Each gradient is the incoming ``grad`` scaled by the cosine of the
    corresponding saved input.
    """
    saved_x, saved_y = ctx.saved_tensors
    grad_x = grad * saved_x.cos()
    grad_y = grad * saved_y.cos()
    return grad_x, grad_y
@numpy_sin.register_fake
def _fw(x, y):
    """Fake kernel for ``numpy_sin``: an uninitialized tensor shaped like ``x``."""
    fake_out = torch.empty_like(x)
    return fake_out
# Register ``backward`` (with ``setup_context`` as its context-saving hook) as
# the autograd formula of the ``mylib::numpy_sin`` custom op.
torch.library.register_autograd(
    "mylib::numpy_sin", backward, setup_context=setup_context
)
# custom op with List[Tensor] as input
@torch.library.custom_op("mylib::numpy_tuple_sin", mutates_args=())
def numpy_tuple_sin(
x: torch.Tensor, y: list[torch.Tensor], z: torch.Tensor
) -> torch.Tensor:
x_np = x.cpu().numpy()
y_np = [i.cpu().numpy() for i in y]
z_np = z.cpu().numpy()
out_np = np.sin(x_np) + np.sin(z_np) + sum(np.sin(i) for i in y_np)
return torch.from_numpy(out_np).to(device=x.device)
def setup_tuple_context(ctx, inputs, output):
    """Save every tensor input of ``numpy_tuple_sin`` for the backward pass.

    ``save_for_backward`` only accepts tensors, so the list argument ``y``
    cannot be saved as a single object. It is flattened into the saved
    sequence as ``(x, *y, z)``; ``tuple_backward`` regroups it.
    """
    x, y, z = inputs
    ctx.save_for_backward(x, *y, z)


def tuple_backward(ctx, grad):
    """Return the gradients of ``numpy_tuple_sin`` w.r.t. ``x``, ``y`` and ``z``.

    The gradient for the list argument ``y`` is itself a list with one entry
    per element; every gradient is ``grad`` scaled by the cosine of the
    corresponding saved input.
    """
    # Saved layout is (x, *y, z): the first and last entries are the plain
    # tensor arguments, everything in between belongs to the list argument.
    x, *y, z = ctx.saved_tensors
    return grad * x.cos(), [grad * item.cos() for item in y], grad * z.cos()
@numpy_tuple_sin.register_fake
def _fw_tuple(x, y, z):
    """Fake kernel for ``numpy_tuple_sin``: an uninitialized tensor shaped like ``x``."""
    fake_out = torch.empty_like(x)
    return fake_out
# Register ``tuple_backward`` (with ``setup_tuple_context`` as its
# context-saving hook) as the autograd formula of ``mylib::numpy_tuple_sin``.
torch.library.register_autograd(
    "mylib::numpy_tuple_sin", tuple_backward, setup_context=setup_tuple_context
)
def manual_strategy(
    op_schema: OpSchema, output_placement: list[Placement]
) -> StrategyType:
    """
    Hand-written strategy for a two-input op that forces one fixed placement.

    Both inputs and the output are assigned ``output_placement`` on the mesh
    of the first argument, and the redistribute cost of moving each input
    strategy to that placement is recorded.

    Args:
        op_schema: Schema whose first two positional args are ``OpStrategy``.
        output_placement: Placements applied to both inputs and the output.

    Returns:
        An ``OpStrategy`` that contains exactly one ``OpSpec``.
    """
    strategy_x, strategy_y = op_schema.args_schema[0], op_schema.args_schema[1]
    assert isinstance(strategy_x, OpStrategy)
    assert isinstance(strategy_y, OpStrategy)

    placements = tuple(output_placement)
    # The input spec reuses the tensor metadata of x's first candidate spec;
    # the output spec is created without tensor metadata.
    input_spec = DTensorSpec(
        mesh=strategy_x.mesh,
        placements=placements,
        tensor_meta=strategy_x.strategies[0].output_spec.tensor_meta,
    )
    output_spec = DTensorSpec(mesh=strategy_x.mesh, placements=placements)

    redistribute_costs = [
        generate_redistribute_costs(arg_strategy, input_spec)
        for arg_strategy in (strategy_x, strategy_y)
    ]
    forced_spec = OpSpec(
        output_specs=output_spec,
        input_specs=[input_spec, input_spec],
        redistribute_cost=redistribute_costs,
    )
    return OpStrategy([forced_spec])
@contextmanager
def op_strategy_context(op_overload, strategy_func, schema_info=None):
    """
    Context manager for temporarily setting an op strategy.

    On entry ``strategy_func`` is registered for ``op_overload``. On exit the
    registration state that existed before entry is restored: a strategy /
    schema info that was already registered is put back, otherwise the entry
    is removed. The sharding propagation cache is cleared on entry and exit so
    that results computed under a different strategy are never reused.

    Args:
        op_overload: The operator overload to set or clear the strategy for.
        strategy_func: The strategy function to set for the operator overload.
        schema_info: Optional schema information for the operator overload.
    Yields:
        None
    """
    propagator = DTensor._op_dispatcher.sharding_propagator
    registries = (propagator.op_strategy_funcs, propagator.op_to_schema_info)
    # Snapshot what is registered right now so it can be restored on exit.
    previous = [
        (registry, op_overload in registry, registry[op_overload] if op_overload in registry else None)
        for registry in registries
    ]
    try:
        # register the op strategy
        register_op_strategy(op_overload, schema_info=schema_info)(strategy_func)
        # drop propagation results cached under the previous registration
        propagator.propagate_op_sharding.cache.cache_clear()
        yield
    finally:
        for registry, was_registered, old_value in previous:
            if was_registered:
                registry[op_overload] = old_value
            elif op_overload in registry:
                del registry[op_overload]
        # clear this op strategy cache
        propagator.propagate_op_sharding.cache.cache_clear()
def detect_exists_identical_opspec(*args, op, mesh, strategy_function) -> bool:
    """
    Check that ``strategy_function`` generates no duplicated OpSpecs.

    For every tensor found in ``args`` (nested containers are flattened), all
    combinations of ``Replicate``/``Shard(dim)`` placements over ``mesh`` are
    offered as candidate input specs, in shuffled order. The resulting
    ``OpSchema`` (one flat ``OpStrategy`` per tensor leaf) is passed to
    ``strategy_function`` and the produced OpSpecs are compared by their
    string representation.

    Args:
        *args: Sample inputs of the op; non-tensor leaves are ignored.
        op: Operator overload the schema is built for.
        mesh: Device mesh used for every candidate spec.
        strategy_function: Strategy function under test.

    Returns:
        ``True`` when every OpSpec in the output strategy is unique, ``False``
        when at least two OpSpecs are identical.
    """
    # Only tensor leaves get candidate specs; metadata and placements are
    # derived from the same leaf so their indices can never drift apart.
    tensor_leaves = [
        leaf for leaf in tree_leaves(args) if isinstance(leaf, torch.Tensor)
    ]
    arg_opspec_list = []
    for tensor in tensor_leaves:
        tensor_meta = extract_tensor_meta(tensor)
        # possible placement choice for this argument on one mesh dimension
        placement_choices = (
            Replicate(),
            *(Shard(dim) for dim in range(tensor.ndim)),
        )
        # expand placement choice into full Placements for this argument
        full_placements = list(
            itertools.product(placement_choices, repeat=mesh.ndim)
        )
        random.shuffle(full_placements)
        arg_opspec_list.append(
            [
                OpSpec(
                    output_specs=DTensorSpec(
                        mesh, placement, tensor_meta=tensor_meta
                    )
                )
                for placement in full_placements
            ]
        )

    op_schema = OpSchema(
        op,
        args_schema=tuple(OpStrategy(specs) for specs in arg_opspec_list),
        kwargs_schema={},
    )
    with op_strategy_context(op, strategy_function):
        output_strategy = strategy_function(op_schema)
        # OpSpec doesn't have hashing, convert to str to compare
        output_strategy_str_list = [
            str(spec)
            for strategy in tree_leaves(output_strategy)
            for spec in strategy.strategies
        ]
    return len(output_strategy_str_list) == len(set(output_strategy_str_list))
class DistTensorReplicateStrategyRegistrationTest(DTensorTestBase):
    """Tests for using ``replicate_op_strategy`` as the strategy of custom ops."""

    @with_comms
    def test_replicate_strategy_placement(self):
        """Sharded inputs produce a fully replicated, numerically equal output."""
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_sin
        with op_strategy_context(test_op.default, replicate_op_strategy):
            input_x = torch.randn([8, 16, 32], device=self.device_type)
            input_y = torch.randn([8, 16, 32], device=self.device_type)
            output = test_op(input_x, input_y)
            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
            input_y_dt = distribute_tensor(input_y, mesh, [Shard(0), Shard(1)])
            output_dt = test_op(input_x_dt, input_y_dt)
            self.assertEqual(output_dt.full_tensor(), output)
            self.assertEqual(output_dt.placements, [Replicate(), Replicate()])

    @with_comms
    def test_no_identical_strategies(self):
        # Test if there are any identical OpSpecs in the output strategy.
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_sin
        x = torch.randn([8, 16, 8], device=self.device_type)
        y = torch.randn([8, 16, 8], device=self.device_type)
        # the helper returns True when all generated OpSpecs are unique
        self.assertTrue(
            detect_exists_identical_opspec(
                x,
                y,
                op=test_op.default,
                mesh=mesh,
                strategy_function=replicate_op_strategy,
            )
        )

    @with_comms
    @patch(
        "torch.distributed.tensor._sharding_prop.ShardingPropagator._select_strategy"
    )
    def test_replicate_strategy_cost(self, mock_select_strategy):
        """The fallback strategy reports the same costs as a manual replicate strategy."""
        # Costs observed by the mocked ``_select_strategy``. The list is only
        # ever mutated in place, never rebound, so clearing it between runs
        # cannot touch the cost list owned by an OpSpec.
        costs_from__select_strategy: list = []

        def mock_select_func(strategy):
            """function copied from _select_strategy but with cost capturing"""
            if len(strategy.strategies) == 1:
                # copy the entries instead of aliasing the OpSpec's own list
                costs_from__select_strategy[:] = (
                    strategy.strategies[0].redistribute_cost or []
                )
                return strategy.strategies[0]
            op_spec_costs: list[float] = []
            for op_spec in strategy.strategies:
                assert op_spec.redistribute_cost is not None, (
                    "must set redistribute cost each OpSpec!"
                )
                costs_from__select_strategy.append(op_spec.redistribute_cost)
                redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost))
                op_spec_costs.append(redistribute_cost)
            return strategy.strategies[op_spec_costs.index(min(op_spec_costs))]

        mock_select_strategy.side_effect = mock_select_func
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_sin
        input_x = torch.randn([8, 16, 8], device=self.device_type)
        input_y = torch.randn([8, 16, 8], device=self.device_type)
        # manual write the strategy to force replicate placement
        manual_replicated_strategy = functools.partial(
            manual_strategy, output_placement=[Replicate(), Replicate()]
        )
        placement_choice_pool = [Replicate(), Shard(0), Shard(1)]
        for shard_a in placement_choice_pool:
            for shard_b in placement_choice_pool:
                input_x_dt = distribute_tensor(input_x, mesh, [shard_a, shard_b])
                input_y_dt = distribute_tensor(input_y, mesh, [shard_b, shard_a])
                # generate expected cost from manual strategy:
                costs_from__select_strategy.clear()
                with op_strategy_context(test_op.default, manual_replicated_strategy):
                    expect_output = test_op(input_x_dt, input_y_dt)
                expected_cost = deepcopy(costs_from__select_strategy)
                # generate cost from default fallback strategy:
                costs_from__select_strategy.clear()
                with op_strategy_context(test_op.default, replicate_op_strategy):
                    fallback_output = test_op(input_x_dt, input_y_dt)
                fallback_cost = deepcopy(costs_from__select_strategy)
                self.assertEqual(
                    expect_output.full_tensor(), fallback_output.full_tensor()
                )
                self.assertEqual(expected_cost, fallback_cost)

    @with_comms
    def test_tuple_replicate_strategy_placement(self):
        """The replicate strategy also handles an op that takes a list of tensors."""
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_tuple_sin
        with op_strategy_context(
            test_op.default,
            replicate_op_strategy,
            schema_info=RuntimeSchemaInfo(needs_pytree=True),
        ):
            input_x = torch.randn([8, 16, 8], device=self.device_type)
            input_y = [
                torch.randn([8, 16, 8], device=self.device_type) for _ in range(3)
            ]
            input_z = torch.randn([8, 16, 8], device=self.device_type)
            output = test_op(input_x, input_y, input_z)
            input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
            input_y_dt = [
                distribute_tensor(i, mesh, [Shard(1), Shard(1)]) for i in input_y
            ]
            input_z_dt = distribute_tensor(input_z, mesh, [Shard(1), Shard(0)])
            output_dt = test_op(input_x_dt, input_y_dt, input_z_dt)
            self.assertEqual(output_dt.full_tensor(), output)
            self.assertEqual(output_dt.placements, [Replicate(), Replicate()])
class ImplicitRegistrationTest(DTensorTestBase):
    """Tests for ``implicit_strategy_context`` registration and cleanup."""

    @with_comms
    def test_implicit_registration(self):
        mesh = init_device_mesh(self.device_type, (2, self.world_size // 2))
        test_op = torch.ops.mylib.numpy_sin
        input_x = torch.randn([8, 16, 8], device=self.device_type)
        input_y = torch.randn([8, 16, 8], device=self.device_type)
        input_x_dt = distribute_tensor(input_x, mesh, [Shard(0), Shard(1)])
        input_y_dt = distribute_tensor(input_y, mesh, [Shard(1), Shard(0)])

        def run_op():
            self._test_op_on_dtensor(test_op, input_x_dt, input_y_dt)

        def assert_unregistered():
            with self.assertRaisesRegex(
                RuntimeError,
                "Operator mylib.numpy_sin.default does not have a sharding strategy registered",
            ):
                run_op()

        # 1. without any registration the op must be rejected
        assert_unregistered()

        # 2. inside the context manager the strategy is registered implicitly
        with implicit_strategy_context():
            run_op()

        # 3. leaving the context manager removes the registration again
        assert_unregistered()

        # 4. an explicit strategy can be supplied at runtime
        # TODO(zpcore): try a different universal strategy once one is available
        explicit_packs = [
            OpStrategyPack(test_op.default, replicate_op_strategy, schema_info=None)
        ]
        with implicit_strategy_context(explicit_packs):
            run_op()
# Run the tests of this file when it is executed as a script.
if __name__ == "__main__":
    run_tests()

View File

@ -13,9 +13,15 @@ from torch.distributed.tensor.placement_types import Placement
try:
from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec
from torch.utils._cxx_pytree import (
register_pytree_node,
tree_leaves,
tree_map_only,
TreeSpec,
)
except ImportError:
from torch.utils._pytree import ( # type: ignore[no-redef, assignment]
register_pytree_node,
tree_leaves,
tree_map_only,
TreeSpec,
@ -217,6 +223,13 @@ class TupleStrategy(StrategyType):
return f"TupleStrategy({child_strategies_str})"
# Register TupleStrategy as a pytree node: flattening yields its children (with
# no extra context), and unflattening rebuilds a TupleStrategy from a tuple of
# the children.
register_pytree_node(
    TupleStrategy,
    lambda node: (node.children, None),
    lambda children, _: TupleStrategy(tuple(children)),
)
@dataclass
class RuntimeSchemaInfo:
"""

View File

@ -4,6 +4,7 @@ import functools
import itertools
import operator
from collections.abc import Iterable, Sequence
from contextlib import contextmanager
from typing import Callable, cast, Optional, TypeVar, Union
from typing_extensions import ParamSpec
@ -19,6 +20,7 @@ from torch.distributed.tensor._op_schema import (
OutputSharding,
PlacementList,
RuntimeSchemaInfo,
StrategyType,
)
from torch.distributed.tensor.device_mesh import DeviceMesh
from torch.distributed.tensor.placement_types import (
@ -95,6 +97,33 @@ def register_op_strategy(
return wrapper
def replicate_op_strategy(op_schema: OpSchema) -> StrategyType:
    """
    Fallback strategy in which every placement is ``Replicate()``.

    Raises:
        RuntimeError: if the op is a view op, or if any of its returns is of
            type ``List[Tensor]``.
    """
    # TODO(zpcore): handle kwarg_inputs_strategy (op_schema.kwargs_schema)
    inputs_strategy = op_schema.args_strategy
    return_types = [str(ret.type) for ret in op_schema.op._schema.returns]

    # TODO(zpcore): Confirm if view op can be handle properly or not. Prevent
    # handling view ops until confirmed.
    if op_schema.op.is_view:
        raise RuntimeError(
            "fallback strategy is unable to handle view ops until confirmed"
        )
    if any(return_type == "List[Tensor]" for return_type in return_types):
        raise RuntimeError(
            "fallback strategy is unable to handle ops with List[Tensor] output "
            "because size of the list may depend on the op's input value"
        )

    # One Replicate() per input strategy plus one extra entry (presumably the
    # output slot -- confirm against expand_to_full_mesh_op_strategy).
    num_placements = len(inputs_strategy) + 1
    all_replicate: PlacementList = [Replicate()] * num_placements
    return expand_to_full_mesh_op_strategy(
        inputs_strategy[0].mesh, op_schema, [all_replicate]
    )
def as_list(
x: Union[list[object], object],
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
@ -332,3 +361,58 @@ def expand_to_full_mesh_op_strategy(
)
all_strategies.append(strategy)
return OpStrategy(all_strategies)
import dataclasses
@dataclasses.dataclass
class OpStrategyPack:
    """
    A dataclass to pack the op strategy and the schema info.

    ``implicit_strategy_context`` accepts a list of these and registers each
    ``strategy_function`` for its ``op_overload`` together with ``schema_info``.
    """

    # operator overload the strategy is registered for
    op_overload: torch._ops.OpOverload
    # callable that maps an OpSchema to the sharding strategy of the op
    strategy_function: Callable[[OpSchema], StrategyType]
    # optional schema info passed along when the strategy is registered
    schema_info: Optional[RuntimeSchemaInfo] = None
@contextmanager
def implicit_strategy_context(
    op_strategy_pack_list: Optional[list[OpStrategyPack]] = None,
):
    """
    Context manager for setting and clearing implicit strategy.

    While the context is active, implicit strategy registration is enabled on
    the sharding propagator. On exit, every op that was implicitly registered
    inside this context and every op named in ``op_strategy_pack_list`` is
    unregistered, registrations that an explicit pack overwrote are restored,
    the enable flag returns to its previous value, and the sharding
    propagation cache is cleared.

    Args:
        op_strategy_pack_list: A list of OpStrategyPack objects. If we specify
            the OpStrategyPack for the operator, the specified strategy will be
            registered. Otherwise, the default replication strategy will be
            registered for all operators.
    """
    propagator = DTensor._op_dispatcher.sharding_propagator
    packs = list(op_strategy_pack_list) if op_strategy_pack_list else []
    tracker = propagator.implicit_strategy_op_tracker
    # Ops tracked before entry belong to an enclosing context; only entries
    # appended after this index are owned (and cleaned up) by this context.
    num_tracked_on_entry = len(tracker)
    previously_enabled = propagator.enable_implicit_strategy

    # Remember registrations that the explicit packs are about to overwrite.
    registries = (propagator.op_strategy_funcs, propagator.op_to_schema_info)
    overwritten = [
        (registry, pack.op_overload, registry[pack.op_overload])
        for registry in registries
        for pack in packs
        if pack.op_overload in registry
    ]

    propagator.enable_implicit_strategy = True
    try:
        for pack in packs:
            # register the op strategy
            register_op_strategy(pack.op_overload, schema_info=pack.schema_info)(
                pack.strategy_function
            )
        yield
    finally:
        propagator.enable_implicit_strategy = previously_enabled
        # Build a fresh list instead of aliasing (and growing) the tracker.
        ops_to_remove = tracker[num_tracked_on_entry:] + [
            pack.op_overload for pack in packs
        ]
        del tracker[num_tracked_on_entry:]
        for op_overload in ops_to_remove:
            for registry in registries:
                if op_overload in registry:
                    del registry[op_overload]
        for registry, op_overload, old_value in overwritten:
            registry[op_overload] = old_value
        # clear this op strategy cache
        propagator.propagate_op_sharding.cache.cache_clear()

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import threading
import warnings
from collections.abc import Sequence
from functools import lru_cache
from itertools import chain
@ -81,6 +82,10 @@ class ShardingPropagator:
aten.slice_backward.default: 1,
}
self.enable_implicit_strategy = False
# track which ops have been implicitly registered
self.implicit_strategy_op_tracker: list[OpOverload] = []
def register_sharding_prop_rule(
self,
op_overload: OpOverload,
@ -254,6 +259,32 @@ class ShardingPropagator:
kwargs_schema=kwargs_op_strategy,
)
def try_propagate_implicit_strategy(
self, op: torch._ops.OpOverload, op_schema: OpSchema
):
"""
Try propagate the implicit replication sharding strategy for an operator given the op_schema.
"""
if op not in self.op_strategy_funcs:
if not self.enable_implicit_strategy:
raise NotImplementedError(
f"Operator {op_schema.op} does not have a sharding strategy registered."
)
else:
# lazy import to avoid circular dependency
from torch.distributed.tensor._ops.utils import replicate_op_strategy
schema_info = op_schema.schema_info if op_schema else None
self.implicit_strategy_op_tracker.append(op)
self.register_op_strategy(op, replicate_op_strategy, schema_info)
# TODO: point warning message to instructions on how to write a
# strategy once we have
warnings.warn(
f"implicitly register sharding strategy op {op.name()} using {replicate_op_strategy}"
)
op_strategy = self.op_strategy_funcs[op](op_schema)
return op_strategy
def propagate(self, op_info: OpInfo) -> None:
# We cannot use an lru cache if we know that inputs will have dynamic shapes,
# because SymInts are not hashable.
@ -278,9 +309,56 @@ class ShardingPropagator:
out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
if op_schema.op in self.op_strategy_funcs:
if op_schema.op in self.op_to_rules:
# propagate the sharding with rule
sharding_prop_func = self.op_to_rules[op_schema.op]
# step 1. there's sharding propagation rule, run
# sharding propagation to get the output sharding
try:
output_sharding = sharding_prop_func(op_schema)
except NotImplementedError as e:
raise e
except Exception as e:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}.\nError: {e}"
) from e
# step 2. if can't get output_spec from sharding
# propagation (i.e. no rules apply for input
# placements), we return the output sharding
# with schema suggestions, which can be used to
# decide how to do redistribute on inputs
if output_sharding.output_spec is None:
if output_sharding.redistribute_schema is None:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}!"
)
else:
# we do auto redistribute on inputs if necessary
# run sharding propagation again with suggested schema
propagation_res = sharding_prop_func(
output_sharding.redistribute_schema
)
# we set the output sharding with the new propagation result
# so that dispatching know both output_spec and redistribute_schema
# exist, which indicates a reshard is needed
output_sharding.output_spec = propagation_res.output_spec
output_sharding.needs_redistribute = True
# associate the output sharding with the output tensor metadata
self._wrap_output_spec_tensor_meta(
op_schema.op, output_sharding.output_spec, out_tensor_meta
)
return output_sharding
else:
# wrap the op_schema with op strategy for sharding strategy propagation
strategy_schema = self._wrap_with_op_strategy(op_schema)
strategy_schema.schema_info = op_schema.schema_info
# assign implicit strategy if enabled
self.try_propagate_implicit_strategy(op_schema.op, strategy_schema)
# run sharding strategy propagation/generation
op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema)
@ -442,53 +520,6 @@ class ShardingPropagator:
op_schema.op, output_sharding.output_spec, out_tensor_meta
)
return output_sharding
elif op_schema.op in self.op_to_rules:
# propagate the sharding with rule
sharding_prop_func = self.op_to_rules[op_schema.op]
# step 1. there's sharding propagation rule, run
# sharding propagation to get the output sharding
try:
output_sharding = sharding_prop_func(op_schema)
except NotImplementedError as e:
raise e
except Exception as e:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}.\nError: {e}"
) from e
# step 2. if can't get output_spec from sharding
# propagation (i.e. no rules apply for input
# placements), we return the output sharding
# with schema suggestions, which can be used to
# decide how to do redistribute on inputs
if output_sharding.output_spec is None:
if output_sharding.redistribute_schema is None:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}!"
)
else:
# we do auto redistribute on inputs if necessary
# run sharding propagation again with suggested schema
propagation_res = sharding_prop_func(
output_sharding.redistribute_schema
)
# we set the output sharding with the new propagation result
# so that dispatching know both output_spec and redistribute_schema
# exist, which indicates a reshard is needed
output_sharding.output_spec = propagation_res.output_spec
output_sharding.needs_redistribute = True
# associate the output sharding with the output tensor metadata
self._wrap_output_spec_tensor_meta(
op_schema.op, output_sharding.output_spec, out_tensor_meta
)
return output_sharding
else:
raise NotImplementedError(
f"Operator {op_schema.op} does not have a sharding strategy registered."
)
def _select_strategy(self, strategy: OpStrategy) -> OpSpec:
if len(strategy.strategies) == 1: