PEP585 update - torch/testing (#145200)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145200
Approved by: https://github.com/bobrenjc93
Author: Aaron Orenstein
Date: 2025-01-19 21:32:39 -08:00
Committed by: PyTorch MergeBot
Parent: 805c4b597a
Commit: dea7ad3371
37 changed files with 262 additions and 298 deletions
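
The change is mechanical throughout: PEP 585 (Python 3.9) made the builtin containers subscriptable, so `typing.List`/`Dict`/`Tuple`/`Set`/`Type` become `list`/`dict`/`tuple`/`set`/`type`, and abstract containers such as `Sequence`, `Iterable`, and `Collection` are now imported from `collections.abc`, while `Any`, `Callable`, `Optional`, and `Union` stay in `typing`. A minimal before/after sketch with illustrative names (not taken from the diff):

```python
# Before (pre-PEP 585): container generics imported from typing.
from typing import Dict, List, Optional, Type

def bucket(values: List[int]) -> Dict[str, List[int]]:
    return {"even": [v for v in values if v % 2 == 0],
            "odd": [v for v in values if v % 2 == 1]}

def make_error(cls: Type[Exception], msg: Optional[str]) -> Exception:
    return cls(msg or "unknown")

# After (Python >= 3.9): subscript the builtins directly; abstract
# containers come from collections.abc instead of typing.
from collections.abc import Sequence

def bucket_585(values: Sequence[int]) -> dict[str, list[int]]:
    return {"even": [v for v in values if v % 2 == 0],
            "odd": [v for v in values if v % 2 == 1]}

def make_error_585(cls: type[Exception], msg: Optional[str]) -> Exception:
    return cls(msg or "unknown")
```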

@ -3,18 +3,8 @@ import abc
import cmath
import collections.abc
import contextlib
from typing import (
Any,
Callable,
Collection,
Dict,
List,
NoReturn,
Optional,
Sequence,
Type,
Union,
)
from collections.abc import Collection, Sequence
from typing import Any, Callable, NoReturn, Optional, Union
from typing_extensions import deprecated
import torch
@ -33,7 +23,7 @@ class ErrorMeta(Exception):
"""Internal testing exception that makes that carries error metadata."""
def __init__(
self, type: Type[Exception], msg: str, *, id: tuple[Any, ...] = ()
self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = ()
) -> None:
super().__init__(
"If you are a user and see this message during normal operation "
@ -82,7 +72,7 @@ _DTYPE_PRECISIONS.update(
def default_tolerances(
*inputs: Union[torch.Tensor, torch.dtype],
dtype_precisions: Optional[Dict[torch.dtype, tuple[float, float]]] = None,
dtype_precisions: Optional[dict[torch.dtype, tuple[float, float]]] = None,
) -> tuple[float, float]:
"""Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype.
@ -341,13 +331,13 @@ class Pair(abc.ABC):
raise UnsupportedInputs
@staticmethod
def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, tuple[Type, ...]]):
def _check_inputs_isinstance(*inputs: Any, cls: Union[type, tuple[type, ...]]):
"""Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise."""
if not all(isinstance(input, cls) for input in inputs):
Pair._inputs_not_supported()
def _fail(
self, type: Type[Exception], msg: str, *, id: tuple[Any, ...] = ()
self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = ()
) -> NoReturn:
"""Raises an :class:`ErrorMeta` from a given exception type and message and the stored id.
@ -451,8 +441,8 @@ class BooleanPair(Pair):
super().__init__(actual, expected, **other_parameters)
@property
def _supported_types(self) -> tuple[Type, ...]:
cls: List[Type] = [bool]
def _supported_types(self) -> tuple[type, ...]:
cls: list[type] = [bool]
if HAS_NUMPY:
cls.append(np.bool_)
return tuple(cls)
@ -545,7 +535,7 @@ class NumberPair(Pair):
self.check_dtype = check_dtype
@property
def _supported_types(self) -> tuple[Type, ...]:
def _supported_types(self) -> tuple[type, ...]:
cls = list(self._NUMBER_TYPES)
if HAS_NUMPY:
cls.append(np.number)
@ -1052,12 +1042,12 @@ def originate_pairs(
actual: Any,
expected: Any,
*,
pair_types: Sequence[Type[Pair]],
sequence_types: tuple[Type, ...] = (collections.abc.Sequence,),
mapping_types: tuple[Type, ...] = (collections.abc.Mapping,),
pair_types: Sequence[type[Pair]],
sequence_types: tuple[type, ...] = (collections.abc.Sequence,),
mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
id: tuple[Any, ...] = (),
**options: Any,
) -> List[Pair]:
) -> list[Pair]:
"""Originates pairs from the individual inputs.
``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
@ -1092,8 +1082,8 @@ def originate_pairs(
and isinstance(expected, sequence_types)
and not isinstance(expected, str)
):
actual_len = len(actual)
expected_len = len(expected)
actual_len = len(actual) # type: ignore[arg-type]
expected_len = len(expected) # type: ignore[arg-type]
if actual_len != expected_len:
raise ErrorMeta(
AssertionError,
@ -1105,8 +1095,8 @@ def originate_pairs(
for idx in range(actual_len):
pairs.extend(
originate_pairs(
actual[idx],
expected[idx],
actual[idx], # type: ignore[index]
expected[idx], # type: ignore[index]
pair_types=pair_types,
sequence_types=sequence_types,
mapping_types=mapping_types,
@ -1117,8 +1107,8 @@ def originate_pairs(
return pairs
elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types):
actual_keys = set(actual.keys())
expected_keys = set(expected.keys())
actual_keys = set(actual.keys()) # type: ignore[attr-defined]
expected_keys = set(expected.keys()) # type: ignore[attr-defined]
if actual_keys != expected_keys:
missing_keys = expected_keys - actual_keys
additional_keys = actual_keys - expected_keys
@ -1141,8 +1131,8 @@ def originate_pairs(
for key in keys:
pairs.extend(
originate_pairs(
actual[key],
expected[key],
actual[key], # type: ignore[index]
expected[key], # type: ignore[index]
pair_types=pair_types,
sequence_types=sequence_types,
mapping_types=mapping_types,
@ -1190,11 +1180,11 @@ def not_close_error_metas(
actual: Any,
expected: Any,
*,
pair_types: Sequence[Type[Pair]] = (ObjectPair,),
sequence_types: tuple[Type, ...] = (collections.abc.Sequence,),
mapping_types: tuple[Type, ...] = (collections.abc.Mapping,),
pair_types: Sequence[type[Pair]] = (ObjectPair,),
sequence_types: tuple[type, ...] = (collections.abc.Sequence,),
mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
**options: Any,
) -> List[ErrorMeta]:
) -> list[ErrorMeta]:
"""Asserts that inputs are equal.
``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
@ -1225,7 +1215,7 @@ def not_close_error_metas(
# Explicitly raising from None to hide the internal traceback
raise error_meta.to_error() from None # noqa: RSE102
error_metas: List[ErrorMeta] = []
error_metas: list[ErrorMeta] = []
for pair in pairs:
try:
pair.compare()
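
The `# type: ignore[...]` comments added in `originate_pairs` above appear to compensate for a mypy limitation rather than a real bug: `isinstance` only narrows a value when its second argument is a literal class or tuple of classes, not when the classes arrive through a variable such as `sequence_types`. A reduced, hypothetical sketch of the pattern (parameters typed `object` so the diagnostic is reproducible):

```python
import collections.abc

def first_elements(actual: object, expected: object,
                   sequence_types: tuple[type, ...] = (collections.abc.Sequence,)):
    if isinstance(actual, sequence_types) and isinstance(expected, sequence_types):
        # mypy cannot narrow `actual`/`expected` here because `sequence_types`
        # is a runtime value, hence the explicit ignore on element access.
        return actual[0], expected[0]  # type: ignore[index]
    return actual, expected
```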

@ -6,7 +6,7 @@ import collections.abc
import functools
import math
import warnings
from typing import cast, List, Optional, Union
from typing import cast, Optional, Union
import torch
@ -43,7 +43,7 @@ def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
def make_tensor(
*shape: Union[int, torch.Size, List[int], tuple[int, ...]],
*shape: Union[int, torch.Size, list[int], tuple[int, ...]],
dtype: torch.dtype,
device: Union[str, torch.device],
low: Optional[float] = None,
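
For reference, `make_tensor`'s variadic `*shape` accepts either unpacked integers or a single sequence, so both spellings below exercise the same code path:

```python
import torch
from torch.testing import make_tensor

a = make_tensor(2, 3, dtype=torch.float32, device="cpu", low=-1.0, high=1.0)
b = make_tensor((2, 3), dtype=torch.float32, device="cpu")
assert a.shape == b.shape == torch.Size([2, 3])
```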

@ -3,7 +3,6 @@
import os
import re
import sys
from typing import List
__all__ = [
"check_code_for_cuda_kernel_launches",
@ -15,7 +14,7 @@ __all__ = [
# launch a kernel without some safety? Use this as a quick workaround
# for a problem with the checker, fix the checker, then de-exclude
# the files in question.
exclude_files: List[str] = []
exclude_files: list[str] = []
# Without using a C++ AST we can't 100% detect kernel launches, so we
# model them as having the pattern "<<<parameters>>>(arguments);"

@ -9,21 +9,10 @@ import sys
import threading
import unittest
from collections import namedtuple
from collections.abc import Iterable, Sequence
from enum import Enum
from functools import partial, wraps
from typing import (
Any,
Callable,
ClassVar,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
TypeVar,
Union,
)
from typing import Any, Callable, ClassVar, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -505,7 +494,7 @@ class DeviceTypeTestBase(TestCase):
def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes):
for dtype in dtypes:
param_kwargs: Dict[str, Any] = {}
param_kwargs: dict[str, Any] = {}
_update_param_kwargs(param_kwargs, "dtype", dtype)
# Note that an empty test suffix is set here so that the dtype can be appended
@ -728,7 +717,7 @@ class PrivateUse1TestBase(DeviceTypeTestBase):
def get_device_type_test_bases():
# set type to List[Any] due to mypy list-of-union issue:
# https://github.com/python/mypy/issues/3351
test_bases: List[Any] = []
test_bases: list[Any] = []
if IS_SANDCASTLE or IS_FBCODE:
if IS_REMOTE_GPU:
@ -1089,7 +1078,7 @@ class ops(_TestParametrizer):
op = check_exhausted_iterator = object()
for op in self.op_list:
# Determine the set of dtypes to use.
dtypes: Union[Set[torch.dtype], Set[None]]
dtypes: Union[set[torch.dtype], set[None]]
if isinstance(self.opinfo_dtypes, Sequence):
dtypes = set(self.opinfo_dtypes)
elif self.opinfo_dtypes == OpDTypes.unsupported_backward:
@ -1854,7 +1843,7 @@ def skipCUDAIfNotMiopenSuggestNHWC(fn):
# Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s.
def skipCUDAVersionIn(versions: Optional[List[tuple[int, int]]] = None):
def skipCUDAVersionIn(versions: Optional[list[tuple[int, int]]] = None):
def dec_fn(fn):
@wraps(fn)
def wrap_fn(self, *args, **kwargs):
@ -1969,7 +1958,7 @@ def skipPRIVATEUSE1(fn):
# TODO: the "all" in the name hasn't been true for quite some time, as we also have, for example, XLA and MPS now.
# This should probably enumerate all available device type test base classes.
def get_all_device_types() -> List[str]:
def get_all_device_types() -> list[str]:
return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]

@ -21,7 +21,7 @@ from datetime import timedelta
from enum import Enum
from functools import partial, reduce, wraps
from io import StringIO
from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable
from typing import NamedTuple, Optional, Union, Any, Callable
from unittest.mock import patch
from torch._logging._internal import trace_log
@ -963,7 +963,7 @@ class DistributedTestBase(MultiProcessTestCase):
def run_subtests(
cls_inst,
subtest_config: Dict[str, List[Any]],
subtest_config: dict[str, list[Any]],
test_fn: Callable,
*test_args,
**test_kwargs: Any,
@ -982,9 +982,9 @@ def run_subtests(
test_kwargs: Keyword arguments to pass to ``test_fn``.
"""
# Convert the config mapping to a list to have a fixed order
subtest_config_items: List[tuple[str, List[Any]]] = list(subtest_config.items())
subtest_config_keys: List[str] = [item[0] for item in subtest_config_items]
subtest_config_values: List[List[Any]] = [item[1] for item in subtest_config_items]
subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items())
subtest_config_keys: list[str] = [item[0] for item in subtest_config_items]
subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items]
for values in itertools.product(*subtest_config_values):
# Map keyword to chosen value
subtest_kwargs = dict(zip(subtest_config_keys, values))
@ -1314,7 +1314,7 @@ class MultiThreadedTestCase(TestCase):
class SaveForwardInputsModule(nn.Module):
def __init__(
self,
forward_inputs: Dict[nn.Module, torch.Tensor],
forward_inputs: dict[nn.Module, torch.Tensor],
cast_forward_inputs: bool,
) -> None:
super().__init__()
@ -1330,7 +1330,7 @@ class SaveForwardInputsModule(nn.Module):
class SaveForwardInputsModel(nn.Module):
def __init__(
self,
forward_inputs: Dict[nn.Module, torch.Tensor],
forward_inputs: dict[nn.Module, torch.Tensor],
cast_forward_inputs: bool,
) -> None:
super().__init__()
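
`run_subtests` sweeps the Cartesian product of all config values, invoking the test once per combination. A standalone sketch of that pattern (hypothetical runner, without the fixture plumbing of the original):

```python
import itertools
from typing import Any, Callable

def sweep_subtests(config: dict[str, list[Any]], test_fn: Callable, /, **fixed: Any) -> None:
    items = list(config.items())  # freeze the iteration order
    keys = [k for k, _ in items]
    for values in itertools.product(*(v for _, v in items)):
        test_fn(**dict(zip(keys, values)), **fixed)

sweep_subtests({"bias": [True, False], "dim": [1, 2]},
               lambda bias, dim: print(f"bias={bias} dim={dim}"))
```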

@ -1,6 +1,5 @@
# mypy: ignore-errors
from typing import List
import torch
@ -158,7 +157,7 @@ def get_all_dtypes(
include_complex=True,
include_complex32=False,
include_qint=False,
) -> List[torch.dtype]:
) -> list[torch.dtype]:
dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
include_half=include_half, include_bfloat16=include_bfloat16
)
@ -171,7 +170,7 @@ def get_all_dtypes(
return dtypes
def get_all_math_dtypes(device) -> List[torch.dtype]:
def get_all_math_dtypes(device) -> list[torch.dtype]:
return (
get_all_int_dtypes()
+ get_all_fp_dtypes(
@ -181,7 +180,7 @@ def get_all_math_dtypes(device) -> List[torch.dtype]:
)
def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
def get_all_complex_dtypes(include_complex32=False) -> list[torch.dtype]:
return (
[torch.complex32, torch.complex64, torch.complex128]
if include_complex32
@ -189,11 +188,11 @@ def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
)
def get_all_int_dtypes() -> List[torch.dtype]:
def get_all_int_dtypes() -> list[torch.dtype]:
return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> list[torch.dtype]:
dtypes = [torch.float32, torch.float64]
if include_half:
dtypes.append(torch.float16)
@ -202,7 +201,7 @@ def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dt
return dtypes
def get_all_qint_dtypes() -> List[torch.dtype]:
def get_all_qint_dtypes() -> list[torch.dtype]:
return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]

@ -12,7 +12,7 @@ from contextlib import nullcontext
from copy import deepcopy
from enum import auto, Enum
from functools import wraps
from typing import Any, Callable, cast, Dict, List, no_type_check, Optional, Type, Union
from typing import Any, Callable, cast, no_type_check, Optional, Union
from unittest import mock
import torch
@ -199,7 +199,7 @@ def _broadcast_state_dict(rank, state_dict):
olist = [state_dict if rank == 0 else None]
dist.broadcast_object_list(olist)
state_dict = cast(Dict[str, torch.Tensor], olist[0])
state_dict = cast(dict[str, torch.Tensor], olist[0])
# Ensure that the state is on DEVICE
for param_name in state_dict.keys():
state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE)
@ -322,7 +322,7 @@ class TransformerWithSharedParams(FSDPTestModel):
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
device_init_mode: DEVICEInitMode,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
deterministic: bool = False,
add_bn: bool = True,
) -> Union[nn.Module, FSDP]:
@ -451,7 +451,7 @@ class NestedWrappedModule(FSDPTestModel):
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
device_init_mode: DEVICEInitMode,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
deterministic: bool = False,
) -> nn.Module:
"""
@ -499,7 +499,7 @@ class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
device_init_mode: DEVICEInitMode,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
deterministic: bool = False,
):
"""
@ -581,7 +581,7 @@ class NonUniformReqGradNWM(NestedWrappedModule):
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
device_init_mode: DEVICEInitMode,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
deterministic: bool = False,
):
"""
@ -674,7 +674,7 @@ class ModuleWithDelay(FSDPTestModel):
@staticmethod
def init(
module_class: Type[FSDPTestModel],
module_class: type[FSDPTestModel],
*model_args: Any,
delay_after_loss_ms: int,
delay_before_reduction_ms: int,
@ -706,7 +706,7 @@ class NestedWrappedModuleWithDelay(ModuleWithDelay):
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
device_init_mode: DEVICEInitMode = DEVICEInitMode.DEVICE_AFTER,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
deterministic: bool = False,
delay_after_loss_ms: int = 0,
delay_before_reduction_ms: int = 0,
@ -826,7 +826,7 @@ class MixtureOfExperts(NestedWrappedModule):
group: dist.ProcessGroup,
fsdp_init_mode: FSDPInitMode,
device_init_mode: DEVICEInitMode,
fsdp_kwargs: Optional[Dict[str, Any]] = None,
fsdp_kwargs: Optional[dict[str, Any]] = None,
deterministic: bool = False,
delay_before_free_ms: int = 0,
):
@ -907,7 +907,7 @@ class MLP(nn.Module):
class MLPStack(nn.Sequential):
def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
modules: List[nn.Module] = [
modules: list[nn.Module] = [
# Use multiplier of 3 to exercise uneven case
MLP(mlp_dim, dim_multiplier=3),
MLP(mlp_dim),
@ -1237,7 +1237,7 @@ class FSDPTest(MultiProcessTestCase):
mixed_precision: Optional[MixedPrecision] = None,
enable_sharded_grad_scaler: bool = False,
use_pure_fp16: bool = False,
sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
):
cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
@ -1315,7 +1315,7 @@ class FSDPTest(MultiProcessTestCase):
def _test_fsdp_parity(
self,
model_class: Type[FSDPTestModel],
model_class: type[FSDPTestModel],
fsdp_init_mode: FSDPInitMode,
device_init_mode: DEVICEInitMode,
ref_init_fn: Optional[Callable] = None,
@ -1329,8 +1329,8 @@ class FSDPTest(MultiProcessTestCase):
use_orig_params: bool = False,
enable_sharded_grad_scaler: bool = False,
use_pure_fp16: bool = False,
init_kwargs: Optional[Dict[str, Any]] = None,
sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
init_kwargs: Optional[dict[str, Any]] = None,
sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
**fsdp_kwargs,
):
"""

@ -16,7 +16,7 @@ from torch.testing._internal.common_utils import enable_profiling_mode # noqa:
# Standard library
from itertools import chain
from typing import List, Union
from typing import Union
from torch._C import TensorType
import io
@ -62,7 +62,7 @@ def check_against_reference(self, func, reference_func, output_func, args, kwarg
return t.detach().clone().requires_grad_(require_grad)
def clone_inputs(preserve_requires_grad: bool):
inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []
inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []
for arg in args:
if isinstance(arg, torch.Tensor):
@ -76,7 +76,7 @@ def check_against_reference(self, func, reference_func, output_func, args, kwarg
# Returns tensors in args that requires_grad, including tensors in TensorList args
def get_recording_tensors(args):
recording_tensors: List[torch.Tensor] = []
recording_tensors: list[torch.Tensor] = []
for arg in args:
if isinstance(arg, torch.Tensor) and arg.requires_grad:
@ -284,7 +284,7 @@ class JitCommonTestCase(TestCase):
self.assertEqual(should_autodiff_node,
found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
def checkShapeAnalysis(self, out_sizes: Union[list[int], list[list[int]]],
traced_graph, assert_propagation, constant_prop=True):
# repropagate input shapes provided by tracing,
prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()

@ -16,7 +16,8 @@ import numpy as np
import numpy.typing as npt
from torch import inf, nan
from typing import Any, Dict, List, Tuple, Union, Sequence
from typing import Any, Union
from collections.abc import Sequence
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import (
_dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
@ -3992,7 +3993,7 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs):
# Ordered as shapes for input, weight, bias,
# and a dict of values of (stride, padding, dilation, groups)
cases: Tuple = (
cases: tuple = (
((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}),
((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}),
((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}),
@ -4140,7 +4141,7 @@ def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=
# Ordered as shapes for input, weight, bias
# and a dict of values of (stride, padding, groups, dilation)
cases: Tuple = (
cases: tuple = (
((1, 3, 4, 4), (3, 3, 3, 3), (3,),
{'stride': (2, 2), 'padding': 2, 'groups': 1}),
((2, 4, 8, 8), (2, 2, 3, 3), (2,),
@ -4185,7 +4186,7 @@ def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs):
# Ordered as shapes for input, weight, bias
# and dict of values of (stride, padding, dilation, groups)
cases: Tuple = (
cases: tuple = (
((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}),
((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}),
((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}),
@ -4658,7 +4659,7 @@ def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
features_options = [[3, 4], [8, 8]]
batch_options: List[List[int]] = [
batch_options: list[list[int]] = [
[], # no batch
[0],
[8],
@ -4684,7 +4685,7 @@ def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs):
features_options = [[3, 4, 5], [8, 8, 8]]
batch_options: List[List[int]] = [
batch_options: list[list[int]] = [
[], # no batch
[0],
[8],
@ -4706,7 +4707,7 @@ def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs):
def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs):
features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]]
batch_options: List[List[int]] = [
batch_options: list[list[int]] = [
[], # no batch
[0],
[8],
@ -5063,7 +5064,7 @@ def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
# Order: input_shape, kernel_size, kwargs
cases: List[tuple[tuple[int, ...], Union[int, tuple[int, ...]], Dict]] = [
cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [
((2, 3, 9), (3,), {}),
((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)),
((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)),
@ -5082,7 +5083,7 @@ def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
# Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
cases: List[tuple[tuple[int, ...], Union[int, tuple[int, ...]], Dict]] = [
cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [
((2, 3, 3, 4, 4), (2, 2, 2), {}),
((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True,
count_include_pad=False, divisor_override=2)),
@ -6637,7 +6638,7 @@ def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs)
batch_size, num_classes = shape = (2, 3)
reductions = ("mean", "sum", "none")
input_shape_and_kwargs: List[tuple[tuple[int, ...], Dict[str, Any]]] = [
input_shape_and_kwargs: list[tuple[tuple[int, ...], dict[str, Any]]] = [
(shape, {}),
((*shape, 1), {}),
((*shape, 1, 2), {}),
@ -6828,17 +6829,17 @@ def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False
def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype,
requires_grad: bool,
*, variant: str, **kwargs) -> List[SampleInput]:
*, variant: str, **kwargs) -> list[SampleInput]:
if variant == 'variadic':
def make_inputs(
tensors: List[torch.Tensor]) -> tuple[Union[torch.Tensor,
List[torch.Tensor]],
tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor,
list[torch.Tensor]],
tuple[torch.Tensor, ...]]:
return tensors
elif variant == 'list':
def make_inputs(
tensors: List[torch.Tensor]) -> tuple[Union[torch.Tensor,
List[torch.Tensor]],
tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor,
list[torch.Tensor]],
tuple[torch.Tensor, ...]]:
return [tensors]
else:
@ -6848,7 +6849,7 @@ def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.d
SCALAR = torch.Size([])
VECTOR = torch.Size([3])
test_cases: List[List[torch.Size]] = [
test_cases: list[list[torch.Size]] = [
[SCALAR],
[VECTOR],
[VECTOR, SCALAR],
@ -9664,7 +9665,7 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):
args.pop()
foreach_unary_op_db: List[OpInfo] = [
foreach_unary_op_db: list[OpInfo] = [
ForeachFuncInfo(
'exp',
sample_inputs_func=foreach_inputs_sample_func(1, False, False),
@ -10835,7 +10836,7 @@ foreach_unary_op_db: List[OpInfo] = [
),
]
foreach_binary_op_db: List[OpInfo] = [
foreach_binary_op_db: list[OpInfo] = [
ForeachFuncInfo(
"add",
sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
@ -11130,7 +11131,7 @@ foreach_binary_op_db: List[OpInfo] = [
)
]
foreach_pointwise_op_db: List[ForeachFuncInfo] = [
foreach_pointwise_op_db: list[ForeachFuncInfo] = [
ForeachFuncInfo(
"addcmul",
sample_inputs_func=foreach_pointwise_sample_func(4, True, True),
@ -11184,7 +11185,7 @@ foreach_pointwise_op_db: List[ForeachFuncInfo] = [
),
]
foreach_reduce_op_db: List[ForeachFuncInfo] = [
foreach_reduce_op_db: list[ForeachFuncInfo] = [
ForeachFuncInfo(
"max",
sample_inputs_func=foreach_max_sample_func(1, False, False),
@ -11267,7 +11268,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
),
]
foreach_other_op_db: List[ForeachFuncInfo] = [
foreach_other_op_db: list[ForeachFuncInfo] = [
ForeachFuncInfo(
"lerp",
sample_inputs_func=foreach_inputs_sample_func(3, True, True),
@ -11665,7 +11666,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
# Operator database (sorted alphabetically)
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
UnaryUfuncInfo('abs',
aliases=('absolute', ),
ref=np.abs,

View File

@ -28,11 +28,10 @@ from torch.testing._internal.common_utils import (
freeze_rng_state, skipIfMPS, skipIfMPSOnMacOS13, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS,
skipIfTorchDynamo)
from types import ModuleType
from typing import List, Type, Set, Dict
import operator
# List of all namespaces containing modules to test.
MODULE_NAMESPACES: List[ModuleType] = [
MODULE_NAMESPACES: list[ModuleType] = [
torch.nn.modules,
torch.ao.nn.qat.modules,
torch.ao.nn.quantizable.modules,
@ -41,7 +40,7 @@ MODULE_NAMESPACES: List[ModuleType] = [
]
# Modules that shouldn't be tested for one reason or another.
MODULES_TO_SKIP: Set[Type] = {
MODULES_TO_SKIP: set[type] = {
torch.nn.Module, # abstract base class
torch.nn.Container, # deprecated
torch.nn.NLLLoss2d, # deprecated
@ -50,14 +49,14 @@ MODULES_TO_SKIP: Set[Type] = {
}
# List of all module classes to test.
MODULE_CLASSES: List[Type] = list(chain(*[
MODULE_CLASSES: list[type] = list(chain(*[
[getattr(namespace, module_name) for module_name in namespace.__all__] # type: ignore[attr-defined]
for namespace in MODULE_NAMESPACES]))
MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]
# Dict of module class -> common name. Useful for making test names more intuitive.
# Example: torch.nn.modules.linear.Linear -> "nn.Linear"
MODULE_CLASS_NAMES: Dict[Type, str] = {}
MODULE_CLASS_NAMES: dict[type, str] = {}
for namespace in MODULE_NAMESPACES:
for module_name in namespace.__all__: # type: ignore[attr-defined]
module_cls = getattr(namespace, module_name)
@ -317,7 +316,7 @@ def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, t
def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_batchmean', {'reduction': 'batchmean'}),
@ -360,7 +359,7 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, tr
requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_none', {'reduction': 'none'}),
@ -425,7 +424,7 @@ def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -450,7 +449,7 @@ def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_g
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -498,7 +497,7 @@ def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, tr
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -980,7 +979,7 @@ def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requi
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -1419,7 +1418,7 @@ def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_gra
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -1455,7 +1454,7 @@ def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, tr
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -1503,7 +1502,7 @@ def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, require
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -1545,8 +1544,8 @@ def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires
make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
reductions: List[str] = ['mean', 'sum', 'none']
cases: List[tuple[str, dict]] = [
reductions: list[str] = ['mean', 'sum', 'none']
cases: list[tuple[str, dict]] = [
('', {}),
('weights', {'weight': make_weight((3,))}),
('ignore_index', {'ignore_index': 1}),
@ -1633,7 +1632,7 @@ def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, tr
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -1799,7 +1798,7 @@ def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requir
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -1833,7 +1832,7 @@ def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requir
def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -2245,7 +2244,7 @@ def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, require
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -2273,7 +2272,7 @@ def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requ
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -2309,7 +2308,7 @@ def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_
make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -2340,7 +2339,7 @@ def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype,
make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -2378,7 +2377,7 @@ def module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_g
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
cases: List[tuple[str, dict]] = [
cases: list[tuple[str, dict]] = [
('', {}),
('reduction_sum', {'reduction': 'sum'}),
('reduction_mean', {'reduction': 'mean'}),
@ -3362,7 +3361,7 @@ _macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_
# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
module_db: list[ModuleInfo] = [
ModuleInfo(torch.nn.AdaptiveAvgPool1d,
module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
skips=(

View File

@ -23,7 +23,8 @@ from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn
from typing import Dict, Callable, List, Sequence, Union, Any
from typing import Callable, Union, Any
from collections.abc import Sequence
TemporaryFile = tempfile.TemporaryFile
PRECISION = 1e-5
@ -502,7 +503,7 @@ def nllloss_no_reduce_test():
def nllloss_no_reduce_ignore_index_test():
t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
kwargs: dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
return dict(
fullname='NLLLoss_no_reduce_ignore_index',
constructor=wrap_functional(
@ -605,7 +606,7 @@ def nllloss2d_no_reduce_test():
def nllloss2d_no_reduce_ignore_index_test():
t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
return dict(
fullname='NLLLoss2d_no_reduce_ignore_index',
constructor=wrap_functional(
@ -662,7 +663,7 @@ def nlllossNd_no_reduce_test():
def nlllossNd_no_reduce_ignore_index_test():
t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
return dict(
fullname='NLLLossNd_no_reduce_ignore_index',
constructor=wrap_functional(
@ -2671,7 +2672,7 @@ def get_new_module_tests():
'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
'Tanhshrink', 'Threshold'
]
non_linear_activations_extra_info: Dict[str, dict] = {
non_linear_activations_extra_info: dict[str, dict] = {
'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
'Threshold': {'constructor_args': (2., 1.)},
'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
@ -3059,7 +3060,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
return output
loss_reference_fns: Dict['str', Callable] = {
loss_reference_fns: dict['str', Callable] = {
'KLDivLoss': kldivloss_reference,
'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
'NLLLoss': nllloss_reference,
@ -3173,7 +3174,7 @@ classification_criterion_no_batch = [
),
('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
]
classification_criterion_no_batch_extra_info: Dict[str, dict] = {
classification_criterion_no_batch_extra_info: dict[str, dict] = {
'MultiLabelMarginLoss': {'check_gradgrad': False},
}
# TODO : Fix these discrepancies
@ -3209,7 +3210,7 @@ class NNTestCase(TestCase):
raise NotImplementedError
@abstractmethod
def _get_parameters(self, module: nn.Module) -> tuple[List[nn.Parameter], List[nn.Parameter]]:
def _get_parameters(self, module: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
raise NotImplementedError
@abstractmethod

@ -6,7 +6,7 @@ import sys
import unittest
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Union
from typing import Any, Union
import torch
from torch import Tensor
@ -56,9 +56,9 @@ class OptimizerInput:
def __init__(
self,
params: Union[
List[Parameter], List[Tensor], Dict[Any, Any], List[Dict[str, Any]]
list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]]
],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
desc: str = "",
):
# params can be a list of Tensors OR param_groups OR None
@ -1256,7 +1256,7 @@ def _get_device_type(device: Union[str, torch.device]) -> str:
def _get_optim_inputs_including_global_cliquey_kwargs(
device, dtype, optim_info, skip=()
) -> List[OptimizerInput]:
) -> list[OptimizerInput]:
"""
Return a list of all configs for a given optimizer as a list of OptimizerInputs,
including configs that have supported global cliquey kwargs (foreach, fused,
@ -1312,7 +1312,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
# Database of OptimizerInfo entries in alphabetical order.
optim_db: List[OptimizerInfo] = [
optim_db: list[OptimizerInfo] = [
OptimizerInfo(
Adadelta,
optim_inputs_func=optim_inputs_func_adadelta,

@ -1,16 +1,16 @@
# Owner(s): ["module: unknown"]
from typing import Dict, Any
from typing import Any
from torch.ao.pruning import BaseSparsifier
import torch
import torch.nn.functional as F
from torch import nn
class ImplementedSparsifier(BaseSparsifier):
def __init__(self, **kwargs: Dict[str, Any]) -> None:
def __init__(self, **kwargs: dict[str, Any]) -> None:
super().__init__(defaults=kwargs)
def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Dict[str, Any]) -> None:
def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: dict[str, Any]) -> None:
module.parametrizations.weight[0].mask[0] = 0 # type: ignore[index, union-attr]
linear_state = self.state['linear1.weight']
linear_state['step_count'] = linear_state.get('step_count', 0) + 1

@ -74,7 +74,7 @@ import os
import unittest
import numpy as np
from torch.testing import FileCheck
from typing import Callable, Dict, Any, Union, Type, Optional
from typing import Callable, Any, Union, Optional
import torch._dynamo as torchdynamo
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
@ -898,8 +898,8 @@ class QuantizationTestCase(TestCase):
def assert_types_for_matched_subgraph_pairs(
self,
matched_subgraph_pairs: Dict[str, tuple[NSSubgraph, NSSubgraph]],
expected_types: Dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]],
matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]],
expected_types: dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]],
gm_a: GraphModule,
gm_b: GraphModule,
) -> None:
@ -952,7 +952,7 @@ class QuantizationTestCase(TestCase):
def assert_ns_compare_dict_valid(
self,
act_compare_dict: Dict[str, Dict[str, Dict[str, Any]]],
act_compare_dict: dict[str, dict[str, dict[str, Any]]],
) -> None:
"""
Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
@ -1214,7 +1214,7 @@ class QuantizationTestCase(TestCase):
self.assertTrue(expected_name in str(q_embeddingbag))
class QuantizationLiteTestCase(QuantizationTestCase):
def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs):
def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs):
# Creates quantized model for testing mobile script modules
qengine = "qnnpack"
with override_quantized_engine(qengine):

@ -49,15 +49,11 @@ from statistics import mean
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Type,
TypeVar,
Union,
)
from collections.abc import Iterable, Iterator
from unittest.mock import MagicMock
import expecttest
@ -646,7 +642,7 @@ class parametrize(_TestParametrizer):
name_fn (Callable): Optional function that takes in parameters and returns subtest name.
"""
def __init__(self, arg_str, arg_values, name_fn=None):
self.arg_names: List[str] = [s.strip() for s in arg_str.split(',') if s != '']
self.arg_names: list[str] = [s.strip() for s in arg_str.split(',') if s != '']
self.arg_values = arg_values
self.name_fn = name_fn
@ -689,7 +685,7 @@ class parametrize(_TestParametrizer):
for idx, values in enumerate(self.arg_values):
maybe_name = None
decorators: List[Any] = []
decorators: list[Any] = []
if isinstance(values, subtest):
sub = values
values = sub.arg_values
@ -942,11 +938,7 @@ parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SL
parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
parser.add_argument('--rerun-disabled-tests', action='store_true')
parser.add_argument('--pytest-single-test', type=str, nargs=1)
if sys.version_info >= (3, 9):
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
else:
parser.add_argument('--showlocals', action='store_true', default=False)
parser.add_argument('--no-showlocals', dest='showlocals', action='store_false')
parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
# Only run when -h or --help flag is active to display both unittest and parser help messages.
def run_unittest_help(argv):
@ -1173,10 +1165,10 @@ def sanitize_pytest_xml(xml_file: str):
tree.write(xml_file)
def get_pytest_test_cases(argv: List[str]) -> List[str]:
def get_pytest_test_cases(argv: list[str]) -> list[str]:
class TestCollectorPlugin:
def __init__(self) -> None:
self.tests: List[Any] = []
self.tests: list[Any] = []
def pytest_collection_finish(self, session):
for item in session.items:
@ -2637,7 +2629,7 @@ def check_if_enable(test: unittest.TestCase):
for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
if matches_test(disabled_test):
platform_to_conditional: Dict = {
platform_to_conditional: dict = {
"mac": IS_MACOS,
"macos": IS_MACOS,
"win": IS_WINDOWS,
@ -2706,7 +2698,7 @@ class RelaxedBooleanPair(BooleanPair):
def _process_inputs(self, actual, expected, *, id):
# We require only one of the inputs to be a boolean; the other can also be a boolean, a
# number, or a single element tensor or array, whereas in default BooleanPair both inputs have to be booleans.
tensor_or_array_types: tuple[Type, ...] = (torch.Tensor, np.ndarray)
tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types)
if not (
(isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
@ -2763,7 +2755,7 @@ class RelaxedNumberPair(NumberPair):
def _process_inputs(self, actual, expected, *, id):
# We require only one of the inputs to be a number; the other can also be a number or a single
# element tensor or array, whereas in default NumberPair both inputs have to be numbers.
tensor_or_array_types: tuple[Type, ...] = (torch.Tensor, np.ndarray)
tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
other_supported_types = (*self._supported_types, *tensor_or_array_types)
if not (
(isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
@ -2849,7 +2841,7 @@ class UnittestPair(Pair):
Define the :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of the inputs the pair should support.
"""
CLS: Union[Type, tuple[Type, ...]]
CLS: Union[type, tuple[type, ...]]
TYPE_NAME: Optional[str] = None
def __init__(self, actual, expected, **other_parameters):
@ -5072,8 +5064,8 @@ def get_tensors_from(args, kwargs):
# Returns scalar tensor representation of a list of integer byte values
def bytes_to_scalar(byte_list: List[int], dtype: torch.dtype, device: torch.device):
dtype_to_ctype: Dict[torch.dtype, Any] = {
def bytes_to_scalar(byte_list: list[int], dtype: torch.dtype, device: torch.device):
dtype_to_ctype: dict[torch.dtype, Any] = {
torch.int8: ctypes.c_int8,
torch.uint8: ctypes.c_uint8,
torch.uint16: ctypes.c_uint16,
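
One small cleanup rides along in this file: the `sys.version_info >= (3, 9)` guard around `--showlocals` is dropped, which is consistent with PEP 585 generics needing Python 3.9 at runtime anyway. `argparse.BooleanOptionalAction` (new in 3.9) registers the paired flags that the removed fallback branch built by hand:

```python
import argparse

parser = argparse.ArgumentParser()
# One call registers both --showlocals and --no-showlocals.
parser.add_argument("--showlocals", action=argparse.BooleanOptionalAction, default=False)

assert parser.parse_args(["--showlocals"]).showlocals is True
assert parser.parse_args(["--no-showlocals"]).showlocals is False
assert parser.parse_args([]).showlocals is False
```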

@ -10,13 +10,10 @@ from typing import (
Any,
Callable,
cast,
Dict,
Iterator,
List,
Sequence,
TypeVar,
Union,
)
from collections.abc import Iterator, Sequence
import torch
import torch.distributed as dist
@ -391,7 +388,7 @@ def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
@wraps(func) # pyre-ignore[6]
def wrapper(
self, *args: tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc]
) -> None:
# if enough GPU we can use GPU, otherwise we fallback to CPU
if not TEST_CUDA or torch.cuda.device_count() < self.world_size:
@ -437,7 +434,7 @@ class DTensorConverter:
self,
mesh: DeviceMesh,
args: tuple[object, ...],
kwargs: Dict[str, object],
kwargs: dict[str, object],
) -> None:
self.hit = 0
self.miss = 0
@ -447,9 +444,9 @@ class DTensorConverter:
flatten_args, flatten_args_spec = tree_flatten(args)
flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
self.flatten_args: List[object] = flatten_args
self.flatten_args: list[object] = flatten_args
self.flatten_args_spec: TreeSpec = flatten_args_spec
self.flatten_kwargs: List[object] = flatten_kwargs
self.flatten_kwargs: list[object] = flatten_kwargs
self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
choices_for_args = [self.gen_sharding_choices_for_arg(arg) for arg in self.flatten_args if isinstance(arg, torch.Tensor)]
@ -490,7 +487,7 @@ class DTensorConverter:
def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
mesh_size = self.mesh.size()
sharding_choices: List[Placement] = [Replicate()]
sharding_choices: list[Placement] = [Replicate()]
# c10d collective does not support bool tensor
# for bool tensor we treat it as replicated
if arg.dtype != torch.bool:
@ -510,12 +507,12 @@ class DTensorConverter:
def __iter__(self) -> "DTensorConverter":
return self
def __next__(self) -> tuple[tuple[object, ...], Dict[str, object]]:
def __next__(self) -> tuple[tuple[object, ...], dict[str, object]]:
try:
next_sharding_choices = next(self.sharding_combs)
idx = 0
new_args: List[object] = []
new_args: list[object] = []
for arg in self.flatten_args:
if isinstance(arg, torch.Tensor):
new_args.append(
@ -527,7 +524,7 @@ class DTensorConverter:
else:
new_args.append(arg)
new_kwargs: List[object] = []
new_kwargs: list[object] = []
for arg in self.flatten_kwargs:
if isinstance(arg, torch.Tensor):
new_kwargs.append(
@ -547,7 +544,7 @@ class DTensorConverter:
raise StopIteration from e
def to_dist_tensor(
self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
) -> torch.Tensor:
if type(t) is torch.Tensor or type(t) is nn.Parameter:
if self.is_supported_tensor(t):

@ -7,7 +7,7 @@ import os
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, cast, Dict, IO, Optional
from typing import Any, Callable, cast, IO, Optional
# introduced as collections.abc.Buffer in Python 3.12
from typing_extensions import Buffer
@ -130,7 +130,7 @@ def with_temp_dir(
assert func is not None
@wraps(func)
def wrapper(self, *args: tuple[object], **kwargs: Dict[str, Any]) -> None:
def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
if dist.is_initialized():
# Only create temp_dir when rank is 0
if dist.get_rank() == 0:

@ -4,7 +4,7 @@
import copy
from itertools import chain
from typing import Any, Dict
from typing import Any
import torch
import torch.nn as nn
@ -32,8 +32,8 @@ class VerifyStateDictMixin:
def _verify_msd(
self,
msd: Dict[str, Any],
dist_msd: Dict[str, Any],
msd: dict[str, Any],
dist_msd: dict[str, Any],
options: StateDictOptions = StateDictOptions(),
offload_to_cpu=False,
) -> None:
@ -56,8 +56,8 @@ class VerifyStateDictMixin:
self,
model: nn.Module,
optim: torch.optim.Optimizer,
osd: Dict[str, Any],
dist_osd: Dict[str, Any],
osd: dict[str, Any],
dist_osd: dict[str, Any],
) -> None:
params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
param_pid_mapping = dict(zip(params, range(len(params))))
@ -110,7 +110,7 @@ class VerifyStateDictMixin:
model: nn.Module,
optim: torch.optim.Optimizer,
new_optim: torch.optim.Optimizer,
dist_osd: Dict[str, Any],
dist_osd: dict[str, Any],
) -> None:
new_dist_osd = _gather_state_dict(dist_osd)
set_state_dict(

@ -3,7 +3,7 @@
import sys
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Optional, Union
from functools import partial, reduce
import torch
@ -93,7 +93,7 @@ class AllToAllBase:
input_buffer[input_indexes[dest_rank]:input_indexes[dest_rank + 1]]
)
def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, List[int], None], world_size: int) -> torch.Tensor:
def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, list[int], None], world_size: int) -> torch.Tensor:
if sizes is None or len(sizes) == 0:
sizes = torch.full(
(world_size,), buf_size // world_size, dtype=torch.int64
@ -316,8 +316,8 @@ class ProcessLocalGroup(dist.ProcessGroup):
self,
output_buffer: torch.Tensor,
input_buffer: torch.Tensor,
output_split_sizes: Optional[List[int]],
input_split_sizes: Optional[List[int]],
output_split_sizes: Optional[list[int]],
input_split_sizes: Optional[list[int]],
opts=AllToAllOptions()
) -> torch.Tensor:
coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
@ -455,14 +455,14 @@ dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "
@dataclass
class WorldData:
default_pg: dist.ProcessGroup
pg_map: Dict[dist.ProcessGroup, tuple[str, Optional[Store]]]
pg_names: Dict[dist.ProcessGroup, str]
pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
pg_backend_config: Dict[dist.ProcessGroup, str]
pg_map: dict[dist.ProcessGroup, tuple[str, Optional[Store]]]
pg_names: dict[dist.ProcessGroup, str]
pg_group_ranks: dict[dist.ProcessGroup, dict[int, int]]
pg_backend_config: dict[dist.ProcessGroup, str]
group_count: int
tags_to_pg: Dict[str, List[dist.ProcessGroup]]
pg_to_tag: Dict[dist.ProcessGroup, str]
pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
tags_to_pg: dict[str, list[dist.ProcessGroup]]
pg_to_tag: dict[dist.ProcessGroup, str]
pg_coalesce_state: dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]
class ThreadLocalWorld:
@ -514,7 +514,7 @@ class ThreadLocalWorld:
return self._get_world().pg_to_tag
@property
def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
def pg_coalesce_state(self) -> dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]:
return self._get_world().pg_coalesce_state

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict
import torch
import torch.distributed.autograd as dist_autograd
@ -35,7 +34,7 @@ class JitDistAutogradTest(RpcAgentTestFixture):
def test_get_gradients(self):
@torch.jit.script
def dist_get_gradients(context_id: int) -> (Dict[Tensor, Tensor]):
def dist_get_gradients(context_id: int) -> (dict[Tensor, Tensor]):
return dist_autograd.get_gradients(context_id)
FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))

@ -2,7 +2,7 @@
import time
import io
from typing import Dict, List, Any
from typing import Any
import torch
import torch.distributed as dist
@ -42,13 +42,13 @@ def rref_local_value(rref: RRef[Tensor]) -> Tensor:
@torch.jit.script
def list_create() -> List[int]:
def list_create() -> list[int]:
global_list = [1, 2, 3]
return global_list
@torch.jit.script
def rref_list_mutate(rref: RRef[List[int]]) -> None:
def rref_list_mutate(rref: RRef[list[int]]) -> None:
rref.local_value().append(4)
rref.to_here().append(5)
rref.to_here(5.0).append(6)
@ -435,7 +435,7 @@ class LocalRRefTest:
def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int:
args = (rref,)
kwargs: Dict[str, Any] = {}
kwargs: dict[str, Any] = {}
fut = rpc.rpc_async(
rref.owner(), script_rref_get_value_my_script_class, args, kwargs
)
@ -465,7 +465,7 @@ class LocalRRefTest:
def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor:
args = (rref,)
kwargs: Dict[str, Any] = {}
kwargs: dict[str, Any] = {}
fut = rpc.rpc_async(
rref.owner_name(),
script_rref_run_forward_my_script_module,
@ -518,7 +518,7 @@ def raise_script():
@torch.jit.script
def script_rpc_async_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
ret = fut.wait()
@ -526,14 +526,14 @@ def script_rpc_async_call(
@torch.jit.script
def script_rpc_sync_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
return res
@torch.jit.script
def script_rpc_remote_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
return rref_res.to_here()
@ -607,7 +607,7 @@ class JitRpcOpTest:
# The error JIT gives is,
# "Dict values must contain only a single type, "
# "expected: Tensor but found str instead."
kwargs: Dict[str, Any] = {
kwargs: dict[str, Any] = {
"tensor_kwarg": torch.tensor([3, 3]),
"str_kwarg": "_str_kwarg",
"int_kwarg": 3,

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
from typing import Dict
import torch
import torch.distributed.rpc as rpc
@ -28,7 +27,7 @@ def two_args_two_kwargs(
@torch.jit.script
def script_rpc_async_call(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
ret = fut.wait()
@ -39,7 +38,7 @@ def script_rpc_async_call(
def rpc_async_call_with_timeout(
dst_worker_name: str,
args: tuple[Tensor, Tensor],
kwargs: Dict[str, Tensor],
kwargs: dict[str, Tensor],
timeout: float,
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
@ -51,7 +50,7 @@ def rpc_async_call_with_timeout(
def rpc_async_call_with_timeout_future_ret(
dst_worker_name: str,
args: tuple[Tensor, Tensor],
kwargs: Dict[str, Tensor],
kwargs: dict[str, Tensor],
timeout: float,
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
@ -60,7 +59,7 @@ def rpc_async_call_with_timeout_future_ret(
@torch.jit.script
def rpc_async_call_future_ret(
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
return fut

@ -3,7 +3,6 @@
import os
import sys
import unittest
from typing import Dict, List, Type
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import (
@ -148,10 +147,10 @@ FAULTY_AGENT_TESTS = [
def generate_tests(
prefix: str,
mixin: Type[RpcAgentTestFixture],
tests: List[Type[RpcAgentTestFixture]],
mixin: type[RpcAgentTestFixture],
tests: list[type[RpcAgentTestFixture]],
module_name: str,
) -> Dict[str, Type[RpcAgentTestFixture]]:
) -> dict[str, type[RpcAgentTestFixture]]:
"""Mix in the classes needed to autogenerate the tests based on the params.
Takes a series of test suites, each written against a "generic" agent (i.e.,
@ -166,7 +165,7 @@ def generate_tests(
that the classes can be fixed to make it look like they belong to it, which
is necessary for pickling to work on them.
"""
ret: Dict[str, Type[RpcAgentTestFixture]] = {}
ret: dict[str, type[RpcAgentTestFixture]] = {}
for test_class in tests:
if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN:
print(
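A minimal sketch of the mixin scheme the docstring above describes (a hypothetical helper under the signature shown, not the code in this diff): each generic suite is combined with the agent-specific fixture via type(), then re-homed so that pickling, used when spawning test subprocesses, can resolve the class by module path.

    def generate_tests_sketch(
        prefix: str,
        mixin: type,
        tests: list[type],
        module_name: str,
    ) -> dict[str, type]:
        ret: dict[str, type] = {}
        for test_class in tests:
            name = f"{prefix}{test_class.__name__}"
            # Derive a concrete suite: generic tests plus the concrete agent fixture.
            fixed = type(name, (test_class, mixin), {})
            # Make the class look native to the target module so it pickles.
            fixed.__module__ = module_name
            ret[name] = fixed
        return ret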

View File

@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import is_iterable_of_tensors, noncont
import collections
from copy import deepcopy
from typing import Any, Dict, List, Union
from typing import Any, Union
import math # noqa: F401
# Testing utils
@ -361,9 +361,9 @@ def get_constant(x):
return x
def get_script_args(args):
formals: List[str] = []
tensors: List[Union[torch.Tensor, List[torch.Tensor]]] = []
actuals: List[str] = []
formals: list[str] = []
tensors: list[Union[torch.Tensor, list[torch.Tensor]]] = []
actuals: list[str] = []
for arg in args:
if isinstance(arg, torch.Tensor):
name = f'i{len(formals)}'
@ -405,14 +405,14 @@ def create_script_fn(self, method_name, func_type):
return script_fn
class SplitInputs:
all_tensors: List[Any]
tensor_args: List[Any]
nontensor_args: List[Any]
arg_types: List[str]
tensor_kwargs: Dict[str, Any]
kwarg_order: List[str]
nontensor_kwargs: Dict[str, Any]
kwarg_types: Dict[str, Any]
all_tensors: list[Any]
tensor_args: list[Any]
nontensor_args: list[Any]
arg_types: list[str]
tensor_kwargs: dict[str, Any]
kwarg_order: list[str]
nontensor_kwargs: dict[str, Any]
kwarg_types: dict[str, Any]
@staticmethod
def _is_tensor_input(arg):

View File

@ -39,7 +39,7 @@ import sys
import tempfile
import textwrap
from importlib.abc import Loader
from typing import Any, Dict, List, Union
from typing import Any, Union
RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
@ -176,7 +176,7 @@ class JitTestCase(JitCommonTestCase):
allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for)
fusion_groups : Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list)
fusion_groups : dict[torch._C.Block, list[torch._C.Node]] = defaultdict(list)
get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
self.assertTrue(len(fusion_groups) == 1, f'got {graph}')
(graph, fusion_nodes) = next(iter(fusion_groups.items()))
@ -385,7 +385,7 @@ class JitTestCase(JitCommonTestCase):
if not frame:
raise RuntimeError("failed to get frame")
i += 1
defined_vars: Dict[str, Any] = {}
defined_vars: dict[str, Any] = {}
defined_vars.update(frame.f_locals)
defined_vars.update(frame.f_globals)
return defined_vars
@ -407,7 +407,7 @@ class JitTestCase(JitCommonTestCase):
with self.assertRaisesRegex(exception, regex):
if isinstance(script, str):
frame = self.get_frame_vars(frames_up)
the_locals: Dict[str, Any] = {}
the_locals: dict[str, Any] = {}
execWrapper(script, glob=frame, loc=the_locals)
frame.update(the_locals)
@ -471,7 +471,7 @@ class JitTestCase(JitCommonTestCase):
# outputs
frame = self.get_frame_vars(frames_up)
the_locals: Dict[str, Any] = {}
the_locals: dict[str, Any] = {}
execWrapper(script, glob=frame, loc=the_locals)
frame.update(the_locals)
@ -796,7 +796,7 @@ class TensorExprTestOptions:
torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)
def clone_inputs(args):
inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []
inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []
for arg in args:
if isinstance(arg, torch.Tensor):
@ -810,7 +810,7 @@ def clone_inputs(args):
def get_traced_sample_variant_pairs(device, dtype, op):
# tuples of (variant, sample)
outputs: List[tuple[Any, Any]] = []
outputs: list[tuple[Any, Any]] = []
samples = op.sample_inputs(device, dtype)

View File

@ -2,7 +2,8 @@
import torch
from torch.utils._pytree import tree_map
from typing import Iterator, List, Optional
from typing import Optional
from collections.abc import Iterator
import logging
import contextlib
import itertools
@ -101,8 +102,8 @@ class LoggingTensorReentrant(LoggingTensor):
# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
def __init__(
self, log_list: List[str], use_shortid_for_all_tensors: bool,
with_type: bool, tracebacks_list: Optional[List]) -> None:
self, log_list: list[str], use_shortid_for_all_tensors: bool,
with_type: bool, tracebacks_list: Optional[list]) -> None:
logging.Handler.__init__(self)
self.log_list = log_list
self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
@ -154,10 +155,10 @@ class GatherTraceback(logging.Filter):
return True
@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[list[str]]:
collect_traceback = python_tb or script_tb or cpp_tb
log_list: List[str] = []
tracebacks_list: List[str] = []
log_list: list[str] = []
tracebacks_list: list[str] = []
handler = LoggingTensorHandler(
log_list,
with_type=True,
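A usage sketch for capture_logs (illustrative; assumes the LoggingTensor class defined earlier in this file):

    with capture_logs() as logs:
        t = LoggingTensor(torch.randn(2))
        t + t
    # logs is a list[str], one formatted line per intercepted aten call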

View File

@ -8,11 +8,12 @@ import math
import operator
import unittest
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import asdict, dataclass
from enum import Enum
from functools import partial
from itertools import product
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
import torch
from torch.testing import make_tensor
@ -689,10 +690,10 @@ class OpInfo:
# the following metadata are test directives for skipping or modifying tests
# information about which tests to skip
skips: Tuple = ()
skips: tuple = ()
# decorators to apply to generated tests
decorators: Tuple = ()
decorators: tuple = ()
# the following are pointers to functions to generate certain classes of inputs
@ -803,11 +804,11 @@ class OpInfo:
# If `supports_cow_input_no_materialize_forward == True`, this list contains
# the arg indices or kwarg names of inputs that are expected to materialize
allow_cow_input_materialize_forward: List[Union[int, str]] = None
allow_cow_input_materialize_forward: list[Union[int, str]] = None
# If `supports_cow_input_no_materialize_backward == True`, this list contains
# the arg indices or kwarg names of inputs that are expected to materialize
allow_cow_input_materialize_backward: List[Union[int, str]] = None
allow_cow_input_materialize_backward: list[Union[int, str]] = None
# wrapper function for gradcheck
gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs)
@ -853,13 +854,13 @@ class OpInfo:
# a list of strings with node names that are expected to be in a
# DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'],
# default is populated to be ['aten::(name of Python operator)']
autodiff_nonfusible_nodes: List[str] = None
autodiff_nonfusible_nodes: list[str] = None
# a list of strings with node names that are expected to be in FusionGroups
# inside of DifferentiableGraphs when this operation is autodiffed.
# Ex: ['aten::add', 'aten::mm'], defaults to an empty list
# Note: currently no ops use fusible nodes
autodiff_fusible_nodes: List[str] = None
autodiff_fusible_nodes: list[str] = None
# the following metadata relates to sparse support and is used in test_sparse.py

View File

@ -13,7 +13,7 @@ from torch.testing._internal.opinfo.definitions import (
# Operator database
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
*fft.op_db,
*linalg.op_db,
*signal.op_db,
@ -21,7 +21,7 @@ op_db: List[OpInfo] = [
*_masked.op_db,
]
python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
*fft.python_ref_db,
*linalg.python_ref_db,
*special.python_ref_db,

View File

@ -3,7 +3,6 @@
import unittest
from collections.abc import Sequence
from functools import partial
from typing import List
import numpy as np
@ -424,7 +423,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
)
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
ReductionOpInfo(
"masked.sum",
ref=reference_reduction_numpy(np.sum),

View File

@ -2,7 +2,6 @@
import unittest
from functools import partial
from typing import List
import numpy as np
@ -117,7 +116,7 @@ def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):
# Operator database
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
SpectralFuncInfo(
"fft.fft",
aten_name="fft_fft",
@ -634,7 +633,7 @@ op_db: List[OpInfo] = [
),
]
python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
SpectralFuncPythonRefInfo(
"_refs.fft.fft",
torch_opinfo_name="fft.fft",

View File

@ -3,9 +3,9 @@
import itertools
import random
import unittest
from collections.abc import Iterable
from functools import partial
from itertools import chain, product
from typing import Iterable, List
import numpy as np
from numpy import inf
@ -1169,7 +1169,7 @@ def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
yield SampleInput(inp, ind=len(shape_lhs))
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
OpInfo(
"linalg.cross",
ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim),
@ -2408,7 +2408,7 @@ op_db: List[OpInfo] = [
),
]
python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
#
# torch.linalg
#

View File

@ -4,7 +4,7 @@ import math
from copy import copy
from dataclasses import dataclass
from functools import partial
from typing import List, Optional
from typing import Optional
import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int
@ -41,7 +41,7 @@ class ExtraOpData:
# each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
#
# If no overload of the op accepts dim-related args, this should be None.
dim_args: List[List[str]] = None
dim_args: list[list[str]] = None
# Helper function to extract names of dim-related args.
# Returns: tuple of (single dim argname if available, dim list argname if available)

View File

@ -3,7 +3,7 @@
import unittest
from functools import partial
from itertools import product
from typing import Callable, List
from typing import Callable
import numpy
@ -345,7 +345,7 @@ def make_signal_windows_opinfo(
)
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
make_signal_windows_opinfo(
name="signal.windows.hamming",
ref=reference_signal_window(scipy.signal.windows.hamming)

View File

@ -3,7 +3,6 @@
import unittest
from functools import partial
from itertools import product
from typing import List
import numpy as np
@ -119,7 +118,7 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
)
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
UnaryUfuncInfo(
"special.i0e",
aten_name="special_i0e",
@ -702,7 +701,7 @@ op_db: List[OpInfo] = [
),
]
python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
#
# Elementwise Unary Special OpInfos
#

View File

@ -2,8 +2,8 @@
import collections
import warnings
from collections.abc import Sequence
from functools import partial, wraps
from typing import Sequence
import numpy as np
import numpy.typing as npt

View File

@ -10,7 +10,8 @@ import re
import tempfile
import threading
import unittest
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import torch
import torch._dynamo
@ -46,7 +47,7 @@ def is_abstract(tensor: torch.Tensor) -> bool:
def safe_schema_check(
op: torch._ops.OpOverload,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
*,
copy_inputs: bool = True,
) -> Any:
@ -62,7 +63,7 @@ def safe_schema_check(
def safe_autograd_registration_check(
op: torch._ops.OpOverload,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
*,
copy_inputs: bool = True,
) -> None:
@ -81,7 +82,7 @@ def safe_autograd_registration_check(
def safe_fake_check(
op: torch._ops.OpOverload,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
*,
copy_inputs: bool = True,
) -> None:
@ -95,7 +96,7 @@ def safe_fake_check(
def safe_aot_autograd_check(
op: torch._ops.OpOverload,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
dynamic: bool,
*,
copy_inputs: bool = True,
@ -155,10 +156,10 @@ DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [
def generate_opcheck_tests(
testcase: Any,
namespaces: List[str],
namespaces: list[str],
failures_dict_path: Optional[str] = None,
additional_decorators: Optional[Dict[str, Callable]] = None,
test_utils: List[str] = DEFAULT_TEST_UTILS,
additional_decorators: Optional[dict[str, Callable]] = None,
test_utils: list[str] = DEFAULT_TEST_UTILS,
) -> None:
"""Given an existing TestCase, use the existing tests to generate
additional validation tests for custom operators.
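A call sketch grounded in the signature above (the test class, namespace, and path are hypothetical):

    import unittest

    class TestMyCustomOps(unittest.TestCase):
        def test_my_op(self) -> None:
            torch.ops.mylib.my_op(torch.randn(3))

    generate_opcheck_tests(
        testcase=TestMyCustomOps,
        namespaces=["mylib"],
        failures_dict_path="failures_dict.json",
    )

Each original test is then re-run once per utility in test_utils (schema, autograd registration, fake tensor, and AOTAutograd, per the safe_* helpers above), validating any intercepted calls to ops in the listed namespaces.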
@ -361,7 +362,7 @@ def validate_failures_dict_formatting(failures_dict_path: str) -> None:
def validate_failures_dict_structure(
failure_dict: "FailuresDict", test_utils: List[str], testcase: Any
failure_dict: "FailuresDict", test_utils: list[str], testcase: Any
) -> None:
"""Validates the failures dict.
@ -447,7 +448,7 @@ class OpCheckMode(TorchFunctionMode):
def __init__(
self,
namespaces: List[str],
namespaces: list[str],
test_util_name: str,
test_util: Callable,
failures_dict: "FailuresDict",
@ -619,11 +620,11 @@ def should_print_better_repro() -> None:
def opcheck(
op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
args: tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
raise_exception: bool = True,
) -> Dict[str, str]:
) -> dict[str, str]:
"""See torch.library.opcheck for docstring"""
if kwargs is None:
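A call sketch (the op is hypothetical; per the annotations above, the result maps each requested test utility to a status string):

    results = opcheck(
        torch.ops.mylib.my_op.default,
        args=(torch.randn(3),),
    )
    # e.g. {"test_schema": "SUCCESS", "test_faketensor": "SUCCESS", ...}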
@ -673,7 +674,7 @@ def generate_repro(
test: str,
op: torch._ops.OpOverload,
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
*,
save_data: bool,
dry_run: bool = False,
@ -738,7 +739,7 @@ def resolve_unique_overload_or_throw(
DUMP_OPTIONS = {"indent": 2, "sort_keys": True}
FailuresDictData = Dict[str, Dict[str, Dict[str, str]]]
FailuresDictData = dict[str, dict[str, dict[str, str]]]
VERSION = 1
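Matching the FailuresDictData alias above, the serialized failures dict nests op name -> test name -> string fields; a hypothetical entry (field names illustrative):

    failures: FailuresDictData = {
        "mylib::my_op": {
            "test_schema__test_my_op": {
                "comment": "known schema mismatch",
                "status": "xfail",
            },
        },
    }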

View File

@ -1,5 +1,5 @@
# mypy: ignore-errors
from typing import Any, Optional, Type
from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
@ -68,7 +68,7 @@ class WrapperSubclass(torch.Tensor):
return return_and_correct_aliasing(func, args, kwargs, out)
def __coerce_same_metadata_as_tangent__(
self, expected_metadata: Any, expected_type: Optional[Type] = None
self, expected_metadata: Any, expected_type: Optional[type] = None
):
if expected_type == type(self.a):
return self.a