[DTensor] have split_strategy return OpStrategy instead of TupleStrategy (#158051)

**Summary**
`split_strategy` used `TupleStrategy` as return type because DTensor sharding
propagation's `OpStrategy` support on multi-returns only applies to `Tuple`.

However, `TupleStrategy` is not a good fit for the `split` op. `TupleStrategy` was
initially introduced to handle the sharding strategy of `foreach_*` ops where
the input args can be partitioned into independent subsets with respect to sharding
decisions, and so can the outputs.

To address the misuse, this PR adds `OpStrategy` propagation for `List[Tensor]`
(note that this support is INCOMPLETE because it only checks the return type
to be `torch.ListType`). Nevertheless, the existing logic for `Tuple` returns makes
a similar assumption, so I think it is acceptable to unblock things this way.

Besides adding `OpStrategy` support to ops having `List[Tensor]` return type,
this PR also changes `split_strategy`'s return from `TupleStrategy` to `OpStrategy`.

**Test**
`pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158051
Approved by: https://github.com/wconstab, https://github.com/zpcore
This commit is contained in:
Xilun Wu
2025-07-14 15:58:20 -07:00
committed by PyTorch MergeBot
parent a2ad16be72
commit 4c1fabf2c9
3 changed files with 33 additions and 19 deletions

View File

@ -345,6 +345,13 @@ class OpSchema:
return_types[0].type, torch.TensorType
)
def return_type_list_tensor_like(self) -> bool:
    """Return True iff the op's schema declares exactly one ``List[Tensor]`` return.

    The original check only verified the single return was a ``torch.ListType``
    and never inspected the element type, so e.g. a ``List[int]`` return would
    have matched despite the "tensor_like" name. We additionally require the
    list's element type to be a Tensor, which is what callers (the sharding
    propagator's output-spec selection) actually rely on.
    """
    return_types = self.op._schema.returns
    if len(return_types) != 1:
        return False
    ret_type = return_types[0].type
    # Both the list-ness and the element type must match for "list of tensors".
    return isinstance(ret_type, torch.ListType) and isinstance(
        ret_type.getElementType(), torch.TensorType
    )
def return_type_tensor(self) -> bool:
return_types = self.op._schema.returns
# all dispatch ops only return Tensor or Tuple[Tensor] for tensor like

View File

@ -1074,7 +1074,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
],
RuntimeSchemaInfo(1),
)
def split_strategy(op_schema: OpSchema) -> TupleStrategy:
def split_strategy(op_schema: OpSchema) -> OpStrategy:
input_strategy = op_schema.args_schema[0]
split_size_or_sections = op_schema.args_schema[1]
assert isinstance(input_strategy, OpStrategy)
@ -1097,23 +1097,27 @@ def split_strategy(op_schema: OpSchema) -> TupleStrategy:
)
assert isinstance(output_size_list, Sized)
split_strategies = []
all_strategies = []
for strategy in input_strategy.strategies:
spec = strategy.output_spec
placements = spec.placements
if is_tensor_dim_sharded(spec, dim=dim):
# if the input is sharded on the split dim, we need to unshard it
placements = unshard_tensor_dim(spec.placements, dim=dim)
for _ in range(len(output_size_list)):
op_strategy = OpStrategy([])
for strategy in input_strategy.strategies:
spec = strategy.output_spec
placements = spec.placements
if is_tensor_dim_sharded(spec, dim=dim):
# if the input is sharded on the split dim, we need to unshard it
placements = unshard_tensor_dim(spec.placements, dim=dim)
spec = DTensorSpec(spec.mesh, placements)
op_strategy.strategies.append(
OpSpec(output_specs=spec, input_specs=([spec]))
input_spec = DTensorSpec(spec.device_mesh, placements, spec.tensor_meta)
output_specs = tuple(
DTensorSpec(spec.device_mesh, placements)
for _ in range(len(output_size_list))
)
all_strategies.append(
OpSpec(
output_specs=output_specs,
input_specs=(input_spec,),
redistribute_cost=[
generate_redistribute_costs(input_strategy, input_spec)
],
)
split_strategies.append(op_strategy)
)
return TupleStrategy(split_strategies)
return OpStrategy(all_strategies)

View File

@ -353,7 +353,10 @@ class ShardingPropagator:
for _ in range(len(op_schema.op._schema.returns))
]
)
elif op_schema.return_type_tensor():
elif (
op_schema.return_type_tensor()
or op_schema.return_type_list_tensor_like()
):
output_specs = output_strategy.output_specs
else:
output_specs = None