mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
PEP585 update - torch/utils (#145201)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145201 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
693d8c7e94
commit
2f9d378f7b
@ -15,18 +15,8 @@ collection support for PyTorch APIs.
|
||||
import functools
|
||||
import sys
|
||||
import types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
overload,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from collections.abc import Iterable
|
||||
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
||||
from typing_extensions import deprecated, TypeIs
|
||||
|
||||
import optree
|
||||
@ -79,14 +69,14 @@ R = TypeVar("R")
|
||||
|
||||
Context = Any
|
||||
PyTree = Any
|
||||
FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]]
|
||||
FlattenFunc = Callable[[PyTree], tuple[list[Any], Context]]
|
||||
UnflattenFunc = Callable[[Iterable[Any], Context], PyTree]
|
||||
OpTreeUnflattenFunc = Callable[[Context, Iterable[Any]], PyTree]
|
||||
DumpableContext = Any # Any json dumpable text
|
||||
ToDumpableContextFn = Callable[[Context], DumpableContext]
|
||||
FromDumpableContextFn = Callable[[DumpableContext], Context]
|
||||
KeyPath = Tuple[KeyEntry, ...]
|
||||
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]
|
||||
KeyPath = tuple[KeyEntry, ...]
|
||||
FlattenWithKeysFunc = Callable[[PyTree], tuple[list[tuple[KeyEntry, Any]], Any]]
|
||||
|
||||
|
||||
def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
@ -98,7 +88,7 @@ def _reverse_args(func: UnflattenFunc) -> OpTreeUnflattenFunc:
|
||||
|
||||
|
||||
def register_pytree_node(
|
||||
cls: Type[Any],
|
||||
cls: type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
@ -166,7 +156,7 @@ def register_pytree_node(
|
||||
category=FutureWarning,
|
||||
)
|
||||
def _register_pytree_node(
|
||||
cls: Type[Any],
|
||||
cls: type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
@ -217,7 +207,7 @@ def _register_pytree_node(
|
||||
|
||||
|
||||
def _private_register_pytree_node(
|
||||
cls: Type[Any],
|
||||
cls: type[Any],
|
||||
flatten_fn: FlattenFunc,
|
||||
unflatten_fn: UnflattenFunc,
|
||||
*,
|
||||
@ -285,7 +275,7 @@ def tree_is_leaf(
|
||||
def tree_flatten(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Tuple[List[Any], TreeSpec]:
|
||||
) -> tuple[list[Any], TreeSpec]:
|
||||
"""Flatten a pytree.
|
||||
|
||||
See also :func:`tree_unflatten`.
|
||||
@ -395,7 +385,7 @@ def tree_iter(
|
||||
def tree_leaves(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
"""Get the leaves of a pytree.
|
||||
|
||||
See also :func:`tree_flatten`.
|
||||
@ -549,12 +539,12 @@ def tree_map_(
|
||||
)
|
||||
|
||||
|
||||
Type2 = Tuple[Type[T], Type[S]]
|
||||
Type3 = Tuple[Type[T], Type[S], Type[U]]
|
||||
Type2 = tuple[type[T], type[S]]
|
||||
Type3 = tuple[type[T], type[S], type[U]]
|
||||
if sys.version_info >= (3, 10):
|
||||
TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
|
||||
TypeAny = Union[type[Any], tuple[type[Any], ...], types.UnionType]
|
||||
else:
|
||||
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]
|
||||
TypeAny = Union[type[Any], tuple[type[Any], ...]]
|
||||
|
||||
Fn2 = Callable[[Union[T, S]], R]
|
||||
Fn3 = Callable[[Union[T, S, U]], R]
|
||||
@ -577,7 +567,7 @@ def map_only(type_or_types_or_pred: Type3[T, S, U], /) -> MapOnlyFn[Fn3[T, S, U,
|
||||
|
||||
|
||||
@overload
|
||||
def map_only(type_or_types_or_pred: Type[T], /) -> MapOnlyFn[Fn[T, Any]]:
|
||||
def map_only(type_or_types_or_pred: type[T], /) -> MapOnlyFn[Fn[T, Any]]:
|
||||
...
|
||||
|
||||
|
||||
@ -640,7 +630,7 @@ def map_only(
|
||||
|
||||
@overload
|
||||
def tree_map_only(
|
||||
type_or_types_or_pred: Type[T],
|
||||
type_or_types_or_pred: type[T],
|
||||
/,
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
@ -694,7 +684,7 @@ def tree_map_only(
|
||||
|
||||
@overload
|
||||
def tree_map_only_(
|
||||
type_or_types_or_pred: Type[T],
|
||||
type_or_types_or_pred: type[T],
|
||||
/,
|
||||
func: Fn[T, Any],
|
||||
tree: PyTree,
|
||||
@ -766,7 +756,7 @@ def tree_any(
|
||||
|
||||
@overload
|
||||
def tree_all_only(
|
||||
type_or_types: Type[T],
|
||||
type_or_types: type[T],
|
||||
/,
|
||||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
@ -810,7 +800,7 @@ def tree_all_only(
|
||||
|
||||
@overload
|
||||
def tree_any_only(
|
||||
type_or_types: Type[T],
|
||||
type_or_types: type[T],
|
||||
/,
|
||||
pred: Fn[T, bool],
|
||||
tree: PyTree,
|
||||
@ -856,7 +846,7 @@ def broadcast_prefix(
|
||||
prefix_tree: PyTree,
|
||||
full_tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> List[Any]:
|
||||
) -> list[Any]:
|
||||
"""Return a list of broadcasted leaves in ``prefix_tree`` to match the number of leaves in ``full_tree``.
|
||||
|
||||
If a ``prefix_tree`` is a prefix of a ``full_tree``, this means the ``full_tree`` can be
|
||||
@ -891,7 +881,7 @@ def broadcast_prefix(
|
||||
Returns:
|
||||
A list of leaves in ``prefix_tree`` broadcasted to match the number of leaves in ``full_tree``.
|
||||
"""
|
||||
result: List[Any] = []
|
||||
result: list[Any] = []
|
||||
|
||||
def add_leaves(x: Any, subtree: PyTree) -> None:
|
||||
subtreespec = tree_structure(subtree, is_leaf=is_leaf)
|
||||
@ -918,7 +908,7 @@ def _broadcast_to_and_flatten(
|
||||
tree: PyTree,
|
||||
treespec: TreeSpec,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Optional[List[Any]]:
|
||||
) -> Optional[list[Any]]:
|
||||
assert _is_pytreespec_instance(treespec)
|
||||
full_tree = tree_unflatten([0] * treespec.num_leaves, treespec)
|
||||
try:
|
||||
@ -977,7 +967,7 @@ class LeafSpec(TreeSpec, metaclass=LeafSpecMeta):
|
||||
def tree_flatten_with_path(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]:
|
||||
) -> tuple[list[tuple[KeyPath, Any]], TreeSpec]:
|
||||
"""Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
@ -1000,7 +990,7 @@ def tree_flatten_with_path(
|
||||
def tree_leaves_with_path(
|
||||
tree: PyTree,
|
||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||
) -> List[Tuple[KeyPath, Any]]:
|
||||
) -> list[tuple[KeyPath, Any]]:
|
||||
"""Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
|
||||
|
||||
Args:
|
||||
|
Reference in New Issue
Block a user