From 4c1fabf2c9eb0c9773b09ff56761f8361fb60304 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Mon, 14 Jul 2025 15:58:20 -0700
Subject: [PATCH] [DTensor] have split_strategy return OpStrategy instead of
 TupleStrategy (#158051)

**Summary**
`split_strategy` used `TupleStrategy` as its return type because DTensor sharding propagation's `OpStrategy` support for multi-output ops only applied to `Tuple` returns. However, `TupleStrategy` is not a good fit for the `split` op.

`TupleStrategy` was originally introduced to handle the sharding strategy of `foreach_*` ops, where the input args can be partitioned into subsets that make sharding decisions independently of one another, and the outputs partition the same way. The outputs of `split` are not independent in this sense: every chunk must follow a single sharding decision made for the one shared input.

To address the misuse, this PR adds `OpStrategy` propagation for ops that return `List[Tensor]`. Note that this support is INCOMPLETE: it only checks that the return type is `torch.ListType`. Nevertheless, the existing logic for `Tuple` returns makes a similar assumption, so I think it's fine to unblock this way.

Besides adding `OpStrategy` support for ops with a `List[Tensor]` return type, this PR also changes `split_strategy`'s return type from `TupleStrategy` to `OpStrategy`.
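To make the structural difference concrete, here is a minimal, self-contained sketch. All classes below (`Spec`, `OpSpec`, `OpStrategy`, `TupleStrategy`) are simplified stand-ins written for illustration, not the real definitions from `torch/distributed/tensor/_op_schema.py`:

```python
# Illustrative mocks only; the real classes carry much more state.
from dataclasses import dataclass
from typing import Union


@dataclass
class Spec:
    """Stand-in for DTensorSpec: just a placement description."""

    placements: str


@dataclass
class OpSpec:
    """One sharding choice: output spec(s) plus the input specs it requires."""

    output_specs: Union[Spec, tuple]
    input_specs: tuple


@dataclass
class OpStrategy:
    """All candidate sharding choices for a single op call."""

    strategies: list


@dataclass
class TupleStrategy:
    """Independent child strategies, intended for foreach_* ops where each
    list element can make its own sharding decision."""

    children: list


replicate = Spec("Replicate()")
num_chunks = 3  # e.g. torch.split produced 3 chunks

# Old shape: one OpStrategy per output chunk, wrapped in a TupleStrategy,
# as if each chunk could be sharded independently of its siblings.
old_style = TupleStrategy(
    children=[
        OpStrategy([OpSpec(output_specs=replicate, input_specs=(replicate,))])
        for _ in range(num_chunks)
    ]
)

# New shape: a single OpStrategy whose OpSpec carries a tuple of output
# specs, so all chunks share one coordinated sharding decision.
new_style = OpStrategy(
    [OpSpec(output_specs=(replicate,) * num_chunks, input_specs=(replicate,))]
)

print(old_style)
print(new_style)
```

This mirrors how sharding propagation already treats `Tuple[Tensor]` returns: one `OpSpec` fixes the placements of every output at once.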
**Test**
`pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158051
Approved by: https://github.com/wconstab, https://github.com/zpcore
---
 torch/distributed/tensor/_op_schema.py       |  7 ++++
 torch/distributed/tensor/_ops/_tensor_ops.py | 40 +++++++++++---------
 torch/distributed/tensor/_sharding_prop.py   |  5 ++-
 3 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/torch/distributed/tensor/_op_schema.py b/torch/distributed/tensor/_op_schema.py
index acf15c6c0ea4..54d85aa1b3ab 100644
--- a/torch/distributed/tensor/_op_schema.py
+++ b/torch/distributed/tensor/_op_schema.py
@@ -345,6 +345,13 @@ class OpSchema:
             return_types[0].type, torch.TensorType
         )
 
+    def return_type_list_tensor_like(self) -> bool:
+        # returns True if the return type is a List
+        return_types = self.op._schema.returns
+        return len(return_types) == 1 and isinstance(
+            return_types[0].type, torch.ListType
+        )
+
     def return_type_tensor(self) -> bool:
         return_types = self.op._schema.returns
         # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py
index a81db1a3b124..9bdfc90d145d 100644
--- a/torch/distributed/tensor/_ops/_tensor_ops.py
+++ b/torch/distributed/tensor/_ops/_tensor_ops.py
@@ -1074,7 +1074,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
     ],
     RuntimeSchemaInfo(1),
 )
-def split_strategy(op_schema: OpSchema) -> TupleStrategy:
+def split_strategy(op_schema: OpSchema) -> OpStrategy:
     input_strategy = op_schema.args_schema[0]
     split_size_or_sections = op_schema.args_schema[1]
     assert isinstance(input_strategy, OpStrategy)
@@ -1097,23 +1097,27 @@ def split_strategy(op_schema: OpSchema) -> TupleStrategy:
     )
     assert isinstance(output_size_list, Sized)
 
-    split_strategies = []
+    all_strategies = []
+    for strategy in input_strategy.strategies:
+        spec = strategy.output_spec
+        placements = spec.placements
+        if is_tensor_dim_sharded(spec, dim=dim):
+            # if the input is sharded on the split dim, we need to unshard it
+            placements = unshard_tensor_dim(spec.placements, dim=dim)
 
-    for _ in range(len(output_size_list)):
-        op_strategy = OpStrategy([])
-
-        for strategy in input_strategy.strategies:
-            spec = strategy.output_spec
-            placements = spec.placements
-            if is_tensor_dim_sharded(spec, dim=dim):
-                # if the input is sharded on the split dim, we need to unshard it
-                placements = unshard_tensor_dim(spec.placements, dim=dim)
-
-            spec = DTensorSpec(spec.mesh, placements)
-
-            op_strategy.strategies.append(
-                OpSpec(output_specs=spec, input_specs=([spec]))
+        input_spec = DTensorSpec(spec.device_mesh, placements, spec.tensor_meta)
+        output_specs = tuple(
+            DTensorSpec(spec.device_mesh, placements)
+            for _ in range(len(output_size_list))
+        )
+        all_strategies.append(
+            OpSpec(
+                output_specs=output_specs,
+                input_specs=(input_spec,),
+                redistribute_cost=[
+                    generate_redistribute_costs(input_strategy, input_spec)
+                ],
             )
-        split_strategies.append(op_strategy)
+        )
 
-    return TupleStrategy(split_strategies)
+    return OpStrategy(all_strategies)
diff --git a/torch/distributed/tensor/_sharding_prop.py b/torch/distributed/tensor/_sharding_prop.py
index 4b1536644b87..69af19fea26a 100644
--- a/torch/distributed/tensor/_sharding_prop.py
+++ b/torch/distributed/tensor/_sharding_prop.py
@@ -353,7 +353,10 @@ class ShardingPropagator:
                         for _ in range(len(op_schema.op._schema.returns))
                     ]
                 )
-            elif op_schema.return_type_tensor():
+            elif (
+                op_schema.return_type_tensor()
+                or op_schema.return_type_list_tensor_like()
+            ):
                 output_specs = output_strategy.output_specs
             else:
                 output_specs = None
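For context, here is a minimal usage sketch of the path this patch touches. It assumes a single-process CPU setup with the gloo backend and an arbitrary free port; the referenced `test_split_on_partial` exercises the `Partial`-placement case, which is not reproduced here:

```python
# Minimal single-process demo of torch.split on a DTensor (illustration
# only; real deployments run one process per mesh element).
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")  # any free port works
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))

# Shard an 8x4 tensor along dim 0, then split along that same dim.
# Per split_strategy above, an input sharded on the split dim is not
# accepted for local computation, so propagation requests an unshard
# (redistribution) of the input first.
dt = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])
chunks = torch.split(dt, 2, dim=0)  # a tuple of DTensors

for c in chunks:
    print(type(c).__name__, tuple(c.shape), c.placements)

dist.destroy_process_group()
```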