mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - mostly toplevels (#145178)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145178 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
1ce533867f
commit
f2cfe8b59f
@ -1,6 +1,6 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
@ -8,14 +8,14 @@ from torch import Tensor
|
||||
from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
in_dims_t = Union[int, Tuple]
|
||||
out_dims_t = Union[int, Tuple[int, ...]]
|
||||
in_dims_t = Union[int, tuple]
|
||||
out_dims_t = Union[int, tuple[int, ...]]
|
||||
|
||||
|
||||
# Checks that all args-to-be-batched have the same batch dim size
|
||||
def _validate_and_get_batch_size(
|
||||
flat_in_dims: List[Optional[int]],
|
||||
flat_args: List,
|
||||
flat_in_dims: list[Optional[int]],
|
||||
flat_args: list,
|
||||
) -> int:
|
||||
batch_sizes = [
|
||||
arg.size(in_dim)
|
||||
@ -30,7 +30,7 @@ def _validate_and_get_batch_size(
|
||||
return batch_sizes[0]
|
||||
|
||||
|
||||
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
|
||||
def _num_outputs(batched_outputs: Union[Tensor, tuple[Tensor, ...]]) -> int:
|
||||
if isinstance(batched_outputs, tuple):
|
||||
return len(batched_outputs)
|
||||
return 1
|
||||
@ -42,7 +42,7 @@ def _as_tuple(
|
||||
value: Any,
|
||||
num_elements: int,
|
||||
error_message_lambda: Callable[[], str],
|
||||
) -> Tuple:
|
||||
) -> tuple:
|
||||
if not isinstance(value, tuple):
|
||||
return (value,) * num_elements
|
||||
if len(value) != num_elements:
|
||||
@ -54,10 +54,10 @@ def _as_tuple(
|
||||
# Returns the (potentially) batched arguments and the batch_size.
|
||||
def _create_batched_inputs(
|
||||
in_dims: in_dims_t,
|
||||
args: Tuple,
|
||||
args: tuple,
|
||||
vmap_level: int,
|
||||
func: Callable,
|
||||
) -> Tuple[Tuple, int]:
|
||||
) -> tuple[tuple, int]:
|
||||
if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
|
||||
raise ValueError(
|
||||
f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
|
||||
@ -114,13 +114,13 @@ def _create_batched_inputs(
|
||||
|
||||
# Undos the batching (and any batch dimensions) associated with the `vmap_level`.
|
||||
def _unwrap_batched(
|
||||
batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
|
||||
batched_outputs: Union[Tensor, tuple[Tensor, ...]],
|
||||
out_dims: out_dims_t,
|
||||
vmap_level: int,
|
||||
batch_size: int,
|
||||
func: Callable,
|
||||
allow_none_pass_through: bool = False,
|
||||
) -> Tuple:
|
||||
) -> tuple:
|
||||
num_outputs = _num_outputs(batched_outputs)
|
||||
out_dims_as_tuple = _as_tuple(
|
||||
out_dims,
|
||||
|
Reference in New Issue
Block a user