mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
By adding a few small helpers (e.g., a `splice` method to `_MeshLayout`, and making `_init_process_groups` static and thus stateless) we can substantially shorten the definition of the unflatten method, and help readability. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165556 Approved by: https://github.com/fduwjj ghstack dependencies: #165554, #165555
272 lines
9.6 KiB
Python
272 lines
9.6 KiB
Python
#################################################################################################
|
|
#
|
|
# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: BSD-3-Clause
|
|
#
|
|
# Redistribution and use in source and binary forms, with or without
|
|
# modification, are permitted provided that the following conditions are met:
|
|
#
|
|
# 1. Redistributions of source code must retain the above copyright notice, this
|
|
# list of conditions and the following disclaimer.
|
|
#
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
# and/or other materials provided with the distribution.
|
|
#
|
|
# 3. Neither the name of the copyright holder nor the names of its
|
|
# contributors may be used to endorse or promote products derived from
|
|
# this software without specific prior written permission.
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
#
|
|
#################################################################################################
|
|
|
|
"""
|
|
Functions for manipulating IntTuples
|
|
"""
|
|
|
|
from functools import reduce
|
|
from itertools import chain
|
|
from typing import Optional, TypeAlias, Union
|
|
from typing_extensions import TypeIs
|
|
|
|
from .typing import Integer
|
|
|
|
|
|
# Type aliases for better readability
|
|
# An IntTuple is either a plain integer or an arbitrarily nested tuple whose
# leaves are integers, e.g. 4, (2, 3) or ((2, 2), 8).
IntTuple: TypeAlias = Union[int, tuple["IntTuple", ...]]
|
|
|
|
|
|
def is_int(x: object) -> TypeIs[int]:
    """Report whether *x* counts as an integer leaf.

    The check is delegated to ``isinstance`` against ``Integer`` (imported from
    the package's ``.typing`` module), so anything accepted there is a leaf.
    """
    matches_integer = isinstance(x, Integer)
    return matches_integer
|
|
|
|
|
|
def is_tuple(x: object) -> TypeIs[tuple]:
|
|
return isinstance(x, tuple)
|
|
|
|
|
|
def as_tuple(x: IntTuple) -> tuple[IntTuple, ...]:
    """Promote a scalar integer to a 1-tuple; return any other value unchanged."""
    return (x,) if is_int(x) else x
|
|
|
|
|
|
def match_structure(a: IntTuple, b: IntTuple) -> bool:
    """Return True when *a* and *b* share the same nesting structure.

    Two integers always match. Two tuples match when they have equal length
    and every pair of corresponding elements matches recursively. Any other
    combination (int vs. tuple, or non-IntTuple values) does not match.
    """
    if is_int(a) and is_int(b):
        return True
    if not (is_tuple(a) and is_tuple(b)):
        return False
    if len(a) != len(b):
        return False
    return all(match_structure(lhs, rhs) for lhs, rhs in zip(a, b))
|
|
|
|
|
|
def flatten(t: IntTuple) -> tuple[int, ...]:
    """Collect the leaves of *t* into a flat tuple, in left-to-right order.

    A non-tuple value becomes a 1-tuple. Empty tuples, at any depth,
    contribute no leaves.
    """
    if not is_tuple(t):
        return (t,)
    leaves: list[int] = []
    for child in t:
        leaves.extend(flatten(child))
    return tuple(leaves)
|
|
|
|
|
|
def signum(a: int) -> int:
    """Return the sign of *a*: 1 when positive, -1 when negative, otherwise 0."""
    if a > 0:
        return 1
    if a < 0:
        return -1
    return 0
|
|
|
|
|
|
def product(a: IntTuple) -> int:
    """Multiply together every leaf of *a*.

    A scalar is returned as-is. An empty tuple yields the multiplicative
    identity 1.
    """
    if not is_tuple(a):
        return a
    total = 1
    for child in a:
        total = total * product(child)
    return total
|
|
|
|
|
|
def inner_product(a: IntTuple, b: IntTuple) -> int:
    """Return the sum of leaf-wise products of two congruent IntTuples.

    Raises:
        AssertionError: if the tuples differ in length at any level, or if one
            operand is a tuple where the other is a scalar.
    """
    if not (is_tuple(a) and is_tuple(b)):
        # Scalar case: both operands must be leaves.
        assert not is_tuple(a) and not is_tuple(b)
        return a * b
    assert len(a) == len(b)
    accumulated = 0
    for lhs, rhs in zip(a, b):
        accumulated += inner_product(lhs, rhs)
    return accumulated
|
|
|
|
|
|
def tuple_max(a: IntTuple) -> int:
    """Return the largest leaf in *a* (a scalar is its own maximum).

    An empty tuple anywhere in the recursion makes ``max`` raise ValueError.
    """
    if not is_tuple(a):
        return a
    return max(map(tuple_max, a))
|
|
|
|
|
|
def elem_scale(a: IntTuple, b: IntTuple) -> IntTuple:
    """Scale *a* element-wise by *b*; the result has the same structure as *a*.

    - tuple, tuple: scale corresponding modes recursively (lengths must match).
    - int, tuple: *b* is first collapsed to its product.
    - int, int: plain multiplication.
    - tuple, int: rejected with AssertionError.
    """
    a_nested = is_tuple(a)
    b_nested = is_tuple(b)
    if a_nested and b_nested:
        assert len(a) == len(b)
        return tuple(elem_scale(x, y) for x, y in zip(a, b))
    if a_nested:
        raise AssertionError("Invalid combination: tuple with int")
    if b_nested:
        return elem_scale(a, product(b))
    return a * b
|
|
|
|
|
|
def shape_div(a: IntTuple, b: IntTuple) -> IntTuple:
    """Inclusive prefix ceil-division of *a* by *b*; the output is congruent to *a*.

    - tuple, tuple: divide corresponding modes (lengths must match).
    - tuple, int: consume the divisor from left to right. Each mode is divided
      by what is left of *b*, and *b* is then reduced by that mode's size.
    - int, tuple: *b* is first collapsed to its product.
    - int, int: ceil(a / b), asserting that one of the two divides the other.
    """
    if is_tuple(a) and is_tuple(b):
        assert len(a) == len(b)
        return tuple(shape_div(x, y) for x, y in zip(a, b))
    if is_tuple(a):
        quotients = []
        remaining = b
        for mode in a:
            quotients.append(shape_div(mode, remaining))
            remaining = shape_div(remaining, product(mode))
        return tuple(quotients)
    if is_tuple(b):
        return shape_div(a, product(b))
    assert a % b == 0 or b % a == 0
    return (a + b - 1) // b
|
|
|
|
|
|
def suffix_product(a: IntTuple, init: IntTuple = 1) -> IntTuple:
    """Exclusive suffix product of *a* seeded with *init*; output is congruent to *a*.

    Each leaf is replaced by *init* times the product of every mode to its
    right. This gives lexicographic strides, where the last mode is fastest:
    ``suffix_product((2, 3, 4)) == (12, 4, 1)``.

    Raises:
        AssertionError: if *init* is a tuple whose structure does not line up
            with *a*.
    """
    if not is_tuple(a):
        if is_tuple(init):
            raise AssertionError("Invalid combination: int with tuple init")
        return init
    if is_tuple(init):
        # TODO: With all these length asserts, may want to create a zip_strict wrapper.
        assert len(a) == len(init)
        return tuple(suffix_product(x, seed) for x, seed in zip(a, init))
    # Walk the modes right to left, folding each mode's size into the running
    # seed, then restore left-to-right order.
    reversed_out = []
    running = init
    for mode in reversed(a):
        reversed_out.append(suffix_product(mode, running))
        running = running * product(mode)
    return tuple(reversed(reversed_out))
|
|
|
|
|
|
def idx2crd(
    idx: IntTuple, shape: IntTuple, stride: Optional[IntTuple] = None
) -> IntTuple:
    """Convert an index (or a tuple of indices) into a coordinate within *shape*.

    *stride* defaults to ``suffix_product(shape)``, i.e. lexicographic order.
    For a scalar leaf the coordinate is ``(idx // stride) % shape``, so indices
    wrap within each mode.

    Raises:
        AssertionError: on rank mismatches or incompatible int/tuple combinations.
    """
    if stride is None:
        stride = suffix_product(shape)
    nested_layout = is_tuple(shape) and is_tuple(stride)

    if not is_tuple(idx):
        if nested_layout:  # "int" tuple tuple
            assert len(shape) == len(stride)
            return tuple(idx2crd(idx, s, d) for s, d in zip(shape, stride))
        # "int" "int" "int"
        assert not is_tuple(shape) and not is_tuple(stride)
        return (idx // stride) % shape

    if not nested_layout:  # tuple "int" "int"
        raise AssertionError("Invalid combination: tuple with int stride")
    assert len(idx) == len(shape) and len(stride) == len(shape)
    return tuple(idx2crd(i, s, d) for i, s, d in zip(idx, shape, stride))
|
|
|
|
|
|
def crd2idx(
    crd: Optional[IntTuple], shape: IntTuple, stride: Optional[IntTuple] = None
) -> int:
    """Map coordinate *crd* in *shape* to a linear index using *stride*.

    *stride* defaults to ``suffix_product(shape)`` (lexicographic order). A
    ``None`` coordinate is treated as 0. A scalar coordinate applied to a
    tuple shape is split across the modes from right to left. The leftmost
    mode receives whatever is left over, without a modulo.

    Raises:
        AssertionError: on rank mismatches or incompatible int/tuple combinations.
    """
    if stride is None:
        stride = suffix_product(shape)
    nested_layout = is_tuple(shape) and is_tuple(stride)

    if is_tuple(crd):
        if not nested_layout:  # tuple "int" "int"
            raise AssertionError(f"Invalid combination: crd={crd}, shape={shape}")
        assert len(crd) == len(shape) and len(stride) == len(shape)
        return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride))

    coord = 0 if crd is None else crd

    if not nested_layout:  # "int" "int" "int"
        assert not is_tuple(shape) and not is_tuple(stride)
        return coord * stride

    # "int" tuple tuple
    assert len(shape) == len(stride)
    if len(shape) == 0:
        return 0
    offset = 0
    for mode_shape, mode_stride in zip(reversed(shape[1:]), reversed(stride[1:])):
        extent = product(mode_shape)
        offset += crd2idx(coord % extent, mode_shape, mode_stride)
        coord = coord // extent
    return offset + crd2idx(coord, shape[0], stride[0])
|
|
|
|
|
|
def crd2crd(
    crd: IntTuple, dst_shape: IntTuple, src_shape: Optional[IntTuple] = None
) -> IntTuple:
    """Transform coordinate *crd* into the iteration space of *dst_shape*.

    Args:
        crd: the coordinate to transform.
        dst_shape: the shape whose coordinate space the result lives in.
        src_shape: the shape that *crd* is a coordinate of. It is only needed
            where a tuple coordinate meets an integer destination mode: that
            tuple is then linearized with ``crd2idx`` against the matching
            source mode. A tuple *src_shape* is threaded mode-by-mode through
            the recursion, so nested source modes are honoured. A non-tuple
            *src_shape* carries no per-mode information and is not propagated.

    Returns:
        A coordinate with the same structure as *dst_shape*, where it was a
        tuple, or a scalar index.

    Raises:
        AssertionError: on rank mismatches, when a tuple coordinate must be
            collapsed without a source shape, or when a scalar coordinate is
            not below a scalar destination extent.
    """
    if is_tuple(crd):
        if is_tuple(dst_shape):  # tuple tuple
            assert len(crd) == len(dst_shape)
            # Pass each sub-coordinate its own source mode. Previously the
            # source shape was dropped here, so nested "tuple int" pairs
            # always failed the `src_shape is not None` check below.
            if is_tuple(src_shape):
                assert len(src_shape) == len(crd)
                src_modes = src_shape
            else:
                src_modes = (None,) * len(crd)
            return tuple(
                crd2crd(x, y, z) for x, y, z in zip(crd, dst_shape, src_modes)
            )
        # tuple "int": ambiguous unless we have src_shape
        assert src_shape is not None
        return crd2idx(crd, src_shape)
    if is_tuple(dst_shape):  # "int" tuple
        return idx2crd(crd, dst_shape)
    # "int" "int"
    assert crd < dst_shape
    return crd
|
|
|
|
|
|
def slice_(crd: Union[None, tuple, int], trg: Union[tuple, int]) -> Union[tuple, int]:
    """Filter *trg* according to *crd*, keeping only the parts paired with ``None``.

    - ``None`` keeps the whole target, wrapped as ``(trg,)``.
    - An integer drops the target, giving ``()``.
    - A tuple recurses mode-by-mode and concatenates the kept pieces.

    Raises:
        AssertionError: if *crd* is a tuple but *trg* is not, or if their
            lengths differ.
    """
    if crd is None:
        # match C++ behavior `return cute::tuple<B>{b};`
        return (trg,)
    if not is_tuple(crd):
        return ()
    if not is_tuple(trg):
        raise AssertionError("Invalid combination: tuple crd with int trg")
    assert len(crd) == len(trg)
    # match C++ behavior of `filter_tuple` using `tuple_cat(...)`: concatenate
    # the per-mode results; modes that kept nothing add no elements.
    kept: list = []
    for sub_crd, sub_trg in zip(crd, trg):
        kept.extend(slice_(sub_crd, sub_trg))
    return tuple(kept)
|
|
|
|
|
|
def has_none(a: Union[None, tuple, int]) -> bool:
    """Return True if ``None`` appears at any leaf of the (possibly nested) value *a*."""
    if not is_tuple(a):
        return a is None
    for child in a:
        if has_none(child):
            return True
    return False
|