PEP585 update - torch/testing (#145200)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145200
Approved by: https://github.com/bobrenjc93

Committed by: PyTorch MergeBot
Parent: 805c4b597a
Commit: dea7ad3371
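For context: PEP 585 (Python 3.9+) made the built-in containers and the collections.abc ABCs directly usable as generic types, deprecating the capitalized typing aliases (List, Dict, Set, Tuple, Type, Sequence, ...). The diff below is almost entirely mechanical one-line substitutions of that form. A minimal, illustrative sketch of the pattern (the function and parameter names here are invented for illustration and do not come from the changed files):

    # Illustrative sketch of the migration, not code from this commit.
    # Before (pre-PEP 585): capitalized aliases imported from typing.
    from typing import Dict, List, Optional, Type

    def scale_old(values: List[float], gains: Optional[Dict[str, float]] = None) -> List[float]:
        gain = (gains or {}).get("gain", 1.0)
        return [v * gain for v in values]

    def make_error_old(cls: Type[Exception], msg: str) -> Exception:
        return cls(msg)

    # After (PEP 585): built-in generics plus collections.abc ABCs;
    # Optional/Union/Any/Callable still come from typing.
    from collections.abc import Sequence
    from typing import Optional

    def scale(values: Sequence[float], gains: Optional[dict[str, float]] = None) -> list[float]:
        gain = (gains or {}).get("gain", 1.0)
        return [v * gain for v in values]

    def make_error(cls: type[Exception], msg: str) -> Exception:
        return cls(msg)

Only the annotations change; runtime behavior is identical, which is why nearly every hunk below is a one-line type substitution plus the corresponding import cleanup.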
@@ -3,18 +3,8 @@ import abc
 import cmath
 import collections.abc
 import contextlib
-from typing import (
-    Any,
-    Callable,
-    Collection,
-    Dict,
-    List,
-    NoReturn,
-    Optional,
-    Sequence,
-    Type,
-    Union,
-)
+from collections.abc import Collection, Sequence
+from typing import Any, Callable, NoReturn, Optional, Union
 from typing_extensions import deprecated
 
 import torch
@@ -33,7 +23,7 @@ class ErrorMeta(Exception):
     """Internal testing exception that makes that carries error metadata."""
 
     def __init__(
-        self, type: Type[Exception], msg: str, *, id: tuple[Any, ...] = ()
+        self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = ()
     ) -> None:
         super().__init__(
             "If you are a user and see this message during normal operation "
@@ -82,7 +72,7 @@ _DTYPE_PRECISIONS.update(
 
 def default_tolerances(
     *inputs: Union[torch.Tensor, torch.dtype],
-    dtype_precisions: Optional[Dict[torch.dtype, tuple[float, float]]] = None,
+    dtype_precisions: Optional[dict[torch.dtype, tuple[float, float]]] = None,
 ) -> tuple[float, float]:
     """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype.
 
@@ -341,13 +331,13 @@ class Pair(abc.ABC):
         raise UnsupportedInputs
 
     @staticmethod
-    def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, tuple[Type, ...]]):
+    def _check_inputs_isinstance(*inputs: Any, cls: Union[type, tuple[type, ...]]):
         """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise."""
         if not all(isinstance(input, cls) for input in inputs):
             Pair._inputs_not_supported()
 
     def _fail(
-        self, type: Type[Exception], msg: str, *, id: tuple[Any, ...] = ()
+        self, type: type[Exception], msg: str, *, id: tuple[Any, ...] = ()
     ) -> NoReturn:
         """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id.
 
@@ -451,8 +441,8 @@ class BooleanPair(Pair):
         super().__init__(actual, expected, **other_parameters)
 
     @property
-    def _supported_types(self) -> tuple[Type, ...]:
-        cls: List[Type] = [bool]
+    def _supported_types(self) -> tuple[type, ...]:
+        cls: list[type] = [bool]
         if HAS_NUMPY:
             cls.append(np.bool_)
         return tuple(cls)
@@ -545,7 +535,7 @@ class NumberPair(Pair):
         self.check_dtype = check_dtype
 
     @property
-    def _supported_types(self) -> tuple[Type, ...]:
+    def _supported_types(self) -> tuple[type, ...]:
         cls = list(self._NUMBER_TYPES)
         if HAS_NUMPY:
             cls.append(np.number)
@@ -1052,12 +1042,12 @@ def originate_pairs(
     actual: Any,
     expected: Any,
     *,
-    pair_types: Sequence[Type[Pair]],
-    sequence_types: tuple[Type, ...] = (collections.abc.Sequence,),
-    mapping_types: tuple[Type, ...] = (collections.abc.Mapping,),
+    pair_types: Sequence[type[Pair]],
+    sequence_types: tuple[type, ...] = (collections.abc.Sequence,),
+    mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
    id: tuple[Any, ...] = (),
     **options: Any,
-) -> List[Pair]:
+) -> list[Pair]:
     """Originates pairs from the individual inputs.
 
     ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
@@ -1092,8 +1082,8 @@ def originate_pairs(
         and isinstance(expected, sequence_types)
         and not isinstance(expected, str)
     ):
-        actual_len = len(actual)
-        expected_len = len(expected)
+        actual_len = len(actual)  # type: ignore[arg-type]
+        expected_len = len(expected)  # type: ignore[arg-type]
         if actual_len != expected_len:
             raise ErrorMeta(
                 AssertionError,
@@ -1105,8 +1095,8 @@ def originate_pairs(
         for idx in range(actual_len):
             pairs.extend(
                 originate_pairs(
-                    actual[idx],
-                    expected[idx],
+                    actual[idx],  # type: ignore[index]
+                    expected[idx],  # type: ignore[index]
                     pair_types=pair_types,
                     sequence_types=sequence_types,
                     mapping_types=mapping_types,
@@ -1117,8 +1107,8 @@ def originate_pairs(
         return pairs
 
     elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types):
-        actual_keys = set(actual.keys())
-        expected_keys = set(expected.keys())
+        actual_keys = set(actual.keys())  # type: ignore[attr-defined]
+        expected_keys = set(expected.keys())  # type: ignore[attr-defined]
         if actual_keys != expected_keys:
             missing_keys = expected_keys - actual_keys
             additional_keys = actual_keys - expected_keys
@@ -1141,8 +1131,8 @@ def originate_pairs(
         for key in keys:
             pairs.extend(
                 originate_pairs(
-                    actual[key],
-                    expected[key],
+                    actual[key],  # type: ignore[index]
+                    expected[key],  # type: ignore[index]
                     pair_types=pair_types,
                     sequence_types=sequence_types,
                     mapping_types=mapping_types,
@@ -1190,11 +1180,11 @@ def not_close_error_metas(
     actual: Any,
     expected: Any,
     *,
-    pair_types: Sequence[Type[Pair]] = (ObjectPair,),
-    sequence_types: tuple[Type, ...] = (collections.abc.Sequence,),
-    mapping_types: tuple[Type, ...] = (collections.abc.Mapping,),
+    pair_types: Sequence[type[Pair]] = (ObjectPair,),
+    sequence_types: tuple[type, ...] = (collections.abc.Sequence,),
+    mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
     **options: Any,
-) -> List[ErrorMeta]:
+) -> list[ErrorMeta]:
     """Asserts that inputs are equal.
 
     ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
@@ -1225,7 +1215,7 @@ def not_close_error_metas(
         # Explicitly raising from None to hide the internal traceback
         raise error_meta.to_error() from None  # noqa: RSE102
 
-    error_metas: List[ErrorMeta] = []
+    error_metas: list[ErrorMeta] = []
     for pair in pairs:
         try:
             pair.compare()
@@ -6,7 +6,7 @@ import collections.abc
 import functools
 import math
 import warnings
-from typing import cast, List, Optional, Union
+from typing import cast, Optional, Union
 
 import torch
 
@@ -43,7 +43,7 @@ def _uniform_random_(t: torch.Tensor, low: float, high: float) -> torch.Tensor:
 
 
 def make_tensor(
-    *shape: Union[int, torch.Size, List[int], tuple[int, ...]],
+    *shape: Union[int, torch.Size, list[int], tuple[int, ...]],
     dtype: torch.dtype,
     device: Union[str, torch.device],
     low: Optional[float] = None,
@@ -3,7 +3,6 @@
 import os
 import re
 import sys
-from typing import List
 
 __all__ = [
     "check_code_for_cuda_kernel_launches",
@@ -15,7 +14,7 @@ __all__ = [
 # launch a kernel without some safety? Use this as a quick workaround
 # for a problem with the checker, fix the checker, then de-exclude
 # the files in question.
-exclude_files: List[str] = []
+exclude_files: list[str] = []
 
 # Without using a C++ AST we can't 100% detect kernel launches, so we
 # model them as having the pattern "<<<parameters>>>(arguments);"
@@ -9,21 +9,10 @@ import sys
 import threading
 import unittest
 from collections import namedtuple
+from collections.abc import Iterable, Sequence
 from enum import Enum
 from functools import partial, wraps
-from typing import (
-    Any,
-    Callable,
-    ClassVar,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Sequence,
-    Set,
-    TypeVar,
-    Union,
-)
+from typing import Any, Callable, ClassVar, Optional, TypeVar, Union
 from typing_extensions import ParamSpec
 
 import torch
@@ -505,7 +494,7 @@ class DeviceTypeTestBase(TestCase):
 
         def dtype_parametrize_fn(test, generic_cls, device_cls, dtypes=dtypes):
             for dtype in dtypes:
-                param_kwargs: Dict[str, Any] = {}
+                param_kwargs: dict[str, Any] = {}
                 _update_param_kwargs(param_kwargs, "dtype", dtype)
 
                 # Note that an empty test suffix is set here so that the dtype can be appended
@@ -728,7 +717,7 @@ class PrivateUse1TestBase(DeviceTypeTestBase):
 def get_device_type_test_bases():
     # set type to List[Any] due to mypy list-of-union issue:
     # https://github.com/python/mypy/issues/3351
-    test_bases: List[Any] = []
+    test_bases: list[Any] = []
 
     if IS_SANDCASTLE or IS_FBCODE:
         if IS_REMOTE_GPU:
@@ -1089,7 +1078,7 @@ class ops(_TestParametrizer):
         op = check_exhausted_iterator = object()
         for op in self.op_list:
             # Determine the set of dtypes to use.
-            dtypes: Union[Set[torch.dtype], Set[None]]
+            dtypes: Union[set[torch.dtype], set[None]]
             if isinstance(self.opinfo_dtypes, Sequence):
                 dtypes = set(self.opinfo_dtypes)
             elif self.opinfo_dtypes == OpDTypes.unsupported_backward:
@@ -1854,7 +1843,7 @@ def skipCUDAIfNotMiopenSuggestNHWC(fn):
 
 
 # Skips a test for specified CUDA versions, given in the form of a list of [major, minor]s.
-def skipCUDAVersionIn(versions: Optional[List[tuple[int, int]]] = None):
+def skipCUDAVersionIn(versions: Optional[list[tuple[int, int]]] = None):
     def dec_fn(fn):
         @wraps(fn)
         def wrap_fn(self, *args, **kwargs):
@@ -1969,7 +1958,7 @@ def skipPRIVATEUSE1(fn):
 
 # TODO: the "all" in the name isn't true anymore for quite some time as we have also have for example XLA and MPS now.
 # This should probably enumerate all available device type test base classes.
-def get_all_device_types() -> List[str]:
+def get_all_device_types() -> list[str]:
     return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
 
 
@@ -21,7 +21,7 @@ from datetime import timedelta
 from enum import Enum
 from functools import partial, reduce, wraps
 from io import StringIO
-from typing import Dict, NamedTuple, Optional, Union, List, Any, Callable
+from typing import NamedTuple, Optional, Union, Any, Callable
 from unittest.mock import patch
 
 from torch._logging._internal import trace_log
@@ -963,7 +963,7 @@ class DistributedTestBase(MultiProcessTestCase):
 
 def run_subtests(
     cls_inst,
-    subtest_config: Dict[str, List[Any]],
+    subtest_config: dict[str, list[Any]],
     test_fn: Callable,
     *test_args,
     **test_kwargs: Any,
@@ -982,9 +982,9 @@ def run_subtests(
         test_kwargs: Keyword arguments to pass to ``test_fn``.
     """
     # Convert the config mapping to a list to have a fixed order
-    subtest_config_items: List[tuple[str, List[Any]]] = list(subtest_config.items())
-    subtest_config_keys: List[str] = [item[0] for item in subtest_config_items]
-    subtest_config_values: List[List[Any]] = [item[1] for item in subtest_config_items]
+    subtest_config_items: list[tuple[str, list[Any]]] = list(subtest_config.items())
+    subtest_config_keys: list[str] = [item[0] for item in subtest_config_items]
+    subtest_config_values: list[list[Any]] = [item[1] for item in subtest_config_items]
     for values in itertools.product(*subtest_config_values):
         # Map keyword to chosen value
         subtest_kwargs = dict(zip(subtest_config_keys, values))
@@ -1314,7 +1314,7 @@ class MultiThreadedTestCase(TestCase):
 class SaveForwardInputsModule(nn.Module):
     def __init__(
         self,
-        forward_inputs: Dict[nn.Module, torch.Tensor],
+        forward_inputs: dict[nn.Module, torch.Tensor],
         cast_forward_inputs: bool,
     ) -> None:
         super().__init__()
@@ -1330,7 +1330,7 @@ class SaveForwardInputsModule(nn.Module):
 class SaveForwardInputsModel(nn.Module):
     def __init__(
         self,
-        forward_inputs: Dict[nn.Module, torch.Tensor],
+        forward_inputs: dict[nn.Module, torch.Tensor],
         cast_forward_inputs: bool,
     ) -> None:
         super().__init__()
@@ -1,6 +1,5 @@
 # mypy: ignore-errors
 
-from typing import List
 
 import torch
 
@@ -158,7 +157,7 @@ def get_all_dtypes(
     include_complex=True,
     include_complex32=False,
     include_qint=False,
-) -> List[torch.dtype]:
+) -> list[torch.dtype]:
     dtypes = get_all_int_dtypes() + get_all_fp_dtypes(
         include_half=include_half, include_bfloat16=include_bfloat16
     )
@@ -171,7 +170,7 @@ def get_all_dtypes(
     return dtypes
 
 
-def get_all_math_dtypes(device) -> List[torch.dtype]:
+def get_all_math_dtypes(device) -> list[torch.dtype]:
     return (
         get_all_int_dtypes()
         + get_all_fp_dtypes(
@@ -181,7 +180,7 @@ def get_all_math_dtypes(device) -> List[torch.dtype]:
     )
 
 
-def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
+def get_all_complex_dtypes(include_complex32=False) -> list[torch.dtype]:
     return (
         [torch.complex32, torch.complex64, torch.complex128]
         if include_complex32
@@ -189,11 +188,11 @@ def get_all_complex_dtypes(include_complex32=False) -> List[torch.dtype]:
     )
 
 
-def get_all_int_dtypes() -> List[torch.dtype]:
+def get_all_int_dtypes() -> list[torch.dtype]:
     return [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
 
 
-def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dtype]:
+def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> list[torch.dtype]:
     dtypes = [torch.float32, torch.float64]
     if include_half:
         dtypes.append(torch.float16)
@@ -202,7 +201,7 @@ def get_all_fp_dtypes(include_half=True, include_bfloat16=True) -> List[torch.dt
     return dtypes
 
 
-def get_all_qint_dtypes() -> List[torch.dtype]:
+def get_all_qint_dtypes() -> list[torch.dtype]:
     return [torch.qint8, torch.quint8, torch.qint32, torch.quint4x2, torch.quint2x4]
 
 
@@ -12,7 +12,7 @@ from contextlib import nullcontext
 from copy import deepcopy
 from enum import auto, Enum
 from functools import wraps
-from typing import Any, Callable, cast, Dict, List, no_type_check, Optional, Type, Union
+from typing import Any, Callable, cast, no_type_check, Optional, Union
 from unittest import mock
 
 import torch
@@ -199,7 +199,7 @@ def _broadcast_state_dict(rank, state_dict):
 
     olist = [state_dict if rank == 0 else None]
     dist.broadcast_object_list(olist)
-    state_dict = cast(Dict[str, torch.Tensor], olist[0])
+    state_dict = cast(dict[str, torch.Tensor], olist[0])
     # Ensure that the state is on DEVICE
     for param_name in state_dict.keys():
         state_dict[param_name] = state_dict[param_name].to(DEVICE_TYPE)
@@ -322,7 +322,7 @@ class TransformerWithSharedParams(FSDPTestModel):
         group: dist.ProcessGroup,
         fsdp_init_mode: FSDPInitMode,
         device_init_mode: DEVICEInitMode,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
         deterministic: bool = False,
         add_bn: bool = True,
     ) -> Union[nn.Module, FSDP]:
@@ -451,7 +451,7 @@ class NestedWrappedModule(FSDPTestModel):
         group: dist.ProcessGroup,
         fsdp_init_mode: FSDPInitMode,
         device_init_mode: DEVICEInitMode,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
         deterministic: bool = False,
     ) -> nn.Module:
         """
@@ -499,7 +499,7 @@ class AlwaysWrapNestedWrappedModule(NestedWrappedModule):
         group: dist.ProcessGroup,
         fsdp_init_mode: FSDPInitMode,
         device_init_mode: DEVICEInitMode,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
         deterministic: bool = False,
     ):
         """
@@ -581,7 +581,7 @@ class NonUniformReqGradNWM(NestedWrappedModule):
         group: dist.ProcessGroup,
         fsdp_init_mode: FSDPInitMode,
         device_init_mode: DEVICEInitMode,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
         deterministic: bool = False,
     ):
         """
@@ -674,7 +674,7 @@ class ModuleWithDelay(FSDPTestModel):
 
     @staticmethod
     def init(
-        module_class: Type[FSDPTestModel],
+        module_class: type[FSDPTestModel],
         *model_args: Any,
         delay_after_loss_ms: int,
         delay_before_reduction_ms: int,
@@ -706,7 +706,7 @@ class NestedWrappedModuleWithDelay(ModuleWithDelay):
         group: dist.ProcessGroup,
         fsdp_init_mode: FSDPInitMode,
         device_init_mode: DEVICEInitMode = DEVICEInitMode.DEVICE_AFTER,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
         deterministic: bool = False,
         delay_after_loss_ms: int = 0,
         delay_before_reduction_ms: int = 0,
@@ -826,7 +826,7 @@ class MixtureOfExperts(NestedWrappedModule):
         group: dist.ProcessGroup,
         fsdp_init_mode: FSDPInitMode,
         device_init_mode: DEVICEInitMode,
-        fsdp_kwargs: Optional[Dict[str, Any]] = None,
+        fsdp_kwargs: Optional[dict[str, Any]] = None,
         deterministic: bool = False,
         delay_before_free_ms: int = 0,
     ):
@@ -907,7 +907,7 @@ class MLP(nn.Module):
 
 class MLPStack(nn.Sequential):
     def __init__(self, mlp_dim: int, *, with_seq_parallel: bool = False):
-        modules: List[nn.Module] = [
+        modules: list[nn.Module] = [
            # Use multiplier of 3 to exercise uneven case
             MLP(mlp_dim, dim_multiplier=3),
             MLP(mlp_dim),
@@ -1237,7 +1237,7 @@ class FSDPTest(MultiProcessTestCase):
         mixed_precision: Optional[MixedPrecision] = None,
         enable_sharded_grad_scaler: bool = False,
         use_pure_fp16: bool = False,
-        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
     ):
         cpu_offload_params = fsdp_cpu_offload and fsdp_cpu_offload.offload_params
 
@@ -1315,7 +1315,7 @@ class FSDPTest(MultiProcessTestCase):
 
     def _test_fsdp_parity(
         self,
-        model_class: Type[FSDPTestModel],
+        model_class: type[FSDPTestModel],
         fsdp_init_mode: FSDPInitMode,
         device_init_mode: DEVICEInitMode,
         ref_init_fn: Optional[Callable] = None,
@@ -1329,8 +1329,8 @@ class FSDPTest(MultiProcessTestCase):
         use_orig_params: bool = False,
         enable_sharded_grad_scaler: bool = False,
         use_pure_fp16: bool = False,
-        init_kwargs: Optional[Dict[str, Any]] = None,
-        sharded_grad_scaler_kwargs: Optional[Dict[str, Any]] = None,
+        init_kwargs: Optional[dict[str, Any]] = None,
+        sharded_grad_scaler_kwargs: Optional[dict[str, Any]] = None,
         **fsdp_kwargs,
     ):
         """
@@ -16,7 +16,7 @@ from torch.testing._internal.common_utils import enable_profiling_mode  # noqa:
 
 # Standard library
 from itertools import chain
-from typing import List, Union
+from typing import Union
 from torch._C import TensorType
 
 import io
@@ -62,7 +62,7 @@ def check_against_reference(self, func, reference_func, output_func, args, kwarg
         return t.detach().clone().requires_grad_(require_grad)
 
     def clone_inputs(preserve_requires_grad: bool):
-        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []
+        inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []
 
         for arg in args:
             if isinstance(arg, torch.Tensor):
@@ -76,7 +76,7 @@ def check_against_reference(self, func, reference_func, output_func, args, kwarg
 
     # Returns tensors in args that requires_grad, including tensors in TensorList args
    def get_recording_tensors(args):
-        recording_tensors: List[torch.Tensor] = []
+        recording_tensors: list[torch.Tensor] = []
 
         for arg in args:
             if isinstance(arg, torch.Tensor) and arg.requires_grad:
@@ -284,7 +284,7 @@ class JitCommonTestCase(TestCase):
         self.assertEqual(should_autodiff_node,
                          found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
 
-    def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
+    def checkShapeAnalysis(self, out_sizes: Union[list[int], list[list[int]]],
                            traced_graph, assert_propagation, constant_prop=True):
         # repropagte input shapes provided by tracing,
         prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
@@ -16,7 +16,8 @@ import numpy as np
 import numpy.typing as npt
 from torch import inf, nan
 
-from typing import Any, Dict, List, Tuple, Union, Sequence
+from typing import Any, Union
+from collections.abc import Sequence
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import (
     _dispatch_dtypes, floating_types, floating_types_and, complex_types, floating_and_complex_types,
@@ -3992,7 +3993,7 @@ def sample_inputs_conv1d(op_info, device, dtype, requires_grad, **kwargs):
 
     # Ordered as shapes for input, weight, bias,
     # and a dict of values of (stride, padding, dilation, groups)
-    cases: Tuple = (
+    cases: tuple = (
         ((1, 3, 4), (3, 3, 3), (3,), {'stride': (2,), 'padding': 2, 'groups': 1}),
         ((2, 4, 8), (2, 2, 3), (2,), {'stride': 3, 'padding': 1, 'groups': 2, 'dilation': 2}),
         ((1, 4, 5), (1, 4, 3), None, {'stride': (2,), 'padding': 'valid'}),
@@ -4140,7 +4141,7 @@ def sample_inputs_conv2d(op_info, device, dtype, requires_grad, jit_fail_sample=
 
     # Ordered as shapes for input, weight, bias
     # and a dict of values of (stride, padding, groups, dilation)
-    cases: Tuple = (
+    cases: tuple = (
         ((1, 3, 4, 4), (3, 3, 3, 3), (3,),
             {'stride': (2, 2), 'padding': 2, 'groups': 1}),
         ((2, 4, 8, 8), (2, 2, 3, 3), (2,),
@@ -4185,7 +4186,7 @@ def sample_inputs_conv3d(opinfo, device, dtype, requires_grad, **kwargs):
 
     # Ordered as shapes for input, weight, bias
     # and dict of values of (stride, padding, dilation, groups)
-    cases: Tuple = (
+    cases: tuple = (
         ((1, 1, 4, 4, 4), (1, 1, 1, 1, 1), (1,), {'padding': 'same'}),
         ((1, 1, 4, 4, 4), (1, 1, 4, 4, 4), (1,), {'stride': (2, 2, 2)}),
         ((1, 1, 5, 5, 5), (1, 1, 3, 3, 3), (1,), {'dilation': 2}),
@@ -4658,7 +4659,7 @@ def sample_inputs_hardswish(self, device, dtype, requires_grad, **kwargs):
 
 def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
     features_options = [[3, 4], [8, 8]]
-    batch_options: List[List[int]] = [
+    batch_options: list[list[int]] = [
         [],  # no batch
         [0],
         [8],
@@ -4684,7 +4685,7 @@ def sample_inputs_linear(self, device, dtype, requires_grad, **kwargs):
 
 def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs):
     features_options = [[3, 4, 5], [8, 8, 8]]
-    batch_options: List[List[int]] = [
+    batch_options: list[list[int]] = [
         [],  # no batch
         [0],
         [8],
@@ -4706,7 +4707,7 @@ def sample_inputs_bilinear(self, device, dtype, requires_grad, **kwargs):
 
 def sample_inputs_glu(self, device, dtype, requires_grad, **kwargs):
     features_options = [[2], [2, 4], [8, 8], [3, 6, 8], [1, 4, 6, 7]]
-    batch_options: List[List[int]] = [
+    batch_options: list[list[int]] = [
         [],  # no batch
         [0],
         [8],
@@ -5063,7 +5064,7 @@ def sample_inputs_avgpool1d(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
     # Order: input_shape, kernel_size, kwargs
-    cases: List[tuple[tuple[int, ...], Union[int, tuple[int, ...]], Dict]] = [
+    cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [
         ((2, 3, 9), (3,), {}),
         ((1, 3, 9), 3, dict(stride=1, padding=1, ceil_mode=True, count_include_pad=False)),
         ((1, 3, 9), (6,), dict(stride=(3,), padding=(2,), ceil_mode=True, count_include_pad=True)),
@@ -5082,7 +5083,7 @@ def sample_inputs_avgpool3d(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
     # Order: input_shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
-    cases: List[tuple[tuple[int, ...], Union[int, tuple[int, ...]], Dict]] = [
+    cases: list[tuple[tuple[int, ...], Union[int, tuple[int, ...]], dict]] = [
         ((2, 3, 3, 4, 4), (2, 2, 2), {}),
         ((1, 2, 4, 4, 4), 2, dict(stride=1, padding=1, ceil_mode=True,
                                   count_include_pad=False, divisor_override=2)),
@@ -6637,7 +6638,7 @@ def sample_inputs_cross_entropy(op_info, device, dtype, requires_grad, **kwargs)
     batch_size, num_classes = shape = (2, 3)
     reductions = ("mean", "sum", "none")
 
-    input_shape_and_kwargs: List[tuple[tuple[int, ...], Dict[str, Any]]] = [
+    input_shape_and_kwargs: list[tuple[tuple[int, ...], dict[str, Any]]] = [
         (shape, {}),
         ((*shape, 1), {}),
         ((*shape, 1, 2), {}),
@@ -6828,17 +6829,17 @@ def sample_inputs_matmul(op_info, device, dtype, requires_grad, is_rmatmul=False
 
 def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.dtype,
                            requires_grad: bool,
-                           *, variant: str, **kwargs) -> List[SampleInput]:
+                           *, variant: str, **kwargs) -> list[SampleInput]:
     if variant == 'variadic':
         def make_inputs(
-                tensors: List[torch.Tensor]) -> tuple[Union[torch.Tensor,
-                                                            List[torch.Tensor]],
+                tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor,
+                                                            list[torch.Tensor]],
                                                      tuple[torch.Tensor, ...]]:
             return tensors
     elif variant == 'list':
         def make_inputs(
-                tensors: List[torch.Tensor]) -> tuple[Union[torch.Tensor,
-                                                            List[torch.Tensor]],
+                tensors: list[torch.Tensor]) -> tuple[Union[torch.Tensor,
+                                                            list[torch.Tensor]],
                                                      tuple[torch.Tensor, ...]]:
             return [tensors]
     else:
@@ -6848,7 +6849,7 @@ def sample_inputs_meshgrid(op_info: OpInfo, device: torch.device, dtype: torch.d
 
     SCALAR = torch.Size([])
     VECTOR = torch.Size([3])
-    test_cases: List[List[torch.Size]] = [
+    test_cases: list[list[torch.Size]] = [
         [SCALAR],
         [VECTOR],
         [VECTOR, SCALAR],
@@ -9664,7 +9665,7 @@ class foreach_pointwise_sample_func(foreach_inputs_sample_func):
             args.pop()
 
 
-foreach_unary_op_db: List[OpInfo] = [
+foreach_unary_op_db: list[OpInfo] = [
     ForeachFuncInfo(
         'exp',
         sample_inputs_func=foreach_inputs_sample_func(1, False, False),
@@ -10835,7 +10836,7 @@ foreach_unary_op_db: List[OpInfo] = [
     ),
 ]
 
-foreach_binary_op_db: List[OpInfo] = [
+foreach_binary_op_db: list[OpInfo] = [
     ForeachFuncInfo(
         "add",
         sample_inputs_func=foreach_inputs_sample_func(2, True, True, True),
@@ -11130,7 +11131,7 @@ foreach_binary_op_db: List[OpInfo] = [
     )
 ]
 
-foreach_pointwise_op_db: List[ForeachFuncInfo] = [
+foreach_pointwise_op_db: list[ForeachFuncInfo] = [
     ForeachFuncInfo(
         "addcmul",
         sample_inputs_func=foreach_pointwise_sample_func(4, True, True),
@@ -11184,7 +11185,7 @@ foreach_pointwise_op_db: List[ForeachFuncInfo] = [
     ),
 ]
 
-foreach_reduce_op_db: List[ForeachFuncInfo] = [
+foreach_reduce_op_db: list[ForeachFuncInfo] = [
     ForeachFuncInfo(
         "max",
         sample_inputs_func=foreach_max_sample_func(1, False, False),
@@ -11267,7 +11268,7 @@ foreach_reduce_op_db: List[ForeachFuncInfo] = [
     ),
 ]
 
-foreach_other_op_db: List[ForeachFuncInfo] = [
+foreach_other_op_db: list[ForeachFuncInfo] = [
     ForeachFuncInfo(
         "lerp",
         sample_inputs_func=foreach_inputs_sample_func(3, True, True),
@@ -11665,7 +11666,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
 
 
 # Operator database (sorted alphabetically)
-op_db: List[OpInfo] = [
+op_db: list[OpInfo] = [
     UnaryUfuncInfo('abs',
                    aliases=('absolute', ),
                    ref=np.abs,
@@ -28,11 +28,10 @@ from torch.testing._internal.common_utils import (
     freeze_rng_state, skipIfMPS, skipIfMPSOnMacOS13, GRADCHECK_NONDET_TOL, TEST_WITH_ROCM, IS_WINDOWS,
     skipIfTorchDynamo)
 from types import ModuleType
-from typing import List, Type, Set, Dict
 import operator
 
 # List of all namespaces containing modules to test.
-MODULE_NAMESPACES: List[ModuleType] = [
+MODULE_NAMESPACES: list[ModuleType] = [
     torch.nn.modules,
     torch.ao.nn.qat.modules,
     torch.ao.nn.quantizable.modules,
@@ -41,7 +40,7 @@ MODULE_NAMESPACES: List[ModuleType] = [
 ]
 
 # Modules that shouldn't be tested for one reason or another.
-MODULES_TO_SKIP: Set[Type] = {
+MODULES_TO_SKIP: set[type] = {
     torch.nn.Module,  # abstract base class
     torch.nn.Container,  # deprecated
     torch.nn.NLLLoss2d,  # deprecated
@@ -50,14 +49,14 @@ MODULES_TO_SKIP: Set[Type] = {
 }
 
 # List of all module classes to test.
-MODULE_CLASSES: List[Type] = list(chain(*[
+MODULE_CLASSES: list[type] = list(chain(*[
     [getattr(namespace, module_name) for module_name in namespace.__all__]  # type: ignore[attr-defined]
     for namespace in MODULE_NAMESPACES]))
 MODULE_CLASSES = [cls for cls in MODULE_CLASSES if cls not in MODULES_TO_SKIP]
 
 # Dict of module class -> common name. Useful for making test names more intuitive.
 # Example: torch.nn.modules.linear.Linear -> "nn.Linear"
-MODULE_CLASS_NAMES: Dict[Type, str] = {}
+MODULE_CLASS_NAMES: dict[type, str] = {}
 for namespace in MODULE_NAMESPACES:
     for module_name in namespace.__all__:  # type: ignore[attr-defined]
         module_cls = getattr(namespace, module_name)
@@ -317,7 +316,7 @@ def module_inputs_torch_nn_Bilinear(module_info, device, dtype, requires_grad, t
 def module_inputs_torch_nn_KLDivLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_batchmean', {'reduction': 'batchmean'}),
@@ -360,7 +359,7 @@ def module_inputs_torch_nn_NLLLoss(module_info, device, dtype, requires_grad, tr
                            requires_grad=False).log_softmax(dim=1).requires_grad_(requires_grad)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_none', {'reduction': 'none'}),
@@ -425,7 +424,7 @@ def module_inputs_torch_nn_GaussianNLLLoss(module_info, device, dtype, requires_
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -450,7 +449,7 @@ def module_inputs_torch_nn_PoissonNLLLoss(module_info, device, dtype, requires_g
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -498,7 +497,7 @@ def module_inputs_torch_nn_MSELoss(module_info, device, dtype, requires_grad, tr
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -980,7 +979,7 @@ def module_inputs_torch_nn_CosineEmbeddingLoss(module_info, device, dtype, requi
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -1419,7 +1418,7 @@ def module_inputs_torch_nn_SmoothL1Loss(module_info, device, dtype, requires_gra
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -1455,7 +1454,7 @@ def module_inputs_torch_nn_BCELoss(module_info, device, dtype, requires_grad, tr
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -1503,7 +1502,7 @@ def module_inputs_torch_nn_BCEWithLogitsLoss(module_info, device, dtype, require
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -1545,8 +1544,8 @@ def module_inputs_torch_nn_CrossEntropyLoss(module_info, device, dtype, requires
     make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    reductions: List[str] = ['mean', 'sum', 'none']
-    cases: List[tuple[str, dict]] = [
+    reductions: list[str] = ['mean', 'sum', 'none']
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('weights', {'weight': make_weight((3,))}),
         ('ignore_index', {'ignore_index': 1}),
@@ -1633,7 +1632,7 @@ def module_inputs_torch_nn_CTCLoss(module_info, device, dtype, requires_grad, tr
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -1799,7 +1798,7 @@ def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requir
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -1833,7 +1832,7 @@ def module_inputs_torch_nn_HingeEmbeddingLoss(module_info, device, dtype, requir
 def module_inputs_torch_nn_HuberLoss(module_info, device, dtype, requires_grad, training, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -2245,7 +2244,7 @@ def module_inputs_torch_nn_MarginRankingLoss(module_info, device, dtype, require
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -2273,7 +2272,7 @@ def module_inputs_torch_nn_MultiLabelMarginLoss(module_info, device, dtype, requ
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -2309,7 +2308,7 @@ def module_inputs_torch_nn_MultiMarginLoss(module_info, device, dtype, requires_
     make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -2340,7 +2339,7 @@ def module_inputs_torch_nn_MultiLabelSoftMarginLoss(module_info, device, dtype,
     make_target = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)
     make_weight = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -2378,7 +2377,7 @@ def module_inputs_torch_nn_SoftMarginLoss(module_info, device, dtype, requires_g
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     make_target = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
 
-    cases: List[tuple[str, dict]] = [
+    cases: list[tuple[str, dict]] = [
         ('', {}),
         ('reduction_sum', {'reduction': 'sum'}),
         ('reduction_mean', {'reduction': 'mean'}),
@@ -3362,7 +3361,7 @@ _macos15_or_newer = torch.backends.mps.is_available() and torch.backends.mps.is_
 
 
 # Database of ModuleInfo entries in alphabetical order.
-module_db: List[ModuleInfo] = [
+module_db: list[ModuleInfo] = [
     ModuleInfo(torch.nn.AdaptiveAvgPool1d,
                module_inputs_func=module_inputs_torch_nn_AdaptiveAvgPool1d,
                skips=(
@@ -23,7 +23,8 @@ from torch.autograd import Variable
 from torch.types import _TensorOrTensors
 import torch.backends.cudnn
 
-from typing import Dict, Callable, List, Sequence, Union, Any
+from typing import Callable, Union, Any
+from collections.abc import Sequence
 
 TemporaryFile = tempfile.TemporaryFile
 PRECISION = 1e-5
@@ -502,7 +503,7 @@ def nllloss_no_reduce_test():
 
 def nllloss_no_reduce_ignore_index_test():
     t = Variable(torch.empty(15).uniform_().mul(10).floor().long())
-    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 2, 'reduction': 'none'}
     return dict(
         fullname='NLLLoss_no_reduce_ignore_index',
         constructor=wrap_functional(
@@ -605,7 +606,7 @@ def nllloss2d_no_reduce_test():
 
 def nllloss2d_no_reduce_ignore_index_test():
     t = Variable(torch.rand(2, 5, 5).mul(3).floor().long())
-    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
     return dict(
         fullname='NLLLoss2d_no_reduce_ignore_index',
         constructor=wrap_functional(
@@ -662,7 +663,7 @@ def nlllossNd_no_reduce_test():
 
 def nlllossNd_no_reduce_ignore_index_test():
     t = Variable(torch.rand(2, 5, 5, 2, 2).mul(3).floor().long())
-    kwargs: Dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
+    kwargs: dict[str, Union[int, str]] = {'ignore_index': 1, 'reduction': 'none'}
     return dict(
         fullname='NLLLossNd_no_reduce_ignore_index',
         constructor=wrap_functional(
@@ -2671,7 +2672,7 @@ def get_new_module_tests():
         'Sigmoid', 'SiLU', 'Mish', 'Softplus', 'Softshrink', 'Softsign', 'Tanh',
         'Tanhshrink', 'Threshold'
     ]
-    non_linear_activations_extra_info: Dict[str, dict] = {
+    non_linear_activations_extra_info: dict[str, dict] = {
         'CELU': {'constructor_args': (2.,), 'default_dtype': torch.double},
         'Threshold': {'constructor_args': (2., 1.)},
         'Hardsigmoid': {'check_gradgrad': False, 'check_jit': False, 'default_dtype': torch.double},
@@ -3059,7 +3060,7 @@ def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0
     return output
 
 
-loss_reference_fns: Dict['str', Callable] = {
+loss_reference_fns: dict['str', Callable] = {
     'KLDivLoss': kldivloss_reference,
     'KLDivLoss_log_target': partial(kldivloss_reference, log_target=True),
     'NLLLoss': nllloss_reference,
@@ -3173,7 +3174,7 @@ classification_criterion_no_batch = [
     ),
     ('MultiLabelSoftMarginLoss', lambda: torch.randn(9, dtype=torch.double), lambda: torch.randn(9)),
 ]
-classification_criterion_no_batch_extra_info: Dict[str, dict] = {
+classification_criterion_no_batch_extra_info: dict[str, dict] = {
     'MultiLabelMarginLoss': {'check_gradgrad': False},
 }
 # TODO : Fix these discrepancies
@@ -3209,7 +3210,7 @@ class NNTestCase(TestCase):
         raise NotImplementedError
 
     @abstractmethod
-    def _get_parameters(self, module: nn.Module) -> tuple[List[nn.Parameter], List[nn.Parameter]]:
+    def _get_parameters(self, module: nn.Module) -> tuple[list[nn.Parameter], list[nn.Parameter]]:
         raise NotImplementedError
 
     @abstractmethod
@@ -6,7 +6,7 @@ import sys
 import unittest
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Dict, List, Union
+from typing import Any, Union
 
 import torch
 from torch import Tensor
@@ -56,9 +56,9 @@ class OptimizerInput:
     def __init__(
         self,
         params: Union[
-            List[Parameter], List[Tensor], Dict[Any, Any], List[Dict[str, Any]]
+            list[Parameter], list[Tensor], dict[Any, Any], list[dict[str, Any]]
         ],
-        kwargs: Dict[str, Any],
+        kwargs: dict[str, Any],
         desc: str = "",
     ):
         # params can be a list of Tensors OR param_groups OR None
@@ -1256,7 +1256,7 @@ def _get_device_type(device: Union[str, torch.device]) -> str:
 
 def _get_optim_inputs_including_global_cliquey_kwargs(
     device, dtype, optim_info, skip=()
-) -> List[OptimizerInput]:
+) -> list[OptimizerInput]:
     """
     Return a list of all configs for a given optimizer as a list of OptimizerInputs,
     including configs that have supported global cliquey kwargs (foreach, fused,
@@ -1312,7 +1312,7 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
 
 
 # Database of OptimizerInfo entries in alphabetical order.
-optim_db: List[OptimizerInfo] = [
+optim_db: list[OptimizerInfo] = [
     OptimizerInfo(
         Adadelta,
         optim_inputs_func=optim_inputs_func_adadelta,
@@ -1,16 +1,16 @@
 # Owner(s): ["module: unknown"]
 
-from typing import Dict, Any
+from typing import Any
 from torch.ao.pruning import BaseSparsifier
 import torch
 import torch.nn.functional as F
 from torch import nn
 
 class ImplementedSparsifier(BaseSparsifier):
-    def __init__(self, **kwargs: Dict[str, Any]) -> None:
+    def __init__(self, **kwargs: dict[str, Any]) -> None:
         super().__init__(defaults=kwargs)
 
-    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: Dict[str, Any]) -> None:
+    def update_mask(self, module: nn.Module, tensor_name: str, **kwargs: dict[str, Any]) -> None:
         module.parametrizations.weight[0].mask[0] = 0  # type: ignore[index, union-attr]
         linear_state = self.state['linear1.weight']
         linear_state['step_count'] = linear_state.get('step_count', 0) + 1
@@ -74,7 +74,7 @@ import os
 import unittest
 import numpy as np
 from torch.testing import FileCheck
-from typing import Callable, Dict, Any, Union, Type, Optional
+from typing import Callable, Any, Union, Optional
 import torch._dynamo as torchdynamo
 import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
 import torch.ao.quantization.quantizer.xpu_inductor_quantizer as xpuiq
@@ -898,8 +898,8 @@ class QuantizationTestCase(TestCase):
 
     def assert_types_for_matched_subgraph_pairs(
         self,
-        matched_subgraph_pairs: Dict[str, tuple[NSSubgraph, NSSubgraph]],
-        expected_types: Dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]],
+        matched_subgraph_pairs: dict[str, tuple[NSSubgraph, NSSubgraph]],
+        expected_types: dict[str, tuple[tuple[Callable, Callable], tuple[Callable, Callable]]],
         gm_a: GraphModule,
         gm_b: GraphModule,
     ) -> None:
@@ -952,7 +952,7 @@ class QuantizationTestCase(TestCase):
 
     def assert_ns_compare_dict_valid(
         self,
-        act_compare_dict: Dict[str, Dict[str, Dict[str, Any]]],
+        act_compare_dict: dict[str, dict[str, dict[str, Any]]],
     ) -> None:
         """
         Verifies that the act_compare_dict (output of Numeric Suite APIs) is valid:
@@ -1214,7 +1214,7 @@ class QuantizationTestCase(TestCase):
         self.assertTrue(expected_name in str(q_embeddingbag))
 
 class QuantizationLiteTestCase(QuantizationTestCase):
-    def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs):
+    def _create_quantized_model(self, model_class: type[torch.nn.Module], **kwargs):
         # Creates quantized model for testing mobile script modules
         qengine = "qnnpack"
         with override_quantized_engine(qengine):
@@ -49,15 +49,11 @@ from statistics import mean
 from typing import (
     Any,
     Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
     Optional,
-    Type,
     TypeVar,
     Union,
 )
+from collections.abc import Iterable, Iterator
 from unittest.mock import MagicMock
 
 import expecttest
@@ -646,7 +642,7 @@ class parametrize(_TestParametrizer):
         name_fn (Callable): Optional function that takes in parameters and returns subtest name.
     """
     def __init__(self, arg_str, arg_values, name_fn=None):
-        self.arg_names: List[str] = [s.strip() for s in arg_str.split(',') if s != '']
+        self.arg_names: list[str] = [s.strip() for s in arg_str.split(',') if s != '']
         self.arg_values = arg_values
         self.name_fn = name_fn
 
@@ -689,7 +685,7 @@ class parametrize(_TestParametrizer):
         for idx, values in enumerate(self.arg_values):
             maybe_name = None
 
-            decorators: List[Any] = []
+            decorators: list[Any] = []
             if isinstance(values, subtest):
                 sub = values
                 values = sub.arg_values
@@ -942,11 +938,7 @@ parser.add_argument('--import-slow-tests', type=str, nargs='?', const=DEFAULT_SL
 parser.add_argument('--import-disabled-tests', type=str, nargs='?', const=DEFAULT_DISABLED_TESTS_FILE)
 parser.add_argument('--rerun-disabled-tests', action='store_true')
 parser.add_argument('--pytest-single-test', type=str, nargs=1)
-if sys.version_info >= (3, 9):
-    parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
-else:
-    parser.add_argument('--showlocals', action='store_true', default=False)
-    parser.add_argument('--no-showlocals', dest='showlocals', action='store_false')
+parser.add_argument('--showlocals', action=argparse.BooleanOptionalAction, default=False)
 
 # Only run when -h or --help flag is active to display both unittest and parser help messages.
 def run_unittest_help(argv):
@@ -1173,10 +1165,10 @@ def sanitize_pytest_xml(xml_file: str):
     tree.write(xml_file)
 
 
-def get_pytest_test_cases(argv: List[str]) -> List[str]:
+def get_pytest_test_cases(argv: list[str]) -> list[str]:
     class TestCollectorPlugin:
         def __init__(self) -> None:
-            self.tests: List[Any] = []
+            self.tests: list[Any] = []
 
         def pytest_collection_finish(self, session):
             for item in session.items:
@@ -2637,7 +2629,7 @@ def check_if_enable(test: unittest.TestCase):
 
     for disabled_test, (issue_url, platforms) in disabled_tests_dict.items():
         if matches_test(disabled_test):
-            platform_to_conditional: Dict = {
+            platform_to_conditional: dict = {
                 "mac": IS_MACOS,
                 "macos": IS_MACOS,
                 "win": IS_WINDOWS,
@@ -2706,7 +2698,7 @@ class RelaxedBooleanPair(BooleanPair):
     def _process_inputs(self, actual, expected, *, id):
         # We require only one of the inputs of the inputs to be a boolean and the other can also be a boolean, a
         # number, or a single element tensor or array, whereas in default BooleanPair both inputs have to be booleans.
-        tensor_or_array_types: tuple[Type, ...] = (torch.Tensor, np.ndarray)
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
         other_supported_types = (*self._supported_types, *self._supported_number_types, *tensor_or_array_types)
         if not (
             (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
@@ -2763,7 +2755,7 @@ class RelaxedNumberPair(NumberPair):
     def _process_inputs(self, actual, expected, *, id):
         # We require only one of the inputs of the inputs to be a number and the other can also be a number or a single
         # element tensor or array, whereas in default NumberPair both inputs have to be numbers.
-        tensor_or_array_types: tuple[Type, ...] = (torch.Tensor, np.ndarray)
+        tensor_or_array_types: tuple[type, ...] = (torch.Tensor, np.ndarray)
         other_supported_types = (*self._supported_types, *tensor_or_array_types)
         if not (
             (isinstance(actual, self._supported_types) and isinstance(expected, other_supported_types))
@@ -2849,7 +2841,7 @@ class UnittestPair(Pair):
 
     Define the :attr:`UnittestPair.CLS` in a subclass to indicate which class(es) of the inputs the pair should support.
     """
-    CLS: Union[Type, tuple[Type, ...]]
+    CLS: Union[type, tuple[type, ...]]
     TYPE_NAME: Optional[str] = None
 
     def __init__(self, actual, expected, **other_parameters):
@@ -5072,8 +5064,8 @@ def get_tensors_from(args, kwargs):
 
 
 # Returns scalar tensor representation of a list of integer byte values
-def bytes_to_scalar(byte_list: List[int], dtype: torch.dtype, device: torch.device):
-    dtype_to_ctype: Dict[torch.dtype, Any] = {
+def bytes_to_scalar(byte_list: list[int], dtype: torch.dtype, device: torch.device):
+    dtype_to_ctype: dict[torch.dtype, Any] = {
         torch.int8: ctypes.c_int8,
         torch.uint8: ctypes.c_uint8,
         torch.uint16: ctypes.c_uint16,
@@ -10,13 +10,10 @@ from typing import (
     Any,
     Callable,
     cast,
-    Dict,
-    Iterator,
-    List,
-    Sequence,
     TypeVar,
     Union,
 )
+from collections.abc import Iterator, Sequence
 
 import torch
 import torch.distributed as dist
@@ -391,7 +388,7 @@ def with_comms(eager_init: Union[TestFunc, bool] = False) -> TestFunc:
 
         @wraps(func)  # pyre-ignore[6]
         def wrapper(
-            self, *args: tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
+            self, *args: tuple[object], **kwargs: dict[str, Any]  # type: ignore[misc]
         ) -> None:
             # if enough GPU we can use GPU, otherwise we fallback to CPU
            if not TEST_CUDA or torch.cuda.device_count() < self.world_size:
@@ -437,7 +434,7 @@ class DTensorConverter:
         self,
         mesh: DeviceMesh,
         args: tuple[object, ...],
-        kwargs: Dict[str, object],
+        kwargs: dict[str, object],
     ) -> None:
         self.hit = 0
         self.miss = 0
@@ -447,9 +444,9 @@ class DTensorConverter:
         flatten_args, flatten_args_spec = tree_flatten(args)
         flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
 
-        self.flatten_args: List[object] = flatten_args
+        self.flatten_args: list[object] = flatten_args
         self.flatten_args_spec: TreeSpec = flatten_args_spec
-        self.flatten_kwargs: List[object] = flatten_kwargs
+        self.flatten_kwargs: list[object] = flatten_kwargs
         self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
 
         choices_for_args = [self.gen_sharding_choices_for_arg(arg) for arg in self.flatten_args if isinstance(arg, torch.Tensor)]
@@ -490,7 +487,7 @@ class DTensorConverter:
 
     def gen_sharding_choices_for_arg(self, arg: torch.Tensor) -> Sequence[Placement]:
         mesh_size = self.mesh.size()
-        sharding_choices: List[Placement] = [Replicate()]
+        sharding_choices: list[Placement] = [Replicate()]
         # c10d collective does not support bool tensor
         # for bool tensor we treat it as replicated
         if arg.dtype != torch.bool:
@@ -510,12 +507,12 @@ class DTensorConverter:
     def __iter__(self) -> "DTensorConverter":
         return self
 
-    def __next__(self) -> tuple[tuple[object, ...], Dict[str, object]]:
+    def __next__(self) -> tuple[tuple[object, ...], dict[str, object]]:
         try:
             next_sharding_choices = next(self.sharding_combs)
             idx = 0
 
-            new_args: List[object] = []
+            new_args: list[object] = []
             for arg in self.flatten_args:
                 if isinstance(arg, torch.Tensor):
                     new_args.append(
@@ -527,7 +524,7 @@ class DTensorConverter:
                 else:
                     new_args.append(arg)
 
-            new_kwargs: List[object] = []
+            new_kwargs: list[object] = []
             for arg in self.flatten_kwargs:
                 if isinstance(arg, torch.Tensor):
                     new_kwargs.append(
@@ -547,7 +544,7 @@ class DTensorConverter:
             raise StopIteration from e
 
     def to_dist_tensor(
-        self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
+        self, t: torch.Tensor, mesh: DeviceMesh, placements: list[Placement]
     ) -> torch.Tensor:
         if type(t) is torch.Tensor or type(t) is nn.Parameter:
             if self.is_supported_tensor(t):
@ -7,7 +7,7 @@ import os
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, cast, Dict, IO, Optional
from typing import Any, Callable, cast, IO, Optional

# introduced as collections.abc.Buffer in Python 3.12
from typing_extensions import Buffer
@ -130,7 +130,7 @@ def with_temp_dir(
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: tuple[object], **kwargs: Dict[str, Any]) -> None:
    def wrapper(self, *args: tuple[object], **kwargs: dict[str, Any]) -> None:
        if dist.is_initialized():
            # Only create temp_dir when rank is 0
            if dist.get_rank() == 0:

@ -4,7 +4,7 @@

import copy
from itertools import chain
from typing import Any, Dict
from typing import Any

import torch
import torch.nn as nn
@ -32,8 +32,8 @@ class VerifyStateDictMixin:

    def _verify_msd(
        self,
        msd: Dict[str, Any],
        dist_msd: Dict[str, Any],
        msd: dict[str, Any],
        dist_msd: dict[str, Any],
        options: StateDictOptions = StateDictOptions(),
        offload_to_cpu=False,
    ) -> None:
@ -56,8 +56,8 @@ class VerifyStateDictMixin:
        self,
        model: nn.Module,
        optim: torch.optim.Optimizer,
        osd: Dict[str, Any],
        dist_osd: Dict[str, Any],
        osd: dict[str, Any],
        dist_osd: dict[str, Any],
    ) -> None:
        params = list(chain.from_iterable(g["params"] for g in optim.param_groups))
        param_pid_mapping = dict(zip(params, range(len(params))))
@ -110,7 +110,7 @@ class VerifyStateDictMixin:
        model: nn.Module,
        optim: torch.optim.Optimizer,
        new_optim: torch.optim.Optimizer,
        dist_osd: Dict[str, Any],
        dist_osd: dict[str, Any],
    ) -> None:
        new_dist_osd = _gather_state_dict(dist_osd)
        set_state_dict(

@ -3,7 +3,7 @@
import sys
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from typing import Optional, Union
from functools import partial, reduce

import torch
@ -93,7 +93,7 @@ class AllToAllBase:
            input_buffer[input_indexes[dest_rank]:input_indexes[dest_rank + 1]]
        )

    def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, List[int], None], world_size: int) -> torch.Tensor:
    def _size_cumsum(self, buf_size: int, sizes: Union[torch.Tensor, list[int], None], world_size: int) -> torch.Tensor:
        if sizes is None or len(sizes) == 0:
            sizes = torch.full(
                (world_size,), buf_size // world_size, dtype=torch.int64
@ -316,8 +316,8 @@ class ProcessLocalGroup(dist.ProcessGroup):
        self,
        output_buffer: torch.Tensor,
        input_buffer: torch.Tensor,
        output_split_sizes: Optional[List[int]],
        input_split_sizes: Optional[List[int]],
        output_split_sizes: Optional[list[int]],
        input_split_sizes: Optional[list[int]],
        opts=AllToAllOptions()
    ) -> torch.Tensor:
        coll = ProcessLocalGroup._start_coll(AllToAllBase(), self)
@ -455,14 +455,14 @@ dist.Backend.register_backend("threaded", _create_threaded_pg, devices=["cpu", "
@dataclass
class WorldData:
    default_pg: dist.ProcessGroup
    pg_map: Dict[dist.ProcessGroup, tuple[str, Optional[Store]]]
    pg_names: Dict[dist.ProcessGroup, str]
    pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
    pg_backend_config: Dict[dist.ProcessGroup, str]
    pg_map: dict[dist.ProcessGroup, tuple[str, Optional[Store]]]
    pg_names: dict[dist.ProcessGroup, str]
    pg_group_ranks: dict[dist.ProcessGroup, dict[int, int]]
    pg_backend_config: dict[dist.ProcessGroup, str]
    group_count: int
    tags_to_pg: Dict[str, List[dist.ProcessGroup]]
    pg_to_tag: Dict[dist.ProcessGroup, str]
    pg_coalesce_state: Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]
    tags_to_pg: dict[str, list[dist.ProcessGroup]]
    pg_to_tag: dict[dist.ProcessGroup, str]
    pg_coalesce_state: dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]


class ThreadLocalWorld:
@ -514,7 +514,7 @@ class ThreadLocalWorld:
        return self._get_world().pg_to_tag

    @property
    def pg_coalesce_state(self) -> Dict[dist.ProcessGroup, List[Union[_CollOp, P2POp]]]:
    def pg_coalesce_state(self) -> dict[dist.ProcessGroup, list[Union[_CollOp, P2POp]]]:
        return self._get_world().pg_coalesce_state

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs

from typing import Dict

import torch
import torch.distributed.autograd as dist_autograd
@ -35,7 +34,7 @@ class JitDistAutogradTest(RpcAgentTestFixture):
    def test_get_gradients(self):

        @torch.jit.script
        def dist_get_gradients(context_id: int) -> (Dict[Tensor, Tensor]):
        def dist_get_gradients(context_id: int) -> (dict[Tensor, Tensor]):
            return dist_autograd.get_gradients(context_id)

        FileCheck().check("get_gradients").run(str(dist_get_gradients.graph))

@ -2,7 +2,7 @@

import time
import io
from typing import Dict, List, Any
from typing import Any

import torch
import torch.distributed as dist
@ -42,13 +42,13 @@ def rref_local_value(rref: RRef[Tensor]) -> Tensor:


@torch.jit.script
def list_create() -> List[int]:
def list_create() -> list[int]:
    global_list = [1, 2, 3]
    return global_list


@torch.jit.script
def rref_list_mutate(rref: RRef[List[int]]) -> None:
def rref_list_mutate(rref: RRef[list[int]]) -> None:
    rref.local_value().append(4)
    rref.to_here().append(5)
    rref.to_here(5.0).append(6)
@ -435,7 +435,7 @@ class LocalRRefTest:

    def use_rref_on_owner(rref: RRef[MyScriptClass]) -> int:
        args = (rref,)
        kwargs: Dict[str, Any] = {}
        kwargs: dict[str, Any] = {}
        fut = rpc.rpc_async(
            rref.owner(), script_rref_get_value_my_script_class, args, kwargs
        )
@ -465,7 +465,7 @@ class LocalRRefTest:

    def use_rref_on_owner(rref: RRef[MyModuleInterface]) -> Tensor:
        args = (rref,)
        kwargs: Dict[str, Any] = {}
        kwargs: dict[str, Any] = {}
        fut = rpc.rpc_async(
            rref.owner_name(),
            script_rref_run_forward_my_script_module,
@ -518,7 +518,7 @@ def raise_script():

@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
@ -526,14 +526,14 @@ def script_rpc_async_call(

@torch.jit.script
def script_rpc_sync_call(
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
    res = rpc.rpc_sync(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return res

@torch.jit.script
def script_rpc_remote_call(
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
    rref_res = rpc.remote(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return rref_res.to_here()
@ -607,7 +607,7 @@ class JitRpcOpTest:
    # The error JIT gives is,
    # "Dict values must contain only a single type, "
    # "expected: Tensor but found str instead."
    kwargs: Dict[str, Any] = {
    kwargs: dict[str, Any] = {
        "tensor_kwarg": torch.tensor([3, 3]),
        "str_kwarg": "_str_kwarg",
        "int_kwarg": 3,

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs

from typing import Dict

import torch
import torch.distributed.rpc as rpc
@ -28,7 +27,7 @@ def two_args_two_kwargs(

@torch.jit.script
def script_rpc_async_call(
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    ret = fut.wait()
@ -39,7 +38,7 @@ def script_rpc_async_call(
def rpc_async_call_with_timeout(
    dst_worker_name: str,
    args: tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    kwargs: dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
@ -51,7 +50,7 @@ def rpc_async_call_with_timeout(
def rpc_async_call_with_timeout_future_ret(
    dst_worker_name: str,
    args: tuple[Tensor, Tensor],
    kwargs: Dict[str, Tensor],
    kwargs: dict[str, Tensor],
    timeout: float,
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs, timeout)
@ -60,7 +59,7 @@ def rpc_async_call_with_timeout_future_ret(

@torch.jit.script
def rpc_async_call_future_ret(
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: Dict[str, Tensor]
    dst_worker_name: str, args: tuple[Tensor, Tensor], kwargs: dict[str, Tensor]
):
    fut = rpc.rpc_async(dst_worker_name, two_args_two_kwargs, args, kwargs)
    return fut

@ -3,7 +3,6 @@
import os
import sys
import unittest
from typing import Dict, List, Type

from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import (
@ -148,10 +147,10 @@ FAULTY_AGENT_TESTS = [

def generate_tests(
    prefix: str,
    mixin: Type[RpcAgentTestFixture],
    tests: List[Type[RpcAgentTestFixture]],
    mixin: type[RpcAgentTestFixture],
    tests: list[type[RpcAgentTestFixture]],
    module_name: str,
) -> Dict[str, Type[RpcAgentTestFixture]]:
) -> dict[str, type[RpcAgentTestFixture]]:
    """Mix in the classes needed to autogenerate the tests based on the params.

    Takes a series of test suites, each written against a "generic" agent (i.e.,
@ -166,7 +165,7 @@ def generate_tests(
    that the classes can be fixed to make it look like they belong to it, which
    is necessary for pickling to work on them.
    """
    ret: Dict[str, Type[RpcAgentTestFixture]] = {}
    ret: dict[str, type[RpcAgentTestFixture]] = {}
    for test_class in tests:
        if IS_SANDCASTLE and TEST_WITH_DEV_DBG_ASAN:
            print(

@ -13,7 +13,7 @@ from torch.testing._internal.common_utils import is_iterable_of_tensors, noncont

import collections
from copy import deepcopy
from typing import Any, Dict, List, Union
from typing import Any, Union
import math  # noqa: F401

# Testing utils
@ -361,9 +361,9 @@ def get_constant(x):
    return x

def get_script_args(args):
    formals: List[str] = []
    tensors: List[Union[torch.Tensor, List[torch.Tensor]]] = []
    actuals: List[str] = []
    formals: list[str] = []
    tensors: list[Union[torch.Tensor, list[torch.Tensor]]] = []
    actuals: list[str] = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            name = f'i{len(formals)}'
@ -405,14 +405,14 @@ def create_script_fn(self, method_name, func_type):
    return script_fn

class SplitInputs:
    all_tensors: List[Any]
    tensor_args: List[Any]
    nontensor_args: List[Any]
    arg_types: List[str]
    tensor_kwargs: Dict[str, Any]
    kwarg_order: List[str]
    nontensor_kwargs: Dict[str, Any]
    kwarg_types: Dict[str, Any]
    all_tensors: list[Any]
    tensor_args: list[Any]
    nontensor_args: list[Any]
    arg_types: list[str]
    tensor_kwargs: dict[str, Any]
    kwarg_order: list[str]
    nontensor_kwargs: dict[str, Any]
    kwarg_types: dict[str, Any]

    @staticmethod
    def _is_tensor_input(arg):

@ -39,7 +39,7 @@ import sys
import tempfile
import textwrap
from importlib.abc import Loader
from typing import Any, Dict, List, Union
from typing import Any, Union

RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
@ -176,7 +176,7 @@ class JitTestCase(JitCommonTestCase):
        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
                         'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for)

        fusion_groups : Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list)
        fusion_groups : dict[torch._C.Block, list[torch._C.Node]] = defaultdict(list)
        get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
        self.assertTrue(len(fusion_groups) == 1, f'got {graph}')
        (graph, fusion_nodes) = next(iter(fusion_groups.items()))
@ -385,7 +385,7 @@ class JitTestCase(JitCommonTestCase):
            if not frame:
                raise RuntimeError("failed to get frame")
            i += 1
        defined_vars: Dict[str, Any] = {}
        defined_vars: dict[str, Any] = {}
        defined_vars.update(frame.f_locals)
        defined_vars.update(frame.f_globals)
        return defined_vars
@ -407,7 +407,7 @@ class JitTestCase(JitCommonTestCase):
        with self.assertRaisesRegex(exception, regex):
            if isinstance(script, str):
                frame = self.get_frame_vars(frames_up)
                the_locals: Dict[str, Any] = {}
                the_locals: dict[str, Any] = {}
                execWrapper(script, glob=frame, loc=the_locals)
                frame.update(the_locals)

@ -471,7 +471,7 @@ class JitTestCase(JitCommonTestCase):
        # outputs

        frame = self.get_frame_vars(frames_up)
        the_locals: Dict[str, Any] = {}
        the_locals: dict[str, Any] = {}
        execWrapper(script, glob=frame, loc=the_locals)
        frame.update(the_locals)

@ -796,7 +796,7 @@ class TensorExprTestOptions:
    torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)

def clone_inputs(args):
    inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []
    inputs: list[Union[torch.Tensor, list[torch.Tensor]]] = []

    for arg in args:
        if isinstance(arg, torch.Tensor):
@ -810,7 +810,7 @@ def clone_inputs(args):

def get_traced_sample_variant_pairs(device, dtype, op):
    # tuples of (variant, sample)
    outputs: List[tuple[Any, Any]] = []
    outputs: list[tuple[Any, Any]] = []

    samples = op.sample_inputs(device, dtype)

@ -2,7 +2,8 @@

import torch
from torch.utils._pytree import tree_map
from typing import Iterator, List, Optional
from typing import Optional
from collections.abc import Iterator
import logging
import contextlib
import itertools
@ -101,8 +102,8 @@ class LoggingTensorReentrant(LoggingTensor):
# https://stackoverflow.com/questions/36408496/python-logging-handler-to-append-to-list
class LoggingTensorHandler(logging.Handler):
    def __init__(
            self, log_list: List[str], use_shortid_for_all_tensors: bool,
            with_type: bool, tracebacks_list: Optional[List]) -> None:
            self, log_list: list[str], use_shortid_for_all_tensors: bool,
            with_type: bool, tracebacks_list: Optional[list]) -> None:
        logging.Handler.__init__(self)
        self.log_list = log_list
        self.use_shortid_for_all_tensors = use_shortid_for_all_tensors
@ -154,10 +155,10 @@ class GatherTraceback(logging.Filter):
        return True

@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[list[str]]:
    collect_traceback = python_tb or script_tb or cpp_tb
    log_list: List[str] = []
    tracebacks_list: List[str] = []
    log_list: list[str] = []
    tracebacks_list: list[str] = []
    handler = LoggingTensorHandler(
        log_list,
        with_type=True,

@ -8,11 +8,12 @@ import math
import operator
import unittest
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import asdict, dataclass
from enum import Enum
from functools import partial
from itertools import product
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union

import torch
from torch.testing import make_tensor
@ -689,10 +690,10 @@ class OpInfo:
    # the following metadata are test directives for skipping or modifying tests

    # information about which tests to skip
    skips: Tuple = ()
    skips: tuple = ()

    # decorators to apply to generated tests
    decorators: Tuple = ()
    decorators: tuple = ()

    # the following are pointers to functions to generate certain classes of inputs

@ -803,11 +804,11 @@ class OpInfo:

    # If `supports_cow_input_no_materialize_forward == True`, this list contains
    # the arg indices or kwarg names of inputs that are expected to materialize
    allow_cow_input_materialize_forward: List[Union[int, str]] = None
    allow_cow_input_materialize_forward: list[Union[int, str]] = None

    # If `supports_cow_input_no_materialize_backward == True`, this list contains
    # the arg indices or kwarg names of inputs that are expected to materialize
    allow_cow_input_materialize_backward: List[Union[int, str]] = None
    allow_cow_input_materialize_backward: list[Union[int, str]] = None

    # wrapper function for gradcheck
    gradcheck_wrapper: Callable = lambda op, *args, **kwargs: op(*args, **kwargs)
@ -853,13 +854,13 @@ class OpInfo:
    # a list of strings with node names that are expected to be in a
    # DifferentiableGraph when autodiffed. Ex: ['aten::add', 'aten::mm'],
    # default is populated to be ['aten::(name of Python operator)']
    autodiff_nonfusible_nodes: List[str] = None
    autodiff_nonfusible_nodes: list[str] = None

    # a list of strings with node names that are expected to be in FusionGroups
    # inside of DifferentiableGraphs when this operation is autodiffed.
    # Ex: ['aten::add', 'aten::mm'], defaults to an empty list
    # Note: currently no ops use fusible nodes
    autodiff_fusible_nodes: List[str] = None
    autodiff_fusible_nodes: list[str] = None

    # the following metadata relates to sparse support and is used in test_sparse.py

@ -13,7 +13,7 @@ from torch.testing._internal.opinfo.definitions import (


# Operator database
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
    *fft.op_db,
    *linalg.op_db,
    *signal.op_db,
@ -21,7 +21,7 @@ op_db: List[OpInfo] = [
    *_masked.op_db,
]

python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
    *fft.python_ref_db,
    *linalg.python_ref_db,
    *special.python_ref_db,

@ -3,7 +3,6 @@
import unittest
from collections.abc import Sequence
from functools import partial
from typing import List

import numpy as np

@ -424,7 +423,7 @@ def sample_inputs_masked_normalize(op_info, device, dtype, requires_grad, **kwar
    )


op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
    ReductionOpInfo(
        "masked.sum",
        ref=reference_reduction_numpy(np.sum),

@ -2,7 +2,6 @@

import unittest
from functools import partial
from typing import List

import numpy as np

@ -117,7 +116,7 @@ def sample_inputs_fftshift(op_info, device, dtype, requires_grad, **kwargs):


# Operator database
op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
    SpectralFuncInfo(
        "fft.fft",
        aten_name="fft_fft",
@ -634,7 +633,7 @@ op_db: List[OpInfo] = [
    ),
]

python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
    SpectralFuncPythonRefInfo(
        "_refs.fft.fft",
        torch_opinfo_name="fft.fft",

@ -3,9 +3,9 @@
import itertools
import random
import unittest
from collections.abc import Iterable
from functools import partial
from itertools import chain, product
from typing import Iterable, List

import numpy as np
from numpy import inf
@ -1169,7 +1169,7 @@ def sample_inputs_tensorinv(op_info, device, dtype, requires_grad, **kwargs):
    yield SampleInput(inp, ind=len(shape_lhs))


op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
    OpInfo(
        "linalg.cross",
        ref=lambda x, y, dim=-1: np.cross(x, y, axis=dim),
@ -2408,7 +2408,7 @@ op_db: List[OpInfo] = [
    ),
]

python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
    #
    # torch.linalg
    #

@ -4,7 +4,7 @@ import math
from copy import copy
from dataclasses import dataclass
from functools import partial
from typing import List, Optional
from typing import Optional

import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int
@ -41,7 +41,7 @@ class ExtraOpData:
    # each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
    #
    # If no overload of the op accepts dim-related args, this should be None.
    dim_args: List[List[str]] = None
    dim_args: list[list[str]] = None

    # Helper function to extract names of dim-related args.
    # Returns: tuple of (single dim argname if available, dim list argname if available)

@ -3,7 +3,7 @@
import unittest
from functools import partial
from itertools import product
from typing import Callable, List
from typing import Callable

import numpy

@ -345,7 +345,7 @@ def make_signal_windows_opinfo(
    )


op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
    make_signal_windows_opinfo(
        name="signal.windows.hamming",
        ref=reference_signal_window(scipy.signal.windows.hamming)

@ -3,7 +3,6 @@
import unittest
from functools import partial
from itertools import product
from typing import List

import numpy as np

@ -119,7 +118,7 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs):
    )


op_db: List[OpInfo] = [
op_db: list[OpInfo] = [
    UnaryUfuncInfo(
        "special.i0e",
        aten_name="special_i0e",
@ -702,7 +701,7 @@ op_db: List[OpInfo] = [
    ),
]

python_ref_db: List[OpInfo] = [
python_ref_db: list[OpInfo] = [
    #
    # Elementwise Unary Special OpInfos
    #

@ -2,8 +2,8 @@

import collections
import warnings
from collections.abc import Sequence
from functools import partial, wraps
from typing import Sequence

import numpy as np
import numpy.typing as npt

@ -10,7 +10,8 @@ import re
import tempfile
import threading
import unittest
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

import torch
import torch._dynamo
@ -46,7 +47,7 @@ def is_abstract(tensor: torch.Tensor) -> bool:
def safe_schema_check(
    op: torch._ops.OpOverload,
    args: tuple[Any, ...],
    kwargs: Dict[str, Any],
    kwargs: dict[str, Any],
    *,
    copy_inputs: bool = True,
) -> Any:
@ -62,7 +63,7 @@ def safe_schema_check(
def safe_autograd_registration_check(
    op: torch._ops.OpOverload,
    args: tuple[Any, ...],
    kwargs: Dict[str, Any],
    kwargs: dict[str, Any],
    *,
    copy_inputs: bool = True,
) -> None:
@ -81,7 +82,7 @@ def safe_autograd_registration_check(
def safe_fake_check(
    op: torch._ops.OpOverload,
    args: tuple[Any, ...],
    kwargs: Dict[str, Any],
    kwargs: dict[str, Any],
    *,
    copy_inputs: bool = True,
) -> None:
@ -95,7 +96,7 @@ def safe_fake_check(
def safe_aot_autograd_check(
    op: torch._ops.OpOverload,
    args: tuple[Any, ...],
    kwargs: Dict[str, Any],
    kwargs: dict[str, Any],
    dynamic: bool,
    *,
    copy_inputs: bool = True,
@ -155,10 +156,10 @@ DEPRECATED_DEFAULT_TEST_UTILS = DEFAULT_TEST_UTILS + [

def generate_opcheck_tests(
    testcase: Any,
    namespaces: List[str],
    namespaces: list[str],
    failures_dict_path: Optional[str] = None,
    additional_decorators: Optional[Dict[str, Callable]] = None,
    test_utils: List[str] = DEFAULT_TEST_UTILS,
    additional_decorators: Optional[dict[str, Callable]] = None,
    test_utils: list[str] = DEFAULT_TEST_UTILS,
) -> None:
    """Given an existing TestCase, use the existing tests to generate
    additional validation tests for custom operators.
@ -361,7 +362,7 @@ def validate_failures_dict_formatting(failures_dict_path: str) -> None:


def validate_failures_dict_structure(
    failure_dict: "FailuresDict", test_utils: List[str], testcase: Any
    failure_dict: "FailuresDict", test_utils: list[str], testcase: Any
) -> None:
    """Validates the failures dict.

@ -447,7 +448,7 @@ class OpCheckMode(TorchFunctionMode):

    def __init__(
        self,
        namespaces: List[str],
        namespaces: list[str],
        test_util_name: str,
        test_util: Callable,
        failures_dict: "FailuresDict",
@ -619,11 +620,11 @@ def should_print_better_repro() -> None:
def opcheck(
    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
    args: tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    kwargs: Optional[dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = DEFAULT_TEST_UTILS,
    raise_exception: bool = True,
) -> Dict[str, str]:
) -> dict[str, str]:
    """See torch.library.opcheck for docstring"""

    if kwargs is None:
@ -673,7 +674,7 @@ def generate_repro(
    test: str,
    op: torch._ops.OpOverload,
    args: tuple[Any, ...],
    kwargs: Dict[str, Any],
    kwargs: dict[str, Any],
    *,
    save_data: bool,
    dry_run: bool = False,
@ -738,7 +739,7 @@ def resolve_unique_overload_or_throw(
DUMP_OPTIONS = {"indent": 2, "sort_keys": True}


FailuresDictData = Dict[str, Dict[str, Dict[str, str]]]
FailuresDictData = dict[str, dict[str, dict[str, str]]]


VERSION = 1
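
For reference, a minimal usage sketch of the public torch.library.opcheck
entry point that the internal opcheck wrapper above backs; only the return
annotation changes in this commit. The mylib.my_op operator name is a
hypothetical placeholder, and the test_utils strings are the documented
utility names:

    import torch
    from torch.library import opcheck

    # Assumes a custom op torch.ops.mylib.my_op was registered elsewhere.
    sample_args = (torch.randn(3),)
    results: dict[str, str] = opcheck(
        torch.ops.mylib.my_op.default,
        sample_args,
        test_utils=("test_schema", "test_faketensor"),
    )
    # On success, each requested test util maps to "SUCCESS".
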
@ -1,5 +1,5 @@
# mypy: ignore-errors
from typing import Any, Optional, Type
from typing import Any, Optional

import torch
import torch.utils._pytree as pytree
@ -68,7 +68,7 @@ class WrapperSubclass(torch.Tensor):
        return return_and_correct_aliasing(func, args, kwargs, out)

    def __coerce_same_metadata_as_tangent__(
        self, expected_metadata: Any, expected_type: Optional[Type] = None
        self, expected_metadata: Any, expected_type: Optional[type] = None
    ):
        if expected_type == type(self.a):
            return self.a
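
The pattern this commit applies throughout, shown as a minimal sketch; the
tally function below is a hypothetical example, not code from this diff:

    # Before: typing aliases, deprecated since Python 3.9 (PEP 585)
    from typing import Dict, List, Optional, Sequence

    def tally(batches: Sequence[List[int]]) -> Dict[str, Optional[int]]:
        return {"total": sum(sum(b) for b in batches)}

    # After: builtin generics, with the ABCs imported from collections.abc
    from collections.abc import Sequence
    from typing import Optional

    def tally(batches: Sequence[list[int]]) -> dict[str, Optional[int]]:
        return {"total": sum(sum(b) for b in batches)}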