PEP585 update - mostly toplevels (#145178)

See #145101 for details.
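
For context, PEP 585 (Python 3.9) made the builtin container types usable as generics, so the deprecated typing aliases List, Tuple, Dict, etc. can be written as list, tuple, dict. Union and Optional are not covered by PEP 585 (their replacement, the X | Y syntax, arrived with PEP 604 in Python 3.10), which is why those imports survive in the diff below. A minimal before/after sketch of the pattern this PR applies (illustrative names only, not taken from the diff):

# Before: typing's capitalized aliases, deprecated since Python 3.9.
from typing import List, Optional, Tuple, Union

dims_t = Union[int, Tuple[int, ...]]

def first(xs: List[int]) -> Optional[int]:
    return xs[0] if xs else None

# After: PEP 585 builtin generics; Union/Optional still come from typing.
from typing import Optional, Union

dims_t = Union[int, tuple[int, ...]]

def first(xs: list[int]) -> Optional[int]:
    return xs[0] if xs else None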

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178
Approved by: https://github.com/bobrenjc93
Aaron Orenstein
2025-01-21 13:42:12 -08:00
committed by PyTorch MergeBot
parent 1ce533867f
commit f2cfe8b59f
39 changed files with 356 additions and 386 deletions

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 import functools
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Union
 from typing_extensions import deprecated
 import torch
@@ -8,14 +8,14 @@ from torch import Tensor
 from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
-in_dims_t = Union[int, Tuple]
-out_dims_t = Union[int, Tuple[int, ...]]
+in_dims_t = Union[int, tuple]
+out_dims_t = Union[int, tuple[int, ...]]
 # Checks that all args-to-be-batched have the same batch dim size
 def _validate_and_get_batch_size(
-    flat_in_dims: List[Optional[int]],
-    flat_args: List,
+    flat_in_dims: list[Optional[int]],
+    flat_args: list,
 ) -> int:
     batch_sizes = [
         arg.size(in_dim)
@@ -30,7 +30,7 @@ def _validate_and_get_batch_size(
     return batch_sizes[0]
-def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
+def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
     if isinstance(batched_outputs, tuple):
         return len(batched_outputs)
     return 1
@@ -42,7 +42,7 @@ def _as_tuple(
     value: Any,
     num_elements: int,
     error_message_lambda: Callable[[], str],
-) -> Tuple:
+) -> tuple:
     if not isinstance(value, tuple):
         return (value,) * num_elements
     if len(value) != num_elements:
@@ -54,10 +54,10 @@ def _as_tuple(
 # Returns the (potentially) batched arguments and the batch_size.
 def _create_batched_inputs(
     in_dims: in_dims_t,
-    args: Tuple,
+    args: tuple,
     vmap_level: int,
     func: Callable,
-) -> Tuple[Tuple, int]:
+) -> tuple[tuple, int]:
     if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
         raise ValueError(
             f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
@@ -114,13 +114,13 @@ def _create_batched_inputs(
 # Undos the batching (and any batch dimensions) associated with the `vmap_level`.
 def _unwrap_batched(
-    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
+    batched_outputs: Union[Tensor, tuple[Tensor, ...]],
     out_dims: out_dims_t,
     vmap_level: int,
     batch_size: int,
     func: Callable,
     allow_none_pass_through: bool = False,
-) -> Tuple:
+) -> tuple:
     num_outputs = _num_outputs(batched_outputs)
     out_dims_as_tuple = _as_tuple(
         out_dims,