This is the first in a series of PRs that adapts operator impls to use a strategy-based approach: each op utilizes OpStrategy and PlacementStrategy to generate its own strategy. By combining the strategy-based approach with the op graph, we can enable more advanced op implementations (decomp becomes possible) and turn sharding prop into something closer to a constraint satisfaction problem. This PR alone only adds some basic tensor op strategies, and it works directly on the op graph that was used for metadata propagation. The tensor ops added in this PR mainly follow the strategy of one of their args. The next set of PRs will add strategies for more ops. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607 Approved by: https://github.com/XilunWu
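As an illustration of the strategy-based approach (a minimal sketch, not code from this PR; the helper name is hypothetical and it assumes the OpSchema, OpStrategy and PlacementStrategy classes defined in the file below), an op that simply follows the sharding of its first tensor argument could build its strategy like this:

def follow_first_arg_strategy(op_schema: OpSchema) -> OpStrategy:
    # the DTensorSpec of the first DTensor argument
    first_arg_spec = op_schema.args_spec[0]
    # single candidate: the output keeps the placements of that input;
    # input_specs just records the current specs of the DTensor args
    candidate = PlacementStrategy(
        output_spec=first_arg_spec,
        input_specs=op_schema.args_spec,
    )
    return OpStrategy([candidate])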
242 lines
9.0 KiB
Python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch

from torch.distributed._tensor.placement_types import DTensorSpec
from torch.utils._pytree import tree_map_only


# Common type aliases
ArgsType = Tuple[object, ...]
KwargsType = Dict[str, object]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type
# should be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


def _rebuild_tensor_from_dtensor_meta(arg) -> object:
    """
    This is used to propagate tensor metadata, must be under fake mode
    """
    assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
    return torch.empty_strided(
        arg.tensor_meta.shape,
        arg.tensor_meta.stride,
        dtype=arg.tensor_meta.dtype,
        requires_grad=arg.tensor_meta.requires_grad,
    )


@dataclass
class PlacementStrategy(object):
    """
    A placement strategy describes acceptable sharding placements of the output
    and the tensor arguments of an operation.
    """

    output_spec: DTensorSpec
    input_specs: Optional[Sequence[DTensorSpec]] = None

    def pretty_print_placements(self, placements):
        return "".join([str(p) for p in placements])

    def __str__(self) -> str:
        if self.input_specs is None:
            input_specs_str = ""
        else:
            input_specs_str = ", ".join(
                [
                    self.pretty_print_placements(spec.placements)
                    for spec in self.input_specs
                ]
            )
        output_spec_str = self.pretty_print_placements(self.output_spec.placements)
        return f"({input_specs_str}) -> ({output_spec_str}) @ mesh layout: {tuple(self.output_spec.mesh.mesh.shape)}"


class StrategyType(object):
    """
    Base class type for op strategies. We have two strategy types:
    OpStrategy and TupleStrategy
    """

    pass


class OpStrategy(StrategyType):
    """
    OpStrategy consists of a list of placement strategies associated with the op
    """

    def __init__(self, strategies: List[PlacementStrategy]) -> None:
        super().__init__()
        self.strategies: List[PlacementStrategy] = strategies

    def __str__(self) -> str:
        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
        return f"OpStrategy: [{strategy_list_str}]"

    def max_num_shards(self) -> int:
        """
        Returns the max number of shards across all placement strategies
        """
        return max([strategy.output_spec.num_shards for strategy in self.strategies])
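

# Illustrative sketch (not part of the original module): an OpStrategy typically holds
# one PlacementStrategy per acceptable sharding of the op, and the shard counts used by
# max_num_shards() above can also rank candidates by how finely their outputs are
# sharded. The helper below is hypothetical.
def _example_pick_most_sharded(op_strategy: OpStrategy) -> PlacementStrategy:
    return max(
        op_strategy.strategies, key=lambda strategy: strategy.output_spec.num_shards
    )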


class TupleStrategy(StrategyType):
    """
    TupleStrategy represents the case where the output strategy of an op is a tuple
    of strategies, i.e. if the output of the op is a tuple of tensors, we should
    return a TupleStrategy that contains a tuple of OpStrategy.

    NOTE: if the output of the op is a List[Tensor], it's likely we should return
    an OpStrategy directly in all cases.
    """

    def __init__(self, childs: Tuple[StrategyType, ...]) -> None:
        super().__init__()
        self.childs: Tuple[StrategyType, ...] = childs

    def __str__(self) -> str:
        tuple_strategies_str = "TupleStrategy: "
        child_strategies_str = "\n".join(
            [
                f"  tuple idx: {idx}, strategy: {str(strat)}"
                for idx, strat in enumerate(self.childs)
            ]
        )
        return f"{tuple_strategies_str}\n{child_strategies_str}"
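

# Illustrative sketch (not part of the original module): for an op whose output is a
# tuple of tensors, each output element gets its own OpStrategy and the whole result is
# wrapped in a TupleStrategy. The hypothetical helper below only shows how the
# containers compose.
def _example_tuple_output_strategy(
    first_output_candidates: List[PlacementStrategy],
    second_output_candidates: List[PlacementStrategy],
) -> TupleStrategy:
    return TupleStrategy(
        (OpStrategy(first_output_candidates), OpStrategy(second_output_candidates))
    )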


@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator's input schema. It
    includes the DTensorSpecs of the DTensor args and the non-tensor args/kwargs
    (with positional order preserved). It is mainly used by the dispatching logic
    below to run things like sharding propagation.

    Registered sharding propagation rules can use this data class and update some
    fields in place when necessary (i.e. for shape related ops) to make sure the
    args/kwargs are legit before passing them to the local tensor operator.
    This is the main reason that we don't freeze this dataclass.

    NOTE: greater access to the operator inputs comes with greater responsibility.
    Here are some basic rules about what can be used and what can be changed.

    Args:
        func_schema: the function schema of the operator
        args_schema: contains args except that the DTensor args have been replaced
            with their DTensorSpecs
        kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
            with their DTensorSpecs

    What can be used:
        - every attribute within this class could be read to conduct
          sharding propagation.
    What can be changed:
        - only the args_schema and kwargs_schema could be changed.
        - every non-tensor arg could be changed to accommodate local tensor
          operations (i.e. for ops like view/reshape/...)
        - every "DTensorSpec" attribute inside `args_schema`, `kwargs_schema` and
          `args_spec` SHOULD NOT be updated! DTensorSpecs are read only and sharding
          propagation shouldn't update them in place, otherwise the input DTensor
          placements would get implicitly changed, which is error-prone.
    """

    func_schema: torch._C.FunctionSchema
    args_schema: ArgsType
    kwargs_schema: KwargsType

    is_inplace: bool = False
    is_out_variant: bool = False

    def __post_init__(self) -> None:
        # simple analysis of the function schema to determine
        # if this is an inplace/out variant; it might not
        # be entirely correct, but it's good enough for now.
        self.is_inplace = self.func_schema.name[-1] == "_"
        self.is_out_variant = "out" in self.func_schema.overload_name

    @property
    def args_spec(self) -> Tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: a clean tuple of arg specs that contains
            NO non-DTensor positional arguments (i.e. int/float/tuple, etc),
            mainly used by sharding propagation to propagate the output spec
        """
        # filter out non-relevant values from the args schema to get a clean spec list
        # this would mainly be used by sharding propagation rules
        return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec))

    def __repr__(self) -> str:
        return (
            f"OpSchema(func_schema={self.func_schema},"
            f" args_schema={self.args_schema},"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __hash__(self) -> int:
        # NOTE: we turn kwargs_schema into a frozenset to hash, as it should not be a nested dict
        frozen_set_kwargs_schema = frozenset(self.kwargs_schema.items())
        return hash((self.func_schema, self.args_spec, frozen_set_kwargs_schema))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, OpSchema):
            return False
        return (
            self.func_schema == other.func_schema
            and self.args_schema == other.args_schema
            and self.kwargs_schema == other.kwargs_schema
        )

    def gen_fake_args(self) -> ArgsType:
        """
        gen_fake_args: generate fake args for the operator, mainly used by
            sharding propagation rules to run the local tensor operator
            and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.args_schema
        )

    def gen_fake_kwargs(self) -> KwargsType:
        """
        gen_fake_kwargs: generate fake kwargs for the operator, mainly used by
            sharding propagation rules to run the local tensor operator
            and get the output spec.
        """
        return tree_map_only(
            DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
        )

    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
        suggestion_args_spec = self.args_spec
        new_arg_schema: List[object] = []
        idx_of_args_spec = 0
        for arg in origin_schema.args_schema:
            if isinstance(arg, DTensorSpec):
                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
                idx_of_args_spec += 1
            else:
                new_arg_schema.append(arg)
        self.args_schema = tuple(new_arg_schema)
        self.kwargs_schema = origin_schema.kwargs_schema
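

# Illustrative sketch (not part of the original module): sharding propagation can call
# gen_fake_args / gen_fake_kwargs under FakeTensorMode so the rebuilt tensors carry only
# metadata, then run the local operator to observe the output shape/stride/dtype. The
# helper name and the use of an OpOverload callable are assumptions for illustration.
def _example_run_op_on_fake_tensors(op: torch._ops.OpOverload, op_schema: OpSchema):
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        fake_args = op_schema.gen_fake_args()
        fake_kwargs = op_schema.gen_fake_kwargs()
        # run the aten op on fake tensors to get the output metadata
        return op(*fake_args, **fake_kwargs)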


@dataclass
class OutputSharding:
    """
    OutputSharding is a data class used by the sharding propagation rules. It
    sets the output_spec upon successful propagation; if propagation failed,
    output_spec becomes None and the sharding propagation rules can give a list
    of suggestions for how the inputs should be resharded.

    NOTE: the schema_suggestions generated by sharding propagation should be
    exactly the same as the operator OpSchema, except for the DTensor DTensorSpecs.
    """

    output_spec: OutputSpecType
    schema_suggestions: Optional[List[OpSchema]] = None
    failed_reason: Optional[str] = None
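

# Illustrative sketch (not part of the original module): a sharding propagation rule
# returns an OutputSharding with output_spec set on success; on failure it sets
# output_spec to None and attaches a resharding suggestion plus a reason. The helper
# below is hypothetical.
def _example_failed_propagation(
    suggested_schema: OpSchema, reason: str
) -> OutputSharding:
    return OutputSharding(
        output_spec=None,
        schema_suggestions=[suggested_schema],
        failed_reason=reason,
    )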