# pytorch/torch/distributed/tensor/_sharding_prop.py
# mypy: allow-untyped-defs
import threading
from collections.abc import Sequence
from functools import lru_cache
from itertools import chain
from typing import Callable, cast, Optional, Union
import torch
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
from torch.distributed.tensor._op_schema import (
OpInfo,
OpSchema,
OpStrategy,
OutputSharding,
OutputSpecType,
PlacementStrategy,
RuntimeSchemaInfo,
StrategyType,
TupleStrategy,
)
from torch.distributed.tensor._utils import (
compute_local_shape_and_global_offset,
compute_local_stride,
try_find_mesh_from_args,
)
aten = torch.ops.aten
def _length(obj) -> int:
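    """Return 0 for None, 1 for a non-sequence object, and ``len(obj)`` otherwise."""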
if obj is None:
return 0
if not isinstance(obj, Sequence):
return 1
return len(obj)
class LocalLRUCache(threading.local):
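    """
    Thread-local wrapper around ``functools.lru_cache`` so that each thread
    keeps its own cache of sharding propagation results.
    """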
def __init__(self, user_function: Callable) -> None:
self.cache = lru_cache(None)(user_function)
def __call__(self, *args, **kwargs) -> object:
return self.cache(*args, **kwargs)
def cache_info(self):
return self.cache.cache_info()
class ShardingPropagator:
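    """
    Propagates DTensor sharding for each operator, either via a registered
    op strategy function or a registered sharding propagation rule, and caches
    results in a thread-local LRU cache.
    """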
def __init__(self) -> None:
self.op_to_rules: dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
self.op_strategy_funcs: dict[
OpOverload,
Callable[[DeviceMesh, OpSchema], StrategyType],
] = {}
        # op map that records the static argnums used to decide whether to reuse
        # the sharding prop cache or re-run sharding prop
self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {}
self.propagate_op_sharding = LocalLRUCache(
self.propagate_op_sharding_non_cached
)
# op map to save indices of shape (and stride) args which may need to be
# modified in sharding prop
self.op_to_shape_and_stride_idx: dict[
OpOverload, Union[int, tuple[int, int]]
] = {
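            # each value is the positional index of the shape arg, or a
            # (shape_idx, stride_idx) tuple when a stride arg is also present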
# new factory ops
aten.new_empty.default: 1,
aten.new_full.default: 1,
aten.new_ones.default: 1,
aten.new_zeros.default: 1,
aten.new_empty_strided.default: (1, 2),
# view ops
aten.expand.default: 1,
aten.reshape.default: 1,
aten.view.default: 1,
aten._unsafe_view.default: 1,
}
def register_sharding_prop_rule(
self,
op_overload: OpOverload,
rule_func: Callable[[OpSchema], OutputSharding],
schema_info: Optional[RuntimeSchemaInfo] = None,
):
"""
Register a sharding propagation rule for an operator.
"""
self.op_to_rules[op_overload] = rule_func
if schema_info is not None:
self.op_to_schema_info[op_overload] = schema_info
def register_op_strategy(
self,
op_overload: OpOverload,
strategy_func: Callable[[DeviceMesh, OpSchema], StrategyType],
schema_info: Optional[RuntimeSchemaInfo] = None,
):
"""
Register a sharding strategy generator for an operator.
"""
self.op_strategy_funcs[op_overload] = strategy_func
if schema_info is not None:
self.op_to_schema_info[op_overload] = schema_info
def _propagate_tensor_meta_non_cached(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
"""
        Propagate the tensor metadata. This returns either a single TensorMeta
        or a list/tuple of TensorMetas, one per op output.
"""
if op_schema.op == aten.equal.default:
# data dependent ops can't be used for fake propagation
return None
        # NOTE: We must run the op under FakeTensorMode so that the metadata
        # propagation does not materialize any real memory
with FakeTensorMode():
fake_args = op_schema.gen_fake_args()
fake_kwargs = op_schema.gen_fake_kwargs()
fake_out = op_schema.op(*fake_args, **fake_kwargs)
if isinstance(fake_out, torch.Tensor):
return TensorMeta(
shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype
)
elif isinstance(fake_out, (tuple, list)):
tensor_meta_list: list[Optional[TensorMeta]] = []
for fake_out_item in fake_out:
if isinstance(fake_out_item, torch.Tensor):
tensor_meta_list.append(
TensorMeta(
shape=fake_out_item.shape,
stride=fake_out_item.stride(),
dtype=fake_out_item.dtype,
)
)
else:
tensor_meta_list.append(None)
return (
tuple(tensor_meta_list)
if isinstance(fake_out, tuple)
else tensor_meta_list
)
else:
            # if the fake output is neither a tensor nor a tuple/list of tensors, return None
return None
@lru_cache # noqa: B019
def _propagate_tensor_meta(
self, op_schema: OpSchema
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
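        """
        Cached wrapper around ``_propagate_tensor_meta_non_cached``; requires
        ``op_schema`` to be hashable.
        """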
return self._propagate_tensor_meta_non_cached(op_schema)
def _wrap_output_spec_tensor_meta(
self,
op: OpOverload,
output_specs: OutputSpecType,
output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]],
) -> None:
"""
Wrap the output_specs with the tensor metadata from the output.
"""
if isinstance(output_specs, DTensorSpec):
if not isinstance(output_tensor_meta, TensorMeta):
# Either error due to ShardingPropagator or due to incorrect OutputSpec
if not isinstance(output_tensor_meta, (tuple, list)):
raise ValueError(
"ShardingPropagator error: output does not have an associated "
"TensorMeta"
)
raise ValueError(
f"For the op {op.name()}, `output_specs` has 1 output which does "
"not equal the "
f"number of op outputs: {len(output_tensor_meta)}."
)
output_specs.tensor_meta = output_tensor_meta
elif isinstance(output_specs, (tuple, list)):
if not isinstance(output_tensor_meta, (tuple, list)) or len(
output_specs
) != len(output_tensor_meta):
raise ValueError(
f"For the op {op.name()}, `output_specs` has {len(output_specs)} "
"outputs which does not equal the "
f"number of op outputs {_length(output_tensor_meta)}."
)
for i, spec in enumerate(output_specs):
if isinstance(spec, DTensorSpec):
output_tensor_meta_i = output_tensor_meta[i]
if not isinstance(output_tensor_meta_i, TensorMeta):
# NOTE: aten.convolution_backward.default is an exception and it
# needs extra handling because the first Tensor in the output
# tuple can be `None` if the input Tensor to convolution op has
# `requires_grad=False` (e.g. convolution layer is the first
# layer in the model). We explicitly allow its corresponding
# TensorMeta to be `None`.
if (
op == aten.convolution_backward.default
and i == 0
and output_tensor_meta_i is None
):
assert isinstance(output_specs, list)
output_specs[i] = None
continue
else:
raise ValueError(
f"ShardingPropagator error: output {i} of {op.name()} "
"does not have an associated TensorMeta"
)
spec.tensor_meta = output_tensor_meta_i
def propagate(self, op_info: OpInfo) -> None:
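        """
        Run sharding propagation for ``op_info`` and store the result on
        ``op_info.output_sharding``, going through the thread-local cache when
        the inputs have static shapes.
        """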
# We cannot use an lru cache if we know that inputs will have dynamic shapes,
# because SymInts are not hashable.
# This is generally ok because this only happens during tracing in torch.compile,
# and tracing does not need to be as fast as eagermode DTensor usages.
if op_info.schema.has_symints:
output_sharding = self.propagate_op_sharding_non_cached(op_info.schema)
else:
output_sharding = cast(
OutputSharding, self.propagate_op_sharding(op_info.schema)
)
op_info.output_sharding = output_sharding
def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding:
"""
Propagate the sharding for an operator given the op_schema.
"""
        # special-case op: we don't need to propagate sharding for a local
        # scalar. TODO: figure out a better way to handle this
if op_schema.op is aten._local_scalar_dense.default:
return OutputSharding(None, op_schema)
out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema)
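        # wrap each DTensorSpec input into a single-choice OpStrategy (or a
        # TupleStrategy for tensor lists) so that strategy functions receive
        # StrategyType arguments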
def spec_to_strategy(spec: object) -> object:
if isinstance(spec, DTensorSpec):
return OpStrategy([PlacementStrategy(spec)])
elif (
isinstance(spec, (list, tuple))
and len(spec) > 0
and isinstance(spec[0], DTensorSpec)
):
                # a list/tuple of tensors creates a TupleStrategy
tuple_strategy = [spec_to_strategy(s) for s in spec]
tuple_strategy = cast(Sequence[StrategyType], tuple_strategy)
return TupleStrategy(
tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy
)
else:
return spec
if op_schema.op in self.op_strategy_funcs:
# generate op strategy for the op.
mesh = try_find_mesh_from_args(op_schema.op, op_schema.args_schema)
# swap the args spec with args strategies
args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema]
kwargs_op_strategy = {
k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items()
}
# construct a new OpSchema on args for strategy based propagation
strategy_schema: OpSchema = OpSchema(
op=op_schema.op,
args_schema=tuple(args_op_strategy),
kwargs_schema=kwargs_op_strategy,
)
op_strategy = self.op_strategy_funcs[op_schema.op](mesh, strategy_schema)
if isinstance(op_strategy, OpStrategy):
# single Op strategy
output_strategy = self._select_strategy(op_strategy)
# check if we need to redistribute the input
needs_redistribute = False
expected_input_specs: list[DTensorSpec] = []
                # if the op does not specify input_specs and output_specs is a
                # DTensorSpec, we use output_specs as the desired spec for each
                # DTensor input arg.
if output_strategy.input_specs is None:
assert isinstance(output_strategy.output_specs, DTensorSpec)
for idx, input_spec in enumerate(op_schema.args_spec):
desired_spec = (
output_strategy.output_spec
if output_strategy.input_specs is None
else output_strategy.input_specs[idx]
)
expected_input_specs.append(
desired_spec.shallow_copy_with_tensor_meta(
input_spec.tensor_meta
)
)
if input_spec.placements != desired_spec.placements:
needs_redistribute = True
suggestion_schema = None
if needs_redistribute:
suggestion_schema = OpSchema(
op_schema.op, tuple(expected_input_specs), {}
)
suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
                # shape (and stride) args may need to be modified for
                # view ops and new factory ops
if op_schema.op in self.op_to_shape_and_stride_idx:
assert isinstance(output_strategy.output_spec, DTensorSpec)
                    # This happens when the output has the same shape as the
                    # input and the input placements are not all Replicate().
if output_strategy.output_spec.is_sharded():
schema = suggestion_schema or op_schema
assert isinstance(out_tensor_meta, TensorMeta)
suggestion_schema = self._adjust_shape_and_stride_args(
out_tensor_meta, schema, output_strategy.output_spec, mesh
)
needs_redistribute = True
# construct output spec for the op
if op_schema.return_type_tuple_tensor_like():
                    # for ops that return multiple tensors, if output_specs is
                    # not a tuple we use a tuple of copies of that single output
                    # spec as the new output_specs
output_specs: OutputSpecType = output_strategy.output_specs
if isinstance(output_specs, DTensorSpec):
output_specs = tuple(
[
# create a new DTensorSpec with the same placement as the
# output_specs in output_strategy
DTensorSpec(
mesh=output_specs.mesh,
placements=output_specs.placements,
tensor_meta=output_specs.tensor_meta,
)
for _ in range(len(op_schema.op._schema.returns))
]
)
elif op_schema.return_type_tensor():
output_specs = output_strategy.output_specs
else:
output_specs = None
output_sharding = OutputSharding(
output_specs,
suggestion_schema,
needs_redistribute=needs_redistribute,
)
elif isinstance(op_strategy, TupleStrategy):
                # TupleStrategy output sharding processing
                # placement strategy selected at runtime for each TupleStrategy child
selected_strategies: list[PlacementStrategy] = []
out_spec_list: list[DTensorSpec] = []
for strategy in op_strategy.childs:
assert isinstance(strategy, OpStrategy)
selected_strategy = self._select_strategy(strategy)
selected_strategies.append(selected_strategy)
out_spec_list.append(selected_strategy.output_spec)
needs_redistribute = False
suggestion_args: list[object] = []
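                # index of the current tensor (or tensor-list) argument among
                # the op's tensor args, used to look up the matching input spec
                # in the selected strategies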
tensor_or_list_tensor_arg_idx = 0
for arg in op_schema.args_schema:
if (
arg
and isinstance(arg, (list, tuple))
and isinstance(arg[0], DTensorSpec)
):
expected_input_spec_list: list[DTensorSpec] = []
for idx, arg_spec in enumerate(arg):
expected_input_spec = selected_strategies[idx].input_spec(
tensor_or_list_tensor_arg_idx
)
expected_input_spec = (
expected_input_spec.shallow_copy_with_tensor_meta(
arg_spec.tensor_meta
)
)
if arg_spec.placements != expected_input_spec.placements:
needs_redistribute = True
expected_input_spec_list.append(expected_input_spec)
suggestion_args.append(
tuple(expected_input_spec_list)
if isinstance(arg, tuple)
else expected_input_spec_list
)
tensor_or_list_tensor_arg_idx += 1
elif isinstance(arg, DTensorSpec):
expected_input_spec = selected_strategies[0].input_spec(
tensor_or_list_tensor_arg_idx
)
expected_input_spec = (
expected_input_spec.shallow_copy_with_tensor_meta(
arg.tensor_meta
)
)
if arg.placements != expected_input_spec.placements:
needs_redistribute = True
suggestion_args.append(expected_input_spec)
tensor_or_list_tensor_arg_idx += 1
else:
suggestion_args.append(arg)
suggestion_schema = None
if needs_redistribute:
suggestion_schema = OpSchema(
op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema
)
output_sharding = OutputSharding(
tuple(out_spec_list) if out_tensor_meta is not None else None,
suggestion_schema,
needs_redistribute=needs_redistribute,
)
else:
raise ValueError("Unsupported op strategy type")
# associate the output sharding with the output tensor metadata
self._wrap_output_spec_tensor_meta(
op_schema.op, output_sharding.output_spec, out_tensor_meta
)
return output_sharding
elif op_schema.op in self.op_to_rules:
            # propagate the sharding with a registered rule
sharding_prop_func = self.op_to_rules[op_schema.op]
            # step 1. there is a sharding propagation rule; run
            # it to get the output sharding
try:
output_sharding = sharding_prop_func(op_schema)
except NotImplementedError as e:
raise e
except Exception as e:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}.\nError: {e}"
) from e
            # step 2. if we can't get an output_spec from sharding
            # propagation (i.e. no rule applies for the input
            # placements), we return the output sharding
            # with schema suggestions, which can be used to
            # decide how to redistribute the inputs
if output_sharding.output_spec is None:
if output_sharding.redistribute_schema is None:
raise RuntimeError(
f"Sharding propagation failed on op {op_schema}!"
)
else:
# we do auto redistribute on inputs if necessary
# run sharding propagation again with suggested schema
propagation_res = sharding_prop_func(
output_sharding.redistribute_schema
)
                    # we set the output sharding with the new propagation result
                    # so that dispatching knows both output_spec and redistribute_schema
                    # exist, which indicates a reshard is needed
output_sharding.output_spec = propagation_res.output_spec
output_sharding.needs_redistribute = True
# associate the output sharding with the output tensor metadata
self._wrap_output_spec_tensor_meta(
op_schema.op, output_sharding.output_spec, out_tensor_meta
)
return output_sharding
else:
raise NotImplementedError(
f"Operator {op_schema.op} does not have a sharding strategy registered."
)
def _select_strategy(self, strategy: OpStrategy) -> PlacementStrategy:
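        """
        Select the PlacementStrategy with the minimal total redistribute cost,
        short-circuiting when there is only one candidate.
        """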
if len(strategy.strategies) == 1:
# short cut with only one possible strategy
return strategy.strategies[0]
strategy_costs: list[float] = []
for strtg in strategy.strategies:
            assert strtg.redistribute_cost is not None, (
                "must set redistribute cost for each strategy!"
            )
redistribute_cost = sum(chain.from_iterable(strtg.redistribute_cost))
strategy_costs.append(redistribute_cost)
# for eager execution, we just select the one with the minimal redistribute cost
return strategy.strategies[strategy_costs.index(min(strategy_costs))]
def _adjust_shape_and_stride_args(
self,
out_tensor_meta: TensorMeta,
schema: OpSchema,
spec: DTensorSpec,
mesh: DeviceMesh,
) -> OpSchema:
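        """
        Rewrite the shape (and, for strided factory ops, the stride) args of
        ``schema`` to their local counterparts, computed from the output tensor
        metadata and the sharded output spec.
        """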
shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
if isinstance(shape_stride_idx, tuple):
shape_idx, stride_idx = shape_stride_idx
else:
shape_idx = shape_stride_idx
stride_idx = None
expected_input_schema = list(schema.args_schema)
        # adjust the shape arg to the local shape inferred from the global
        # output shape and the output placements
expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset(
out_tensor_meta.shape, mesh, spec.placements
)
        # adjust the stride arg for aten.new_empty_strided.default
        if stride_idx is not None:
expected_input_schema[stride_idx] = compute_local_stride(
out_tensor_meta.stride, mesh, spec.placements
)
return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)