Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[dtensor] refactor view ops to use OpStrategy (#126011)
As titled. Some ops require adjustment of the output shape argument. In the rule-based sharding prop, the global output shape was inferred inside the rule (in `view_ops.py`); in the strategy-based sharding prop, it is now obtained from the propagated out_tensor_meta (in `sharding_prop.py`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126011
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
Committed by: PyTorch MergeBot
Parent: a0df40f195
Commit: 9edf54df4d
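The gist of the refactor, as a minimal before/after sketch (the two dictionary entries are copied from the view_ops.py hunks further down; this is not a literal excerpt of either file version):

    # rule-based (before): each view op carried an Op record, including the
    # index of its shape argument so the rule could rewrite the global shape
    torch.reshape: Op(
        dim_map=lambda input, shape: view_groups(input.shape, shape),
        shape_argnum=1,
    ),

    # strategy-based (after): only the dim-map callable is registered; the
    # output shape is now read from the propagated out_tensor_meta in
    # sharding_prop.py
    torch.reshape: lambda input, shape: view_groups(input.shape, shape),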
@@ -11,9 +11,9 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate,
 from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed._tensor.ops.view_ops import (
     Broadcast,
+    dim_maps,
     Flatten,
     InputDim,
-    ops,
     Repeat,
     Singleton,
     Split,
@@ -130,8 +130,8 @@ class TestViewOps(DTensorTestBase):
         return 6

     def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
-        spec = ops[op]
-        rules = spec.dim_map(*args, **kwargs)
+        dim_map = dim_maps[op]
+        rules = dim_map(*args, **kwargs)
         outputs = op(*args, **kwargs)
         flat_args = pytree.arg_tree_leaves(*args)
         in_shape = flat_args[0].shape
@@ -163,7 +163,6 @@ class TestViewOps(DTensorTestBase):
         )

         for in_shard in all_sharding_choices:
-            # print(f' |--- {in_shard}')
             in_dt = distribute_tensor(args[0], device_mesh, in_shard)

             comm_mode = CommDebugMode()
@@ -180,7 +179,7 @@ class TestViewOps(DTensorTestBase):
                 self.assertEqual(outputs, full_out)

     def dimmap_test(self, op, args, expected_rule_output):
-        rules = ops[op].dim_map(*args)
+        rules = dim_maps[op](*args)
         self.assertEqual(rules, expected_rule_output)
         self.call_dt_test(op, args, {}, self.device_mesh)

@@ -229,7 +228,7 @@ class TestViewOps(DTensorTestBase):
        )

        with self.assertRaises(AssertionError):
-            ops[torch.broadcast_to].dim_map(randn(24, 36), (1, 2, 4))
+            dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4))

        self.dimmap_test(
            torch.broadcast_to,
@@ -495,14 +494,14 @@ class TestViewOps(DTensorTestBase):
            InputDim(0),
            Flatten((InputDim(1), InputDim(2))),
        )
-        view_as_complex_rule = ops[torch.view_as_complex].dim_map(inp)
+        view_as_complex_rule = dim_maps[torch.view_as_complex](inp)
        self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
        expected_view_as_real_rule = (
            InputDim(0),
            Split(InputDim(1), (13, 2), 0),
            Split(InputDim(1), (13, 2), 1),
        )
-        view_as_real_rule = ops[torch.view_as_real].dim_map(intermediate)
+        view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
        self.assertEqual(view_as_real_rule, expected_view_as_real_rule)

        # test sharded computation correctness
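For orientation, the "rules" checked in this test are DimMap tuples: one DimSpec per output dimension (InputDim, Flatten, Split, ...) describing how that dimension is derived from the input. A hedged illustration, with shapes chosen for the example rather than taken from the test (module path as of this commit):

    import torch
    from torch import Tensor
    from torch.distributed._tensor.ops.view_ops import dim_maps

    x = torch.randn(4, 2, 3)
    rule = dim_maps[Tensor.view](x, 4, 6)
    # expected: (InputDim(0), Flatten((InputDim(1), InputDim(2))))
    # output dim 0 passes through; output dim 1 flattens input dims 1 and 2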
@@ -9,11 +9,7 @@ import torch.utils._pytree as pytree
 from torch import Tensor

 from torch.distributed._tensor import DeviceMesh, Replicate, Shard
-from torch.distributed._tensor.ops.view_ops import (
-    DimSpec,
-    InputDim,
-    ops as view_op_rules,
-)
+from torch.distributed._tensor.ops.view_ops import dim_maps, DimSpec, InputDim
 from torch.distributed._tensor.placement_types import _Partial, DTensorSpec

 aten = torch.ops.aten
@@ -80,12 +76,12 @@ class BatchDimAnalyzer:
             return self.batch_dim_map[node]

         if node.target in self.dim_rule_map:
-            view_op_rule = view_op_rules[self.dim_rule_map[node.target]]  # type: ignore[index]
+            dim_map = dim_maps[self.dim_rule_map[node.target]]  # type: ignore[index]
             args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args)
             kwargs_val = pytree.tree_map_only(
                 fx.Node, lambda n: n.meta["val"], node.kwargs
             )
-            output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val)
+            output_dim_rules = dim_map(*args_val, **kwargs_val)

             def collect_input_dim(cmd: DimSpec, input_dims: Set[int]):
                 if isinstance(cmd, InputDim):
@@ -161,6 +161,14 @@ class OpStrategy(StrategyType):
     def output_shape(self):
         return self.strategies[0].output_spec.shape

+    @property
+    def ndim(self):
+        return self.output_ndim
+
+    @property
+    def shape(self):
+        return self.output_shape
+

 class TupleStrategy(StrategyType):
     """
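The two new convenience properties simply alias output_ndim/output_shape. A hedged sketch of the intended use inside a strategy function (a fragment, not a standalone script: op_schema is the OpSchema handed to the strategy by the sharding propagator, and cast is typing.cast, as in reshape_strategy further down):

    input_strategy = cast(OpStrategy, op_schema.args_schema[0])
    global_in_shape = input_strategy.shape  # same value as .output_shape
    global_in_ndim = input_strategy.ndim    # same value as .output_ndim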
@@ -16,23 +16,24 @@ from typing import (
 import torch

 from torch import Tensor
-from torch._subclasses.fake_tensor import unset_fake_temporarily
-from torch.distributed._tensor._utils import compute_local_shape
 from torch.distributed._tensor.api import Shard
 from torch.distributed._tensor.op_schema import (
     OpSchema,
-    OutputSharding,
+    OpStrategy,
+    PlacementStrategy,
     RuntimeSchemaInfo,
+    StrategyType,
 )
 from torch.distributed._tensor.ops.utils import (
+    generate_redistribute_costs,
     normalize_dim,
     normalize_dims,
     prod,
-    register_prop_rule,
+    register_op_strategy,
 )

 from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate
-from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
+from torch.distributed.device_mesh import DeviceMesh

 aten = torch.ops.aten

@@ -454,68 +455,41 @@ def dim_reduction(
     )


-@dataclass
-class Op:
-    dim_map: Callable[..., DimMap]
-    shape_argnum: Optional[int] = None
-
-
-ops: Dict[Callable[..., torch.Tensor], Op] = {
-    torch.atleast_1d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 1)),
-    torch.atleast_2d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 2)),
-    torch.atleast_3d: Op(dim_map=lambda x: dim_atleast_3d(x.ndim)),
-    torch.broadcast_to: Op(
-        dim_map=lambda input, shape: expand(input.shape, shape), shape_argnum=1
+dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
+    torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1),
+    torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2),
+    torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim),
+    torch.broadcast_to: lambda input, shape: expand(input.shape, shape),
+    Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
+    torch.flatten: lambda tensor: dim_flatten(tensor.ndim),
+    torch.movedim: lambda input, source, destination: dim_movedim(
+        input.ndim, source, destination
     ),
-    Tensor.expand: Op(
-        dim_map=lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
-        shape_argnum=1,
+    torch.permute: lambda input, dims: tuple(
+        InputDim(i) for i in normalize_dims(dims, input.ndim)
     ),
-    torch.flatten: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)),
-    torch.movedim: Op(
-        dim_map=lambda input, source, destination: dim_movedim(
-            input.ndim, source, destination
-        )
-    ),
-    torch.permute: Op(
-        dim_map=lambda input, dims: tuple(
-            InputDim(i) for i in normalize_dims(dims, input.ndim)
-        )
-    ),
-    torch.ravel: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)),
-    Tensor.repeat: Op(dim_map=lambda self, *sizes: dim_repeat(self.ndim, sizes)),
-    torch.reshape: Op(
-        dim_map=lambda input, shape: view_groups(input.shape, shape),
-        shape_argnum=1,
-    ),
-    torch.squeeze: Op(dim_map=lambda input, dim=None: dim_squeeze(input.shape, dim)),
-    torch.tile: Op(dim_map=lambda input, dims: dim_tile(input.ndim, dims)),
-    torch.transpose: Op(
-        dim_map=lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1)
-    ),
-    torch.unsqueeze: Op(dim_map=lambda input, dim: dim_unsqueeze(input.ndim, dim)),
-    Tensor.view: Op(
-        dim_map=lambda input, *shape: view_groups(input.shape, shape),
-        shape_argnum=1,
-    ),
-    torch.view_as_complex: Op(
-        dim_map=lambda input: dim_flatten(input.ndim, input.ndim - 2)
-    ),
-    torch.view_as_real: Op(dim_map=lambda input: dim_view_as_real(input.shape)),
+    torch.ravel: lambda tensor: dim_flatten(tensor.ndim),
+    Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes),
+    torch.reshape: lambda input, shape: view_groups(input.shape, shape),
+    torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim),
+    torch.tile: lambda input, dims: dim_tile(input.ndim, dims),
+    torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1),
+    torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim),
+    Tensor.view: lambda input, *shape: view_groups(input.shape, shape),
+    torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2),
+    torch.view_as_real: lambda input: dim_view_as_real(input.shape),
 }


 def propagate_shape_and_sharding(
-    in_shard: Sequence[Placement],
+    input_src_placements: Sequence[Placement],
     local_in_shape: Shape,
     rule: DimMap,
     mesh_sizes: Shape,
-) -> Tuple[Shape, Optional[Sequence[Placement]], torch.Tensor]:
+) -> Tuple[Sequence[Placement], Sequence[Placement]]:
     """
-    Determine output sharding and tensor shape based on given global tensor shape and input sharding.
-
-    Takes as input the global shape of the tensor, and the input sharding,
-    and produce corresponding output sharding and shape of the output tensor.
+    Determine input target sharding and output sharding based on
+    given global tensor shape and input source sharding.

     Sharding propagation follows mapped dimensions:
     - An output dimension that maps directly to an input dimension is sharded equally
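Under the new signature, this helper no longer computes the global output shape; it maps the input placements to a pair (input target placements, output placements). A hedged sketch of the intended behaviour, following the docstring rules above rather than an executed run (module path and imports as of this commit):

    import torch
    from torch.distributed._tensor import Shard
    from torch.distributed._tensor.ops.view_ops import (
        dim_maps,
        propagate_shape_and_sharding,
    )

    # reshape (8, 4) -> (8, 2, 2) with the input sharded on dim 0
    rule = dim_maps[torch.reshape](torch.empty(8, 4), (8, 2, 2))
    input_tgt, output = propagate_shape_and_sharding(
        [Shard(0)],  # input source placements
        (8, 4),      # global input shape
        rule,
        (2,),        # mesh sizes (one mesh dim of size 2)
    )
    # expected: input_tgt == [Shard(0)] and output == [Shard(0)],
    # since output dim 0 maps directly to input dim 0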
@@ -524,16 +498,13 @@ def propagate_shape_and_sharding(
     - An output dimension that is a split of the input dimension can only be sharded
       if the leftmost split size is divisible by the mesh dimension
     """
-    assert len(in_shard) == len(mesh_sizes)
-    sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)}
+    assert len(input_src_placements) == len(mesh_sizes)
     # for each input dim, for each mesh dim, provides a list of possible shardable dimensions
-    shardable_dims: torch.Tensor = torch.ones(
-        (len(local_in_shape), len(mesh_sizes)), dtype=torch.bool
-    )
+    mesh_ndim = len(mesh_sizes)
+    shardable_dims: Dict[int, List[bool]] = {}

     # in case an input dimension disappears (e.g. collapsing, reduction)
     # we cannot shard in that dimension (we need a replication fall-back rule)
-
     seen_input_dims: Set[int] = set()

     def collect_used_inputs(cmd: DimSpec) -> None:
@@ -545,28 +516,19 @@ def propagate_shape_and_sharding(
     for cmd in rule:
         collect_used_inputs(cmd)
     for dim in range(len(local_in_shape)):
-        shardable_dims[dim, :] = dim in seen_input_dims
+        shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim

-    def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]:
+    def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]:
         if isinstance(cmd, InputDim):
-            seen_input_dims.add(cmd.input_dim)
-            return (
-                local_in_shape[cmd.input_dim],
-                cmd if cmd.input_dim in sharded_in_dims else None,
-            )
+            return cmd
         elif isinstance(cmd, Flatten):
             for dim in cmd.input_dims[1:]:
                 if isinstance(dim, InputDim):
-                    shardable_dims[dim.input_dim, :] = False
+                    shardable_dims[dim.input_dim] = [False] * mesh_ndim
             dim0 = cmd.input_dims[0]
-            return (
-                prod(get_dim_size(a)[0] for a in cmd.input_dims),
-                dim0
-                if isinstance(dim0, InputDim) and dim0.input_dim in sharded_in_dims
-                else None,
-            )
+            return dim0 if isinstance(dim0, InputDim) else None
         elif isinstance(cmd, Split):
-            _, in_dim = get_dim_size(cmd.input_dim)
+            in_dim = get_in_dim_to_shard(cmd.input_dim)
             out_size = cmd.group_shape[cmd.split_id]
             if cmd.split_id == 0 and in_dim is not None:
                 # we need to check that the input dimension is divisible
@@ -579,14 +541,13 @@ def propagate_shape_and_sharding(
                 # but we will allow it if that's the input and it's compatible

                 # 1. is this dimension shardable on each individual mesh dim?
-                for mesh_dim, mesh_dim_size in enumerate(mesh_sizes):
-                    shardable_dims[in_dim.input_dim, mesh_dim] = (
-                        out_size % mesh_dim_size == 0
-                    )
+                shardable_dims[in_dim.input_dim] = [
+                    out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes
+                ]

                 # 2. here we special case things like [Shard(0), Shard(0)]
                 submesh_size = 1
-                for size, shard in zip(mesh_sizes, in_shard):
+                for size, shard in zip(mesh_sizes, input_src_placements):
                     if isinstance(shard, Shard) and shard.dim == in_dim:
                         submesh_size *= size
                 assert (
@@ -594,158 +555,113 @@ def propagate_shape_and_sharding(
                 ), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."

            # we will only shard our first component of the split
-            return out_size, in_dim if cmd.split_id == 0 else None
-        elif isinstance(cmd, Singleton):
-            return 1, None
-        elif isinstance(cmd, Broadcast):
-            return cmd.dim_size, None
-        elif isinstance(cmd, NewDim):
-            return cmd.size, None
+            return in_dim if cmd.split_id == 0 else None
         elif isinstance(cmd, Repeat):
-            size, in_dim = get_dim_size(cmd.input_dim)
+            in_dim = get_in_dim_to_shard(cmd.input_dim)
             if in_dim is not None:
-                shardable_dims[in_dim.input_dim, :] = False
-            return size * cmd.times, None
+                shardable_dims[in_dim.input_dim] = [False] * mesh_ndim
+            return None
         else:
-            raise RuntimeError(f"cmd not found: {cmd}, in rule: {rule}")
+            return None

-    dim_map = {}
-    out_shape = []
+    # for each output dim, find the corresponding input dim in terms of sharding prop
+    shard_dim_map = {}
     for dim, cmd in enumerate(rule):
-        out_size, in_dim = get_dim_size(cmd)
-        out_shape.append(out_size)
+        in_dim = get_in_dim_to_shard(cmd)
         if in_dim is not None:
-            dim_map[in_dim.input_dim] = dim
+            shard_dim_map[in_dim.input_dim] = dim

-    needs_reshard = any(
-        isinstance(placement, Shard) and not shardable_dims[placement.dim][mesh_dim]
-        for mesh_dim, placement in enumerate(in_shard)
-    )
+    input_tgt_placements = [
+        Replicate()
+        if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim]
+        else p
+        for mesh_dim, p in enumerate(input_src_placements)
+    ]
+    output_placements = [
+        Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p
+        for p in input_tgt_placements
+    ]

-    output_placements = (
-        None
-        if needs_reshard
-        else [Shard(dim_map[s.dim]) if isinstance(s, Shard) else s for s in in_shard]
-    )
-
-    return (tuple(out_shape), output_placements, shardable_dims)
+    return input_tgt_placements, output_placements


-def register_prop_rule_map(
+def register_op_strategy_map(
     aten_op_overload: torch._ops.OpOverload,
     local_op_name: Callable[..., torch.Tensor],
     schema_info: Optional[RuntimeSchemaInfo] = None,
 ) -> None:
-    spec: Op = ops[local_op_name]
+    dim_map: Callable[..., DimMap] = dim_maps[local_op_name]

-    @register_prop_rule(aten_op_overload, schema_info=schema_info)
-    def reshape_prop(op_schema: OpSchema) -> OutputSharding:
-        rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
-        input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0])
-        mesh = input_dtensor_spec.mesh
-
-        assert isinstance(
-            input_dtensor_spec, DTensorSpec
-        ), "Expected first input to be a DTensorSpec"
-        global_in_shape = input_dtensor_spec.shape
+    @register_op_strategy(aten_op_overload, schema_info=schema_info)
+    def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
+        rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
+        input_strategy = cast(OpStrategy, op_schema.args_schema[0])
+        global_in_shape = input_strategy.output_shape
         assert global_in_shape is not None, "Shape required."

-        with disable_proxy_modes_tracing(), unset_fake_temporarily():
-            (
-                global_out_shape,
-                shard_out,
-                shardable_dims,
-            ) = propagate_shape_and_sharding(
-                input_dtensor_spec.placements,
+        output_strategy = OpStrategy([])
+        for input_placement_strategy in input_strategy.strategies:
+            input_src_spec = input_placement_strategy.output_spec
+
+            input_tgt_placements, output_placements = propagate_shape_and_sharding(
+                input_src_spec.placements,
                 tuple(global_in_shape),
                 rules,
                 mesh.shape,
             )

-        if shard_out is not None:
-            # no reshard needed
-            output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out))
-
-            # We only need the local shape to lower the call into the local op
-            args = op_schema.args_schema
-            shape_argnum = spec.shape_argnum
-            if shape_argnum is not None:
-                # compute the local shape from the global shape, then return
-                # a resharding even if we don't really reshard, the only reason
-                # for this type of resharding is to lower the global shape to
-                # local shape
-                local_out_shape = compute_local_shape(
-                    list(global_out_shape), mesh, shard_out
-                )
-
-                suggested_schema = OpSchema(
-                    op=op_schema.op,
-                    args_schema=args[:shape_argnum]
-                    + (tuple(local_out_shape),)
-                    + args[shape_argnum + 1 :],
-                    kwargs_schema=op_schema.kwargs_schema,
-                )
-                return OutputSharding(
-                    output_spec=output_dtensor_spec,
-                    redistribute_schema=suggested_schema,
-                    needs_redistribute=True,
-                )
-
-            return OutputSharding(output_spec=output_dtensor_spec)
-
-        else:
             # TODO: optimize this. we shouldn't simply blindly replicate
             # unshardable dims ...
             # FIXME: this can be wrong for situations where we have
             # [Shard(0), Shard(0)]
-            suggested_placements = [
-                p
-                if not isinstance(p, Shard) or shardable_dims[p.dim][mesh_dim]
-                else Replicate()
-                for mesh_dim, p in enumerate(input_dtensor_spec.placements)
-            ]
-            return OutputSharding(
-                output_spec=None,
-                redistribute_schema=OpSchema(
-                    op=op_schema.op,
-                    args_schema=(
-                        DTensorSpec(
-                            placements=tuple(suggested_placements),
-                            mesh=input_dtensor_spec.mesh,
-                            tensor_meta=input_dtensor_spec.tensor_meta,
-                        ),
-                    )
-                    + op_schema.args_schema[1:],
-                    kwargs_schema=op_schema.kwargs_schema,
-                ),
+            input_tgt_spec = DTensorSpec(
+                placements=tuple(input_tgt_placements),
+                mesh=input_src_spec.mesh,
+                tensor_meta=input_src_spec.tensor_meta,
+            )
+            redistribute_costs = [
+                generate_redistribute_costs(input_strategy, input_tgt_spec)
+            ]
+
+            output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements))
+            output_strategy.strategies.append(
+                PlacementStrategy(
+                    output_specs=output_spec,
+                    input_specs=(input_tgt_spec,),
+                    redistribute_cost=redistribute_costs,
+                )
             )

+        return output_strategy
+

-register_prop_rule_map(aten.squeeze.default, torch.squeeze)
-register_prop_rule_map(
+register_op_strategy_map(aten.squeeze.default, torch.squeeze)
+register_op_strategy_map(
     aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1))
-register_prop_rule_map(
+register_op_strategy_map(
+    aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
+)
+register_op_strategy_map(
     aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(
+register_op_strategy_map(
     aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(
+register_op_strategy_map(
     aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(
+register_op_strategy_map(
     aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(
+register_op_strategy_map(
     aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(
+register_op_strategy_map(
     aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(
+register_op_strategy_map(
     aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1)
 )
-register_prop_rule_map(aten.view_as_complex.default, torch.view_as_complex)
-register_prop_rule_map(aten.view_as_real.default, torch.view_as_real)
+register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex)
+register_op_strategy_map(aten.view_as_real.default, torch.view_as_real)
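End-to-end, these registrations let view ops run directly on sharded DTensors. A hedged usage sketch (assumes torch.distributed is already initialized and `mesh` is a 1-D DeviceMesh over 2 ranks; not part of the commit):

    import torch
    from torch.distributed._tensor import distribute_tensor, Shard

    x = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
    y = x.view(8, 2, 2)  # dispatched via aten.view.default -> Tensor.view dim map
    # y should stay Shard(0); each rank's local shard would have shape (4, 2, 2)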
@@ -45,15 +45,21 @@ class ShardingPropagator:
         # op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
         self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
         self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached)  # type: ignore[method-assign]
-        # op map to save indices of size (and stride) args which may need to be modified in sharding prop
-        self.op_to_size_and_stride_idx: Dict[
+        # op map to save indices of shape (and stride) args which may need to be modified in sharding prop
+        self.op_to_shape_and_stride_idx: Dict[
             OpOverload, Union[int, Tuple[int, int]]
         ] = {
+            # new factory ops
             aten.new_empty.default: 1,
             aten.new_full.default: 1,
             aten.new_ones.default: 1,
             aten.new_zeros.default: 1,
             aten.new_empty_strided.default: (1, 2),
+            # view ops
+            aten.expand.default: 1,
+            aten.reshape.default: 1,
+            aten.view.default: 1,
+            aten._unsafe_view.default: 1,
        }

    def register_sharding_prop_rule(
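For the view ops added to this table, the shape argument is later rewritten to the local shape of the output shard (see the _adjust_shape_and_stride_args hunk further down). Roughly, as a hedged sketch (out_tensor_meta, spec, and mesh name the propagated output tensor meta, output DTensorSpec, and device mesh from that method):

    local_shape = compute_local_shape(
        out_tensor_meta.shape,  # global output shape from sharding prop
        mesh,
        spec.placements,
    )
    # e.g. a global (8, 2, 2) output with [Shard(0)] on a 2-device mesh
    # yields a local shape of (4, 2, 2), which replaces the size argument
    # (args_schema index 1 for aten.view.default) before local dispatch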
@@ -260,16 +266,19 @@ class ShardingPropagator:
             )
             suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)

-        # size and stride args need to be modified for new factory ops, potentially
-        if op_schema.op in self.op_to_size_and_stride_idx:
+        # shape and stride args need to be modified for
+        # view ops and new factory ops, potentially
+        if op_schema.op in self.op_to_shape_and_stride_idx:
             assert isinstance(output_strategy.output_spec, DTensorSpec)
             # It happens when the output has the same shape as the input
             # and the input placements are not all Replicate().
             if output_strategy.output_spec.is_sharded():
-                needs_redistribute = True
-                suggestion_schema = self._adjust_size_and_stride_args(
-                    op_schema, output_strategy.output_spec, mesh
+                schema = suggestion_schema or op_schema
+                assert isinstance(out_tensor_meta, TensorMeta)
+                suggestion_schema = self._adjust_shape_and_stride_args(
+                    out_tensor_meta, schema, output_strategy.output_spec, mesh
                 )
+                needs_redistribute = True

         # construct output spec for the op
         if op_schema.return_type_tuple_tensor_like():
@@ -442,29 +451,31 @@ class ShardingPropagator:
         # for eager execution, we just select the one with the minimal redistribute cost
         return strategy.strategies[strategy_costs.index(min(strategy_costs))]

-    def _adjust_size_and_stride_args(
-        self, op_schema: OpSchema, spec: DTensorSpec, mesh: DeviceMesh
+    def _adjust_shape_and_stride_args(
+        self,
+        out_tensor_meta: TensorMeta,
+        schema: OpSchema,
+        spec: DTensorSpec,
+        mesh: DeviceMesh,
     ) -> OpSchema:
-        size_stride_idx = self.op_to_size_and_stride_idx[op_schema.op]
-        if isinstance(size_stride_idx, tuple):
-            size_idx, stride_idx = size_stride_idx
+        shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
+        if isinstance(shape_stride_idx, tuple):
+            shape_idx, stride_idx = shape_stride_idx
         else:
-            size_idx = size_stride_idx
+            shape_idx = shape_stride_idx
             stride_idx = None

-        expected_input_schema = list(op_schema.args_schema)
-        size = cast(list, expected_input_schema[size_idx])
-        # # adjust size to be the same as that of the _local_tensor
-        # # of the DTensor input arg at index 0, which is inferred
-        expected_input_schema[size_idx] = compute_local_shape(
-            size, mesh, spec.placements
+        expected_input_schema = list(schema.args_schema)
+        # adjust shape to be the same as that of the _local_tensor
+        # of the DTensor input arg at index 0, which is inferred
+        expected_input_schema[shape_idx] = compute_local_shape(
+            out_tensor_meta.shape, mesh, spec.placements
         )

         # adjust the stride arg for aten.new_empty_strided.default
         if stride_idx:
-            stride = cast(list, expected_input_schema[stride_idx])
             expected_input_schema[stride_idx] = compute_local_stride(
-                stride, mesh, spec.placements
+                out_tensor_meta.stride, mesh, spec.placements
             )

-        return OpSchema(op_schema.op, tuple(expected_input_schema), {})
+        return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)