Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
PEP585 update - torch/_C torch/_decomp torch/_lazy torch/_library torch/_numpy torch/_prims torch/_refs torch/_strobelight (#145102)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145102
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #145105
Committed by: PyTorch MergeBot
Parent: a79100ab11
Commit: 5b5766665d
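The change itself is mechanical: annotations spelled with the deprecated `typing` aliases (`Tuple`, `List`, `Dict`, `Set`, ...) become the PEP 585 built-in generics, and the corresponding names are dropped from the `typing` imports. A minimal sketch of the before/after pattern applied throughout the diff below (the `summarize` function is illustrative only, not taken from the PR):

```python
# Before: typing aliases, deprecated since Python 3.9 (PEP 585)
from typing import Dict, List, Tuple

def summarize(xs: List[int]) -> Tuple[int, Dict[str, int]]:
    return len(xs), {"total": sum(xs)}

# After: built-in generics; no typing import is needed for these containers
def summarize(xs: list[int]) -> tuple[int, dict[str, int]]:
    return len(xs), {"total": sum(xs)}
```

Names without a builtin equivalent (`Callable`, `Optional`, `Union`, `Any`, `NewType`, ...) stay in the `typing` import, which is why most hunks below only trim the import line rather than delete it.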
@@ -1,11 +1,11 @@
-from typing import Callable, Tuple
+from typing import Callable

 from torch._dynamo.compiled_autograd import AutogradCompilerInstance

 def set_autograd_compiler(
     autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
     dynamic: bool,
-) -> Tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
+) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
 def clear_cache() -> None: ...
 def is_cache_empty() -> bool: ...
 def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
@@ -1,5 +1,5 @@
 import types
-from typing import Dict, NewType, Tuple
+from typing import NewType

 from torch._dynamo.types import DynamoCallback, DynamoGuardHook

@@ -31,17 +31,17 @@ class _ExtraState:
 # properties Dynamo cares about for a frame.
 class _PyInterpreterFrame:
     f_code: types.CodeType
-    f_locals: Dict[str, object]
-    f_globals: Dict[str, object]
-    f_builtins: Dict[str, object]
+    f_locals: dict[str, object]
+    f_globals: dict[str, object]
+    f_builtins: dict[str, object]
     f_lasti: int
     f_lineo: int
     f_back: types.FrameType
     # A tuple containing cell objects captured by this frame.
-    closure: Tuple[types.CellType]
+    closure: tuple[types.CellType]

 def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...

 py_opcode_caches: list[int]

-def code_framelocals_names(code: types.CodeType) -> Tuple[str]: ...
+def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-from typing import Any, Callable, Dict
+from typing import Any, Callable

 import torch

@@ -121,7 +121,7 @@ def install_storage_overlapping_guard(
 ): ...
 def profile_guard_manager(
     guard_manager: GuardManager,
-    f_locals: Dict[str, Any],
+    f_locals: dict[str, Any],
 ) -> float: ...

 class TensorGuards:
@@ -1,4 +1,4 @@
-from typing import AnyStr, overload, Tuple
+from typing import AnyStr, overload

 from torch import Tensor

@@ -16,4 +16,4 @@ class DelayedError:
     @overload
     def __call__(self, i0: Tensor) -> Tensor: ...
     @overload
-    def __call__(self, *args: Tensor) -> Tuple[Tensor, ...]: ...
+    def __call__(self, *args: Tensor) -> tuple[Tensor, ...]: ...
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 import inspect
 from collections import defaultdict
+from collections.abc import Sequence
 from functools import lru_cache, partial, wraps
 from itertools import chain
 from typing import (
@@ -9,7 +10,6 @@ from typing import (
     FrozenSet,
     List,
     Optional,
-    Sequence,
     Set,
     TYPE_CHECKING,
     TypeVar,
@@ -44,8 +44,8 @@ _P = ParamSpec("_P")

 # TODO: relax key type here; torch registrations should be possible to; but
 # right now this type is accurate
-global_decomposition_table: Dict[
-    str, Dict[torch._ops.OperatorBase, Callable]
+global_decomposition_table: dict[
+    str, dict[torch._ops.OperatorBase, Callable]
 ] = defaultdict(dict)

 decomposition_table = global_decomposition_table["post_autograd"]
@@ -78,7 +78,7 @@ def _add_op_to_registry(registry, op, fn):
     If op is OpOverload, it will be added to the registry directly.
     If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
     """
-    overloads: List[Union[torch._ops.OperatorBase]] = []
+    overloads: list[Union[torch._ops.OperatorBase]] = []
     if isinstance(op, HigherOrderOperator):
         # There's no concept of overloads for HigherOrderOperator
         registry[op] = fn
@@ -232,7 +232,7 @@ def register_decomposition(
 def get_decompositions(
     aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
     type: str = "post_autograd",
-) -> Dict[torch._ops.OperatorBase, Callable]:
+) -> dict[torch._ops.OperatorBase, Callable]:
     """
     Retrieve a dictionary of decompositions corresponding to the list of
     operator overloads and overload packets passed as input. Overload
@@ -251,7 +251,7 @@ def get_decompositions(
     for opo in registry:
         if isinstance(opo, (OpOverload, OpOverloadPacket)):
             packets_to_overloads[opo.overloadpacket].append(opo)
-    decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
+    decompositions: dict[torch._ops.OperatorBase, Callable] = {}
     for op in aten_ops:
         if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
             for op_overload in packets_to_overloads[op]:
@@ -262,7 +262,7 @@


 def remove_decompositions(
-    decompositions: Dict[torch._ops.OperatorBase, Callable],
+    decompositions: dict[torch._ops.OperatorBase, Callable],
     aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
 ) -> None:
     """
@@ -297,7 +297,7 @@ def core_aten_decompositions() -> "CustomDecompTable":
 # excluding decompositions that results in prim ops
 # Resulting opset of decomposition is core aten ops
 def _core_aten_decompositions_post_autograd() -> (
-    Dict[torch._ops.OperatorBase, Callable]
+    dict[torch._ops.OperatorBase, Callable]
 ):
     aten = torch.ops.aten
     return get_decompositions(
@ -5,10 +5,11 @@ import itertools
|
||||
import numbers
|
||||
import operator
|
||||
import sys
|
||||
from collections.abc import Iterable
|
||||
from enum import Enum
|
||||
from functools import partial, reduce
|
||||
from itertools import chain, product
|
||||
from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch._meta_registrations
|
||||
@ -39,7 +40,7 @@ DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
|
||||
|
||||
# None of these functions are publicly accessible; get at them
|
||||
# from torch._decomps
|
||||
__all__: List[str] = []
|
||||
__all__: list[str] = []
|
||||
|
||||
aten = torch._ops.ops.aten
|
||||
|
||||
@ -299,7 +300,7 @@ def _prelu_kernel_backward(
|
||||
grad_output: Tensor,
|
||||
self: Tensor,
|
||||
weight: Tensor,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
input_grad = torch.where(self > 0, grad_output, weight * grad_output)
|
||||
weight_grad = torch.where(self > 0, 0.0, self * grad_output)
|
||||
return (input_grad, weight_grad)
|
||||
@ -690,7 +691,7 @@ def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
|
||||
@out_wrapper()
|
||||
def slice_backward(
|
||||
grad_output: Tensor,
|
||||
input_sizes: List[int],
|
||||
input_sizes: list[int],
|
||||
dim: int,
|
||||
start: int,
|
||||
end: int,
|
||||
@ -760,7 +761,7 @@ def slice_forward(
|
||||
|
||||
def _normalize_start_end(
|
||||
x: Tensor, dim: int, start: Optional[int], end: Optional[int]
|
||||
) -> Tuple[int, int]:
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Normalize start and end such that both are in the range
|
||||
[0, x.get_size()[dim]] and start <= end.
|
||||
@ -824,7 +825,7 @@ def slice_scatter(
|
||||
|
||||
@register_decomposition(aten.select_backward)
|
||||
@out_wrapper()
|
||||
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
|
||||
def select_backward(grad_output: Tensor, input_sizes: list[int], dim: int, index: int):
|
||||
grad_input = grad_output.new_zeros(input_sizes)
|
||||
return torch.select_scatter(grad_input, grad_output, dim, index)
|
||||
|
||||
@ -832,7 +833,7 @@ def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index
|
||||
@register_decomposition(aten.diagonal_backward)
|
||||
@out_wrapper()
|
||||
def diagonal_backward(
|
||||
grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
|
||||
grad_output: Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
|
||||
):
|
||||
grad_input = grad_output.new_zeros(input_sizes)
|
||||
return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
|
||||
@ -899,10 +900,10 @@ def _im2col_col2im_indices_along_dim(
|
||||
@out_wrapper()
|
||||
def im2col(
|
||||
input: Tensor,
|
||||
kernel_size: List[int],
|
||||
dilation: List[int],
|
||||
padding: List[int],
|
||||
stride: List[int],
|
||||
kernel_size: list[int],
|
||||
dilation: list[int],
|
||||
padding: list[int],
|
||||
stride: list[int],
|
||||
) -> Tensor:
|
||||
torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
|
||||
torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
|
||||
@ -982,11 +983,11 @@ def im2col(
|
||||
@pw_cast_for_opmath
|
||||
def col2im(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
kernel_size: List[int],
|
||||
dilation: List[int],
|
||||
padding: List[int],
|
||||
stride: List[int],
|
||||
output_size: list[int],
|
||||
kernel_size: list[int],
|
||||
dilation: list[int],
|
||||
padding: list[int],
|
||||
stride: list[int],
|
||||
) -> Tensor:
|
||||
torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
|
||||
torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
|
||||
@ -1094,7 +1095,7 @@ def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
|
||||
@register_decomposition(aten.unfold_backward)
|
||||
@out_wrapper()
|
||||
def unfold_backward(
|
||||
grad: Tensor, input_size: List[int], dimension: int, size: int, step: int
|
||||
grad: Tensor, input_size: list[int], dimension: int, size: int, step: int
|
||||
) -> Tensor:
|
||||
if len(input_size) == 0:
|
||||
return torch.squeeze_copy(grad, 0)
|
||||
@ -1257,7 +1258,7 @@ def embedding_dense_backward(
|
||||
)
|
||||
|
||||
|
||||
def prod(x: List[int]):
|
||||
def prod(x: list[int]):
|
||||
r = 1
|
||||
for i in x:
|
||||
r *= i
|
||||
@ -1265,10 +1266,10 @@ def prod(x: List[int]):
|
||||
|
||||
|
||||
def _pad_chunk(
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
dim: int,
|
||||
num_chunks: int,
|
||||
) -> List[Tensor]:
|
||||
) -> list[Tensor]:
|
||||
padded_tensors = []
|
||||
for tensor in tensors:
|
||||
tensor_size = tensor.size()
|
||||
@ -1285,7 +1286,7 @@ def _pad_chunk(
|
||||
return padded_tensors
|
||||
|
||||
|
||||
def have_same_ndims(tensors: List[Tensor]):
|
||||
def have_same_ndims(tensors: list[Tensor]):
|
||||
ndim = tensors[0].ndim
|
||||
for tensor in tensors:
|
||||
if tensor.ndim != ndim:
|
||||
@ -1293,7 +1294,7 @@ def have_same_ndims(tensors: List[Tensor]):
|
||||
return True
|
||||
|
||||
|
||||
def leading_dimension_matches(tensors: List[Tensor], dim: int):
|
||||
def leading_dimension_matches(tensors: list[Tensor], dim: int):
|
||||
leading_dim_sizes = tensors[0].size()[:dim]
|
||||
for tensor in tensors:
|
||||
torch._check(
|
||||
@ -1303,7 +1304,7 @@ def leading_dimension_matches(tensors: List[Tensor], dim: int):
|
||||
|
||||
|
||||
def _preprocess_chunk_cat_inputs(
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
dim: int,
|
||||
num_chunks: int,
|
||||
):
|
||||
@ -1341,7 +1342,7 @@ def _preprocess_chunk_cat_inputs(
|
||||
|
||||
@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
|
||||
def _chunk_cat(
|
||||
tensors: List[Tensor],
|
||||
tensors: list[Tensor],
|
||||
dim: int,
|
||||
num_chunks: int,
|
||||
out: Optional[Tensor] = None,
|
||||
@ -1361,10 +1362,10 @@ def _chunk_cat(
|
||||
)
|
||||
def split_with_sizes_copy(
|
||||
self: Tensor,
|
||||
split_sizes: List[int],
|
||||
split_sizes: list[int],
|
||||
dim: int = 0,
|
||||
out: Optional[List[Tensor]] = None,
|
||||
) -> Optional[List[Tensor]]:
|
||||
out: Optional[list[Tensor]] = None,
|
||||
) -> Optional[list[Tensor]]:
|
||||
splits = aten.split_with_sizes(self, split_sizes, dim=dim)
|
||||
if out is None:
|
||||
return [s.clone(memory_format=torch.contiguous_format) for s in splits]
|
||||
@ -1376,19 +1377,19 @@ def split_with_sizes_copy(
|
||||
|
||||
|
||||
@register_decomposition(aten.unsafe_split.Tensor)
|
||||
def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
|
||||
def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
|
||||
return aten.split.Tensor(input, split_size, dim)
|
||||
|
||||
|
||||
@register_decomposition(aten.unsafe_split_with_sizes.default)
|
||||
def unsafe_split_with_sizes(
|
||||
input: Tensor, split_sizes: List[int], dim: int = 0
|
||||
) -> Tuple[Tensor, ...]:
|
||||
input: Tensor, split_sizes: list[int], dim: int = 0
|
||||
) -> tuple[Tensor, ...]:
|
||||
return aten.split_with_sizes.default(input, split_sizes, dim)
|
||||
|
||||
|
||||
@register_decomposition(aten.split.Tensor)
|
||||
def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
|
||||
def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
|
||||
input_sizes = self.shape
|
||||
dim_size = input_sizes[dim]
|
||||
if split_size == 0:
|
||||
@ -1412,7 +1413,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
|
||||
self: Tensor,
|
||||
tensor_indices_or_sections: Tensor,
|
||||
dim: int = 0,
|
||||
) -> Tuple[Tensor, ...]:
|
||||
) -> tuple[Tensor, ...]:
|
||||
assert tensor_indices_or_sections.device.type == "cpu"
|
||||
assert tensor_indices_or_sections.dtype == torch.int64
|
||||
split_dim = tensor_indices_or_sections.dim()
|
||||
@ -1505,8 +1506,8 @@ def native_group_norm_backward(
|
||||
C: int,
|
||||
HxW: int,
|
||||
group: int,
|
||||
output_mask: List[bool],
|
||||
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
output_mask: list[bool],
|
||||
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
utils.check_same_device(
|
||||
grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
|
||||
)
|
||||
@ -1593,12 +1594,12 @@ def native_group_norm_backward_out(
|
||||
C: int,
|
||||
HxW: int,
|
||||
group: int,
|
||||
output_mask: List[bool],
|
||||
output_mask: list[bool],
|
||||
*,
|
||||
out0: torch.Tensor,
|
||||
out1: torch.Tensor,
|
||||
out2: torch.Tensor,
|
||||
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
result = native_group_norm_backward(
|
||||
grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
|
||||
)
|
||||
@ -1622,13 +1623,13 @@ def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
|
||||
def native_layer_norm_backward(
|
||||
grad_out: Tensor,
|
||||
input: Tensor,
|
||||
normalized_shape: List[int],
|
||||
normalized_shape: list[int],
|
||||
mean: Tensor,
|
||||
rstd: Tensor,
|
||||
weight: Optional[Tensor],
|
||||
bias: Optional[Tensor],
|
||||
output_mask: List[bool],
|
||||
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
output_mask: list[bool],
|
||||
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
input_shape = input.shape
|
||||
input_ndim = input.dim()
|
||||
computation_dtype = utils.get_computation_dtype(input.dtype)
|
||||
@ -1643,8 +1644,8 @@ def native_layer_norm_backward(
|
||||
axis = input_ndim - len(normalized_shape)
|
||||
inner_dims = input_shape[axis:]
|
||||
outer_dims = input_shape[:axis]
|
||||
inner_dim_indices: List[int] = []
|
||||
outer_dim_indices: List[int] = []
|
||||
inner_dim_indices: list[int] = []
|
||||
outer_dim_indices: list[int] = []
|
||||
for i in range(input_ndim):
|
||||
if i >= axis:
|
||||
inner_dim_indices.append(i)
|
||||
@ -1705,17 +1706,17 @@ def native_layer_norm_backward(
|
||||
def native_layer_norm_backward_out(
|
||||
grad_out: Tensor,
|
||||
input: Tensor,
|
||||
normalized_shape: List[int],
|
||||
normalized_shape: list[int],
|
||||
mean: Tensor,
|
||||
rstd: Tensor,
|
||||
weight: Optional[Tensor],
|
||||
bias: Optional[Tensor],
|
||||
output_mask: List[bool],
|
||||
output_mask: list[bool],
|
||||
*,
|
||||
out0: torch.Tensor,
|
||||
out1: torch.Tensor,
|
||||
out2: torch.Tensor,
|
||||
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
result = native_layer_norm_backward(
|
||||
grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
|
||||
)
|
||||
@ -1738,7 +1739,7 @@ def native_batch_norm_helper(
|
||||
momentum: float,
|
||||
eps: float,
|
||||
functional: bool,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
reduction_dims = [0] + list(range(2, input.dim()))
|
||||
computation_dtype = utils.get_computation_dtype(input.dtype)
|
||||
new_running_mean = running_mean
|
||||
@ -1821,7 +1822,7 @@ def native_batch_norm(
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
|
||||
input, weight, bias, running_mean, running_var, training, momentum, eps, False
|
||||
)
|
||||
@ -1849,7 +1850,7 @@ def native_batch_norm_decomposition(
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
if running_mean is None and running_var is None:
|
||||
return aten._native_batch_norm_legit(
|
||||
input, weight, bias, training, momentum, eps
|
||||
@ -1876,7 +1877,7 @@ def native_batch_norm_decomposition(
|
||||
|
||||
|
||||
@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
|
||||
def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
|
||||
def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> list[Tensor]:
|
||||
dim_size = tensor.size(dim)
|
||||
split_size = (dim_size + chunks - 1) // chunks
|
||||
|
||||
@ -1896,7 +1897,7 @@ def _native_batch_norm_legit_no_training(
|
||||
running_var: Tensor,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
return aten._native_batch_norm_legit.default(
|
||||
input,
|
||||
weight,
|
||||
@ -1919,7 +1920,7 @@ def _native_batch_norm_legit(
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
|
||||
input, weight, bias, running_mean, running_var, training, momentum, eps, False
|
||||
)
|
||||
@ -1934,7 +1935,7 @@ def _native_batch_norm_legit_no_stats(
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
|
||||
input, weight, bias, None, None, training, momentum, eps, False
|
||||
)
|
||||
@ -1951,7 +1952,7 @@ def _native_batch_norm_legit_functional(
|
||||
training: bool,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
(
|
||||
output,
|
||||
save_mean,
|
||||
@ -2002,7 +2003,7 @@ def _batch_norm_with_update(
|
||||
running_var: Tensor,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
|
||||
input,
|
||||
weight,
|
||||
@ -2029,7 +2030,7 @@ def _batch_norm_with_update_functional(
|
||||
running_var: Tensor,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
|
||||
(
|
||||
output,
|
||||
save_mean,
|
||||
@ -2056,7 +2057,7 @@ def _batch_norm_no_update(
|
||||
running_var: Tensor,
|
||||
momentum: float,
|
||||
eps: float,
|
||||
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
|
||||
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
|
||||
input,
|
||||
weight,
|
||||
@ -2190,9 +2191,9 @@ def batch_norm_backward(
|
||||
save_invstd: Optional[Tensor],
|
||||
train: bool,
|
||||
eps: float,
|
||||
output_mask: List[bool],
|
||||
output_mask: list[bool],
|
||||
reserve: Tensor,
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
return native_batch_norm_backward(
|
||||
grad_out,
|
||||
input,
|
||||
@ -2218,8 +2219,8 @@ def native_batch_norm_backward(
|
||||
save_invstd: Optional[Tensor],
|
||||
train: bool,
|
||||
eps: float,
|
||||
output_mask: List[bool],
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
output_mask: list[bool],
|
||||
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
input_dtype = input.dtype
|
||||
if weight is not None:
|
||||
weight_dtype = weight.dtype
|
||||
@ -2262,10 +2263,10 @@ def native_batch_norm_backward(
|
||||
mean = running_mean_cast
|
||||
invstd = torch.rsqrt(running_var_cast + eps)
|
||||
|
||||
broadcast_mask: List[int] = [1] * input_rank
|
||||
broadcast_mask: list[int] = [1] * input_rank
|
||||
broadcast_mask[axis] = input_shape[axis]
|
||||
|
||||
reduction_axes: List[int] = []
|
||||
reduction_axes: list[int] = []
|
||||
for i in range(input_rank):
|
||||
if i != axis:
|
||||
reduction_axes.append(i)
|
||||
@ -2320,12 +2321,12 @@ def native_batch_norm_backward_out(
|
||||
save_invstd: Optional[Tensor],
|
||||
train: bool,
|
||||
eps: float,
|
||||
output_mask: List[bool],
|
||||
output_mask: list[bool],
|
||||
*,
|
||||
out0: torch.Tensor,
|
||||
out1: torch.Tensor,
|
||||
out2: torch.Tensor,
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
result = native_batch_norm_backward(
|
||||
grad_out,
|
||||
input,
|
||||
@ -2403,7 +2404,7 @@ def cudnn_batch_norm_backward(
|
||||
@register_decomposition(aten._adaptive_avg_pool2d)
|
||||
@out_wrapper()
|
||||
@pw_cast_for_opmath
|
||||
def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
|
||||
def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]):
|
||||
# Preconditions
|
||||
device = input.device
|
||||
shape = input.shape
|
||||
@ -2506,7 +2507,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
|
||||
|
||||
|
||||
def _max_unpoolnd(
|
||||
self: TensorLike, indices: TensorLike, output_size: List[int], dim: int
|
||||
self: TensorLike, indices: TensorLike, output_size: list[int], dim: int
|
||||
):
|
||||
# If the input tensors self and indices came from max_pool call as
|
||||
# required by the documentation, this operation is deterministic
|
||||
@ -2534,7 +2535,7 @@ def _max_unpoolnd(
|
||||
def max_unpool2d(
|
||||
self: TensorLike,
|
||||
indices: TensorLike,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
):
|
||||
torch._check(
|
||||
indices.dtype == torch.int64,
|
||||
@ -2581,9 +2582,9 @@ def max_unpool2d(
|
||||
def max_unpool3d(
|
||||
input: TensorLike,
|
||||
indices: TensorLike,
|
||||
output_size: List[int],
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
output_size: list[int],
|
||||
stride: list[int],
|
||||
padding: list[int],
|
||||
):
|
||||
torch._check(
|
||||
indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
|
||||
@ -2761,7 +2762,7 @@ def _index_copy(
|
||||
@register_decomposition(aten.log_sigmoid_forward)
|
||||
@out_wrapper("output", "buffer")
|
||||
@pw_cast_for_opmath
|
||||
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
|
||||
min = torch.minimum(self.new_zeros(()), self)
|
||||
z = torch.exp(-torch.abs(self))
|
||||
if self.is_cuda or self.is_xpu:
|
||||
@ -2840,8 +2841,8 @@ def get_scale_value(scales, idx):
|
||||
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
|
||||
def _upsample_nearest_vec(
|
||||
input: Tensor,
|
||||
output_size: Optional[List[int]],
|
||||
scale_factors: Optional[List[float]],
|
||||
output_size: Optional[list[int]],
|
||||
scale_factors: Optional[list[float]],
|
||||
) -> Tensor:
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scales = (
|
||||
@ -2861,8 +2862,8 @@ def _upsample_nearest_vec(
|
||||
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
|
||||
def _upsample_nearest_exact_vec(
|
||||
input: Tensor,
|
||||
output_size: Optional[List[int]],
|
||||
scale_factors: Optional[List[float]],
|
||||
output_size: Optional[list[int]],
|
||||
scale_factors: Optional[list[float]],
|
||||
) -> Tensor:
|
||||
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
|
||||
scales = (
|
||||
@ -2909,7 +2910,7 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
|
||||
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
|
||||
def upsample_nearest1d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
scales: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
return _upsample_nearest(input, output_size, [scales])
|
||||
@ -2923,7 +2924,7 @@ def upsample_nearest1d(
|
||||
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
|
||||
def upsample_nearest_exact1d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
scales: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
return _upsample_nearest(input, output_size, [scales], exact=True)
|
||||
@ -2935,7 +2936,7 @@ def upsample_nearest_exact1d(
|
||||
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
|
||||
def upsample_nearest2d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
@ -2950,7 +2951,7 @@ def upsample_nearest2d(
|
||||
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
|
||||
def _upsample_nearest_exact2d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
@ -2963,7 +2964,7 @@ def _upsample_nearest_exact2d(
|
||||
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
|
||||
def upsample_nearest3d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
scales_d: Optional[float] = None,
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
@ -2979,7 +2980,7 @@ def upsample_nearest3d(
|
||||
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
|
||||
def _upsample_nearest_exact3d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
scales_d: Optional[float] = None,
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
@ -2992,8 +2993,8 @@ def _upsample_nearest_exact3d(
|
||||
@pw_cast_for_opmath
|
||||
def _upsample_nearest(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
scales: List[Optional[float]],
|
||||
output_size: list[int],
|
||||
scales: list[Optional[float]],
|
||||
exact: bool = False,
|
||||
) -> Tensor:
|
||||
spatial_indices = _compute_upsample_nearest_indices(
|
||||
@ -3072,7 +3073,7 @@ def one_layer_rnn_data(
|
||||
hh_bias = params[3] if has_biases else None
|
||||
|
||||
step_output = []
|
||||
hiddens: List[torch.Tensor] = []
|
||||
hiddens: list[torch.Tensor] = []
|
||||
|
||||
last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
|
||||
cur_hidden = hidden.narrow(0, 0, last_batch_size)
|
||||
@ -3159,7 +3160,7 @@ def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
|
||||
hx = hidden[0].unsqueeze(0)
|
||||
cx = hidden[1].unsqueeze(0)
|
||||
|
||||
batch_sizes: List[int] = []
|
||||
batch_sizes: list[int] = []
|
||||
mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2
|
||||
hidden_size = hx.size(2)
|
||||
num_layers = 1
|
||||
@ -3710,7 +3711,7 @@ def _upsample_linear_vec(input, output_size, align_corners, scale_factors):
|
||||
@out_wrapper()
|
||||
def upsample_linear1d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
align_corners: bool,
|
||||
scales_w: Optional[float] = None,
|
||||
) -> Tensor:
|
||||
@ -3724,7 +3725,7 @@ def upsample_linear1d(
|
||||
@out_wrapper()
|
||||
def upsample_bilinear2d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
align_corners: bool,
|
||||
scales_h: Optional[float] = None,
|
||||
scales_w: Optional[float] = None,
|
||||
@ -3738,7 +3739,7 @@ def upsample_bilinear2d(
|
||||
@out_wrapper()
|
||||
def upsample_trilinear3d(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
align_corners: bool,
|
||||
scales_d: Optional[float] = None,
|
||||
scales_h: Optional[float] = None,
|
||||
@ -3785,9 +3786,9 @@ def _compute_weight_precision(weights: TensorSequenceType) -> Tensor:
|
||||
@pw_cast_for_opmath
|
||||
def _upsample_linear(
|
||||
input: Tensor,
|
||||
output_size: List[int],
|
||||
output_size: list[int],
|
||||
align_corners: bool,
|
||||
scales: List[Optional[float]],
|
||||
scales: list[Optional[float]],
|
||||
) -> Tensor:
|
||||
# get dimensions of original image
|
||||
n_channels = input.shape[1]
|
||||
@ -3937,7 +3938,7 @@ def _nll_loss_forward(
|
||||
weight: Optional[Tensor],
|
||||
reduction: int,
|
||||
ignore_index: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
# self can be [N, C] or [C]
|
||||
# target can be [N] or []
|
||||
|
||||
@ -3992,7 +3993,7 @@ def nll_loss_forward(
|
||||
weight: Optional[Tensor],
|
||||
reduction: int,
|
||||
ignore_index: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
|
||||
assert (
|
||||
target.dim() <= 1
|
||||
@ -4020,7 +4021,7 @@ def nll_loss2d_forward(
|
||||
weight: Optional[Tensor],
|
||||
reduction: int,
|
||||
ignore_index: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
return _nll_loss_forward(self, target, weight, reduction, ignore_index)
|
||||
|
||||
|
||||
@ -4108,7 +4109,7 @@ def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: boo
|
||||
return grid_x + grid_y + grid_z + grid_one
|
||||
|
||||
|
||||
def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool):
|
||||
def _affine_grid_generator_4d(theta: Tensor, size: list[int], align_corners: bool):
|
||||
n, _, h, w = size
|
||||
base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners)
|
||||
# base_grid shape is (h, w, 3) and theta shape is (n, 2, 3)
|
||||
@ -4118,7 +4119,7 @@ def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: boo
|
||||
return grid.view(n, h, w, 2)
|
||||
|
||||
|
||||
def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: bool):
|
||||
def _affine_grid_generator_5d(theta: Tensor, size: list[int], align_corners: bool):
|
||||
n, _, d, h, w = size
|
||||
base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners)
|
||||
# base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4)
|
||||
@ -4131,7 +4132,7 @@ def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: boo
|
||||
@register_decomposition(aten.affine_grid_generator)
|
||||
@out_wrapper()
|
||||
@pw_cast_for_opmath
|
||||
def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool):
|
||||
def affine_grid_generator(theta: Tensor, size: list[int], align_corners: bool):
|
||||
torch._check(
|
||||
len(size) in (4, 5),
|
||||
lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.",
|
||||
@ -4454,7 +4455,7 @@ def matmul(tensor1, tensor2, *, is_out=False):
|
||||
m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
|
||||
p = tensor2.size(-1) if dim_tensor2 > 1 else 1
|
||||
|
||||
batch_tensor2: List[int] = []
|
||||
batch_tensor2: list[int] = []
|
||||
# TODO: handling of slice
|
||||
for i in range(dim_tensor2 - 2):
|
||||
batch_tensor2.append(tensor2.size(i))
|
||||
@ -4520,7 +4521,7 @@ def matmul(tensor1, tensor2, *, is_out=False):
|
||||
@pw_cast_for_opmath
|
||||
def upsample_bicubic2d_default(
|
||||
input: Tensor,
|
||||
output_size: Tuple[int, int],
|
||||
output_size: tuple[int, int],
|
||||
align_corners: bool,
|
||||
scale_h: Optional[float] = None,
|
||||
scale_w: Optional[float] = None,
|
||||
@ -4608,9 +4609,9 @@ def upsample_bicubic2d_default(
|
||||
@pw_cast_for_opmath
|
||||
def upsample_bicubic2d_vec(
|
||||
a: Tensor,
|
||||
output_size: Optional[Tuple[int, int]],
|
||||
output_size: Optional[tuple[int, int]],
|
||||
align_corners: bool,
|
||||
scale_factors: Optional[Tuple[float, float]] = None,
|
||||
scale_factors: Optional[tuple[float, float]] = None,
|
||||
) -> Tensor:
|
||||
torch._check(
|
||||
bool(output_size) + bool(scale_factors) == 1,
|
||||
@ -4619,7 +4620,7 @@ def upsample_bicubic2d_vec(
|
||||
if output_size is None:
|
||||
assert scale_factors is not None
|
||||
output_size = cast(
|
||||
Tuple[int, int],
|
||||
tuple[int, int],
|
||||
tuple(
|
||||
sym_int(sym_float(w) * scale)
|
||||
for w, scale in zip(a.shape[2:], scale_factors)
|
||||
@ -4634,7 +4635,7 @@ def upsample_bicubic2d_vec(
|
||||
@register_decomposition(aten.reflection_pad3d)
|
||||
@pw_cast_for_opmath
|
||||
@out_wrapper()
|
||||
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
|
||||
def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
|
||||
def idx(left, middle, right):
|
||||
dim_idx = torch.arange(-left, middle + right, device=a.device)
|
||||
return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
|
||||
@ -4651,7 +4652,7 @@ def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
|
||||
@register_decomposition(aten.replication_pad3d)
|
||||
@pw_cast_for_opmath
|
||||
@out_wrapper()
|
||||
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
|
||||
def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
|
||||
def idx(left, middle, right):
|
||||
dim_idx = torch.arange(-left, middle + right, device=a.device)
|
||||
return torch.clamp(dim_idx, 0, middle - 1)
|
||||
@ -4665,7 +4666,7 @@ def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
|
||||
|
||||
def _reflection_or_replication_pad(
|
||||
a: Tensor,
|
||||
padding: Tuple[int, ...],
|
||||
padding: tuple[int, ...],
|
||||
idx_fn: Callable[[int, int, int], Tensor],
|
||||
) -> Tensor:
|
||||
dim = len(padding) // 2
|
||||
@ -4681,7 +4682,7 @@ def _reflection_or_replication_pad(
|
||||
|
||||
result = a
|
||||
for i in range(dim):
|
||||
idx: List[Any] = [None] * result.dim()
|
||||
idx: list[Any] = [None] * result.dim()
|
||||
idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
|
||||
result = aten._unsafe_index(result, idx)
|
||||
|
||||
@ -4887,7 +4888,7 @@ def multilabel_margin_loss_forward(
|
||||
input: Tensor,
|
||||
target: Tensor,
|
||||
reduction: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
orig_input_shape = input.shape
|
||||
orig_target_shape = target.shape
|
||||
input = torch.atleast_2d(input)
|
||||
@ -4953,7 +4954,7 @@ def scaled_dot_product_flash_attention_for_cpu(
|
||||
*,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
torch._check(
|
||||
torch.is_floating_point(query),
|
||||
lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-decorators
|
||||
# mypy: allow-untyped-defs
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch._decomp
|
||||
@ -10,7 +10,7 @@ from torch._prims_common.wrappers import _maybe_remove_out_wrapper
|
||||
|
||||
|
||||
decomposition_table = torch._decomp.decomposition_table
|
||||
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
|
||||
decomposition_table_for_jvp: dict[torch._ops.OperatorBase, Callable] = {}
|
||||
register_decomposition = torch._decomp.register_decomposition
|
||||
aten = torch.ops.aten
|
||||
|
||||
@ -104,7 +104,7 @@ def trace(self: Tensor) -> Tensor:
|
||||
|
||||
|
||||
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
|
||||
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
|
||||
min = torch.minimum(self.new_zeros(()), self)
|
||||
z = torch.exp(-torch.abs(self))
|
||||
if self.is_cuda or self.is_xpu:
|
||||
@ -115,7 +115,7 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
|
||||
|
||||
def recompute_mean_var(
|
||||
input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
|
||||
input: Tensor, rstd: Tensor, inner_dim_indices: list[int], keepdim: bool
|
||||
):
|
||||
# for most norm decompositions, it will be the same as the core version except for here.
|
||||
# We recompute the mean and variance so that they track gradients through input
|
||||
@ -132,13 +132,13 @@ def recompute_mean_var(
|
||||
def native_layer_norm_backward(
|
||||
grad_out: Tensor,
|
||||
input: Tensor,
|
||||
normalized_shape: List[int],
|
||||
normalized_shape: list[int],
|
||||
mean: Tensor,
|
||||
rstd: Tensor,
|
||||
weight: Optional[Tensor],
|
||||
bias: Optional[Tensor],
|
||||
output_mask: List[bool],
|
||||
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
output_mask: list[bool],
|
||||
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
|
||||
input_shape = input.shape
|
||||
input_ndim = input.dim()
|
||||
|
||||
@ -205,7 +205,7 @@ def native_layer_norm_backward(
|
||||
return (d_input, d_weight, d_bias)
|
||||
|
||||
|
||||
def prod(x: List[int]):
|
||||
def prod(x: list[int]):
|
||||
r = 1
|
||||
for i in x:
|
||||
r *= i
|
||||
@ -223,8 +223,8 @@ def native_batch_norm_backward(
|
||||
save_invstd: Optional[Tensor],
|
||||
train: bool,
|
||||
eps: float,
|
||||
output_mask: List[bool],
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
output_mask: list[bool],
|
||||
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
input_shape = input.shape
|
||||
input_rank = input.dim()
|
||||
assert input_rank >= 2, "rank of the input must be at least 2"
|
||||
@ -251,7 +251,7 @@ def native_batch_norm_backward(
|
||||
broadcast_mask = [1] * input_rank
|
||||
broadcast_mask[axis] = input_shape[axis]
|
||||
|
||||
reduction_axes: List[int] = []
|
||||
reduction_axes: list[int] = []
|
||||
for i in range(input_rank):
|
||||
if i != axis:
|
||||
reduction_axes.append(i)
|
||||
@ -305,9 +305,9 @@ def batch_norm_backward(
|
||||
save_var: Optional[Tensor],
|
||||
update: bool,
|
||||
eps: float,
|
||||
output_mask: List[bool],
|
||||
output_mask: list[bool],
|
||||
reserve: Tensor,
|
||||
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
|
||||
return native_batch_norm_backward(
|
||||
grad_out,
|
||||
input,
|
||||
|
@ -2,7 +2,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch._decomp as decomp
|
||||
@ -12,7 +12,7 @@ from torch._ops import OpOverload
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
|
||||
rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict)
|
||||
|
||||
|
||||
def register_rng_decomposition(aten_op):
|
||||
|
@ -1,11 +1,11 @@
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch._C._lazy
|
||||
|
||||
|
||||
class DeviceContext:
|
||||
_CONTEXTS: Dict[str, Any] = {}
|
||||
_CONTEXTS: dict[str, Any] = {}
|
||||
_CONTEXTS_LOCK = threading.Lock()
|
||||
|
||||
def __init__(self, device: str) -> None:
|
||||
|
@ -3,7 +3,7 @@ import copy
|
||||
import dataclasses
|
||||
import itertools
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
import torch._lazy as lazy
|
||||
@ -28,14 +28,14 @@ class GraphInputMatcher:
|
||||
TS/XLA graph inputs.
|
||||
"""
|
||||
|
||||
tensor_id_to_arg_idx: Dict[int, int]
|
||||
graph_input_tensor_ids: List[int]
|
||||
tensor_id_to_arg_idx: dict[int, int]
|
||||
graph_input_tensor_ids: list[int]
|
||||
# there are 2 categories of graph_input_tensors.
|
||||
# Category 1: those whose id are not found in tensor_id_to_arg_idx. These are
|
||||
# most likely const tensors and we can get its content from graph_input_tensors
|
||||
# Category 2: those whose id are found in tensor_id_to_arg_idx. We should get
|
||||
# the tensor from method arguments
|
||||
graph_input_ivalues: List[Any]
|
||||
graph_input_ivalues: list[Any]
|
||||
|
||||
# get the real graph input tensors
|
||||
def __call__(self, args):
|
||||
@ -71,10 +71,10 @@ class ReturnValueHandler:
|
||||
"""
|
||||
|
||||
def __init__(self, lazy_out_list):
|
||||
self.index: List[List[int]] = []
|
||||
self.index: list[list[int]] = []
|
||||
self.total_count = len(lazy_out_list)
|
||||
|
||||
tensor_id_to_idx: Dict[int, int] = {}
|
||||
tensor_id_to_idx: dict[int, int] = {}
|
||||
for dup_idx, lazy_tensor in enumerate(lazy_out_list):
|
||||
uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None)
|
||||
if uniq_idx is not None:
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Optional, Protocol
|
||||
from typing import Any, Callable, Optional, Protocol
|
||||
|
||||
from torch import _C, _ops, autograd, Tensor
|
||||
from torch.utils import _pytree
|
||||
@ -28,7 +28,7 @@ def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
|
||||
@dataclass
|
||||
class Metadata:
|
||||
keyset: _C.DispatchKeySet
|
||||
keyword_only_args: Dict[str, Any]
|
||||
keyword_only_args: dict[str, Any]
|
||||
|
||||
def forward_no_grad(*args):
|
||||
metadata = args[-1]
|
||||
|
@ -2,20 +2,9 @@
|
||||
import inspect
|
||||
import logging
|
||||
import weakref
|
||||
from collections.abc import Iterable, Sequence
|
||||
from contextlib import contextmanager
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
overload,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Literal, Optional, overload, Union
|
||||
|
||||
import torch
|
||||
from torch import _C, _ops, Tensor
|
||||
@ -200,16 +189,16 @@ class CustomOpDef:
|
||||
|
||||
self._init_fn = fn
|
||||
|
||||
self._backend_fns: Dict[Union[str, None], Callable] = {}
|
||||
self._backend_fns: dict[Union[str, None], Callable] = {}
|
||||
self._abstract_fn: Optional[Callable] = None
|
||||
self._setup_context_fn: Optional[Callable] = None
|
||||
self._backward_fn: Optional[Callable] = None
|
||||
self._torch_dispatch_fns: Dict[type, Callable] = {}
|
||||
self._torch_dispatch_fns: dict[type, Callable] = {}
|
||||
self._vmap_fn: Optional[Callable] = None
|
||||
|
||||
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
|
||||
self._register_to_dispatcher()
|
||||
self._disabled_kernel: Set = set()
|
||||
self._disabled_kernel: set = set()
|
||||
OPDEFS[self._qualname] = self
|
||||
|
||||
@property
|
||||
@ -332,7 +321,7 @@ class CustomOpDef:
|
||||
|
||||
def inner(fn):
|
||||
if device_types is None or isinstance(device_types, str):
|
||||
dtypes: List[Union[str, None]] = [device_types]
|
||||
dtypes: list[Union[str, None]] = [device_types]
|
||||
else:
|
||||
dtypes = list(device_types)
|
||||
for device_type in dtypes:
|
||||
@ -807,7 +796,7 @@ def increment_version(val: Any) -> None:
|
||||
# decorator.
|
||||
|
||||
|
||||
OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
|
||||
OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {}
|
||||
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import copy
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Protocol, Tuple, Union
|
||||
from typing import Any, Optional, Protocol, Union
|
||||
|
||||
import torch
|
||||
from torch._library.utils import parse_namespace
|
||||
@ -56,7 +56,7 @@ class HasStaticMethodFromReal(Protocol):
|
||||
|
||||
class FakeClassRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._registered_class: Dict[str, Any] = {}
|
||||
self._registered_class: dict[str, Any] = {}
|
||||
|
||||
def has_impl(self, full_qualname: str) -> bool:
|
||||
return full_qualname in self._registered_class
|
||||
@ -290,7 +290,7 @@ def _full_qual_class_name(qualname: str) -> str:
|
||||
|
||||
|
||||
# Return the namespace and class name from fully qualified name.
|
||||
def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]:
|
||||
def _ns_and_class_name(full_qualname: str) -> tuple[str, str]:
|
||||
splits = full_qualname.split(".")
|
||||
assert len(splits) == 5, f"Could not split {full_qualname=}"
|
||||
_torch, _torch_ns, _classes, ns, class_name = splits
|
||||
|
@ -1,6 +1,7 @@
|
||||
import contextlib
|
||||
import threading
|
||||
from typing import Any, Callable, Generator, Iterable, Optional, Union
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from torch.utils._exposed_in import exposed_in
|
||||
|
||||
|
@ -3,7 +3,8 @@ import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union
|
||||
from collections.abc import Iterable, Iterator
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -55,7 +56,7 @@ def get_source(stacklevel: int) -> str:
|
||||
return source
|
||||
|
||||
|
||||
def parse_namespace(qualname: str) -> Tuple[str, str]:
|
||||
def parse_namespace(qualname: str) -> tuple[str, str]:
|
||||
splits = qualname.split("::")
|
||||
if len(splits) != 2:
|
||||
raise ValueError(
|
||||
@ -189,8 +190,8 @@ def fill_defaults(schema, args, kwargs):
|
||||
|
||||
|
||||
def zip_schema(
|
||||
schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
) -> Iterable[Tuple[_C.Argument, Any]]:
|
||||
schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> Iterable[tuple[_C.Argument, Any]]:
|
||||
"""zips schema.arguments and (args, kwargs) together.
|
||||
|
||||
Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
|
||||
@ -332,7 +333,7 @@ def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
|
||||
|
||||
|
||||
def iter_tensors(
|
||||
args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1
|
||||
args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1
|
||||
) -> Iterator[torch.Tensor]:
|
||||
def check(arg):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
@ -465,7 +466,7 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
|
||||
def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]:
|
||||
idxs = []
|
||||
keys = []
|
||||
for i, info in enumerate(schema.arguments):
|
||||
|
@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import itertools
|
||||
import operator
|
||||
from typing import Optional, Sequence, TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@ -20,6 +20,8 @@ from . import _dtypes_impl, _util
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from ._normalizations import (
|
||||
ArrayLike,
|
||||
ArrayLikeOrScalar,
|
||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import builtins
|
||||
import math
|
||||
import operator
|
||||
from typing import Sequence
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import math
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
@ -12,6 +12,10 @@ from . import _dtypes_impl, _util
|
||||
from ._normalizations import ArrayLike, KeepDims, normalizer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
class LinAlgError(Exception):
|
||||
pass
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from enum import Enum
|
||||
from functools import partial, reduce
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
|
||||
from typing import Callable, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import torch._prims_common as utils
|
||||
@ -231,8 +232,8 @@ def TensorMeta(
|
||||
if isinstance(tensorlike, Number):
|
||||
assert not shape and (shape is None or isinstance(shape, Sequence))
|
||||
assert not strides and (strides is None or isinstance(strides, Sequence))
|
||||
inferred_shape: Tuple[int, ...] = ()
|
||||
inferred_strides: Tuple[int, ...] = ()
|
||||
inferred_shape: tuple[int, ...] = ()
|
||||
inferred_strides: tuple[int, ...] = ()
|
||||
inferred_dtype = type_to_dtype(type(tensorlike))
|
||||
inferred_device = torch.device("cpu")
|
||||
# TODO: This looks wrong, a number that is wrapped into a tensor
|
||||
@ -266,7 +267,7 @@ def TensorMeta(
|
||||
def _make_prim(
|
||||
*,
|
||||
schema: str,
|
||||
return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
|
||||
return_type: Union[RETURN_TYPE, tuple[RETURN_TYPE, ...]],
|
||||
meta: Callable,
|
||||
impl_aten: Callable,
|
||||
doc: str,
|
||||
@ -383,7 +384,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
|
||||
def _prim_elementwise_meta(
|
||||
*args,
|
||||
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
|
||||
args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
|
||||
args_with_fixed_dtypes: Optional[tuple[TensorLikeType, ...]] = None,
|
||||
) -> FakeTensor:
|
||||
"""
|
||||
Meta function for elementwise operations that produce outputs in the same dtype
|
||||
@ -1358,7 +1359,7 @@ def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
|
||||
def _collapsed_shape(shape: ShapeType, start: int, end: int) -> tuple[int, ...]:
|
||||
"""
|
||||
Returns the shape of a with dims in [start, end) merged into a single dimension.
|
||||
"""
|
||||
@ -1374,7 +1375,7 @@ def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
|
||||
|
||||
def _collapse_view_helper(
|
||||
a: TensorLikeType, start: int, end: int
|
||||
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
|
||||
) -> tuple[Optional[ShapeType], Optional[StrideType]]:
|
||||
assert isinstance(a, TensorLike)
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
@ -1534,8 +1535,8 @@ def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLik
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
new_shape: List[int] = []
|
||||
new_strides: List[int] = []
|
||||
new_shape: list[int] = []
|
||||
new_strides: list[int] = []
|
||||
for idx in range(a.ndim):
|
||||
if idx == dim:
|
||||
new_shape.extend((outer_length, inner_length))
|
||||
@ -1797,7 +1798,7 @@ def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
|
||||
)
|
||||
|
||||
|
||||
def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
|
||||
def _cat_aten(tensors: Union[tuple[Tensor, ...], list[Tensor]], dim: int) -> Tensor:
|
||||
return torch.cat(tensors, dim)
|
||||
|
||||
|
||||
@ -2609,7 +2610,7 @@ scalar_tensor = _make_prim(
|
||||
|
||||
def _svd_meta(
|
||||
A: TensorLikeType, *, full_matrices: bool
|
||||
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
|
||||
) -> tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
|
||||
utils.check_is_matrix(A, "linalg.svd")
|
||||
utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)
|
||||
|
||||
@ -2646,7 +2647,7 @@ def _svd_meta(
|
||||
|
||||
def _svd_aten(
|
||||
A: TensorLikeType, *, full_matrices: bool
|
||||
) -> Tuple[Tensor, Tensor, Tensor]:
|
||||
) -> tuple[Tensor, Tensor, Tensor]:
|
||||
return torch.linalg.svd(A, full_matrices=full_matrices)
|
||||
|
||||
|
||||
@ -2899,7 +2900,7 @@ fft_c2r = _make_prim(
|
||||
)
|
||||
|
||||
|
||||
def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
|
||||
def _frexp_meta(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]:
|
||||
torch._check(
|
||||
self.dtype.is_floating_point,
|
||||
lambda: "torch.frexp() only supports floating-point dtypes",
|
||||
|
@ -1,7 +1,8 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import functools
|
||||
from collections.abc import Sequence
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Callable, Dict, Optional, Sequence
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch._decomp
|
||||
@ -28,7 +29,7 @@ def torch_to_refs_map():
|
||||
(torch.fft, torch._refs.fft),
|
||||
(torch.linalg, torch._refs.linalg),
|
||||
]
|
||||
r: Dict[Any, Any] = {
|
||||
r: dict[Any, Any] = {
|
||||
torch.Tensor.__invert__: torch._refs.bitwise_not,
|
||||
torch.Tensor.__xor__: torch._refs.bitwise_xor,
|
||||
torch.Tensor.__and__: torch._refs.bitwise_and,
|
||||
@ -107,7 +108,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
||||
orig_func: Callable,
|
||||
types: Sequence,
|
||||
args: Sequence[Any] = (),
|
||||
kwargs: Optional[Dict] = None,
|
||||
kwargs: Optional[dict] = None,
|
||||
):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
|
@ -1,5 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -85,7 +85,7 @@ def register_philox_rand():
|
||||
shape: torch.Size,
|
||||
seed: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
stride: Optional[Tuple[int, ...]],
|
||||
stride: Optional[tuple[int, ...]],
|
||||
device: _device,
|
||||
dtype: _dtype,
|
||||
):
|
||||
@ -102,7 +102,7 @@ def register_philox_rand():
|
||||
shape: torch.Size,
|
||||
seed: torch.Tensor,
|
||||
offset: torch.Tensor,
|
||||
stride: Optional[Tuple[int, ...]],
|
||||
stride: Optional[tuple[int, ...]],
|
||||
device: _device,
|
||||
dtype: _dtype,
|
||||
):
|
||||
|
@ -2,6 +2,7 @@
import inspect
import warnings
from functools import wraps
from types import GenericAlias
from typing import (
Callable,
List,
@ -260,6 +261,15 @@ def out_wrapper(
Adds the out parameter to a Python reference.
"""
out_type = (
TensorLikeType
if is_tensor
else GenericAlias(
tuple, tuple(TensorLikeType for _ in range(len(out_names)))
)
)
# For backward compatibility - should be able to remove once PEP585
# conversion is complete.
bc_out_type = (
TensorLikeType
if is_tensor
else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
@ -301,12 +311,12 @@ def out_wrapper(
assert (
(isinstance(result, TensorLike) and is_tensor)
or (
isinstance(result, Tuple) # type: ignore[arg-type]
isinstance(result, tuple) # type: ignore[arg-type]
and len(result) == len(out_names) # type: ignore[arg-type]
)
or (
fn.__name__ == "unbind"
and isinstance(result, (List, Tuple)) # type: ignore[arg-type]
and isinstance(result, (List, tuple)) # type: ignore[arg-type]
)
)
# unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
@ -336,9 +346,9 @@ def out_wrapper(
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
if fn.__name__ != "unbind":
assert isinstance(out, Tuple) # type: ignore[arg-type]
assert isinstance(out, tuple) # type: ignore[arg-type]
else:
assert isinstance(out, (List, Tuple)) # type: ignore[arg-type]
assert isinstance(out, (list, tuple)) # type: ignore[arg-type]
torch._check_type(
len(out) == len(result), # type: ignore[arg-type]
lambda: f"expected tuple of {len(result)} elements but got {len(out)}", # type: ignore[arg-type]
@ -362,6 +372,7 @@ def out_wrapper(
assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
sig.empty,
out_type,
bc_out_type,
)
params = *sig.parameters.values(), out_param

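
The out_wrapper hunk above now spells the expected return annotation with types.GenericAlias instead of typing.Tuple, keeping bc_out_type only for signatures still written in the old style. A short standalone sketch of why that works (not the library code): subscripting a builtin such as tuple produces exactly a types.GenericAlias, so the hand-built object matches what a PEP 585 annotation evaluates to.

from types import GenericAlias

# tuple[int, str] and GenericAlias(tuple, (int, str)) describe the same annotation.
ann = GenericAlias(tuple, (int, str))
assert isinstance(tuple[int, str], GenericAlias)
assert ann.__origin__ is tuple
assert ann.__args__ == (int, str)
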
@ -7,21 +7,10 @@ import itertools
import math
import operator
import warnings
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
overload,
Sequence,
Tuple,
Union,
)
from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union

import torch
import torch._prims as prims
@ -411,7 +400,7 @@ def _broadcast_shapes(*_shapes):
assert isinstance(shape, Sequence)

# Computes common shape
common_shape: List[Union[int, torch.SymInt]] = [
common_shape: list[Union[int, torch.SymInt]] = [
1,
] * reduce(max, (len(shape) for shape in shapes))
for arg_idx, shape in enumerate(shapes):
@ -1421,7 +1410,7 @@ def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:

@register_decomposition(aten.frexp)
@out_wrapper("mantissa", "exponent")
def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
def frexp(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]:
return torch.return_types.frexp(prims.frexp(self))


@ -2052,7 +2041,7 @@ def _to_device(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
kwargs = {
"device": device,
"dtype": dtype,
@ -2070,7 +2059,7 @@ def _to_device_str(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
kwargs = {
"device": torch.device(device),
"dtype": dtype,
@ -2087,7 +2076,7 @@ def _to_dtype(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
kwargs = {
"dtype": dtype,
"non_blocking": non_blocking,
@ -2103,7 +2092,7 @@ def _to_other(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
device = other.device
dtype = other.dtype
layout = other.layout
@ -2311,7 +2300,7 @@ def any(
@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
def sum(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Union[Optional[int], Optional[list[int]]] = None,
keepdim: bool = False,
*,
dtype: Optional[torch.dtype] = None,
@ -2363,7 +2352,7 @@ def sum_to_size(
@register_decomposition(aten.prod)
def prod(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Union[Optional[int], Optional[list[int]]] = None,
keepdim: bool = False,
*,
dtype=None,
@ -2481,7 +2470,7 @@ def var(
@out_wrapper()
def std(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Union[Optional[int], Optional[list[int]]] = None,
unbiased: Optional[bool] = None,
keepdim: bool = False,
*,
@ -2660,7 +2649,7 @@ def addr(
# CompositeImplicitAutograd - don't register decomp
def atleast_1d(
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_1d`."""
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
@ -2684,7 +2673,7 @@ def _unsqueeze_atleast(
# CompositeImplicitAutograd - don't register decomp
def atleast_2d(
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_2d`."""
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
@ -2699,7 +2688,7 @@ def atleast_2d(
# CompositeImplicitAutograd - don't register decomp
def atleast_3d(
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_3d`."""
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
@ -2742,7 +2731,7 @@ def broadcast_shapes(*shapes) -> ShapeType:

@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
def broadcast_tensors(*tensors) -> list[TensorLikeType]:
if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
tensors = tensors[0]
return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
@ -2900,7 +2889,7 @@ def conj(input: TensorLikeType) -> TensorLikeType:
@register_decomposition(aten.constant_pad_nd)
@out_wrapper()
def constant_pad_nd(
input: TensorLikeType, pad: List[int], value: NumberType = 0
input: TensorLikeType, pad: list[int], value: NumberType = 0
) -> TensorLikeType:
torch._check(
len(pad) % 2 == 0,
@ -3045,7 +3034,7 @@ def expand_as(a: Tensor, b: Tensor) -> Tensor:
return a.expand(b.shape)


def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> tuple[TensorLikeType, ...]:
if chunks <= 0:
msg = f"Expected at least one chunk, but got {chunks}!"
raise ValueError(msg)
@ -3148,7 +3137,7 @@ def narrow(

def _normalize(
a: Tensor, norm_dims: DimsType, eps: float
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
"""Computes mean and 1/std of a tensor along norm_dims.

Used as a helper function for normalization layers.
@ -3176,7 +3165,7 @@ def _normalize(


# add all specified dimensions
def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
def _unsqueeze_multiple(x: TensorLikeType, dimensions: list[int]) -> TensorLikeType:
for dim in sorted(dimensions):
x = torch.unsqueeze(x, dim)
return x
@ -3192,7 +3181,7 @@ def native_group_norm(
flattened_inner_size: int,
num_groups: int,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
torch._check(
input.ndim >= 2,
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
@ -3243,7 +3232,7 @@ def native_layer_norm(
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
normalized_ndim = len(normalized_shape)
torch._check(
normalized_ndim >= 1,
@ -4163,8 +4152,8 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType

@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
self: Tensor, split_sizes: list[int], dim: int = 0
) -> list[Tensor]:
# NB: Perform the check_is_size tests first so that the
# sum test does not try to do a replacement
for i in range(len(split_sizes)):
@ -4197,7 +4186,7 @@ def tensor_split(
a: TensorLikeType,
indices_or_sections: Union[Tensor, DimsType],
dim: int = 0,
) -> Tuple[TensorLikeType, ...]:
) -> tuple[TensorLikeType, ...]:
_dim = utils.canonicalize_dim(a.ndim, dim)
if a.ndim == 0:
msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
@ -4263,7 +4252,7 @@ def tensor_split(
# CompositeImplicitAutograd - don't register decomp
def hsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
) -> tuple[TensorLikeType, ...]:
torch._check(
a.ndim >= 1,
lambda: (
@ -4305,7 +4294,7 @@ def hsplit(
# CompositeImplicitAutograd - don't register decomp
def vsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
) -> tuple[TensorLikeType, ...]:
torch._check(
a.ndim >= 2,
lambda: (
@ -4480,7 +4469,7 @@ def diag_embed(

@register_decomposition(aten.block_diag)
@out_wrapper()
def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
def _block_diag_iterable(tensors: list[TensorLikeType]) -> TensorLikeType:
"""
Reference implementation of torch.block_diag
"""
@ -4516,7 +4505,7 @@ def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
return torch.cat(result, dim=0)


def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType:
def block_diag(*tensors: list[TensorLikeType]) -> TensorLikeType:
"""
This is used as an input to PythonRefInfo. `torch.block_diag`
expects arguments splatted, but `aten.block_diag` expects only
@ -5305,9 +5294,9 @@ def meshgrid(*tensors: TensorLikeType, indexing: str):

@register_decomposition(aten.meshgrid) # type: ignore[misc]
def meshgrid(
*tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
*tensors: Union[TensorLikeType, list[TensorLikeType], tuple[TensorLikeType]],
indexing: str,
) -> List[TensorLikeType]:
) -> list[TensorLikeType]:
# This ref simultaneously handles two overloads (see stubs above)
# The `indexing` argument is currently optional for torch.meshgrid, but we
# plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
@ -5346,7 +5335,7 @@ def meshgrid(
),
)

result_shape: List[int] = []
result_shape: list[int] = []
for t in tensors:
assert isinstance(t, TensorLike) # mypy
torch._check(
@ -5355,7 +5344,7 @@ def meshgrid(
)
result_shape.append(t.numel())

grids: List[TensorLikeType] = []
grids: list[TensorLikeType] = []
for i, t in enumerate(tensors):
assert isinstance(t, TensorLike) # mypy
if t.ndim == 0:
@ -5436,7 +5425,7 @@ def movedim(
@register_decomposition(aten.empty_strided)
@out_wrapper()
def empty_strided(
shape: Union[ShapeType, Tuple[ShapeType]],
shape: Union[ShapeType, tuple[ShapeType]],
strides: StrideType,
*,
dtype: Optional[torch.dtype] = None,
@ -5854,7 +5843,7 @@ def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
# form a pentagon that can be broken down into a top trapezoid and a bottom
# rectangle. For the implementation of tril_indices, we need the sizes of
# both of these, as well as the length of the top side of the trapezoid.
def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
def _get_tril_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
if row == 0 or col == 0:
return 0, 0, 0

@ -5932,7 +5921,7 @@ def tril_indices(
# a bottom rectangle instead. Note that you can't reduce this to
# _get_tril_sizes(col, row, -offset) because that would correspond to
# decomposing into a left trapezoid and right rectangle.
def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
def _get_triu_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
if row == 0 or col == 0:
return 0, 0, 0

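
The helpers touched in this file keep their behavior and only change spelling; for instance, the _unsqueeze_multiple hunk above can be exercised verbatim outside the library. A runnable sketch (renamed here so it is not mistaken for the library function):

import torch


def unsqueeze_multiple(x: torch.Tensor, dimensions: list[int]) -> torch.Tensor:
    # Insert a size-1 dimension at each requested position, in ascending order.
    for dim in sorted(dimensions):
        x = torch.unsqueeze(x, dim)
    return x


t = torch.randn(3, 4)
assert unsqueeze_multiple(t, [0, 2]).shape == (1, 3, 1, 4)
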
@ -1,5 +1,6 @@
import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
from collections.abc import Iterable, Sequence
from typing import Literal, NamedTuple, Optional, Union

import torch
import torch._prims as prims
@ -88,7 +89,7 @@ def _maybe_promote_tensor_fft(


def _resize_fft_input(
x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
x: TensorLikeType, dims: tuple[int, ...], sizes: tuple[int, ...]
) -> TensorLikeType:
"""
Fixes the shape of x such that x.size(dims[i]) == sizes[i],
@ -268,8 +269,8 @@ def ihfft(


class _ShapeAndDims(NamedTuple):
shape: Tuple[int, ...]
dims: Tuple[int, ...]
shape: tuple[int, ...]
dims: tuple[int, ...]


def _canonicalize_fft_shape_and_dim_args(
@ -339,8 +340,8 @@ def _prod(xs: Iterable[int]) -> int:
def _fftn_c2c(
function_name: str,
input: TensorLikeType,
shape: Tuple[int, ...],
dim: Tuple[int, ...],
shape: tuple[int, ...],
dim: tuple[int, ...],
norm: NormType,
forward: bool,
) -> TensorLikeType:
@ -429,8 +430,8 @@ def ihfftn(


class _CanonicalizeC2rReturn(NamedTuple):
shape: Tuple[int, ...]
dim: Tuple[int, ...]
shape: tuple[int, ...]
dim: tuple[int, ...]
last_dim_size: int


@ -566,7 +567,7 @@ def ihfft2(
return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)


def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> list[int]:
"""Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
if dim is None:
return list(range(x.ndim))
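
NamedTuple fields accept the builtin generics just as readily, which is what the _ShapeAndDims and _CanonicalizeC2rReturn hunks above rely on. A runnable sketch with an invented class name:

from typing import NamedTuple


class ShapeAndDims(NamedTuple):
    shape: tuple[int, ...]
    dims: tuple[int, ...]


sd = ShapeAndDims(shape=(4, 8), dims=(0, 1))
assert sd.shape == (4, 8) and sd.dims == (0, 1)
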
@ -288,7 +288,7 @@ def norm(

# CompositeImplicitAutograd
@out_wrapper("U", "S", "Vh", exact_dtype=True)
def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
def svd(A: TensorLikeType, full_matrices: bool = True) -> tuple[Tensor, Tensor, Tensor]:
return prims.svd(A, full_matrices=full_matrices)

@ -1,4 +1,4 @@
from typing import List


__all__: List[str] = []
__all__: list[str] = []

@ -6,9 +6,10 @@ import os
import re
import subprocess
import time
from collections.abc import Sequence
from threading import Lock
from timeit import default_timer as timer
from typing import Any, Callable, List, Optional, Sequence, TypeVar
from typing import Any, Callable, Optional, TypeVar
from typing_extensions import ParamSpec


@ -77,8 +78,8 @@ class StrobelightCLIFunctionProfiler:
run_user_name: str = "pytorch-strobelight-ondemand",
timeout_wait_for_running_sec: int = 60,
timeout_wait_for_finished_sec: int = 60,
recorded_env_variables: Optional[List[str]] = None,
sample_tags: Optional[List[str]] = None,
recorded_env_variables: Optional[list[str]] = None,
sample_tags: Optional[list[str]] = None,
stack_max_len: int = 127,
async_stack_max_len: int = 127,
):
@ -91,7 +92,7 @@ class StrobelightCLIFunctionProfiler:
# Results of the most recent run.
# Tracks the strobelight run id of the most recent run
self.current_run_id: Optional[int] = None
self.profile_result: Optional[List[str]] = None
self.profile_result: Optional[list[str]] = None
self.sample_tags = sample_tags

def _run_async(self) -> None: