PEP585 update - torch/_C torch/_decomp torch/_lazy torch/_library torch/_numpy torch/_prims torch/_refs torch/_strobelight (#145102)

See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145102
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #145105
Aaron Orenstein
2025-01-18 08:56:06 -08:00
committed by PyTorch MergeBot
parent a79100ab11
commit 5b5766665d
27 changed files with 268 additions and 266 deletions
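
The change is mechanical: container annotations imported from typing (List, Dict, Tuple, Set) become the PEP 585 builtin generics (list, dict, tuple, set), and abstract container types such as Sequence, Iterable, and Generator move to collections.abc. A minimal sketch of the before/after pattern, using a hypothetical helper for illustration (not code from this PR):

    # Pre-PEP 585 spelling, removed by this commit:
    #   from typing import Dict, List, Tuple
    #   def lookup(keys: List[str]) -> Dict[str, Tuple[int, ...]]: ...

    # Post-PEP 585 spelling, valid on Python 3.9+:
    from collections.abc import Sequence

    def lookup(keys: Sequence[str]) -> dict[str, tuple[int, ...]]:
        return {k: (len(k),) for k in keys}

Runtime behavior is unchanged; only the spelling of the annotations differs.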

View File

@ -1,11 +1,11 @@
from typing import Callable, Tuple
from typing import Callable
from torch._dynamo.compiled_autograd import AutogradCompilerInstance
def set_autograd_compiler(
autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
dynamic: bool,
) -> Tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
) -> tuple[Callable[[], AutogradCompilerInstance] | None, bool]: ...
def clear_cache() -> None: ...
def is_cache_empty() -> bool: ...
def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...

View File

@ -1,5 +1,5 @@
import types
from typing import Dict, NewType, Tuple
from typing import NewType
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
@ -31,17 +31,17 @@ class _ExtraState:
# properties Dynamo cares about for a frame.
class _PyInterpreterFrame:
f_code: types.CodeType
f_locals: Dict[str, object]
f_globals: Dict[str, object]
f_builtins: Dict[str, object]
f_locals: dict[str, object]
f_globals: dict[str, object]
f_builtins: dict[str, object]
f_lasti: int
f_lineo: int
f_back: types.FrameType
# A tuple containing cell objects captured by this frame.
closure: Tuple[types.CellType]
closure: tuple[types.CellType]
def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
py_opcode_caches: list[int]
def code_framelocals_names(code: types.CodeType) -> Tuple[str]: ...
def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, Callable, Dict
from typing import Any, Callable
import torch
@ -121,7 +121,7 @@ def install_storage_overlapping_guard(
): ...
def profile_guard_manager(
guard_manager: GuardManager,
f_locals: Dict[str, Any],
f_locals: dict[str, Any],
) -> float: ...
class TensorGuards:

View File

@ -1,4 +1,4 @@
from typing import AnyStr, overload, Tuple
from typing import AnyStr, overload
from torch import Tensor
@ -16,4 +16,4 @@ class DelayedError:
@overload
def __call__(self, i0: Tensor) -> Tensor: ...
@overload
def __call__(self, *args: Tensor) -> Tuple[Tensor, ...]: ...
def __call__(self, *args: Tensor) -> tuple[Tensor, ...]: ...

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import inspect
from collections import defaultdict
from collections.abc import Sequence
from functools import lru_cache, partial, wraps
from itertools import chain
from typing import (
@ -9,7 +10,6 @@ from typing import (
FrozenSet,
List,
Optional,
Sequence,
Set,
TYPE_CHECKING,
TypeVar,
@ -44,8 +44,8 @@ _P = ParamSpec("_P")
# TODO: relax key type here; torch registrations should be possible to; but
# right now this type is accurate
global_decomposition_table: Dict[
str, Dict[torch._ops.OperatorBase, Callable]
global_decomposition_table: dict[
str, dict[torch._ops.OperatorBase, Callable]
] = defaultdict(dict)
decomposition_table = global_decomposition_table["post_autograd"]
@ -78,7 +78,7 @@ def _add_op_to_registry(registry, op, fn):
If op is OpOverload, it will be added to the registry directly.
If op is OpOverloadPacket, all the valid op_overloads in the packet will be added to the registry.
"""
overloads: List[Union[torch._ops.OperatorBase]] = []
overloads: list[Union[torch._ops.OperatorBase]] = []
if isinstance(op, HigherOrderOperator):
# There's no concept of overloads for HigherOrderOperator
registry[op] = fn
@ -232,7 +232,7 @@ def register_decomposition(
def get_decompositions(
aten_ops: Sequence[Union[torch._ops.OperatorBase, OpOverloadPacket]],
type: str = "post_autograd",
) -> Dict[torch._ops.OperatorBase, Callable]:
) -> dict[torch._ops.OperatorBase, Callable]:
"""
Retrieve a dictionary of decompositions corresponding to the list of
operator overloads and overload packets passed as input. Overload
@ -251,7 +251,7 @@ def get_decompositions(
for opo in registry:
if isinstance(opo, (OpOverload, OpOverloadPacket)):
packets_to_overloads[opo.overloadpacket].append(opo)
decompositions: Dict[torch._ops.OperatorBase, Callable] = {}
decompositions: dict[torch._ops.OperatorBase, Callable] = {}
for op in aten_ops:
if isinstance(op, OpOverloadPacket) and op in packets_to_overloads:
for op_overload in packets_to_overloads[op]:
@ -262,7 +262,7 @@ def get_decompositions(
def remove_decompositions(
decompositions: Dict[torch._ops.OperatorBase, Callable],
decompositions: dict[torch._ops.OperatorBase, Callable],
aten_ops: Sequence[Union[OpOverload, OpOverloadPacket]],
) -> None:
"""
@ -297,7 +297,7 @@ def core_aten_decompositions() -> "CustomDecompTable":
# excluding decompositions that results in prim ops
# Resulting opset of decomposition is core aten ops
def _core_aten_decompositions_post_autograd() -> (
Dict[torch._ops.OperatorBase, Callable]
dict[torch._ops.OperatorBase, Callable]
):
aten = torch.ops.aten
return get_decompositions(
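
The signature changes in this file touch only annotations; get_decompositions still returns a plain dict at runtime. A minimal usage sketch (the operator chosen below is illustrative, not taken from this PR):

    import torch
    from torch._decomp import get_decompositions

    # Annotated as dict[torch._ops.OperatorBase, Callable] after this change;
    # the runtime value has always been an ordinary dict keyed by overloads.
    decomps = get_decompositions([torch.ops.aten.native_layer_norm.default])
    assert isinstance(decomps, dict)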

View File

@ -5,10 +5,11 @@ import itertools
import numbers
import operator
import sys
from collections.abc import Iterable
from enum import Enum
from functools import partial, reduce
from itertools import chain, product
from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, cast, Optional, Union
import torch
import torch._meta_registrations
@ -39,7 +40,7 @@ DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
# None of these functions are publicly accessible; get at them
# from torch._decomps
__all__: List[str] = []
__all__: list[str] = []
aten = torch._ops.ops.aten
@ -299,7 +300,7 @@ def _prelu_kernel_backward(
grad_output: Tensor,
self: Tensor,
weight: Tensor,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
input_grad = torch.where(self > 0, grad_output, weight * grad_output)
weight_grad = torch.where(self > 0, 0.0, self * grad_output)
return (input_grad, weight_grad)
@ -690,7 +691,7 @@ def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
@out_wrapper()
def slice_backward(
grad_output: Tensor,
input_sizes: List[int],
input_sizes: list[int],
dim: int,
start: int,
end: int,
@ -760,7 +761,7 @@ def slice_forward(
def _normalize_start_end(
x: Tensor, dim: int, start: Optional[int], end: Optional[int]
) -> Tuple[int, int]:
) -> tuple[int, int]:
"""
Normalize start and end such that both are in the range
[0, x.get_size()[dim]] and start <= end.
@ -824,7 +825,7 @@ def slice_scatter(
@register_decomposition(aten.select_backward)
@out_wrapper()
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
def select_backward(grad_output: Tensor, input_sizes: list[int], dim: int, index: int):
grad_input = grad_output.new_zeros(input_sizes)
return torch.select_scatter(grad_input, grad_output, dim, index)
@ -832,7 +833,7 @@ def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index
@register_decomposition(aten.diagonal_backward)
@out_wrapper()
def diagonal_backward(
grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
grad_output: Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
):
grad_input = grad_output.new_zeros(input_sizes)
return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
@ -899,10 +900,10 @@ def _im2col_col2im_indices_along_dim(
@out_wrapper()
def im2col(
input: Tensor,
kernel_size: List[int],
dilation: List[int],
padding: List[int],
stride: List[int],
kernel_size: list[int],
dilation: list[int],
padding: list[int],
stride: list[int],
) -> Tensor:
torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
@ -982,11 +983,11 @@ def im2col(
@pw_cast_for_opmath
def col2im(
input: Tensor,
output_size: List[int],
kernel_size: List[int],
dilation: List[int],
padding: List[int],
stride: List[int],
output_size: list[int],
kernel_size: list[int],
dilation: list[int],
padding: list[int],
stride: list[int],
) -> Tensor:
torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
@ -1094,7 +1095,7 @@ def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
@register_decomposition(aten.unfold_backward)
@out_wrapper()
def unfold_backward(
grad: Tensor, input_size: List[int], dimension: int, size: int, step: int
grad: Tensor, input_size: list[int], dimension: int, size: int, step: int
) -> Tensor:
if len(input_size) == 0:
return torch.squeeze_copy(grad, 0)
@ -1257,7 +1258,7 @@ def embedding_dense_backward(
)
def prod(x: List[int]):
def prod(x: list[int]):
r = 1
for i in x:
r *= i
@ -1265,10 +1266,10 @@ def prod(x: List[int]):
def _pad_chunk(
tensors: List[Tensor],
tensors: list[Tensor],
dim: int,
num_chunks: int,
) -> List[Tensor]:
) -> list[Tensor]:
padded_tensors = []
for tensor in tensors:
tensor_size = tensor.size()
@ -1285,7 +1286,7 @@ def _pad_chunk(
return padded_tensors
def have_same_ndims(tensors: List[Tensor]):
def have_same_ndims(tensors: list[Tensor]):
ndim = tensors[0].ndim
for tensor in tensors:
if tensor.ndim != ndim:
@ -1293,7 +1294,7 @@ def have_same_ndims(tensors: List[Tensor]):
return True
def leading_dimension_matches(tensors: List[Tensor], dim: int):
def leading_dimension_matches(tensors: list[Tensor], dim: int):
leading_dim_sizes = tensors[0].size()[:dim]
for tensor in tensors:
torch._check(
@ -1303,7 +1304,7 @@ def leading_dimension_matches(tensors: List[Tensor], dim: int):
def _preprocess_chunk_cat_inputs(
tensors: List[Tensor],
tensors: list[Tensor],
dim: int,
num_chunks: int,
):
@ -1341,7 +1342,7 @@ def _preprocess_chunk_cat_inputs(
@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
def _chunk_cat(
tensors: List[Tensor],
tensors: list[Tensor],
dim: int,
num_chunks: int,
out: Optional[Tensor] = None,
@ -1361,10 +1362,10 @@ def _chunk_cat(
)
def split_with_sizes_copy(
self: Tensor,
split_sizes: List[int],
split_sizes: list[int],
dim: int = 0,
out: Optional[List[Tensor]] = None,
) -> Optional[List[Tensor]]:
out: Optional[list[Tensor]] = None,
) -> Optional[list[Tensor]]:
splits = aten.split_with_sizes(self, split_sizes, dim=dim)
if out is None:
return [s.clone(memory_format=torch.contiguous_format) for s in splits]
@ -1376,19 +1377,19 @@ def split_with_sizes_copy(
@register_decomposition(aten.unsafe_split.Tensor)
def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
return aten.split.Tensor(input, split_size, dim)
@register_decomposition(aten.unsafe_split_with_sizes.default)
def unsafe_split_with_sizes(
input: Tensor, split_sizes: List[int], dim: int = 0
) -> Tuple[Tensor, ...]:
input: Tensor, split_sizes: list[int], dim: int = 0
) -> tuple[Tensor, ...]:
return aten.split_with_sizes.default(input, split_sizes, dim)
@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
input_sizes = self.shape
dim_size = input_sizes[dim]
if split_size == 0:
@ -1412,7 +1413,7 @@ def tensor_split_tensor_indices_or_sections_py_impl(
self: Tensor,
tensor_indices_or_sections: Tensor,
dim: int = 0,
) -> Tuple[Tensor, ...]:
) -> tuple[Tensor, ...]:
assert tensor_indices_or_sections.device.type == "cpu"
assert tensor_indices_or_sections.dtype == torch.int64
split_dim = tensor_indices_or_sections.dim()
@ -1505,8 +1506,8 @@ def native_group_norm_backward(
C: int,
HxW: int,
group: int,
output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
utils.check_same_device(
grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
)
@ -1593,12 +1594,12 @@ def native_group_norm_backward_out(
C: int,
HxW: int,
group: int,
output_mask: List[bool],
output_mask: list[bool],
*,
out0: torch.Tensor,
out1: torch.Tensor,
out2: torch.Tensor,
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
result = native_group_norm_backward(
grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
)
@ -1622,13 +1623,13 @@ def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
def native_layer_norm_backward(
grad_out: Tensor,
input: Tensor,
normalized_shape: List[int],
normalized_shape: list[int],
mean: Tensor,
rstd: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_ndim = input.dim()
computation_dtype = utils.get_computation_dtype(input.dtype)
@ -1643,8 +1644,8 @@ def native_layer_norm_backward(
axis = input_ndim - len(normalized_shape)
inner_dims = input_shape[axis:]
outer_dims = input_shape[:axis]
inner_dim_indices: List[int] = []
outer_dim_indices: List[int] = []
inner_dim_indices: list[int] = []
outer_dim_indices: list[int] = []
for i in range(input_ndim):
if i >= axis:
inner_dim_indices.append(i)
@ -1705,17 +1706,17 @@ def native_layer_norm_backward(
def native_layer_norm_backward_out(
grad_out: Tensor,
input: Tensor,
normalized_shape: List[int],
normalized_shape: list[int],
mean: Tensor,
rstd: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
output_mask: List[bool],
output_mask: list[bool],
*,
out0: torch.Tensor,
out1: torch.Tensor,
out2: torch.Tensor,
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
result = native_layer_norm_backward(
grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
)
@ -1738,7 +1739,7 @@ def native_batch_norm_helper(
momentum: float,
eps: float,
functional: bool,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
reduction_dims = [0] + list(range(2, input.dim()))
computation_dtype = utils.get_computation_dtype(input.dtype)
new_running_mean = running_mean
@ -1821,7 +1822,7 @@ def native_batch_norm(
training: bool,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input, weight, bias, running_mean, running_var, training, momentum, eps, False
)
@ -1849,7 +1850,7 @@ def native_batch_norm_decomposition(
training: bool,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
if running_mean is None and running_var is None:
return aten._native_batch_norm_legit(
input, weight, bias, training, momentum, eps
@ -1876,7 +1877,7 @@ def native_batch_norm_decomposition(
@aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> List[Tensor]:
def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> list[Tensor]:
dim_size = tensor.size(dim)
split_size = (dim_size + chunks - 1) // chunks
@ -1896,7 +1897,7 @@ def _native_batch_norm_legit_no_training(
running_var: Tensor,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
return aten._native_batch_norm_legit.default(
input,
weight,
@ -1919,7 +1920,7 @@ def _native_batch_norm_legit(
training: bool,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input, weight, bias, running_mean, running_var, training, momentum, eps, False
)
@ -1934,7 +1935,7 @@ def _native_batch_norm_legit_no_stats(
training: bool,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input, weight, bias, None, None, training, momentum, eps, False
)
@ -1951,7 +1952,7 @@ def _native_batch_norm_legit_functional(
training: bool,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
(
output,
save_mean,
@ -2002,7 +2003,7 @@ def _batch_norm_with_update(
running_var: Tensor,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input,
weight,
@ -2029,7 +2030,7 @@ def _batch_norm_with_update_functional(
running_var: Tensor,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
(
output,
save_mean,
@ -2056,7 +2057,7 @@ def _batch_norm_no_update(
running_var: Tensor,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
input,
weight,
@ -2190,9 +2191,9 @@ def batch_norm_backward(
save_invstd: Optional[Tensor],
train: bool,
eps: float,
output_mask: List[bool],
output_mask: list[bool],
reserve: Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
return native_batch_norm_backward(
grad_out,
input,
@ -2218,8 +2219,8 @@ def native_batch_norm_backward(
save_invstd: Optional[Tensor],
train: bool,
eps: float,
output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
output_mask: list[bool],
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
input_dtype = input.dtype
if weight is not None:
weight_dtype = weight.dtype
@ -2262,10 +2263,10 @@ def native_batch_norm_backward(
mean = running_mean_cast
invstd = torch.rsqrt(running_var_cast + eps)
broadcast_mask: List[int] = [1] * input_rank
broadcast_mask: list[int] = [1] * input_rank
broadcast_mask[axis] = input_shape[axis]
reduction_axes: List[int] = []
reduction_axes: list[int] = []
for i in range(input_rank):
if i != axis:
reduction_axes.append(i)
@ -2320,12 +2321,12 @@ def native_batch_norm_backward_out(
save_invstd: Optional[Tensor],
train: bool,
eps: float,
output_mask: List[bool],
output_mask: list[bool],
*,
out0: torch.Tensor,
out1: torch.Tensor,
out2: torch.Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
result = native_batch_norm_backward(
grad_out,
input,
@ -2403,7 +2404,7 @@ def cudnn_batch_norm_backward(
@register_decomposition(aten._adaptive_avg_pool2d)
@out_wrapper()
@pw_cast_for_opmath
def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]):
# Preconditions
device = input.device
shape = input.shape
@ -2506,7 +2507,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Tuple[int, int]):
def _max_unpoolnd(
self: TensorLike, indices: TensorLike, output_size: List[int], dim: int
self: TensorLike, indices: TensorLike, output_size: list[int], dim: int
):
# If the input tensors self and indices came from max_pool call as
# required by the documentation, this operation is deterministic
@ -2534,7 +2535,7 @@ def _max_unpoolnd(
def max_unpool2d(
self: TensorLike,
indices: TensorLike,
output_size: List[int],
output_size: list[int],
):
torch._check(
indices.dtype == torch.int64,
@ -2581,9 +2582,9 @@ def max_unpool2d(
def max_unpool3d(
input: TensorLike,
indices: TensorLike,
output_size: List[int],
stride: List[int],
padding: List[int],
output_size: list[int],
stride: list[int],
padding: list[int],
):
torch._check(
indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
@ -2761,7 +2762,7 @@ def _index_copy(
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper("output", "buffer")
@pw_cast_for_opmath
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
min = torch.minimum(self.new_zeros(()), self)
z = torch.exp(-torch.abs(self))
if self.is_cuda or self.is_xpu:
@ -2840,8 +2841,8 @@ def get_scale_value(scales, idx):
@aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
def _upsample_nearest_vec(
input: Tensor,
output_size: Optional[List[int]],
scale_factors: Optional[List[float]],
output_size: Optional[list[int]],
scale_factors: Optional[list[float]],
) -> Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scales = (
@ -2861,8 +2862,8 @@ def _upsample_nearest_vec(
@aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
def _upsample_nearest_exact_vec(
input: Tensor,
output_size: Optional[List[int]],
scale_factors: Optional[List[float]],
output_size: Optional[list[int]],
scale_factors: Optional[list[float]],
) -> Tensor:
osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
scales = (
@ -2909,7 +2910,7 @@ def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest1d(
input: Tensor,
output_size: List[int],
output_size: list[int],
scales: Optional[float] = None,
) -> Tensor:
return _upsample_nearest(input, output_size, [scales])
@ -2923,7 +2924,7 @@ def upsample_nearest1d(
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest_exact1d(
input: Tensor,
output_size: List[int],
output_size: list[int],
scales: Optional[float] = None,
) -> Tensor:
return _upsample_nearest(input, output_size, [scales], exact=True)
@ -2935,7 +2936,7 @@ def upsample_nearest_exact1d(
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest2d(
input: Tensor,
output_size: List[int],
output_size: list[int],
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> Tensor:
@ -2950,7 +2951,7 @@ def upsample_nearest2d(
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def _upsample_nearest_exact2d(
input: Tensor,
output_size: List[int],
output_size: list[int],
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
) -> Tensor:
@ -2963,7 +2964,7 @@ def _upsample_nearest_exact2d(
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def upsample_nearest3d(
input: Tensor,
output_size: List[int],
output_size: list[int],
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
@ -2979,7 +2980,7 @@ def upsample_nearest3d(
@out_wrapper(preserve_memory_format=True, exact_dtype=True)
def _upsample_nearest_exact3d(
input: Tensor,
output_size: List[int],
output_size: list[int],
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
@ -2992,8 +2993,8 @@ def _upsample_nearest_exact3d(
@pw_cast_for_opmath
def _upsample_nearest(
input: Tensor,
output_size: List[int],
scales: List[Optional[float]],
output_size: list[int],
scales: list[Optional[float]],
exact: bool = False,
) -> Tensor:
spatial_indices = _compute_upsample_nearest_indices(
@ -3072,7 +3073,7 @@ def one_layer_rnn_data(
hh_bias = params[3] if has_biases else None
step_output = []
hiddens: List[torch.Tensor] = []
hiddens: list[torch.Tensor] = []
last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
cur_hidden = hidden.narrow(0, 0, last_batch_size)
@ -3159,7 +3160,7 @@ def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
hx = hidden[0].unsqueeze(0)
cx = hidden[1].unsqueeze(0)
batch_sizes: List[int] = []
batch_sizes: list[int] = []
mode = 2 # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2
hidden_size = hx.size(2)
num_layers = 1
@ -3710,7 +3711,7 @@ def _upsample_linear_vec(input, output_size, align_corners, scale_factors):
@out_wrapper()
def upsample_linear1d(
input: Tensor,
output_size: List[int],
output_size: list[int],
align_corners: bool,
scales_w: Optional[float] = None,
) -> Tensor:
@ -3724,7 +3725,7 @@ def upsample_linear1d(
@out_wrapper()
def upsample_bilinear2d(
input: Tensor,
output_size: List[int],
output_size: list[int],
align_corners: bool,
scales_h: Optional[float] = None,
scales_w: Optional[float] = None,
@ -3738,7 +3739,7 @@ def upsample_bilinear2d(
@out_wrapper()
def upsample_trilinear3d(
input: Tensor,
output_size: List[int],
output_size: list[int],
align_corners: bool,
scales_d: Optional[float] = None,
scales_h: Optional[float] = None,
@ -3785,9 +3786,9 @@ def _compute_weight_precision(weights: TensorSequenceType) -> Tensor:
@pw_cast_for_opmath
def _upsample_linear(
input: Tensor,
output_size: List[int],
output_size: list[int],
align_corners: bool,
scales: List[Optional[float]],
scales: list[Optional[float]],
) -> Tensor:
# get dimensions of original image
n_channels = input.shape[1]
@ -3937,7 +3938,7 @@ def _nll_loss_forward(
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
# self can be [N, C] or [C]
# target can be [N] or []
@ -3992,7 +3993,7 @@ def nll_loss_forward(
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
assert self.dim() > 0 and self.dim() <= 2, "input tensor should be 1D or 2D"
assert (
target.dim() <= 1
@ -4020,7 +4021,7 @@ def nll_loss2d_forward(
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
return _nll_loss_forward(self, target, weight, reduction, ignore_index)
@ -4108,7 +4109,7 @@ def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: boo
return grid_x + grid_y + grid_z + grid_one
def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: bool):
def _affine_grid_generator_4d(theta: Tensor, size: list[int], align_corners: bool):
n, _, h, w = size
base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners)
# base_grid shape is (h, w, 3) and theta shape is (n, 2, 3)
@ -4118,7 +4119,7 @@ def _affine_grid_generator_4d(theta: Tensor, size: List[int], align_corners: boo
return grid.view(n, h, w, 2)
def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: bool):
def _affine_grid_generator_5d(theta: Tensor, size: list[int], align_corners: bool):
n, _, d, h, w = size
base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners)
# base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4)
@ -4131,7 +4132,7 @@ def _affine_grid_generator_5d(theta: Tensor, size: List[int], align_corners: boo
@register_decomposition(aten.affine_grid_generator)
@out_wrapper()
@pw_cast_for_opmath
def affine_grid_generator(theta: Tensor, size: List[int], align_corners: bool):
def affine_grid_generator(theta: Tensor, size: list[int], align_corners: bool):
torch._check(
len(size) in (4, 5),
lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.",
@ -4454,7 +4455,7 @@ def matmul(tensor1, tensor2, *, is_out=False):
m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
p = tensor2.size(-1) if dim_tensor2 > 1 else 1
batch_tensor2: List[int] = []
batch_tensor2: list[int] = []
# TODO: handling of slice
for i in range(dim_tensor2 - 2):
batch_tensor2.append(tensor2.size(i))
@ -4520,7 +4521,7 @@ def matmul(tensor1, tensor2, *, is_out=False):
@pw_cast_for_opmath
def upsample_bicubic2d_default(
input: Tensor,
output_size: Tuple[int, int],
output_size: tuple[int, int],
align_corners: bool,
scale_h: Optional[float] = None,
scale_w: Optional[float] = None,
@ -4608,9 +4609,9 @@ def upsample_bicubic2d_default(
@pw_cast_for_opmath
def upsample_bicubic2d_vec(
a: Tensor,
output_size: Optional[Tuple[int, int]],
output_size: Optional[tuple[int, int]],
align_corners: bool,
scale_factors: Optional[Tuple[float, float]] = None,
scale_factors: Optional[tuple[float, float]] = None,
) -> Tensor:
torch._check(
bool(output_size) + bool(scale_factors) == 1,
@ -4619,7 +4620,7 @@ def upsample_bicubic2d_vec(
if output_size is None:
assert scale_factors is not None
output_size = cast(
Tuple[int, int],
tuple[int, int],
tuple(
sym_int(sym_float(w) * scale)
for w, scale in zip(a.shape[2:], scale_factors)
@ -4634,7 +4635,7 @@ def upsample_bicubic2d_vec(
@register_decomposition(aten.reflection_pad3d)
@pw_cast_for_opmath
@out_wrapper()
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
def idx(left, middle, right):
dim_idx = torch.arange(-left, middle + right, device=a.device)
return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
@ -4651,7 +4652,7 @@ def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
@register_decomposition(aten.replication_pad3d)
@pw_cast_for_opmath
@out_wrapper()
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
def idx(left, middle, right):
dim_idx = torch.arange(-left, middle + right, device=a.device)
return torch.clamp(dim_idx, 0, middle - 1)
@ -4665,7 +4666,7 @@ def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
def _reflection_or_replication_pad(
a: Tensor,
padding: Tuple[int, ...],
padding: tuple[int, ...],
idx_fn: Callable[[int, int, int], Tensor],
) -> Tensor:
dim = len(padding) // 2
@ -4681,7 +4682,7 @@ def _reflection_or_replication_pad(
result = a
for i in range(dim):
idx: List[Any] = [None] * result.dim()
idx: list[Any] = [None] * result.dim()
idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
result = aten._unsafe_index(result, idx)
@ -4887,7 +4888,7 @@ def multilabel_margin_loss_forward(
input: Tensor,
target: Tensor,
reduction: int,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
orig_input_shape = input.shape
orig_target_shape = target.shape
input = torch.atleast_2d(input)
@ -4953,7 +4954,7 @@ def scaled_dot_product_flash_attention_for_cpu(
*,
attn_mask: Optional[Tensor] = None,
scale: Optional[float] = None,
) -> Tuple[Tensor, Tensor]:
) -> tuple[Tensor, Tensor]:
torch._check(
torch.is_floating_point(query),
lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import inspect
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Optional
import torch
import torch._decomp
@ -10,7 +10,7 @@ from torch._prims_common.wrappers import _maybe_remove_out_wrapper
decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
decomposition_table_for_jvp: dict[torch._ops.OperatorBase, Callable] = {}
register_decomposition = torch._decomp.register_decomposition
aten = torch.ops.aten
@ -104,7 +104,7 @@ def trace(self: Tensor) -> Tensor:
@maybe_register_decomposition(aten.log_sigmoid_forward.default)
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
min = torch.minimum(self.new_zeros(()), self)
z = torch.exp(-torch.abs(self))
if self.is_cuda or self.is_xpu:
@ -115,7 +115,7 @@ def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
def recompute_mean_var(
input: Tensor, rstd: Tensor, inner_dim_indices: List[int], keepdim: bool
input: Tensor, rstd: Tensor, inner_dim_indices: list[int], keepdim: bool
):
# for most norm decompositions, it will be the same as the core version except for here.
# We recompute the mean and variance so that they track gradients through input
@ -132,13 +132,13 @@ def recompute_mean_var(
def native_layer_norm_backward(
grad_out: Tensor,
input: Tensor,
normalized_shape: List[int],
normalized_shape: list[int],
mean: Tensor,
rstd: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
output_mask: list[bool],
) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_ndim = input.dim()
@ -205,7 +205,7 @@ def native_layer_norm_backward(
return (d_input, d_weight, d_bias)
def prod(x: List[int]):
def prod(x: list[int]):
r = 1
for i in x:
r *= i
@ -223,8 +223,8 @@ def native_batch_norm_backward(
save_invstd: Optional[Tensor],
train: bool,
eps: float,
output_mask: List[bool],
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
output_mask: list[bool],
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_rank = input.dim()
assert input_rank >= 2, "rank of the input must be at least 2"
@ -251,7 +251,7 @@ def native_batch_norm_backward(
broadcast_mask = [1] * input_rank
broadcast_mask[axis] = input_shape[axis]
reduction_axes: List[int] = []
reduction_axes: list[int] = []
for i in range(input_rank):
if i != axis:
reduction_axes.append(i)
@ -305,9 +305,9 @@ def batch_norm_backward(
save_var: Optional[Tensor],
update: bool,
eps: float,
output_mask: List[bool],
output_mask: list[bool],
reserve: Tensor,
) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
return native_batch_norm_backward(
grad_out,
input,

View File

@ -2,7 +2,7 @@
# mypy: allow-untyped-defs
import functools
from collections import defaultdict
from typing import Callable, Dict
from typing import Callable
import torch
import torch._decomp as decomp
@ -12,7 +12,7 @@ from torch._ops import OpOverload
aten = torch.ops.aten
rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
rng_decompositions: dict[str, dict[OpOverload, Callable]] = defaultdict(dict)
def register_rng_decomposition(aten_op):

View File

@ -1,11 +1,11 @@
import threading
from typing import Any, Dict, Optional
from typing import Any, Optional
import torch._C._lazy
class DeviceContext:
_CONTEXTS: Dict[str, Any] = {}
_CONTEXTS: dict[str, Any] = {}
_CONTEXTS_LOCK = threading.Lock()
def __init__(self, device: str) -> None:

View File

@ -3,7 +3,7 @@ import copy
import dataclasses
import itertools
import os
from typing import Any, Callable, Dict, List
from typing import Any, Callable
import torch
import torch._lazy as lazy
@ -28,14 +28,14 @@ class GraphInputMatcher:
TS/XLA graph inputs.
"""
tensor_id_to_arg_idx: Dict[int, int]
graph_input_tensor_ids: List[int]
tensor_id_to_arg_idx: dict[int, int]
graph_input_tensor_ids: list[int]
# there are 2 categories of graph_input_tensors.
# Category 1: those whose id are not found in tensor_id_to_arg_idx. These are
# most likely const tensors and we can get its content from graph_input_tensors
# Category 2: those whose id are found in tensor_id_to_arg_idx. We should get
# the tensor from method arguments
graph_input_ivalues: List[Any]
graph_input_ivalues: list[Any]
# get the real graph input tensors
def __call__(self, args):
@ -71,10 +71,10 @@ class ReturnValueHandler:
"""
def __init__(self, lazy_out_list):
self.index: List[List[int]] = []
self.index: list[list[int]] = []
self.total_count = len(lazy_out_list)
tensor_id_to_idx: Dict[int, int] = {}
tensor_id_to_idx: dict[int, int] = {}
for dup_idx, lazy_tensor in enumerate(lazy_out_list):
uniq_idx = tensor_id_to_idx.get(id(lazy_tensor), None)
if uniq_idx is not None:

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import dataclasses
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Protocol
from typing import Any, Callable, Optional, Protocol
from torch import _C, _ops, autograd, Tensor
from torch.utils import _pytree
@ -28,7 +28,7 @@ def make_autograd_impl(op: _ops.OpOverload, info: InfoProtocol) -> Callable:
@dataclass
class Metadata:
keyset: _C.DispatchKeySet
keyword_only_args: Dict[str, Any]
keyword_only_args: dict[str, Any]
def forward_no_grad(*args):
metadata = args[-1]

View File

@ -2,20 +2,9 @@
import inspect
import logging
import weakref
from collections.abc import Iterable, Sequence
from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
overload,
Sequence,
Set,
Union,
)
from typing import Any, Callable, Literal, Optional, overload, Union
import torch
from torch import _C, _ops, Tensor
@ -200,16 +189,16 @@ class CustomOpDef:
self._init_fn = fn
self._backend_fns: Dict[Union[str, None], Callable] = {}
self._backend_fns: dict[Union[str, None], Callable] = {}
self._abstract_fn: Optional[Callable] = None
self._setup_context_fn: Optional[Callable] = None
self._backward_fn: Optional[Callable] = None
self._torch_dispatch_fns: Dict[type, Callable] = {}
self._torch_dispatch_fns: dict[type, Callable] = {}
self._vmap_fn: Optional[Callable] = None
self._lib = get_library_allowing_overwrite(self._namespace, self._name)
self._register_to_dispatcher()
self._disabled_kernel: Set = set()
self._disabled_kernel: set = set()
OPDEFS[self._qualname] = self
@property
@ -332,7 +321,7 @@ class CustomOpDef:
def inner(fn):
if device_types is None or isinstance(device_types, str):
dtypes: List[Union[str, None]] = [device_types]
dtypes: list[Union[str, None]] = [device_types]
else:
dtypes = list(device_types)
for device_type in dtypes:
@ -807,7 +796,7 @@ def increment_version(val: Any) -> None:
# decorator.
OPDEF_TO_LIB: Dict[str, "torch.library.Library"] = {}
OPDEF_TO_LIB: dict[str, "torch.library.Library"] = {}
OPDEFS: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import copy
import logging
from typing import Any, Dict, Optional, Protocol, Tuple, Union
from typing import Any, Optional, Protocol, Union
import torch
from torch._library.utils import parse_namespace
@ -56,7 +56,7 @@ class HasStaticMethodFromReal(Protocol):
class FakeClassRegistry:
def __init__(self) -> None:
self._registered_class: Dict[str, Any] = {}
self._registered_class: dict[str, Any] = {}
def has_impl(self, full_qualname: str) -> bool:
return full_qualname in self._registered_class
@ -290,7 +290,7 @@ def _full_qual_class_name(qualname: str) -> str:
# Return the namespace and class name from fully qualified name.
def _ns_and_class_name(full_qualname: str) -> Tuple[str, str]:
def _ns_and_class_name(full_qualname: str) -> tuple[str, str]:
splits = full_qualname.split(".")
assert len(splits) == 5, f"Could not split {full_qualname=}"
_torch, _torch_ns, _classes, ns, class_name = splits

View File

@ -1,6 +1,7 @@
import contextlib
import threading
from typing import Any, Callable, Generator, Iterable, Optional, Union
from collections.abc import Generator, Iterable
from typing import Any, Callable, Optional, Union
from torch.utils._exposed_in import exposed_in

View File

@ -3,7 +3,8 @@ import dataclasses
import inspect
import sys
import warnings
from typing import Any, Callable, Dict, Iterable, Iterator, List, Tuple, Union
from collections.abc import Iterable, Iterator
from typing import Any, Callable, Union
import torch
import torch.utils._pytree as pytree
@ -55,7 +56,7 @@ def get_source(stacklevel: int) -> str:
return source
def parse_namespace(qualname: str) -> Tuple[str, str]:
def parse_namespace(qualname: str) -> tuple[str, str]:
splits = qualname.split("::")
if len(splits) != 2:
raise ValueError(
@ -189,8 +190,8 @@ def fill_defaults(schema, args, kwargs):
def zip_schema(
schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
schema: _C.FunctionSchema, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> Iterable[tuple[_C.Argument, Any]]:
"""zips schema.arguments and (args, kwargs) together.
Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
@ -332,7 +333,7 @@ def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
def iter_tensors(
args: Tuple[Any], kwargs: Dict[str, Any], allowed_nesting: int = 1
args: tuple[Any], kwargs: dict[str, Any], allowed_nesting: int = 1
) -> Iterator[torch.Tensor]:
def check(arg):
if isinstance(arg, torch.Tensor):
@ -465,7 +466,7 @@ def has_fake_kernel(op: torch._ops.OpOverload) -> bool:
return False
def mutated_args_kwargs(schema: _C.FunctionSchema) -> Tuple[List[int], List[str]]:
def mutated_args_kwargs(schema: _C.FunctionSchema) -> tuple[list[int], list[str]]:
idxs = []
keys = []
for i, info in enumerate(schema.arguments):

View File

@ -12,7 +12,7 @@ from __future__ import annotations
import builtins
import itertools
import operator
from typing import Optional, Sequence, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING
import torch
@ -20,6 +20,8 @@ from . import _dtypes_impl, _util
if TYPE_CHECKING:
from collections.abc import Sequence
from ._normalizations import (
ArrayLike,
ArrayLikeOrScalar,

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import builtins
import math
import operator
from typing import Sequence
from collections.abc import Sequence
import torch

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import functools
import math
from typing import Sequence
from typing import TYPE_CHECKING
import torch
@ -12,6 +12,10 @@ from . import _dtypes_impl, _util
from ._normalizations import ArrayLike, KeepDims, normalizer
if TYPE_CHECKING:
from collections.abc import Sequence
class LinAlgError(Exception):
pass

View File

@ -1,8 +1,9 @@
# mypy: allow-untyped-defs
import operator
from collections.abc import Sequence
from enum import Enum
from functools import partial, reduce
from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
from typing import Callable, List, Optional, Tuple, Type, Union
import torch
import torch._prims_common as utils
@ -231,8 +232,8 @@ def TensorMeta(
if isinstance(tensorlike, Number):
assert not shape and (shape is None or isinstance(shape, Sequence))
assert not strides and (strides is None or isinstance(strides, Sequence))
inferred_shape: Tuple[int, ...] = ()
inferred_strides: Tuple[int, ...] = ()
inferred_shape: tuple[int, ...] = ()
inferred_strides: tuple[int, ...] = ()
inferred_dtype = type_to_dtype(type(tensorlike))
inferred_device = torch.device("cpu")
# TODO: This looks wrong, a number that is wrapped into a tensor
@ -266,7 +267,7 @@ def TensorMeta(
def _make_prim(
*,
schema: str,
return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
return_type: Union[RETURN_TYPE, tuple[RETURN_TYPE, ...]],
meta: Callable,
impl_aten: Callable,
doc: str,
@ -383,7 +384,7 @@ class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
def _prim_elementwise_meta(
*args,
type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
args_with_fixed_dtypes: Optional[tuple[TensorLikeType, ...]] = None,
) -> FakeTensor:
"""
Meta function for elementwise operations that produce outputs in the same dtype
@ -1358,7 +1359,7 @@ def _validate_collapse_args(a: Tensor, start: int, end: int) -> None:
)
def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
def _collapsed_shape(shape: ShapeType, start: int, end: int) -> tuple[int, ...]:
"""
Returns the shape of a with dims in [start, end) merged into a single dimension.
"""
@ -1374,7 +1375,7 @@ def _collapsed_shape(shape: ShapeType, start: int, end: int) -> Tuple[int, ...]:
def _collapse_view_helper(
a: TensorLikeType, start: int, end: int
) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
) -> tuple[Optional[ShapeType], Optional[StrideType]]:
assert isinstance(a, TensorLike)
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@ -1534,8 +1535,8 @@ def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLik
)
raise ValueError(msg)
new_shape: List[int] = []
new_strides: List[int] = []
new_shape: list[int] = []
new_strides: list[int] = []
for idx in range(a.ndim):
if idx == dim:
new_shape.extend((outer_length, inner_length))
@ -1797,7 +1798,7 @@ def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
)
def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
def _cat_aten(tensors: Union[tuple[Tensor, ...], list[Tensor]], dim: int) -> Tensor:
return torch.cat(tensors, dim)
@ -2609,7 +2610,7 @@ scalar_tensor = _make_prim(
def _svd_meta(
A: TensorLikeType, *, full_matrices: bool
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
) -> tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
utils.check_is_matrix(A, "linalg.svd")
utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)
@ -2646,7 +2647,7 @@ def _svd_meta(
def _svd_aten(
A: TensorLikeType, *, full_matrices: bool
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
return torch.linalg.svd(A, full_matrices=full_matrices)
@ -2899,7 +2900,7 @@ fft_c2r = _make_prim(
)
def _frexp_meta(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
def _frexp_meta(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]:
torch._check(
self.dtype.is_floating_point,
lambda: "torch.frexp() only supports floating-point dtypes",

View File

@ -1,7 +1,8 @@
# mypy: allow-untyped-defs
import functools
from collections.abc import Sequence
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence
from typing import Any, Callable, Optional
import torch
import torch._decomp
@ -28,7 +29,7 @@ def torch_to_refs_map():
(torch.fft, torch._refs.fft),
(torch.linalg, torch._refs.linalg),
]
r: Dict[Any, Any] = {
r: dict[Any, Any] = {
torch.Tensor.__invert__: torch._refs.bitwise_not,
torch.Tensor.__xor__: torch._refs.bitwise_xor,
torch.Tensor.__and__: torch._refs.bitwise_and,
@ -107,7 +108,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
orig_func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Optional[Dict] = None,
kwargs: Optional[dict] = None,
):
if kwargs is None:
kwargs = {}

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.utils._pytree as pytree
@ -85,7 +85,7 @@ def register_philox_rand():
shape: torch.Size,
seed: torch.Tensor,
offset: torch.Tensor,
stride: Optional[Tuple[int, ...]],
stride: Optional[tuple[int, ...]],
device: _device,
dtype: _dtype,
):
@ -102,7 +102,7 @@ def register_philox_rand():
shape: torch.Size,
seed: torch.Tensor,
offset: torch.Tensor,
stride: Optional[Tuple[int, ...]],
stride: Optional[tuple[int, ...]],
device: _device,
dtype: _dtype,
):

View File

@ -2,6 +2,7 @@
import inspect
import warnings
from functools import wraps
from types import GenericAlias
from typing import (
Callable,
List,
@ -260,6 +261,15 @@ def out_wrapper(
Adds the out parameter to a Python reference.
"""
out_type = (
TensorLikeType
if is_tensor
else GenericAlias(
tuple, tuple(TensorLikeType for _ in range(len(out_names)))
)
)
# For backward compatibility - should be able to remove once PEP585
# conversion is complete.
bc_out_type = (
TensorLikeType
if is_tensor
else Tuple[tuple(TensorLikeType for _ in range(len(out_names)))]
@ -301,12 +311,12 @@ def out_wrapper(
assert (
(isinstance(result, TensorLike) and is_tensor)
or (
isinstance(result, Tuple) # type: ignore[arg-type]
isinstance(result, tuple) # type: ignore[arg-type]
and len(result) == len(out_names) # type: ignore[arg-type]
)
or (
fn.__name__ == "unbind"
and isinstance(result, (List, Tuple)) # type: ignore[arg-type]
and isinstance(result, (List, tuple)) # type: ignore[arg-type]
)
)
# unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829
@ -336,9 +346,9 @@ def out_wrapper(
_safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type]
else:
if fn.__name__ != "unbind":
assert isinstance(out, Tuple) # type: ignore[arg-type]
assert isinstance(out, tuple) # type: ignore[arg-type]
else:
assert isinstance(out, (List, Tuple)) # type: ignore[arg-type]
assert isinstance(out, (list, tuple)) # type: ignore[arg-type]
torch._check_type(
len(out) == len(result), # type: ignore[arg-type]
lambda: f"expected tuple of {len(result)} elements but got {len(out)}", # type: ignore[arg-type]
@ -362,6 +372,7 @@ def out_wrapper(
assert isinstance(sig.return_annotation, str) or sig.return_annotation in (
sig.empty,
out_type,
bc_out_type,
)
params = *sig.parameters.values(), out_param
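
The new out_type above is built with types.GenericAlias because the tuple arity is only known at runtime (one TensorLikeType per out name), so it cannot be written as a literal subscription. A minimal sketch of the construction, with a stand-in class replacing the real TensorLikeType alias:

    from types import GenericAlias

    class TensorLikeType:  # stand-in for torch's tensor type alias
        pass

    out_names = ("values", "indices")

    # Programmatically builds tuple[TensorLikeType, TensorLikeType]:
    out_type = GenericAlias(tuple, tuple(TensorLikeType for _ in out_names))

    assert out_type.__origin__ is tuple
    assert out_type.__args__ == (TensorLikeType, TensorLikeType)
    assert out_type == tuple[TensorLikeType, TensorLikeType]

The old typing.Tuple spelling is kept as bc_out_type because a reference whose return annotation still uses the typing form does not compare equal to the builtin generic; the return-annotation assertion accepts either until the PEP 585 conversion is complete, as the in-code comment notes.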

View File

@ -7,21 +7,10 @@ import itertools
import math
import operator
import warnings
from collections.abc import Iterable
from collections.abc import Iterable, Sequence
from enum import Enum
from functools import partial, reduce, singledispatch, wraps
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
overload,
Sequence,
Tuple,
Union,
)
from typing import Any, Callable, cast, Dict, List, Optional, overload, Tuple, Union
import torch
import torch._prims as prims
@ -411,7 +400,7 @@ def _broadcast_shapes(*_shapes):
assert isinstance(shape, Sequence)
# Computes common shape
common_shape: List[Union[int, torch.SymInt]] = [
common_shape: list[Union[int, torch.SymInt]] = [
1,
] * reduce(max, (len(shape) for shape in shapes))
for arg_idx, shape in enumerate(shapes):
@ -1421,7 +1410,7 @@ def fmod(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
@register_decomposition(aten.frexp)
@out_wrapper("mantissa", "exponent")
def frexp(self: TensorLikeType) -> Tuple[TensorLikeType, TensorLikeType]:
def frexp(self: TensorLikeType) -> tuple[TensorLikeType, TensorLikeType]:
return torch.return_types.frexp(prims.frexp(self))
@ -2052,7 +2041,7 @@ def _to_device(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
kwargs = {
"device": device,
"dtype": dtype,
@ -2070,7 +2059,7 @@ def _to_device_str(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
kwargs = {
"device": torch.device(device),
"dtype": dtype,
@ -2087,7 +2076,7 @@ def _to_dtype(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
kwargs = {
"dtype": dtype,
"non_blocking": non_blocking,
@ -2103,7 +2092,7 @@ def _to_other(
non_blocking: bool = False,
copy: bool = False,
memory_format: Optional[torch.memory_format] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
device = other.device
dtype = other.dtype
layout = other.layout
@ -2311,7 +2300,7 @@ def any(
@register_decomposition([aten.sum.dim_IntList, aten.sum.IntList_out])
def sum(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Union[Optional[int], Optional[list[int]]] = None,
keepdim: bool = False,
*,
dtype: Optional[torch.dtype] = None,
@ -2363,7 +2352,7 @@ def sum_to_size(
@register_decomposition(aten.prod)
def prod(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Union[Optional[int], Optional[list[int]]] = None,
keepdim: bool = False,
*,
dtype=None,
@ -2481,7 +2470,7 @@ def var(
@out_wrapper()
def std(
a: TensorLikeType,
dim: Union[Optional[int], Optional[List[int]]] = None,
dim: Union[Optional[int], Optional[list[int]]] = None,
unbiased: Optional[bool] = None,
keepdim: bool = False,
*,
@ -2660,7 +2649,7 @@ def addr(
# CompositeImplicitAutograd - don't register decomp
def atleast_1d(
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_1d`."""
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
@ -2684,7 +2673,7 @@ def _unsqueeze_atleast(
# CompositeImplicitAutograd - don't register decomp
def atleast_2d(
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_2d`."""
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
@ -2699,7 +2688,7 @@ def atleast_2d(
# CompositeImplicitAutograd - don't register decomp
def atleast_3d(
arg: Union[TensorLikeType, Sequence[TensorLikeType]], *args: TensorLikeType
) -> Union[TensorLikeType, Tuple[TensorLikeType, ...]]:
) -> Union[TensorLikeType, tuple[TensorLikeType, ...]]:
"""Reference implementation of :func:`torch.atleast_3d`."""
if not args and isinstance(arg, collections.abc.Sequence):
args_ = arg
@ -2742,7 +2731,7 @@ def broadcast_shapes(*shapes) -> ShapeType:
@aten.broadcast_tensors.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.broadcast_tensors.default.py_impl(DispatchKey.Meta)
def broadcast_tensors(*tensors) -> List[TensorLikeType]:
def broadcast_tensors(*tensors) -> list[TensorLikeType]:
if len(tensors) == 1 and not isinstance(tensors[0], Tensor):
tensors = tensors[0]
return list(_maybe_broadcast(*tensors, preserve_cpu_scalar_tensors=False))
@ -2900,7 +2889,7 @@ def conj(input: TensorLikeType) -> TensorLikeType:
@register_decomposition(aten.constant_pad_nd)
@out_wrapper()
def constant_pad_nd(
input: TensorLikeType, pad: List[int], value: NumberType = 0
input: TensorLikeType, pad: list[int], value: NumberType = 0
) -> TensorLikeType:
torch._check(
len(pad) % 2 == 0,
@ -3045,7 +3034,7 @@ def expand_as(a: Tensor, b: Tensor) -> Tensor:
return a.expand(b.shape)
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType, ...]:
def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> tuple[TensorLikeType, ...]:
if chunks <= 0:
msg = f"Expected at least one chunk, but got {chunks}!"
raise ValueError(msg)
@ -3148,7 +3137,7 @@ def narrow(
def _normalize(
a: Tensor, norm_dims: DimsType, eps: float
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
"""Computes mean and 1/std of a tensor along norm_dims.
Used as a helper function for normalization layers.
@ -3176,7 +3165,7 @@ def _normalize(
# add all specified dimensions
def _unsqueeze_multiple(x: TensorLikeType, dimensions: List[int]) -> TensorLikeType:
def _unsqueeze_multiple(x: TensorLikeType, dimensions: list[int]) -> TensorLikeType:
for dim in sorted(dimensions):
x = torch.unsqueeze(x, dim)
return x
@ -3192,7 +3181,7 @@ def native_group_norm(
flattened_inner_size: int,
num_groups: int,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
torch._check(
input.ndim >= 2,
lambda: f"Expected at least 2 dimensions for input tensor but received {input.ndim}",
@ -3243,7 +3232,7 @@ def native_layer_norm(
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
) -> tuple[Tensor, Tensor, Tensor]:
normalized_ndim = len(normalized_shape)
torch._check(
normalized_ndim >= 1,
@ -4163,8 +4152,8 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType
@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
self: Tensor, split_sizes: list[int], dim: int = 0
) -> list[Tensor]:
# NB: Perform the check_is_size tests first so that the
# sum test does not try to do a replacement
for i in range(len(split_sizes)):
@ -4197,7 +4186,7 @@ def tensor_split(
a: TensorLikeType,
indices_or_sections: Union[Tensor, DimsType],
dim: int = 0,
) -> Tuple[TensorLikeType, ...]:
) -> tuple[TensorLikeType, ...]:
_dim = utils.canonicalize_dim(a.ndim, dim)
if a.ndim == 0:
msg = "tensor_split: received a rank zero tensor, but expected a tensor of rank one or greater!"
@ -4263,7 +4252,7 @@ def tensor_split(
# CompositeImplicitAutograd - don't register decomp
def hsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
) -> tuple[TensorLikeType, ...]:
torch._check(
a.ndim >= 1,
lambda: (
@ -4305,7 +4294,7 @@ def hsplit(
# CompositeImplicitAutograd - don't register decomp
def vsplit(
a: TensorLikeType, indices_or_sections: DimsType
) -> Tuple[TensorLikeType, ...]:
) -> tuple[TensorLikeType, ...]:
torch._check(
a.ndim >= 2,
lambda: (
@ -4480,7 +4469,7 @@ def diag_embed(
@register_decomposition(aten.block_diag)
@out_wrapper()
def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
def _block_diag_iterable(tensors: list[TensorLikeType]) -> TensorLikeType:
"""
Reference implementation of torch.block_diag
"""
@ -4516,7 +4505,7 @@ def _block_diag_iterable(tensors: List[TensorLikeType]) -> TensorLikeType:
return torch.cat(result, dim=0)
def block_diag(*tensors: List[TensorLikeType]) -> TensorLikeType:
def block_diag(*tensors: list[TensorLikeType]) -> TensorLikeType:
"""
This is used as an input to PythonRefInfo. `torch.block_diag`
expects arguments splatted, but `aten.block_diag` expects only
@ -5305,9 +5294,9 @@ def meshgrid(*tensors: TensorLikeType, indexing: str):
@register_decomposition(aten.meshgrid) # type: ignore[misc]
def meshgrid(
*tensors: Union[TensorLikeType, List[TensorLikeType], Tuple[TensorLikeType]],
*tensors: Union[TensorLikeType, list[TensorLikeType], tuple[TensorLikeType]],
indexing: str,
) -> List[TensorLikeType]:
) -> list[TensorLikeType]:
# This ref simultaneously handles two overloads (see stubs above)
# The `indexing` argument is currently optional for torch.meshgrid, but we
# plan to make the argument required: https://github.com/pytorch/pytorch/issues/50276
@ -5346,7 +5335,7 @@ def meshgrid(
),
)
result_shape: List[int] = []
result_shape: list[int] = []
for t in tensors:
assert isinstance(t, TensorLike) # mypy
torch._check(
@ -5355,7 +5344,7 @@ def meshgrid(
)
result_shape.append(t.numel())
grids: List[TensorLikeType] = []
grids: list[TensorLikeType] = []
for i, t in enumerate(tensors):
assert isinstance(t, TensorLike) # mypy
if t.ndim == 0:
@ -5436,7 +5425,7 @@ def movedim(
@register_decomposition(aten.empty_strided)
@out_wrapper()
def empty_strided(
shape: Union[ShapeType, Tuple[ShapeType]],
shape: Union[ShapeType, tuple[ShapeType]],
strides: StrideType,
*,
dtype: Optional[torch.dtype] = None,
@ -5854,7 +5843,7 @@ def tril(a: TensorLikeType, diagonal: int = 0) -> TensorLikeType:
# form a pentagon that can be broken down into a top trapezoid and a bottom
# rectangle. For the implementation of tril_indices, we need the sizes of
# both of these, as well as the length of the top side of the trapezoid.
def _get_tril_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
def _get_tril_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
if row == 0 or col == 0:
return 0, 0, 0
@ -5932,7 +5921,7 @@ def tril_indices(
# a bottom rectangle instead. Note that you can't reduce this to
# _get_tril_sizes(col, row, -offset) because that would correspond to
# decomposing into a left trapezoid and right rectangle.
def _get_triu_sizes(row: int, col: int, offset: int) -> Tuple[int, int, int]:
def _get_triu_sizes(row: int, col: int, offset: int) -> tuple[int, int, int]:
if row == 0 or col == 0:
return 0, 0, 0

View File

@ -1,5 +1,6 @@
import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union
from collections.abc import Iterable, Sequence
from typing import Literal, NamedTuple, Optional, Union
import torch
import torch._prims as prims
@ -88,7 +89,7 @@ def _maybe_promote_tensor_fft(
def _resize_fft_input(
x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
x: TensorLikeType, dims: tuple[int, ...], sizes: tuple[int, ...]
) -> TensorLikeType:
"""
Fixes the shape of x such that x.size(dims[i]) == sizes[i],
@ -268,8 +269,8 @@ def ihfft(
class _ShapeAndDims(NamedTuple):
shape: Tuple[int, ...]
dims: Tuple[int, ...]
shape: tuple[int, ...]
dims: tuple[int, ...]
def _canonicalize_fft_shape_and_dim_args(
@ -339,8 +340,8 @@ def _prod(xs: Iterable[int]) -> int:
def _fftn_c2c(
function_name: str,
input: TensorLikeType,
shape: Tuple[int, ...],
dim: Tuple[int, ...],
shape: tuple[int, ...],
dim: tuple[int, ...],
norm: NormType,
forward: bool,
) -> TensorLikeType:
@ -429,8 +430,8 @@ def ihfftn(
class _CanonicalizeC2rReturn(NamedTuple):
shape: Tuple[int, ...]
dim: Tuple[int, ...]
shape: tuple[int, ...]
dim: tuple[int, ...]
last_dim_size: int
@ -566,7 +567,7 @@ def ihfft2(
return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> list[int]:
"""Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
if dim is None:
return list(range(x.ndim))

View File

@ -288,7 +288,7 @@ def norm(
# CompositeImplicitAutograd
@out_wrapper("U", "S", "Vh", exact_dtype=True)
def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
def svd(A: TensorLikeType, full_matrices: bool = True) -> tuple[Tensor, Tensor, Tensor]:
return prims.svd(A, full_matrices=full_matrices)

View File

@ -1,4 +1,4 @@
from typing import List
__all__: List[str] = []
__all__: list[str] = []

View File

@ -6,9 +6,10 @@ import os
import re
import subprocess
import time
from collections.abc import Sequence
from threading import Lock
from timeit import default_timer as timer
from typing import Any, Callable, List, Optional, Sequence, TypeVar
from typing import Any, Callable, Optional, TypeVar
from typing_extensions import ParamSpec
@ -77,8 +78,8 @@ class StrobelightCLIFunctionProfiler:
run_user_name: str = "pytorch-strobelight-ondemand",
timeout_wait_for_running_sec: int = 60,
timeout_wait_for_finished_sec: int = 60,
recorded_env_variables: Optional[List[str]] = None,
sample_tags: Optional[List[str]] = None,
recorded_env_variables: Optional[list[str]] = None,
sample_tags: Optional[list[str]] = None,
stack_max_len: int = 127,
async_stack_max_len: int = 127,
):
@ -91,7 +92,7 @@ class StrobelightCLIFunctionProfiler:
# Results of the most recent run.
# Tracks the strobelight run id of the most recent run
self.current_run_id: Optional[int] = None
self.profile_result: Optional[List[str]] = None
self.profile_result: Optional[list[str]] = None
self.sample_tags = sample_tags
def _run_async(self) -> None: