PEP585 update - torch/utils (#145201)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145201
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-20 16:17:30 -08:00
committed by PyTorch MergeBot
parent 693d8c7e94
commit 2f9d378f7b
70 changed files with 491 additions and 550 deletions

View File

@ -15,18 +15,8 @@ collection support for PyTorch APIs.
import functools
import sys
import types
from typing import (
Any,
Callable,
Iterable,
List,
Optional,
overload,
Tuple,
Type,
TypeVar,
Union,
)
from collections.abc import Iterable
from typing import Any, Callable, Optional, overload, TypeVar, Union
from typing_extensions import deprecated, TypeIs
import optree
@ -79,14 +69,14 @@ R = TypeVar("R")
Context = Any
PyTree = Any
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
DumpableContext = Any # Any json dumpable text
ToDumpableContextFn = Callable[[Context], DumpableContext]
FromDumpableContextFn = Callable[[DumpableContext], Context]
KeyPath = Tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
KeyPath = tuple[KeyEntry, ...]
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
@ -98,7 +88,7 @@ def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
def register_pytree_node(
cls: Type[Any],
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
@ -166,7 +156,7 @@ def register_pytree_node(
category=FutureWarning,
)
def _register_pytree_node(
cls: Type[Any],
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
@ -217,7 +207,7 @@ def _register_pytree_node(
def _private_register_pytree_node(
cls: Type[Any],
cls: type[Any],
flatten_fn: FlattenFunc,
unflatten_fn: UnflattenFunc,
*,
@ -285,7 +275,7 @@ def tree_is_leaf(
def tree_flatten(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
) -> tuple[list[Any], TreeSpec]:
"""Flatten a pytree.
See also :func:`tree_unflatten`.
@ -395,7 +385,7 @@ def tree_iter(
def tree_leaves(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
) -> list[Any]:
"""Get the leaves of a pytree.
See also :func:`tree_flatten`.
@ -549,12 +539,12 @@ def tree_map_(
)
Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
Type2 = tuple[type[T], type[S]]
Type3 = tuple[type[T], type[S], type[U]]
if sys.version_info >= (3, 10):
TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
else:
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
TypeAny = Union[type[Any], tuple[type[Any], ...]]
Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
@ -577,7 +567,7 @@ def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U,
@overload
def map_only(type_or_types_or_pred: Type[T], /) -> MapOnlyFn[Fn[T, Any]]:
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
...
@ -640,7 +630,7 @@ def map_only(
@overload
def tree_map_only(
type_or_types_or_pred: Type[T],
type_or_types_or_pred: type[T],
/,
func: Fn[T, Any],
tree: PyTree,
@ -694,7 +684,7 @@ def tree_map_only(
@overload
def tree_map_only_(
type_or_types_or_pred: Type[T],
type_or_types_or_pred: type[T],
/,
func: Fn[T, Any],
tree: PyTree,
@ -766,7 +756,7 @@ def tree_any(
@overload
def tree_all_only(
type_or_types: Type[T],
type_or_types: type[T],
/,
pred: Fn[T, bool],
tree: PyTree,
@ -810,7 +800,7 @@ def tree_all_only(
@overload
def tree_any_only(
type_or_types: Type[T],
type_or_types: type[T],
/,
pred: Fn[T, bool],
tree: PyTree,
@ -856,7 +846,7 @@ def broadcast_prefix(
prefix_tree: PyTree,
full_tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
) -> list[Any]:
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
@ -891,7 +881,7 @@ def broadcast_prefix(
Returns:
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
"""
result: List[Any] = []
result: list[Any] = []
def add_leaves(x: Any, subtree: PyTree) -> None:
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
@ -918,7 +908,7 @@ def _broadcast_to_and_flatten(
tree: PyTree,
treespec: TreeSpec,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Optional[List[Any]]:
) -> Optional[list[Any]]:
assert _is_pytreespec_instance(treespec)
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
try:
@ -977,7 +967,7 @@ class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
def tree_flatten_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
"""Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
Args:
@ -1000,7 +990,7 @@ def tree_flatten_with_path(
def tree_leaves_with_path(
tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Tuple[KeyPath, Any]]:
) -> list[tuple[KeyPath, Any]]:
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
Args: