[DTensor] have split_strategy return OpStrategy instead of TupleStrategy (#158051)
**Summary**

`split_strategy` used `TupleStrategy` as its return type because DTensor sharding propagation's `OpStrategy` support for multi-output ops only applied to `Tuple` returns. However, `TupleStrategy` is not a good fit for the `split` op: it was originally introduced to handle the sharding strategy of `foreach_*` ops, where the input args can be partitioned into independent subsets with respect to sharding decisions, and so can the outputs.

To address the misuse, this PR adds `OpStrategy` propagation for ops returning `List[Tensor]` (note that this support is INCOMPLETE because it only checks that the return type is `torch.ListType`). Nevertheless, the existing logic for `Tuple` returns makes a similar assumption, so I think it is fine to unblock it this way. Besides adding `OpStrategy` support for ops with a `List[Tensor]` return type, this PR also changes `split_strategy`'s return type from `TupleStrategy` to `OpStrategy`.

**Test**

`pytest test/distributed/tensor/test_tensor_ops.py -s -k test_split_on_partial`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158051
Approved by: https://github.com/wconstab, https://github.com/zpcore
commit 4c1fabf2c9 (parent a2ad16be72), committed by PyTorch MergeBot
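For orientation, here is a minimal single-process sketch (not part of the PR; the setup, sizes, and port are illustrative) of the call path this change serves: `torch.split` on a DTensor routes through `split_strategy` during sharding propagation. The PR's actual coverage is the multi-rank test `test_split_on_partial` named above.

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard

# Single-process setup so the sketch runs without torchrun; with more ranks the
# same code exercises real resharding. Address/port choices are arbitrary.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
local = torch.arange(12, dtype=torch.float32).reshape(4, 3)
dt = DTensor.from_local(local, mesh, [Shard(0)])

# torch.split on a DTensor dispatches through DTensor's sharding propagation,
# which after this PR sees a single OpStrategy from split_strategy (each OpSpec
# carrying one output spec per chunk) rather than a TupleStrategy.
chunks = torch.split(dt, 2, dim=0)
print([c.placements for c in chunks])

dist.destroy_process_group()
```

With more ranks (e.g. under torchrun) the same code exercises the redistribution path when the input is sharded or partial along the split dim.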
@@ -345,6 +345,13 @@ class OpSchema:
             return_types[0].type, torch.TensorType
         )
 
+    def return_type_list_tensor_like(self) -> bool:
+        # returns True if the return type is a List
+        return_types = self.op._schema.returns
+        return len(return_types) == 1 and isinstance(
+            return_types[0].type, torch.ListType
+        )
+
     def return_type_tensor(self) -> bool:
         return_types = self.op._schema.returns
         # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like
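An illustrative check, not from the diff, of what the new `return_type_list_tensor_like()` helper inspects: `aten.split.Tensor` declares a single `List[Tensor]` return, so its only return's JIT type is a `torch.ListType`.

```python
import torch

# The helper above reduces to this check on the op's schema: exactly one
# return whose JIT type is torch.ListType (i.e. List[Tensor] for split).
returns = torch.ops.aten.split.Tensor._schema.returns
print(len(returns))                                 # 1
print(isinstance(returns[0].type, torch.ListType))  # True
```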
@@ -1074,7 +1074,7 @@ def prop_index(op_schema: OpSchema) -> OutputSharding:
     ],
     RuntimeSchemaInfo(1),
 )
-def split_strategy(op_schema: OpSchema) -> TupleStrategy:
+def split_strategy(op_schema: OpSchema) -> OpStrategy:
     input_strategy = op_schema.args_schema[0]
     split_size_or_sections = op_schema.args_schema[1]
     assert isinstance(input_strategy, OpStrategy)
@@ -1097,11 +1097,7 @@ def split_strategy(op_schema: OpSchema) -> TupleStrategy:
     )
     assert isinstance(output_size_list, Sized)
 
-    split_strategies = []
-
-    for _ in range(len(output_size_list)):
-        op_strategy = OpStrategy([])
-
+    all_strategies = []
     for strategy in input_strategy.strategies:
         spec = strategy.output_spec
         placements = spec.placements
@@ -1109,11 +1105,19 @@ def split_strategy(op_schema: OpSchema) -> TupleStrategy:
             # if the input is sharded on the split dim, we need to unshard it
             placements = unshard_tensor_dim(spec.placements, dim=dim)
 
-        spec = DTensorSpec(spec.mesh, placements)
-
-        op_strategy.strategies.append(
-            OpSpec(output_specs=spec, input_specs=([spec]))
+        input_spec = DTensorSpec(spec.device_mesh, placements, spec.tensor_meta)
+        output_specs = tuple(
+            DTensorSpec(spec.device_mesh, placements)
+            for _ in range(len(output_size_list))
+        )
+        all_strategies.append(
+            OpSpec(
+                output_specs=output_specs,
+                input_specs=(input_spec,),
+                redistribute_cost=[
+                    generate_redistribute_costs(input_strategy, input_spec)
+                ],
+            )
         )
-        split_strategies.append(op_strategy)
-
-    return TupleStrategy(split_strategies)
+    return OpStrategy(all_strategies)
@@ -353,7 +353,10 @@ class ShardingPropagator:
                         for _ in range(len(op_schema.op._schema.returns))
                     ]
                 )
-            elif op_schema.return_type_tensor():
+            elif (
+                op_schema.return_type_tensor()
+                or op_schema.return_type_list_tensor_like()
+            ):
                 output_specs = output_strategy.output_specs
             else:
                 output_specs = None
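A rough illustration, written for this note rather than taken from the PR, of which branch of the updated elif chain different ops would fall into based purely on their schemas:

```python
import torch

def returns_kind(op) -> str:
    # Stand-in mirroring the OpSchema checks: one Tensor return, one
    # List[Tensor] return (the newly handled case), or multiple returns
    # (the existing Tuple path).
    rets = op._schema.returns
    if len(rets) == 1 and isinstance(rets[0].type, torch.TensorType):
        return "single Tensor return"
    if len(rets) == 1 and isinstance(rets[0].type, torch.ListType):
        return "single List[Tensor] return"
    return "multiple returns (Tuple path)"

print(returns_kind(torch.ops.aten.add.Tensor))    # single Tensor return
print(returns_kind(torch.ops.aten.split.Tensor))  # single List[Tensor] return
print(returns_kind(torch.ops.aten.topk.default))  # multiple returns (Tuple path)
```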