[dtensor] refactor view ops to use OpStrategy (#126011)

As titled. Some view ops require adjusting their output shape argument from the global shape to the local shape. In rule-based sharding prop, the global output shape was inferred inside the rule (in `view_ops.py`); in strategy-based sharding prop, it is now obtained from the propagated `out_tensor_meta` (in `sharding_prop.py`).
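
A minimal sketch of that shape-argument adjustment, assuming the evenly divisible case: the `Shard`/`Replicate` classes and the `local_shape` helper below are simplified stand-ins for the DTensor placement types and `compute_local_shape`, not the actual implementation.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple, Union


@dataclass(frozen=True)
class Shard:
    dim: int  # tensor dimension sharded across this mesh dimension


@dataclass(frozen=True)
class Replicate:
    pass


def local_shape(
    global_shape: Sequence[int],
    mesh_shape: Sequence[int],
    placements: Sequence[Union[Shard, Replicate]],
) -> Tuple[int, ...]:
    """Evenly-divisible sketch of compute_local_shape: divide every sharded
    tensor dimension by the size of each mesh dimension that shards it."""
    shape = list(global_shape)
    for mesh_dim, placement in enumerate(placements):
        if isinstance(placement, Shard):
            assert shape[placement.dim] % mesh_shape[mesh_dim] == 0
            shape[placement.dim] //= mesh_shape[mesh_dim]
    return tuple(shape)


# If the propagated out_tensor_meta says the global view output is (8, 16) and
# the output is Shard(0) on a 4-rank mesh, the local op should see shape (2, 16).
print(local_shape((8, 16), (4,), [Shard(0)]))  # -> (2, 16)
```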

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126011
Approved by: https://github.com/wanchaol, https://github.com/XilunWu
Author: Tianyu Liu
Date: 2024-05-16 13:13:02 -07:00
Committed by: PyTorch MergeBot
parent a0df40f195
commit 9edf54df4d
5 changed files with 159 additions and 229 deletions

View File

@ -11,9 +11,9 @@ from torch.distributed._tensor import DeviceMesh, distribute_tensor, Replicate,
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.ops.view_ops import (
Broadcast,
dim_maps,
Flatten,
InputDim,
ops,
Repeat,
Singleton,
Split,
@ -130,8 +130,8 @@ class TestViewOps(DTensorTestBase):
return 6
def call_dt_test(self, op, args, kwargs, device_mesh: DeviceMesh):
spec = ops[op]
rules = spec.dim_map(*args, **kwargs)
dim_map = dim_maps[op]
rules = dim_map(*args, **kwargs)
outputs = op(*args, **kwargs)
flat_args = pytree.arg_tree_leaves(*args)
in_shape = flat_args[0].shape
@ -163,7 +163,6 @@ class TestViewOps(DTensorTestBase):
)
for in_shard in all_sharding_choices:
# print(f' |--- {in_shard}')
in_dt = distribute_tensor(args[0], device_mesh, in_shard)
comm_mode = CommDebugMode()
@ -180,7 +179,7 @@ class TestViewOps(DTensorTestBase):
self.assertEqual(outputs, full_out)
def dimmap_test(self, op, args, expected_rule_output):
rules = ops[op].dim_map(*args)
rules = dim_maps[op](*args)
self.assertEqual(rules, expected_rule_output)
self.call_dt_test(op, args, {}, self.device_mesh)
@ -229,7 +228,7 @@ class TestViewOps(DTensorTestBase):
)
with self.assertRaises(AssertionError):
ops[torch.broadcast_to].dim_map(randn(24, 36), (1, 2, 4))
dim_maps[torch.broadcast_to](randn(24, 36), (1, 2, 4))
self.dimmap_test(
torch.broadcast_to,
@ -495,14 +494,14 @@ class TestViewOps(DTensorTestBase):
InputDim(0),
Flatten((InputDim(1), InputDim(2))),
)
view_as_complex_rule = ops[torch.view_as_complex].dim_map(inp)
view_as_complex_rule = dim_maps[torch.view_as_complex](inp)
self.assertEqual(view_as_complex_rule, expected_view_as_complex_rule)
expected_view_as_real_rule = (
InputDim(0),
Split(InputDim(1), (13, 2), 0),
Split(InputDim(1), (13, 2), 1),
)
view_as_real_rule = ops[torch.view_as_real].dim_map(intermediate)
view_as_real_rule = dim_maps[torch.view_as_real](intermediate)
self.assertEqual(view_as_real_rule, expected_view_as_real_rule)
# test sharded computation correctness

View File

@ -9,11 +9,7 @@ import torch.utils._pytree as pytree
from torch import Tensor
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
from torch.distributed._tensor.ops.view_ops import (
DimSpec,
InputDim,
ops as view_op_rules,
)
from torch.distributed._tensor.ops.view_ops import dim_maps, DimSpec, InputDim
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec
aten = torch.ops.aten
@ -80,12 +76,12 @@ class BatchDimAnalyzer:
return self.batch_dim_map[node]
if node.target in self.dim_rule_map:
view_op_rule = view_op_rules[self.dim_rule_map[node.target]] # type: ignore[index]
dim_map = dim_maps[self.dim_rule_map[node.target]] # type: ignore[index]
args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args)
kwargs_val = pytree.tree_map_only(
fx.Node, lambda n: n.meta["val"], node.kwargs
)
output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val)
output_dim_rules = dim_map(*args_val, **kwargs_val)
def collect_input_dim(cmd: DimSpec, input_dims: Set[int]):
if isinstance(cmd, InputDim):

View File

@ -161,6 +161,14 @@ class OpStrategy(StrategyType):
def output_shape(self):
return self.strategies[0].output_spec.shape
@property
def ndim(self):
return self.output_ndim
@property
def shape(self):
return self.output_shape
class TupleStrategy(StrategyType):
"""

View File

@ -16,23 +16,24 @@ from typing import (
import torch
from torch import Tensor
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.distributed._tensor._utils import compute_local_shape
from torch.distributed._tensor.api import Shard
from torch.distributed._tensor.op_schema import (
OpSchema,
OutputSharding,
OpStrategy,
PlacementStrategy,
RuntimeSchemaInfo,
StrategyType,
)
from torch.distributed._tensor.ops.utils import (
generate_redistribute_costs,
normalize_dim,
normalize_dims,
prod,
register_prop_rule,
register_op_strategy,
)
from torch.distributed._tensor.placement_types import DTensorSpec, Placement, Replicate
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
from torch.distributed.device_mesh import DeviceMesh
aten = torch.ops.aten
@ -454,68 +455,41 @@ def dim_reduction(
)
@dataclass
class Op:
dim_map: Callable[..., DimMap]
shape_argnum: Optional[int] = None
ops: Dict[Callable[..., torch.Tensor], Op] = {
torch.atleast_1d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 1)),
torch.atleast_2d: Op(dim_map=lambda x: dim_pad_left(x.ndim, 2)),
torch.atleast_3d: Op(dim_map=lambda x: dim_atleast_3d(x.ndim)),
torch.broadcast_to: Op(
dim_map=lambda input, shape: expand(input.shape, shape), shape_argnum=1
dim_maps: Dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = {
torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1),
torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2),
torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim),
torch.broadcast_to: lambda input, shape: expand(input.shape, shape),
Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
torch.flatten: lambda tensor: dim_flatten(tensor.ndim),
torch.movedim: lambda input, source, destination: dim_movedim(
input.ndim, source, destination
),
Tensor.expand: Op(
dim_map=lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)),
shape_argnum=1,
torch.permute: lambda input, dims: tuple(
InputDim(i) for i in normalize_dims(dims, input.ndim)
),
torch.flatten: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)),
torch.movedim: Op(
dim_map=lambda input, source, destination: dim_movedim(
input.ndim, source, destination
)
),
torch.permute: Op(
dim_map=lambda input, dims: tuple(
InputDim(i) for i in normalize_dims(dims, input.ndim)
)
),
torch.ravel: Op(dim_map=lambda tensor: dim_flatten(tensor.ndim)),
Tensor.repeat: Op(dim_map=lambda self, *sizes: dim_repeat(self.ndim, sizes)),
torch.reshape: Op(
dim_map=lambda input, shape: view_groups(input.shape, shape),
shape_argnum=1,
),
torch.squeeze: Op(dim_map=lambda input, dim=None: dim_squeeze(input.shape, dim)),
torch.tile: Op(dim_map=lambda input, dims: dim_tile(input.ndim, dims)),
torch.transpose: Op(
dim_map=lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1)
),
torch.unsqueeze: Op(dim_map=lambda input, dim: dim_unsqueeze(input.ndim, dim)),
Tensor.view: Op(
dim_map=lambda input, *shape: view_groups(input.shape, shape),
shape_argnum=1,
),
torch.view_as_complex: Op(
dim_map=lambda input: dim_flatten(input.ndim, input.ndim - 2)
),
torch.view_as_real: Op(dim_map=lambda input: dim_view_as_real(input.shape)),
torch.ravel: lambda tensor: dim_flatten(tensor.ndim),
Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes),
torch.reshape: lambda input, shape: view_groups(input.shape, shape),
torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim),
torch.tile: lambda input, dims: dim_tile(input.ndim, dims),
torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1),
torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim),
Tensor.view: lambda input, *shape: view_groups(input.shape, shape),
torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2),
torch.view_as_real: lambda input: dim_view_as_real(input.shape),
}
def propagate_shape_and_sharding(
in_shard: Sequence[Placement],
input_src_placements: Sequence[Placement],
local_in_shape: Shape,
rule: DimMap,
mesh_sizes: Shape,
) -> Tuple[Shape, Optional[Sequence[Placement]], torch.Tensor]:
) -> Tuple[Sequence[Placement], Sequence[Placement]]:
"""
Determine output sharding and tensor shape based on given global tensor shape and input sharding.
Takes as input the global shape of the tensor, and the input sharding,
and produce corresponding output sharding and shape of the output tensor.
Determine input target sharding and output sharding based on
given global tensor shape and input source sharding.
Sharding propagation follows mapped dimensions:
- An output dimension that maps directly to an input dimension is sharded equally
@ -524,16 +498,13 @@ def propagate_shape_and_sharding(
- An output dimension that is a split of the input dimension can only be sharded
if the leftmost split size is divisible by the mesh dimension
"""
assert len(in_shard) == len(mesh_sizes)
sharded_in_dims: Set[int] = {s.dim for s in in_shard if isinstance(s, Shard)}
assert len(input_src_placements) == len(mesh_sizes)
# for each input dim, for each mesh dim, provides a list of possible shardable dimensions
shardable_dims: torch.Tensor = torch.ones(
(len(local_in_shape), len(mesh_sizes)), dtype=torch.bool
)
mesh_ndim = len(mesh_sizes)
shardable_dims: Dict[int, List[bool]] = {}
# in case an input dimension disappears (e.g. collapsing, reduction)
# we cannot shard in that dimension (we need a replication fall-back rule)
seen_input_dims: Set[int] = set()
def collect_used_inputs(cmd: DimSpec) -> None:
@ -545,28 +516,19 @@ def propagate_shape_and_sharding(
for cmd in rule:
collect_used_inputs(cmd)
for dim in range(len(local_in_shape)):
shardable_dims[dim, :] = dim in seen_input_dims
shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim
def get_dim_size(cmd: DimSpec) -> Tuple[int, Optional[InputDim]]:
def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]:
if isinstance(cmd, InputDim):
seen_input_dims.add(cmd.input_dim)
return (
local_in_shape[cmd.input_dim],
cmd if cmd.input_dim in sharded_in_dims else None,
)
return cmd
elif isinstance(cmd, Flatten):
for dim in cmd.input_dims[1:]:
if isinstance(dim, InputDim):
shardable_dims[dim.input_dim, :] = False
shardable_dims[dim.input_dim] = [False] * mesh_ndim
dim0 = cmd.input_dims[0]
return (
prod(get_dim_size(a)[0] for a in cmd.input_dims),
dim0
if isinstance(dim0, InputDim) and dim0.input_dim in sharded_in_dims
else None,
)
return dim0 if isinstance(dim0, InputDim) else None
elif isinstance(cmd, Split):
_, in_dim = get_dim_size(cmd.input_dim)
in_dim = get_in_dim_to_shard(cmd.input_dim)
out_size = cmd.group_shape[cmd.split_id]
if cmd.split_id == 0 and in_dim is not None:
# we need to check that the input dimension is divisible
@ -579,14 +541,13 @@ def propagate_shape_and_sharding(
# but we will allow it if that's the input and it's compatible
# 1. is this dimension shardable on each individual mesh dim?
for mesh_dim, mesh_dim_size in enumerate(mesh_sizes):
shardable_dims[in_dim.input_dim, mesh_dim] = (
out_size % mesh_dim_size == 0
)
shardable_dims[in_dim.input_dim] = [
out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes
]
# 2. here we special case things like [Shard(0), Shard(0)]
submesh_size = 1
for size, shard in zip(mesh_sizes, in_shard):
for size, shard in zip(mesh_sizes, input_src_placements):
if isinstance(shard, Shard) and shard.dim == in_dim:
submesh_size *= size
assert (
@ -594,158 +555,113 @@ def propagate_shape_and_sharding(
), f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}."
# we will only shard our first component of the split
return out_size, in_dim if cmd.split_id == 0 else None
elif isinstance(cmd, Singleton):
return 1, None
elif isinstance(cmd, Broadcast):
return cmd.dim_size, None
elif isinstance(cmd, NewDim):
return cmd.size, None
return in_dim if cmd.split_id == 0 else None
elif isinstance(cmd, Repeat):
size, in_dim = get_dim_size(cmd.input_dim)
in_dim = get_in_dim_to_shard(cmd.input_dim)
if in_dim is not None:
shardable_dims[in_dim.input_dim, :] = False
return size * cmd.times, None
shardable_dims[in_dim.input_dim] = [False] * mesh_ndim
return None
else:
raise RuntimeError(f"cmd not found: {cmd}, in rule: {rule}")
return None
dim_map = {}
out_shape = []
# for each output dim, find the corresponding input dim in terms of sharding prop
shard_dim_map = {}
for dim, cmd in enumerate(rule):
out_size, in_dim = get_dim_size(cmd)
out_shape.append(out_size)
in_dim = get_in_dim_to_shard(cmd)
if in_dim is not None:
dim_map[in_dim.input_dim] = dim
shard_dim_map[in_dim.input_dim] = dim
needs_reshard = any(
isinstance(placement, Shard) and not shardable_dims[placement.dim][mesh_dim]
for mesh_dim, placement in enumerate(in_shard)
)
input_tgt_placements = [
Replicate()
if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim]
else p
for mesh_dim, p in enumerate(input_src_placements)
]
output_placements = [
Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p
for p in input_tgt_placements
]
output_placements = (
None
if needs_reshard
else [Shard(dim_map[s.dim]) if isinstance(s, Shard) else s for s in in_shard]
)
return (tuple(out_shape), output_placements, shardable_dims)
return input_tgt_placements, output_placements
def register_prop_rule_map(
def register_op_strategy_map(
aten_op_overload: torch._ops.OpOverload,
local_op_name: Callable[..., torch.Tensor],
schema_info: Optional[RuntimeSchemaInfo] = None,
) -> None:
spec: Op = ops[local_op_name]
dim_map: Callable[..., DimMap] = dim_maps[local_op_name]
@register_prop_rule(aten_op_overload, schema_info=schema_info)
def reshape_prop(op_schema: OpSchema) -> OutputSharding:
rules = spec.dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
input_dtensor_spec = cast(DTensorSpec, op_schema.args_schema[0])
mesh = input_dtensor_spec.mesh
assert isinstance(
input_dtensor_spec, DTensorSpec
), "Expected first input to be a DTensorSpec"
global_in_shape = input_dtensor_spec.shape
@register_op_strategy(aten_op_overload, schema_info=schema_info)
def reshape_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema)
input_strategy = cast(OpStrategy, op_schema.args_schema[0])
global_in_shape = input_strategy.output_shape
assert global_in_shape is not None, "Shape required."
with disable_proxy_modes_tracing(), unset_fake_temporarily():
(
global_out_shape,
shard_out,
shardable_dims,
) = propagate_shape_and_sharding(
input_dtensor_spec.placements,
output_strategy = OpStrategy([])
for input_placement_strategy in input_strategy.strategies:
input_src_spec = input_placement_strategy.output_spec
input_tgt_placements, output_placements = propagate_shape_and_sharding(
input_src_spec.placements,
tuple(global_in_shape),
rules,
mesh.shape,
)
if shard_out is not None:
# no reshard needed
output_dtensor_spec = DTensorSpec(mesh=mesh, placements=tuple(shard_out))
# We only need the local shape to lower the call into the local op
args = op_schema.args_schema
shape_argnum = spec.shape_argnum
if shape_argnum is not None:
# compute the local shape from the global shape, then return
# a resharding even if we don't really reshard, the only reason
# for this type of resharding is to lower the global shape to
# local shape
local_out_shape = compute_local_shape(
list(global_out_shape), mesh, shard_out
)
suggested_schema = OpSchema(
op=op_schema.op,
args_schema=args[:shape_argnum]
+ (tuple(local_out_shape),)
+ args[shape_argnum + 1 :],
kwargs_schema=op_schema.kwargs_schema,
)
return OutputSharding(
output_spec=output_dtensor_spec,
redistribute_schema=suggested_schema,
needs_redistribute=True,
)
return OutputSharding(output_spec=output_dtensor_spec)
else:
# TODO: optimize this. we shouldn't simply blindly replicate
# unshardable dims ...
# FIXME: this can be wrong for situations where we have
# [Shard(0), Shard(0)]
suggested_placements = [
p
if not isinstance(p, Shard) or shardable_dims[p.dim][mesh_dim]
else Replicate()
for mesh_dim, p in enumerate(input_dtensor_spec.placements)
input_tgt_spec = DTensorSpec(
placements=tuple(input_tgt_placements),
mesh=input_src_spec.mesh,
tensor_meta=input_src_spec.tensor_meta,
)
redistribute_costs = [
generate_redistribute_costs(input_strategy, input_tgt_spec)
]
return OutputSharding(
output_spec=None,
redistribute_schema=OpSchema(
op=op_schema.op,
args_schema=(
DTensorSpec(
placements=tuple(suggested_placements),
mesh=input_dtensor_spec.mesh,
tensor_meta=input_dtensor_spec.tensor_meta,
),
)
+ op_schema.args_schema[1:],
kwargs_schema=op_schema.kwargs_schema,
),
output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements))
output_strategy.strategies.append(
PlacementStrategy(
output_specs=output_spec,
input_specs=(input_tgt_spec,),
redistribute_cost=redistribute_costs,
)
)
return output_strategy
register_prop_rule_map(aten.squeeze.default, torch.squeeze)
register_prop_rule_map(
register_op_strategy_map(aten.squeeze.default, torch.squeeze)
register_op_strategy_map(
aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1))
register_prop_rule_map(
register_op_strategy_map(
aten.view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
)
register_op_strategy_map(
aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(
register_op_strategy_map(
aten._unsafe_view.default, Tensor.view, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(
register_op_strategy_map(
aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(
register_op_strategy_map(
aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(
register_op_strategy_map(
aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(
register_op_strategy_map(
aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(
register_op_strategy_map(
aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1)
)
register_prop_rule_map(aten.view_as_complex.default, torch.view_as_complex)
register_prop_rule_map(aten.view_as_real.default, torch.view_as_real)
register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex)
register_op_strategy_map(aten.view_as_real.default, torch.view_as_real)
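
Based on the test and op-table changes above, a short usage sketch of the new `dim_maps` registry: each supported view op maps to a callable that produces the op's `DimMap`, from which the strategy derives output placements. The printed rule is indicative of the `DimMap` structure rather than an exact repr.

```python
import torch
from torch.distributed._tensor.ops.view_ops import dim_maps

# DimMap for reshaping a (24, 36) tensor to (8, 3, 36): input dim 0 is split
# into (8, 3) and input dim 1 passes through unchanged.
rule = dim_maps[torch.reshape](torch.randn(24, 36), (8, 3, 36))
print(rule)
# roughly: (Split(InputDim(0), (8, 3), 0), Split(InputDim(0), (8, 3), 1), InputDim(1))

# Per the propagation rules documented in view_ops.py, a Shard(0) input stays
# sharded here only if the leftmost split size (8) is divisible by the mesh
# dimension size; otherwise the input is redistributed to Replicate first.
```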

View File

@ -45,15 +45,21 @@ class ShardingPropagator:
# op map to save static argnum to decide to reuse sharding prop cache or re-run sharding prop
self.op_to_schema_info: Dict[OpOverload, RuntimeSchemaInfo] = {}
self.propagate_op_sharding = lru_cache(None)(self.propagate_op_sharding_non_cached) # type: ignore[method-assign]
# op map to save indices of size (and stride) args which may need to be modified in sharding prop
self.op_to_size_and_stride_idx: Dict[
# op map to save indices of shape (and stride) args which may need to be modified in sharding prop
self.op_to_shape_and_stride_idx: Dict[
OpOverload, Union[int, Tuple[int, int]]
] = {
# new factory ops
aten.new_empty.default: 1,
aten.new_full.default: 1,
aten.new_ones.default: 1,
aten.new_zeros.default: 1,
aten.new_empty_strided.default: (1, 2),
# view ops
aten.expand.default: 1,
aten.reshape.default: 1,
aten.view.default: 1,
aten._unsafe_view.default: 1,
}
def register_sharding_prop_rule(
@ -260,16 +266,19 @@ class ShardingPropagator:
)
suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
# size and stride args need to be modified for new factory ops, potentially
if op_schema.op in self.op_to_size_and_stride_idx:
# shape and stride args need to be modified for
# view ops and new factory ops, potentially
if op_schema.op in self.op_to_shape_and_stride_idx:
assert isinstance(output_strategy.output_spec, DTensorSpec)
# It happens when the output has the same shape as the input
# and the input placements are not all Replicate().
if output_strategy.output_spec.is_sharded():
needs_redistribute = True
suggestion_schema = self._adjust_size_and_stride_args(
op_schema, output_strategy.output_spec, mesh
schema = suggestion_schema or op_schema
assert isinstance(out_tensor_meta, TensorMeta)
suggestion_schema = self._adjust_shape_and_stride_args(
out_tensor_meta, schema, output_strategy.output_spec, mesh
)
needs_redistribute = True
# construct output spec for the op
if op_schema.return_type_tuple_tensor_like():
@ -442,29 +451,31 @@ class ShardingPropagator:
# for eager execution, we just select the one with the minimal redistribute cost
return strategy.strategies[strategy_costs.index(min(strategy_costs))]
def _adjust_size_and_stride_args(
self, op_schema: OpSchema, spec: DTensorSpec, mesh: DeviceMesh
def _adjust_shape_and_stride_args(
self,
out_tensor_meta: TensorMeta,
schema: OpSchema,
spec: DTensorSpec,
mesh: DeviceMesh,
) -> OpSchema:
size_stride_idx = self.op_to_size_and_stride_idx[op_schema.op]
if isinstance(size_stride_idx, tuple):
size_idx, stride_idx = size_stride_idx
shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op]
if isinstance(shape_stride_idx, tuple):
shape_idx, stride_idx = shape_stride_idx
else:
size_idx = size_stride_idx
shape_idx = shape_stride_idx
stride_idx = None
expected_input_schema = list(op_schema.args_schema)
size = cast(list, expected_input_schema[size_idx])
# # adjust size to be the same as that of the _local_tensor
# # of the DTensor input arg at index 0, which is inferred
expected_input_schema[size_idx] = compute_local_shape(
size, mesh, spec.placements
expected_input_schema = list(schema.args_schema)
# adjust shape to be the same as that of the _local_tensor
# of the DTensor input arg at index 0, which is inferred
expected_input_schema[shape_idx] = compute_local_shape(
out_tensor_meta.shape, mesh, spec.placements
)
# adjust the stride arg for aten.new_empty_strided.default
if stride_idx:
stride = cast(list, expected_input_schema[stride_idx])
expected_input_schema[stride_idx] = compute_local_stride(
stride, mesh, spec.placements
out_tensor_meta.stride, mesh, spec.placements
)
return OpSchema(op_schema.op, tuple(expected_input_schema), {})
return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema)
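
Finally, an illustrative end-to-end check of the behavior this refactor covers. The 2-rank `gloo`/`cpu` setup is an assumption; the script would be launched with `torchrun --nproc_per_node=2`.

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

dist.init_process_group("gloo")
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

# view() takes the *global* shape; sharding prop rewrites it to the local shape
# and keeps Shard(0) since the leftmost split size 6 is divisible by 2 ranks.
x = distribute_tensor(torch.randn(12, 4), mesh, [Shard(0)])
y = x.view(6, 2, 4)
print(y.shape, y.placements)  # torch.Size([6, 2, 4]) (Shard(dim=0),)

dist.destroy_process_group()
```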