[BE] fix typo in torch/distributed/tensor/: childs -> children (#156609)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156609
Approved by: https://github.com/wanchaol, https://github.com/cyyever
ghstack dependencies: #156311
This commit is contained in:
Xuehai Pan
2025-07-09 13:23:54 +08:00
committed by PyTorch MergeBot
parent 4cc8b60d1b
commit ffe11b2bf2
6 changed files with 37 additions and 22 deletions

View File

@ -3,6 +3,7 @@ from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Optional, Union
from typing_extensions import deprecated
import torch
from torch._ops import OpOverload
@ -174,18 +175,32 @@ class TupleStrategy(StrategyType):
then we should return a single OpStrategy instead of a TupleStrategy
"""
def __init__(self, childs: Sequence[StrategyType]) -> None:
def __init__(
self,
children: Sequence[StrategyType],
) -> None:
super().__init__()
self.childs: Sequence[StrategyType] = childs
self.children: Sequence[StrategyType] = children
@property
@deprecated(
"TupleStrategy.childs is deprecated, use TupleStrategy.children instead.", # codespell:ignore childs
category=FutureWarning,
)
def childs(self) -> Sequence[StrategyType]: # codespell:ignore childs
"""
Alias for children, to maintain backward compatibility.
"""
return self.children
def child_mesh(self, index: int) -> DeviceMesh:
op_strategy = self.childs[index]
op_strategy = self.children[index]
assert isinstance(op_strategy, OpStrategy)
return op_strategy.mesh
def __str__(self) -> str:
child_strategies_str = ", ".join(
[f"{str(strat)}" for idx, strat in enumerate(self.childs)]
[f"{str(strat)}" for idx, strat in enumerate(self.children)]
)
return f"TupleStrategy({child_strategies_str})"
@ -282,7 +297,7 @@ class OpSchema:
args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs))
mesh_shape = arg.mesh_shape
elif isinstance(arg, TupleStrategy):
first_op_strategy = arg.childs[0]
first_op_strategy = arg.children[0]
assert isinstance(first_op_strategy, OpStrategy)
mesh_shape = first_op_strategy.mesh_shape
args_schema.append(str(arg))
@ -342,7 +357,7 @@ class OpSchema:
mesh = first_arg.mesh
elif isinstance(first_arg, (list, tuple, TupleStrategy)):
first_elem = (
first_arg.childs[0]
first_arg.children[0]
if isinstance(first_arg, TupleStrategy)
else first_arg[0]
)

View File

@ -419,8 +419,8 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
assert isinstance(input_tuple_strategy, TupleStrategy)
norm_type = args_schema[1] if len(args_schema) > 1 else 2
assert isinstance(norm_type, (int, float, str)), f"{norm_type}"
output_tuple_strategy_childs: list[OpStrategy] = []
for op_strategy in input_tuple_strategy.childs:
output_tuple_strategy_children: list[OpStrategy] = []
for op_strategy in input_tuple_strategy.children:
assert isinstance(op_strategy, OpStrategy), f"{op_strategy}"
reduce_dims = list(range(op_strategy.ndim))
output_strategy = common_reduction_strategy(
@ -429,8 +429,8 @@ def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy:
reduction_linear=True,
reduction_op=NormReduction(norm_type),
)
output_tuple_strategy_childs.append(output_strategy)
return TupleStrategy(output_tuple_strategy_childs)
output_tuple_strategy_children.append(output_strategy)
return TupleStrategy(output_tuple_strategy_children)
@register_op_strategy(

View File

@ -707,12 +707,12 @@ def list_pointwise_strategy(
) -> list[Optional[TupleStrategy]]:
first_arg = args_schema[0]
assert isinstance(first_arg, TupleStrategy)
strategy_len = len(first_arg.childs)
strategy_len = len(first_arg.children)
tuple_strategies: list[Optional[TupleStrategy]] = []
for arg_idx, arg in enumerate(args_schema):
if isinstance(arg, TupleStrategy):
# every tuple strategy should have the same length
assert len(arg.childs) == strategy_len
assert len(arg.children) == strategy_len
tuple_strategies.append(arg)
elif isinstance(arg, OpStrategy):
if arg_idx > 0: # implicitly broadcast
@ -732,10 +732,10 @@ def list_pointwise_strategy(
follow_strategy: TupleStrategy = not_none(args_strategies[0])
list_strategy: list[OpStrategy] = []
for child_idx, child_strtgy in enumerate(follow_strategy.childs):
for child_idx, child_strtgy in enumerate(follow_strategy.children):
assert isinstance(child_strtgy, OpStrategy)
args_schema: list[Optional[OpStrategy]] = [
cast(OpStrategy, arg_strategy.childs[child_idx]) if arg_strategy else None
cast(OpStrategy, arg_strategy.children[child_idx]) if arg_strategy else None
for arg_strategy in args_strategies
]
pointwise_strategy: OpStrategy = common_pointwise_strategy(

View File

@ -602,7 +602,7 @@ def _derive_follow_placements_from_tuple_strategy(
follow_placements: Optional[list[Placement]] = None
mesh = tuple_strategy.child_mesh(0)
for arg_strategy in tuple_strategy.childs:
for arg_strategy in tuple_strategy.children:
assert isinstance(arg_strategy, OpStrategy)
if arg_strategy.mesh != mesh:
raise ValueError(
@ -644,7 +644,7 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0]
first_input_strategy = input_tuple_strategy.children[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
@ -662,7 +662,7 @@ def stack_strategy(op_schema: OpSchema) -> StrategyType:
input_specs = tuple(
DTensorSpec(mesh, tuple(follow_placements))
for _ in range(len(input_tuple_strategy.childs))
for _ in range(len(input_tuple_strategy.children))
)
follow_placements = normalize_shard_for_stack(follow_placements, dim)
@ -681,7 +681,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema
input_tuple_strategy = args_schema[0]
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
first_input_strategy = input_tuple_strategy.childs[0]
first_input_strategy = input_tuple_strategy.children[0]
assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}"
common_input_ndim = first_input_strategy.ndim
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
@ -701,7 +701,7 @@ def cat_strategy(op_schema: OpSchema) -> StrategyType:
input_specs = tuple(
DTensorSpec(mesh, tuple(follow_placements))
for _ in range(len(input_tuple_strategy.childs))
for _ in range(len(input_tuple_strategy.children))
)
op_strategy.strategies.append(
OpSpec(
@ -765,7 +765,7 @@ def prop_index_put(op_schema: OpSchema) -> StrategyType:
# 1. `indices` should all be replicated first.
indices_redistribute_costs = []
new_indices_spec: list[Optional[DTensorSpec]] = []
for indices_spec_child in indices_spec.childs:
for indices_spec_child in indices_spec.children:
assert isinstance(indices_spec_child, OpStrategy)
replicated_spec = DTensorSpec(

View File

@ -368,7 +368,7 @@ class ShardingPropagator:
# runtime select OpSpec for each TupleStrategy input arg
selected_strategies: list[OpSpec] = []
out_spec_list: list[DTensorSpec] = []
for strategy in op_strategy.childs:
for strategy in op_strategy.children:
assert isinstance(strategy, OpStrategy)
selected_strategy = self._select_strategy(strategy)
selected_strategies.append(selected_strategy)

View File

@ -77,7 +77,7 @@ def register_sharding(op: Union[OpOverload, list[OpOverload]]):
# take the output spec from the first strategy
return strategy.strategies[0].output_spec
elif isinstance(strategy, TupleStrategy):
return tuple(strategy_to_spec(s) for s in strategy.childs)
return tuple(strategy_to_spec(s) for s in strategy.children)
else:
return strategy