Fixes #164543

This PR improves the `__str__` method of DTensor's `OpSchema` to provide a more readable error message when dispatch fails, since the error message prints `{op_info.schema}`.

Example 1: `aten.embedding`

```
aten.embedding.default(Spec(f32[2048, 256](S(0))), Spec(i64[16, 2048](S(0)R))) on DeviceMesh((dp=2, tp=2), 'cuda', stride=(2, 1)))
```

Example 2: `aten.mm`

```
aten.mm.default(Spec(f32[1024, 512](S(1))), Spec(f32[512, 256](S(0)))) on DeviceMesh((tp=4), 'cuda', stride=(1,)))
```

Example 3: `aten._scaled_dot_product_flash_attention`

```
aten._scaled_dot_product_flash_attention.default(Spec(f16[8, 16, 128, 64](RS(1))), Spec(f16[8, 16, 128, 64](RS(1))), Spec(f16[8, 16, 128, 64](RS(1)))) on DeviceMesh((dp=2, tp=4), 'cuda', stride=(4, 1)))
```

Added test:

```
python test/distributed/tensor/test_dtensor_ops.py -k test_embedding_error_msg
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164623
Approved by: https://github.com/zpcore
605 lines · 22 KiB · Python
# mypy: allow-untyped-defs
"""
DTensor operator schema definitions and utilities.

This module defines the core data structures and utilities for describing and managing
distributed tensor operations in PyTorch's DTensor system. It provides the foundational
schema types used for sharding propagation, operator strategy selection, and distributed
execution planning.

Key components:
- OpSpec: Describes acceptable sharding placements for operations
- OpStrategy: Represents the possible sharding strategies for an operator
- TupleStrategy: Container for multiple strategies when ops take a tuple/list of tensors as input
- OpSchema: Describes operator input/output schemas with DTensorSpecs
- OutputSharding: Manages output sharding specifications and redistribution
- RuntimeSchemaInfo: Runtime execution metadata for operators
- OpInfo: Complete runtime operator execution information

These schema definitions enable the DTensor system to:
1. Propagate tensor sharding information to the operator outputs
2. Greedily select sharding strategies for distributed operations
3. Plan and execute tensor redistributions when needed
4. Cache sharding decisions for performance optimization
"""

from collections.abc import Sequence
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Optional, Union
from typing_extensions import deprecated

import torch
from torch._C import (
    _DTensor_OpSchema_post_init,
    _DTensor_OpSchema_recompute_comparison_key,
)
from torch._ops import OpOverload
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor._dtensor_spec import DTensorSpec
from torch.distributed.tensor.placement_types import Placement


try:
    from torch.utils._cxx_pytree import (
        register_pytree_node,
        tree_leaves,
        tree_map_only,
        TreeSpec,
    )
except ImportError:
    from torch.utils._pytree import (  # type: ignore[no-redef, assignment]
        register_pytree_node,
        tree_leaves,
        tree_map_only,
        TreeSpec,
    )


# Common type aliases
ArgsType = tuple[object, ...]
KwargsType = dict[str, object]

PlacementList = list[Optional[Placement]]

# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type should
# be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


def _rebuild_tensor_from_dtensor_meta(arg) -> object:
    """
    This is used to propagate tensor metadata; it must be called under fake mode.
    """
    assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta."
    return torch.empty_strided(
        arg.tensor_meta.shape,
        arg.tensor_meta.stride,
        dtype=arg.tensor_meta.dtype,
    )


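# Illustrative sketch: rebuilding a fake tensor from a DTensorSpec's metadata. `mesh`
# is an assumed, already-constructed DeviceMesh, and the TensorMeta import below
# assumes it lives alongside DTensorSpec; under FakeTensorMode no real storage is
# allocated, which is what sharding propagation relies on.
#
#     from torch._subclasses.fake_tensor import FakeTensorMode
#     from torch.distributed.tensor._dtensor_spec import TensorMeta
#     from torch.distributed.tensor.placement_types import Shard
#
#     meta = TensorMeta(shape=torch.Size([8, 4]), stride=(4, 1), dtype=torch.float32)
#     spec = DTensorSpec(mesh, (Shard(0),), tensor_meta=meta)
#     with FakeTensorMode():
#         fake = _rebuild_tensor_from_dtensor_meta(spec)  # fake.shape == (8, 4)

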
def _pretty_print_spec(spec: object) -> str:
    if spec is None:
        return "None"
    elif isinstance(spec, DTensorSpec):
        return "".join([str(p) for p in spec.placements])
    elif isinstance(spec, Sequence):
        return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")"
    else:
        raise RuntimeError(f"Unknown spec type to print: spec={spec}")


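# Illustrative sketch of the compact placement strings produced above, assuming `mesh`
# is an already-constructed 2-D DeviceMesh. Each placement renders via its __str__
# (Shard(0) -> "S(0)", Replicate() -> "R"), so a spec sharded on dim 0 and replicated
# on the second mesh dim prints as "S(0)R":
#
#     from torch.distributed.tensor.placement_types import Replicate, Shard
#
#     spec = DTensorSpec(mesh, (Shard(0), Replicate()))
#     _pretty_print_spec(spec)          # "S(0)R"
#     _pretty_print_spec([spec, None])  # "(S(0)R, None)"

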
@dataclass
class OpSpec:
    """
    An OpSpec describes an acceptable sharding placement of an operation, with the
    specified DTensorSpecs for both the output and the inputs.

    note: when the op return value is a single DTensor object, output_specs is a
    DTensorSpec; when the return value is a tuple of Optional[DTensor],
    output_specs is a tuple of Optional[DTensorSpec].

    note: we MUST produce a DTensorSpec for every output that is a Tensor. None
    entries only occur for outputs that are not Tensors (e.g., operators that return
    Optional[Tensor] or other non-Tensor values).

    invariant: the DeviceMesh on all DTensorSpecs must be the same
    """

    # output_specs and input_specs are related: for this op, given these input_specs,
    # this is the way the output would look
    output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]]
    input_specs: Optional[Sequence[DTensorSpec]] = None

"""
|
|
redistribute_cost tells how expensive it is to redistribute a given input into the
|
|
placement specified in this OpSpec.
|
|
|
|
outer list: one entry (list) per (tensor) input in the op's arg schema
|
|
inner list: one entry (cost value) per possible sharding spec for that input
|
|
|
|
Example:
|
|
-------
|
|
another_op() -> tensor_a # another_op produces the output that becomes our first input
|
|
my_op(tensor_a)
|
|
|
|
Let's assume this OpSpec's input_specs are [Replicate()],
|
|
but another_op() supports 2 strategies (OpSpecs) which produce outputs of
|
|
Replicate()
|
|
Shard(0)
|
|
|
|
In this example, redistribute_costs would look like this
|
|
[
|
|
# one row representing "my_op's first input" (tensor_a)
|
|
[
|
|
# two entries, one for each strategies supported by another_op
|
|
0.0, # cost of redistributing tensor_a from 'Replicate()'
|
|
K, # cost of redistributing tensor_a from 'Shard(0)'
|
|
],
|
|
"""
|
|
redistribute_cost: Optional[list[list[float]]] = None
|
|
|
|
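    # Illustrative sketch of a single OpSpec, assuming `mesh` is an already-constructed
    # 1-D DeviceMesh and the op is matmul-like: both inputs replicated, output
    # replicated, and one candidate sharding per input to redistribute from (so each
    # inner cost list has a single entry).
    #
    #     from torch.distributed.tensor.placement_types import Replicate
    #
    #     replicate = DTensorSpec(mesh, (Replicate(),))
    #     spec = OpSpec(
    #         output_specs=replicate,
    #         input_specs=[replicate, replicate],
    #         redistribute_cost=[[0.0], [0.0]],
    #     )
    #     str(spec)  # "(R, R) -> R"
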
    @cached_property
    def output_spec(self) -> DTensorSpec:
        """
        This function requires that the strategy have exactly one DTensorSpec as the
        output spec. If the output_specs is a tuple, we throw an exception.
        """
        if isinstance(self.output_specs, DTensorSpec):
            return self.output_specs
        else:
            raise ValueError(
                f"function output_spec expects a single DTensorSpec but got: {self.output_specs}"
            )

    @cached_property
    def mesh(self):
        if isinstance(self.output_specs, DTensorSpec):
            return self.output_specs.mesh
        elif isinstance(self.output_specs, tuple):
            out_spec = self.output_specs[0]
            assert isinstance(out_spec, DTensorSpec)
            return out_spec.mesh
        else:
            raise ValueError(
                f"function mesh expects a single DTensorSpec or a tuple of DTensorSpec but got: {self.output_specs}"
            )

    def input_spec(self, index: int = 0) -> DTensorSpec:
        assert self.input_specs is not None, "input_specs of OpSpec is None!"
        assert len(self.input_specs) > index, (
            f"Invalid index {index} for input_specs of length "
            f"{len(self.input_specs)}: {self.input_specs}"
        )
        return self.input_specs[index]

    def __str__(self) -> str:
        if self.input_specs is not None:
            input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> "
        else:
            input_specs_str = ""
        output_spec_str = _pretty_print_spec(self.output_specs)
        return f"{input_specs_str}{output_spec_str}"


class StrategyType:
    """
    Base class type for op strategy. We have two StrategyTypes:
    OpStrategy and TupleStrategy
    """


class OpStrategy(StrategyType):
    """
    OpStrategy consists of a list of sharding strategies associated with the op,
    where each strategy is an OpSpec that describes the acceptable input/output sharding.

    invariant: the DeviceMesh on all OpSpecs must be the same
    """

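    # Illustrative sketch, assuming `mesh` is an already-constructed 1-D DeviceMesh:
    # an OpStrategy for a pointwise-style op that accepts either a fully replicated
    # or a dim-0 sharded input and produces an output with the matching placement.
    #
    #     from torch.distributed.tensor.placement_types import Replicate, Shard
    #
    #     replicated = DTensorSpec(mesh, (Replicate(),))
    #     sharded = DTensorSpec(mesh, (Shard(0),))
    #     strategy = OpStrategy(
    #         [
    #             OpSpec(output_specs=replicated, input_specs=[replicated]),
    #             OpSpec(output_specs=sharded, input_specs=[sharded]),
    #         ]
    #     )
    #     str(strategy)  # "OpStrategy[(R) -> R, (S(0)) -> S(0)] @ mesh: ..."
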
    def __init__(self, strategies: list[OpSpec]) -> None:
        super().__init__()
        self.strategies: list[OpSpec] = strategies

    def __str__(self) -> str:
        strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
        mesh_shape = self.mesh_shape
        return f"OpStrategy[{strategy_list_str}] @ mesh: {mesh_shape}"

    def max_num_shards(self) -> int:
        """
        Returns the max number of shards across all OpSpecs
        """
        return max(strategy.output_spec.num_shards for strategy in self.strategies)

    @property
    def mesh(self):
        return self.strategies[0].mesh

    @property
    def mesh_shape(self):
        return self.strategies[0].mesh.shape

    @property
    def ndim(self):
        return self.strategies[0].output_spec.ndim

    @property
    def shape(self):
        return self.strategies[0].output_spec.shape


class TupleStrategy(StrategyType):
    """
    TupleStrategy is a special case for operators that are fundamentally compound or batched such that some subset
    of the inputs and outputs are completely unrelated to some other subset.

    Generally, foreach_* ops are the most common use case for TupleStrategy, because they accept lists of inputs
    but operate independently on each input or tuple of zipped inputs.

    For example, [out_a, out_b] = torch._foreach_add([a, b], scalar): input a's sharding only affects out_a's sharding,
    independent of b and out_b.

    An example of an operator that should NOT use TupleStrategy is torch.split. It produces a List[Tensor]
    as its output, but the sharding decision of one output is bound together with the decision
    of each other output and the common input.
    """

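    # Illustrative sketch, assuming `mesh` is an already-constructed 1-D DeviceMesh:
    # a TupleStrategy for a foreach-style op whose two list elements are handled
    # independently, with one child OpStrategy per element.
    #
    #     from torch.distributed.tensor.placement_types import Replicate, Shard
    #
    #     shard0 = DTensorSpec(mesh, (Shard(0),))
    #     replicate = DTensorSpec(mesh, (Replicate(),))
    #     tuple_strategy = TupleStrategy(
    #         [
    #             OpStrategy([OpSpec(output_specs=shard0, input_specs=[shard0])]),
    #             OpStrategy([OpSpec(output_specs=replicate, input_specs=[replicate])]),
    #         ]
    #     )
    #     # str(tuple_strategy) ->
    #     # "TupleStrategy(OpStrategy[(S(0)) -> S(0)] @ mesh: ..., OpStrategy[(R) -> R] @ mesh: ...)"
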
    def __init__(
        self,
        children: Sequence[StrategyType],
    ) -> None:
        super().__init__()
        self.children: Sequence[StrategyType] = children

    @property
    @deprecated(
        "TupleStrategy.childs is deprecated, use TupleStrategy.children instead.",  # codespell:ignore childs
        category=FutureWarning,
    )
    def childs(self) -> Sequence[StrategyType]:  # codespell:ignore childs
        """
        Alias for children, to maintain backward compatibility.
        """
        return self.children

    def child_mesh(self, index: int) -> DeviceMesh:
        op_strategy = self.children[index]
        assert isinstance(op_strategy, OpStrategy)
        return op_strategy.mesh

    def __str__(self) -> str:
        child_strategies_str = ", ".join([str(strat) for strat in self.children])
        return f"TupleStrategy({child_strategies_str})"


try:
    register_pytree_node(
        TupleStrategy,
        lambda node: (node.children, None),
        lambda children, _: TupleStrategy(tuple(children)),
    )
except ValueError:
    # already registered TupleStrategy, skip
    pass


@dataclass
class RuntimeSchemaInfo:
    """
    RuntimeSchemaInfo stores the operator schema related information for runtime (eager)
    execution. This is mainly used in two ways: 1. to generate a hash of the args to
    determine whether to re-run sharding prop or not; 2. to determine if we need pytree.
    """

    # This static_argnum records the static arg "starting index" for ops that have non-tensor
    # args/kwargs which would affect sharding propagation results. All args starting from
    # this index would be hashed into our sharding cache.
    # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc.
    static_argnum: int = 100
    # This static_kwargkey records static kwarg names which would affect sharding prop
    static_kwargkey: Optional[list[str]] = None
    # Each op can decide if it wants to use pytree flatten/unflatten during operator
    # eager execution. By default we don't do flatten/unflatten; we only do it if the
    # op indicates it needs to. This is to accelerate eager performance.
    needs_pytree: bool = False


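# Illustrative sketch of registering schema info for an op (the values here are
# hypothetical, not taken from DTensor's actual op registrations): hash args starting
# at positional index 1 and also hash the `dtype` kwarg, since both can change the
# sharding propagation result.
#
#     schema_info = RuntimeSchemaInfo(static_argnum=1, static_kwargkey=["dtype"])

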
@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator's input schema. It includes
    DTensorSpecs/OpStrategies (instead of DTensors) and non-tensor args/kwargs (positional
    order preserved). It is mainly used by DTensor's dispatching logic to perform various
    actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.)

    NOTE: this must be used as a read-only data class
    TODO: make this a frozen dataclass

    Args:
        op: the operator overload we are intercepting
        args_schema: contains the args, except that DTensor args have been replaced
            with their DTensorSpec or OpStrategy
        kwargs_schema: contains the kwargs, except that DTensor kwargs have been replaced
            with their DTensorSpec or OpStrategy
    """

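    # Illustrative sketch, assuming `mesh` is an already-constructed 1-D DeviceMesh:
    # an OpSchema for aten.mm with both matrix inputs described by DTensorSpecs. Its
    # __str__ renders the compact form used in dispatch error messages, similar to
    # "aten.mm.default(Spec(...(S(1))), Spec(...(S(0)))) on DeviceMesh(...)".
    #
    #     from torch.distributed.tensor.placement_types import Shard
    #
    #     lhs = DTensorSpec(mesh, (Shard(1),))
    #     rhs = DTensorSpec(mesh, (Shard(0),))
    #     schema = OpSchema(torch.ops.aten.mm.default, (lhs, rhs), {})
    #     print(schema)
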
    op: OpOverload
    args_schema: ArgsType
    kwargs_schema: KwargsType

    schema_info: Optional[RuntimeSchemaInfo] = None

    _comparison_key: Optional[tuple[object, ...]] = None

    has_symints: bool = field(init=False)

    @property
    def args_spec(self) -> tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: contains a clean tuple of the args' DTensorSpecs,
        with NO non-DTensor positional arguments (i.e. int/float/tuple, etc.);
        mainly used by sharding propagation to propagate the output spec
        """
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, DTensorSpec))

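    # Illustrative sketch of the filtering above, assuming `mesh` is an
    # already-constructed 1-D DeviceMesh: non-tensor positional args (the clamp
    # bounds here) are dropped and only the DTensorSpecs remain.
    #
    #     from torch.distributed.tensor.placement_types import Shard
    #
    #     spec = DTensorSpec(mesh, (Shard(0),))
    #     schema = OpSchema(torch.ops.aten.clamp.default, (spec, 0.0, 1.0), {})
    #     schema.args_spec  # (spec,)
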
    @property
    def args_strategy(self) -> tuple[OpStrategy, ...]:
        # filter out non-relevant values from the args schema to get a clean OpStrategy list
        # kept separate from args_spec for ease of type annotation
        # TODO: see if we should merge this with args_spec
        args = (
            tree_leaves(self.args_schema)
            if self.schema_info is not None and self.schema_info.needs_pytree
            else self.args_schema
        )
        return tuple(item for item in args if isinstance(item, OpStrategy))

    def __repr__(self) -> str:
        args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema])
        return (
            f"OpSchema(op={self.op},"
            f" args_schema=({args_schema}),"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __str__(self) -> str:
        args_schema: list[str] = []
        device_mesh = None

        for arg in self.args_schema:
            if isinstance(arg, DTensorSpec):
                args_schema.append(str(arg))
                device_mesh = arg.mesh
            elif isinstance(arg, OpStrategy):
                assert len(arg.strategies) == 1
                args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
                device_mesh = arg.mesh
            elif isinstance(arg, TupleStrategy):
                first_op_strategy = arg.children[0]
                assert isinstance(first_op_strategy, OpStrategy)
                device_mesh = first_op_strategy.mesh
                args_schema.append(str(arg))
            else:
                args_schema.append(str(arg))

        return f"{self.op}({', '.join(args_schema)}) on {device_mesh}"

    def __post_init__(self) -> None:
        _DTensor_OpSchema_post_init(self)

    def arg_type_tensor_or_tensor_list_like(self, arg: object) -> bool:
        is_tensor = isinstance(arg, DTensorSpec)
        if is_tensor:
            return True

        if not isinstance(arg, list):
            return False

        return all(isinstance(e, DTensorSpec) or e is None for e in arg)

    def return_type_tuple_tensor_like(self) -> bool:
        # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats
        # in the tuple, but the first element must be a Tensor, so this check is enough
        return_types = self.op._schema.returns
        return len(return_types) > 1 and isinstance(
            return_types[0].type, torch.TensorType
        )

    def return_type_list_tensor_like(self) -> bool:
        # returns True if the return type is a List
        return_types = self.op._schema.returns
        return len(return_types) == 1 and isinstance(
            return_types[0].type, torch.ListType
        )

    def return_type_tensor(self) -> bool:
        return_types = self.op._schema.returns
        # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
        # return types, so this check is enough for tensor like types
        return isinstance(return_types[0].type, torch.TensorType)

    def get_mesh_from_args(self, validate: bool = True) -> DeviceMesh:
        """
        This util can be used to get a mesh from an OpSchema that contains multiple
        DTensors as arguments. When `validate` is True, it will try to validate that all the
        arguments have the same mesh to avoid unexpected cross mesh errors.

        NOTE: this util currently does not handle TupleStrategy when `validate=True`.
        This is because for TupleStrategy there could be different types of checks, i.e.:
            - for stack and cat like ops, we need to check whether every input within
              a TupleStrategy is on the same mesh
            - for foreach like ops, we need to check that "zipped" inputs are on the
              same mesh for each index.
        """
        first_arg = self.args_schema[0]
        if isinstance(first_arg, (DTensorSpec, OpStrategy)):
            mesh = first_arg.mesh
        elif isinstance(first_arg, (list, tuple, TupleStrategy)):
            first_elem = (
                first_arg.children[0]
                if isinstance(first_arg, TupleStrategy)
                else first_arg[0]
            )
            assert isinstance(first_elem, (DTensorSpec, OpStrategy))
            mesh = first_elem.mesh
        else:
            raise ValueError(f"Cannot find device mesh from args for op: {self.op}.")

        if validate:
            for arg in self.args_schema[1:]:
                if isinstance(arg, (DTensorSpec, OpStrategy)) and arg.mesh != mesh:
                    raise RuntimeError(
                        f"DTensor does not support cross-mesh operation on {self.op}! "
                        f"Got meshes: {mesh} {arg.mesh}. "
                        f"Please make sure all the arguments have the same DeviceMesh."
                    )

        return mesh

    def is_inplace_op(self) -> bool:
        # simple analysis of the function schema to determine
        # if this is an inplace variant; it might not
        # be entirely correct, but it's good enough for now.
        return self.op._schema.name[-1] == "_"

    def is_out_variant_op(self) -> bool:
        # simple analysis of the function schema to determine
        # if this is an out variant; it might not
        # be entirely correct, but it's good enough for now.
        return "out" in self.op._schema.overload_name

    def is_view_op(self) -> bool:
        return self.op._schema._is_view_op()

    def _recompute_comparison_key(self) -> None:
        _DTensor_OpSchema_recompute_comparison_key(self)

    def __hash__(self) -> int:
        return hash(self._comparison_key)

    def __eq__(self, other: object) -> bool:
        # early return checks
        if not isinstance(other, OpSchema):
            return False

        if self.op != other.op:
            return False

        if len(self.args_schema) != len(other.args_schema):
            return False

        return self._comparison_key == other._comparison_key

    def gen_fake_args(self) -> ArgsType:
        """
        gen_fake_args: generate fake args for the operator; this is mainly used
        by sharding propagation rules to generate fake args for the operator
        to run the local tensor operator and get the output spec.
        """
        return tree_map_only(
            DTensorSpec,
            _rebuild_tensor_from_dtensor_meta,
            self.args_schema,
            is_leaf=lambda x: isinstance(x, DTensorSpec),
        )

    def gen_fake_kwargs(self) -> KwargsType:
        """
        gen_fake_kwargs: generate fake kwargs for the operator; this is mainly used
        by sharding propagation rules to generate fake kwargs for the operator
        to run the local tensor operator and get the output spec.
        """
        return tree_map_only(
            DTensorSpec,
            _rebuild_tensor_from_dtensor_meta,
            self.kwargs_schema,
            is_leaf=lambda x: isinstance(x, DTensorSpec),
        )

    def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
        suggestion_args_spec = self.args_spec
        new_arg_schema: list[object] = []
        idx_of_args_spec = 0
        if (
            origin_schema.schema_info is not None
            and origin_schema.schema_info.needs_pytree
        ):
            args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema)
        else:
            args_schema = origin_schema.args_schema
        for arg in args_schema:
            if isinstance(arg, DTensorSpec):
                new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
                idx_of_args_spec += 1
            else:
                new_arg_schema.append(arg)
        self.args_schema = tuple(new_arg_schema)
        self.kwargs_schema = origin_schema.kwargs_schema
        self._recompute_comparison_key()


@dataclass
class OutputSharding:
    """
    OutputSharding is a data class used by the sharding propagation; it sets the
    output_spec upon successful propagation. If needs_redistribute is set to True,
    a redistribute_schema is returned alongside it to indicate that the input
    arguments need to be redistributed before the op execution.

    NOTE: the redistribute_schema generated by sharding propagation should be
    exactly the same as the operator OpSchema, except for the DTensorSpecs
    """

    # specifies the output sharding pattern
    output_spec: OutputSpecType
    # schema for redistribution if needed
    redistribute_schema: Optional[OpSchema] = None
    # flag indicating if inputs need redistribution
    needs_redistribute: bool = False
    # flag to use values from `redistribute_schema`
    use_val_from_redistribute_schema: bool = False

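    # Illustrative sketch, assuming `out_spec` is the propagated output DTensorSpec and
    # `suggested_schema` is an OpSchema whose DTensorSpecs describe where the inputs
    # should live: a sharding-prop result that asks the dispatcher to redistribute the
    # inputs before running the op.
    #
    #     sharding = OutputSharding(
    #         output_spec=out_spec,
    #         redistribute_schema=suggested_schema,
    #         needs_redistribute=True,
    #     )
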
    @cached_property
    def mesh(self):
        if isinstance(self.output_spec, DTensorSpec):
            return self.output_spec.mesh
        elif isinstance(self.output_spec, tuple):
            out_spec = self.output_spec[0]
            if isinstance(out_spec, DTensorSpec):
                return out_spec.mesh
            else:
                raise ValueError(f"Unknown output spec type: {type(out_spec)}")
        else:
            raise ValueError(f"Unknown output spec type: {type(self.output_spec)}")


@dataclass
class OpInfo:
    """
    All runtime op execution info is packed here
    """

    # The first compute device mesh recorded from the args
    # NOTE: one op could have multiple meshes from its args. We just record the first
    # mesh here to check if the current rank should participate in computation or not.
    compute_mesh: DeviceMesh

    # complete runtime operator infos
    schema: OpSchema
    flat_args_schema: list[object]
    local_args: Sequence[object]
    local_kwargs: dict[str, object]
    args_tree_spec: Optional[TreeSpec] = None

    # the output sharding info
    output_sharding: Optional[OutputSharding] = None