diff --git a/functorch/dim/__init__.py b/functorch/dim/__init__.py index cc620c94e699..691b1b984f8d 100644 --- a/functorch/dim/__init__.py +++ b/functorch/dim/__init__.py @@ -1,8 +1,3 @@ -import dis -import inspect -from collections.abc import Sequence -from typing import Union - import functorch._C import torch from functorch._C import dim as _C diff --git a/functorch/einops/_parsing.py b/functorch/einops/_parsing.py index 0ef9dff72a52..2352ea932426 100644 --- a/functorch/einops/_parsing.py +++ b/functorch/einops/_parsing.py @@ -27,7 +27,7 @@ from __future__ import annotations import keyword import warnings -from typing import List, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING, Union if TYPE_CHECKING: @@ -73,11 +73,11 @@ class ParsedExpression: """ self.has_ellipsis: bool = False self.has_ellipsis_parenthesized: Optional[bool] = None - self.identifiers: Set[Union[str, AnonymousAxis]] = set() + self.identifiers: set[Union[str, AnonymousAxis]] = set() # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition self.has_non_unitary_anonymous_axes: bool = False # composition keeps structure of composite axes, see how different corner cases are handled in tests - self.composition: List[Union[List[Union[str, AnonymousAxis]], str]] = [] + self.composition: list[Union[list[Union[str, AnonymousAxis]], str]] = [] if "." in expression: if "..." not in expression: raise ValueError( @@ -90,7 +90,7 @@ class ParsedExpression: expression = expression.replace("...", _ellipsis) self.has_ellipsis = True - bracket_group: Optional[List[Union[str, AnonymousAxis]]] = None + bracket_group: Optional[list[Union[str, AnonymousAxis]]] = None def add_axis_name(x: str) -> None: if x in self.identifiers: @@ -164,7 +164,7 @@ class ParsedExpression: @staticmethod def check_axis_name_return_reason( name: str, allow_underscore: bool = False - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: """Check if the given axis name is valid, and a message explaining why if not. Valid axes names are python identifiers except keywords, and should not start or end with an underscore. @@ -174,7 +174,7 @@ class ParsedExpression: allow_underscore (bool): whether axis names are allowed to start with an underscore Returns: - Tuple[bool, str]: whether the axis name is valid, a message explaining why if not + tuple[bool, str]: whether the axis name is valid, a message explaining why if not """ if not str.isidentifier(name): return False, "not a valid python identifier" @@ -211,7 +211,7 @@ class ParsedExpression: def parse_pattern( pattern: str, axes_lengths: Mapping[str, int] -) -> Tuple[ParsedExpression, ParsedExpression]: +) -> tuple[ParsedExpression, ParsedExpression]: """Parse an `einops`-style pattern into a left-hand side and right-hand side `ParsedExpression` object. 
Args: @@ -219,7 +219,7 @@ def parse_pattern( axes_lengths (Mapping[str, int]): any additional length specifications for dimensions Returns: - Tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions + tuple[ParsedExpression, ParsedExpression]: a tuple containing the left-hand side and right-hand side expressions """ # adapted from einops.einops._prepare_transformation_recipe # https://github.com/arogozhnikov/einops/blob/230ac1526c1f42c9e1f7373912c7f8047496df11/einops/einops.py diff --git a/functorch/einops/rearrange.py b/functorch/einops/rearrange.py index 02c27f432cba..d7d71f5103f9 100644 --- a/functorch/einops/rearrange.py +++ b/functorch/einops/rearrange.py @@ -1,7 +1,7 @@ from __future__ import annotations import functools -from typing import Callable, Dict, List, Tuple, TYPE_CHECKING, Union +from typing import Callable, TYPE_CHECKING, Union import torch from functorch._C import dim as _C @@ -18,7 +18,6 @@ from ._parsing import ( if TYPE_CHECKING: from collections.abc import Sequence - __all__ = ["rearrange"] dims = _C.dims @@ -69,9 +68,9 @@ def _create_rearrange_callable( # an identity rearrangement on a 0-dimension tensor return lambda tensor: tensor - first_class_dims: Tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims)) - identifier_dim_map: Dict[Union[str, AnonymousAxis], Tuple[str, ...]] = {} - anon_axes: List[AnonymousAxis] = [] + first_class_dims: tuple[str, ...] = tuple(f"d{i}" for i in range(n_dims)) + identifier_dim_map: dict[Union[str, AnonymousAxis], tuple[str, ...]] = {} + anon_axes: list[AnonymousAxis] = [] # map the left-hand side identifiers to strings representing first class dims dims_i = 0 @@ -99,11 +98,11 @@ def _create_rearrange_callable( raise ValueError(f"Unexpected dimension: {dimension}") def composition_to_dims( - composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]], - ) -> List[Union[str, Tuple[str, ...]]]: + composition: Sequence[Union[list[Union[str, AnonymousAxis]], str]], + ) -> list[Union[str, tuple[str, ...]]]: """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first class dims.""" - dim_composition: List[Union[str, Tuple[str, ...]]] = [] + dim_composition: list[Union[str, tuple[str, ...]]] = [] for dimension in composition: if isinstance(dimension, list): dim_composition.append( @@ -152,7 +151,7 @@ def _create_rearrange_callable( def rearrange( - tensor: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]], + tensor: Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]], pattern: str, **axes_lengths: int, ) -> torch.Tensor: diff --git a/test/ao/sparsity/test_activation_sparsifier.py b/test/ao/sparsity/test_activation_sparsifier.py index 1a59aa0ac40b..56a88596f994 100644 --- a/test/ao/sparsity/test_activation_sparsifier.py +++ b/test/ao/sparsity/test_activation_sparsifier.py @@ -2,7 +2,6 @@ import copy import logging -from typing import List import torch import torch.nn as nn @@ -247,7 +246,7 @@ class TestActivationSparsifier(TestCase): assert mask2 is None else: assert type(mask1) == type(mask2) - if isinstance(mask1, List): + if isinstance(mask1, list): assert len(mask1) == len(mask2) for idx in range(len(mask1)): assert torch.all(mask1[idx] == mask2[idx]) @@ -258,7 +257,7 @@ class TestActivationSparsifier(TestCase): for state in state_dict["state"].values(): mask = state["mask"] if mask is not None: - if isinstance(mask, List): + if isinstance(mask, list): for idx in range(len(mask)): assert 
mask[idx].is_sparse else: diff --git a/test/ao/sparsity/test_data_scheduler.py b/test/ao/sparsity/test_data_scheduler.py index 5f102486ecf7..6481867292e4 100644 --- a/test/ao/sparsity/test_data_scheduler.py +++ b/test/ao/sparsity/test_data_scheduler.py @@ -3,7 +3,6 @@ import copy import logging import warnings -from typing import Tuple import torch from torch import nn @@ -73,7 +72,7 @@ class TestBaseDataScheduler(TestCase): def _get_name_data_config(self, some_data, defaults): config = copy.deepcopy(defaults) - if isinstance(some_data, Tuple): + if isinstance(some_data, tuple): # dealing with data_list name, data = some_data else: diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index 90b204aec780..4f987b994ae8 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -4,7 +4,6 @@ import copy import itertools import logging import math -from typing import Tuple import torch from torch import nn @@ -54,7 +53,7 @@ class _BaseDataSparsiferTestCase(TestCase): @staticmethod def _get_name_data_config(some_data, defaults=None): - if isinstance(some_data, Tuple): + if isinstance(some_data, tuple): # dealing with data_list name, data = some_data config = defaults @@ -482,8 +481,9 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase): nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5)), ) - param4, param5 = nn.Parameter(torch.randn(1, 1)), nn.Parameter( - torch.randn(4, 4) + param4, param5 = ( + nn.Parameter(torch.randn(1, 1)), + nn.Parameter(torch.randn(4, 4)), ) data_list = [("param1", param1), ("param2", param2), ("param3", param3)] defaults = {"test": 3} @@ -585,8 +585,9 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase): nn.Parameter(torch.randn(4, 4)), nn.Parameter(torch.randn(5, 5)), ) - param4, param5 = nn.Parameter(torch.randn(10, 10)), nn.Parameter( - torch.randn(4, 4) + param4, param5 = ( + nn.Parameter(torch.randn(10, 10)), + nn.Parameter(torch.randn(4, 4)), ) data_list = [("param1", param1), ("param2", param2), ("param3", param3)] defaults = { diff --git a/test/custom_operator/test_infer_schema_annotation.py b/test/custom_operator/test_infer_schema_annotation.py index 755a3364047a..3e32ffc661b1 100644 --- a/test/custom_operator/test_infer_schema_annotation.py +++ b/test/custom_operator/test_infer_schema_annotation.py @@ -2,13 +2,17 @@ from __future__ import annotations import typing -from typing import List, Optional, Sequence, Union # noqa: F401 +from typing import List, Optional, Union import torch from torch import Tensor, types from torch.testing._internal.common_utils import run_tests, TestCase +if typing.TYPE_CHECKING: + from collections.abc import Sequence + + mutates_args = {} diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index 575a7d6059c9..18852166a8a5 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -6,7 +6,8 @@ import functools import itertools import unittest from collections import defaultdict -from typing import Any, Iterable, List, Optional, Tuple, Union +from collections.abc import Iterable +from typing import Any, List, Optional, Tuple, Union import torch import torch.distributed as dist diff --git a/test/distributed/_tensor/test_pointwise_ops.py b/test/distributed/_tensor/test_pointwise_ops.py index 4863bc9a9e94..189f1ddea719 100644 --- 
a/test/distributed/_tensor/test_pointwise_ops.py +++ b/test/distributed/_tensor/test_pointwise_ops.py @@ -1,7 +1,8 @@ # Copyright (c) Meta Platforms, Inc. and affiliates # Owner(s): ["oncall: distributed"] -from typing import Any, Callable, Dict, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Callable, Dict, Optional from unittest import skip import torch diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 5150e8ce810b..4562e028aa3c 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -10,8 +10,9 @@ import inspect import io import operator import unittest +from collections.abc import Sequence from enum import Enum -from typing import Dict, List, Sequence +from typing import Dict, List from unittest.mock import patch import torch diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 75760b484a6a..57931a105186 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -21,10 +21,11 @@ import warnings import weakref from abc import ABC from collections import namedtuple +from collections.abc import Iterator from copy import deepcopy from enum import Enum, IntEnum from functools import wraps -from typing import Any, Dict, Iterator, List, Literal, Tuple, TypedDict +from typing import Any, Dict, List, Literal, Tuple, TypedDict from unittest import mock import numpy as np diff --git a/test/export/test_lift_unlift.py b/test/export/test_lift_unlift.py index c027fc557178..28a3f61f8ab9 100644 --- a/test/export/test_lift_unlift.py +++ b/test/export/test_lift_unlift.py @@ -1,6 +1,7 @@ # Owner(s): ["oncall: export"] import unittest -from typing import Any, Dict, Optional, OrderedDict, Tuple +from collections import OrderedDict +from typing import Any, Dict, Optional, Tuple import torch from torch._export.passes.lift_constants_pass import ( diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 1d9f25b66273..130b0e52557b 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -14,8 +14,7 @@ import random import types import unittest import warnings -from collections import namedtuple -from typing import OrderedDict +from collections import namedtuple, OrderedDict from unittest.case import skipIf from common_utils import ( diff --git a/test/fx/test_partitioner_order.py b/test/fx/test_partitioner_order.py index 97b0e22ff3da..a646ec1bc776 100644 --- a/test/fx/test_partitioner_order.py +++ b/test/fx/test_partitioner_order.py @@ -1,7 +1,7 @@ # Owner(s): ["module: fx"] import unittest -from typing import Mapping +from collections.abc import Mapping import torch from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner diff --git a/test/lazy/test_ts_opinfo.py b/test/lazy/test_ts_opinfo.py index 2cc75f6e4ac2..7c467dc62413 100644 --- a/test/lazy/test_ts_opinfo.py +++ b/test/lazy/test_ts_opinfo.py @@ -3,8 +3,8 @@ import functools import itertools import os +from collections.abc import Sequence from pathlib import Path -from typing import Sequence from unittest import skip import yaml diff --git a/test/nn/test_module_hooks.py b/test/nn/test_module_hooks.py index 2ebfe81540cd..c9c29f0ba4a3 100644 --- a/test/nn/test_module_hooks.py +++ b/test/nn/test_module_hooks.py @@ -9,7 +9,7 @@ from collections import namedtuple, OrderedDict from copy import deepcopy from functools import partial from tempfile import NamedTemporaryFile -from typing import Any, Dict, List, Tuple +from typing import Any import torch import torch.nn as nn @@ -55,11 +55,11 @@ 
class ToyModel(nn.Module): def forward_hook( self: TestCase, - fired_hooks: List[int], + fired_hooks: list[int], expected_module: nn.Module, hook_id: int, module: nn.Module, - inp: Tuple[torch.Tensor], + inp: tuple[torch.Tensor], out: torch.Tensor, ) -> None: fired_hooks.append(hook_id) @@ -69,11 +69,11 @@ def forward_hook( def forward_pre_hook( self: TestCase, - fired_hooks: List[int], + fired_hooks: list[int], expected_module: nn.Module, hook_id: int, module: nn.Module, - inp: Tuple[torch.Tensor], + inp: tuple[torch.Tensor], ) -> None: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) @@ -82,12 +82,12 @@ def forward_pre_hook( def full_backward_hook( self: TestCase, - fired_hooks: List[int], + fired_hooks: list[int], expected_module: nn.Module, hook_id: int, module: nn.Module, - grad_input: Tuple[torch.Tensor], - grad_output: Tuple[torch.Tensor], + grad_input: tuple[torch.Tensor], + grad_output: tuple[torch.Tensor], ) -> None: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) @@ -97,11 +97,11 @@ def full_backward_hook( def full_backward_pre_hook( self: TestCase, - fired_hooks: List[int], + fired_hooks: list[int], expected_module: nn.Module, hook_id: int, module: nn.Module, - grad_input: Tuple[torch.Tensor], + grad_input: tuple[torch.Tensor], ) -> None: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) @@ -122,8 +122,8 @@ class KwargModel(nn.Module): def internal_forward_hook( self, module: nn.Module, - args: Tuple[torch.Tensor], - kwargs: Dict[str, Any], + args: tuple[torch.Tensor], + kwargs: dict[str, Any], out: torch.Tensor, ): return out + kwargs["bias"] @@ -142,13 +142,13 @@ class FailsInForwardModel(nn.Module): def kwarg_forward_pre_hook( self: TestCase, - fired_hooks: List[int], + fired_hooks: list[int], expected_module: nn.Module, hook_id: int, module: nn.Module, - args: Tuple[torch.Tensor], - kwargs: Dict[str, Any], -) -> Tuple[Any, Any]: + args: tuple[torch.Tensor], + kwargs: dict[str, Any], +) -> tuple[Any, Any]: fired_hooks.append(hook_id) self.assertEqual(id(module), id(expected_module)) self.assertEqual(len(args), 1) @@ -158,12 +158,12 @@ def kwarg_forward_pre_hook( def kwarg_forward_hook( self: TestCase, - fired_hooks: List[int], + fired_hooks: list[int], expected_module: nn.Module, hook_id: int, module: nn.Module, - args: Tuple[torch.Tensor], - kwargs: Dict[str, Any], + args: tuple[torch.Tensor], + kwargs: dict[str, Any], out: torch.Tensor, ) -> Any: fired_hooks.append(hook_id) @@ -188,7 +188,7 @@ class DummyContextManager: class TestModuleHooks(TestCase): @parametrize_test("named_tuple", (True, False)) def test_forward_hooks(self, named_tuple): - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(forward_hook, self, fired_hooks, model.net1.seq2) @@ -210,7 +210,7 @@ class TestModuleHooks(TestCase): @parametrize_test("named_tuple", (True, False)) def test_forward_pre_hooks(self, named_tuple): - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1) @@ -232,7 +232,7 @@ class TestModuleHooks(TestCase): @parametrize_test("named_tuple", (True, False)) def test_full_backward_hooks(self, named_tuple): - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(full_backward_hook, self, fired_hooks, model.net1) @@ -254,7 +254,7 @@ 
class TestModuleHooks(TestCase): @parametrize_test("named_tuple", (True, False)) def test_full_backward_pre_hooks(self, named_tuple): - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1) @@ -294,7 +294,7 @@ class TestModuleHooks(TestCase): @parametrize_test("named_tuple", (True, False)) def test_mixed_hooks(self, named_tuple): - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] model = ToyModel(named_tuple) x = torch.randn(10, 10) model.register_forward_pre_hook( @@ -319,7 +319,7 @@ class TestModuleHooks(TestCase): def test_kwarg_hooks(self): # 1. test forward pre hook - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] x: torch.Tensor = torch.ones(10, 10) bias: torch.Tensor = torch.ones(10, 10) model = KwargModel() @@ -336,7 +336,7 @@ class TestModuleHooks(TestCase): self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5) # 2. test forward pre and forward hooks - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] x: torch.Tensor = torch.ones(10, 10) bias: torch.Tensor = torch.ones(10, 10) model = KwargModel() @@ -372,7 +372,7 @@ class TestModuleHooks(TestCase): def test_remove_kwarg_hooks(self): # test forward pre and forward hooks - fired_hooks: List[int] = [] + fired_hooks: list[int] = [] x: torch.Tensor = torch.ones(10, 10) bias: torch.Tensor = torch.ones(10, 10) model = KwargModel() @@ -1217,8 +1217,8 @@ class TestModuleGlobalHooks(TestCase): def test_module_global_hooks_with_kwargs(self): def kwarg_global_forward_hook( module: nn.Module, - args: Tuple[torch.Tensor], - kwargs: Dict[str, Any], + args: tuple[torch.Tensor], + kwargs: dict[str, Any], out: torch.Tensor, ) -> Any: out = out + kwargs["bias"] diff --git a/test/nn/test_packed_sequence.py b/test/nn/test_packed_sequence.py index 1d8b8966af16..0d6de0145106 100644 --- a/test/nn/test_packed_sequence.py +++ b/test/nn/test_packed_sequence.py @@ -2,7 +2,6 @@ import itertools import random -from typing import List import torch import torch.nn.utils.rnn as rnn_utils @@ -219,7 +218,7 @@ class PackedSequenceTest(TestCase): # more dimensions maxlen = 9 for num_dim in (0, 1, 2, 3): - sequences: List[torch.Tensor] = [] + sequences: list[torch.Tensor] = [] trailing_dims = [4] * num_dim for i in range(1, maxlen + 1): seq_len = i * i diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index 67b38690bf85..82f6ca2fafe3 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -1573,10 +1573,10 @@ torch.cuda.synchronize() return torch.stack([col, col + 2], 1).view(2, 2, 2, 2) if adaptive: - cls_name = "AdaptiveMaxPool{}d".format(num_dim) # noqa: UP032 + cls_name = f"AdaptiveMaxPool{num_dim}d" else: # FIXME(#105716): Test fails when using f-string - cls_name = "MaxPool{}d".format(num_dim) # noqa: UP032 + cls_name = f"MaxPool{num_dim}d" module_cls = getattr(nn, cls_name) module = module_cls(2, return_indices=True).to(device, dtype=dtype) numel = 4 ** (num_dim + 1) diff --git a/test/onnx/internal/test_registraion.py b/test/onnx/internal/test_registraion.py index 0fb87ac019e6..39afcc24ee65 100644 --- a/test/onnx/internal/test_registraion.py +++ b/test/onnx/internal/test_registraion.py @@ -1,7 +1,7 @@ # Owner(s): ["module: onnx"] """Unit tests for the internal registration wrapper module.""" -from typing import Sequence +from collections.abc import Sequence from torch.onnx import errors from torch.onnx._internal import registration diff --git a/test/onnx/onnx_test_common.py 
b/test/onnx/onnx_test_common.py index 69a9a3b4e556..8f84d15e70a2 100644 --- a/test/onnx/onnx_test_common.py +++ b/test/onnx/onnx_test_common.py @@ -10,19 +10,8 @@ import logging import os import unittest import warnings -from typing import ( - Any, - Callable, - Collection, - Iterable, - List, - Mapping, - Optional, - Sequence, - Tuple, - Type, - Union, -) +from collections.abc import Collection, Iterable, Mapping, Sequence +from typing import Any, Callable, List, Optional, Tuple, Type, Union import numpy as np import onnxruntime diff --git a/test/onnx/test_models_onnxruntime.py b/test/onnx/test_models_onnxruntime.py index 43f495c50964..28e6344849f1 100644 --- a/test/onnx/test_models_onnxruntime.py +++ b/test/onnx/test_models_onnxruntime.py @@ -3,7 +3,8 @@ import os import unittest from collections import OrderedDict -from typing import List, Mapping, Tuple +from collections.abc import Mapping +from typing import List, Tuple import onnx_test_common import parameterized diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 41b2c78ca7c7..dbca01ddbce6 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -10,7 +10,7 @@ import itertools import unittest import unittest.mock import warnings -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import numpy as np @@ -26,6 +26,10 @@ from torch.onnx._internal import registration from torch.testing._internal import common_quantization, common_utils, jit_utils +if TYPE_CHECKING: + from collections.abc import Iterable + + def export_to_onnx( model: Union[torch.nn.Module, torch.jit.ScriptFunction], input: Union[torch.Tensor, Tuple[torch.Tensor]], diff --git a/test/package/test_glob_group.py b/test/package/test_glob_group.py index ad798b5e869a..f41f2a86f6da 100644 --- a/test/package/test_glob_group.py +++ b/test/package/test_glob_group.py @@ -1,6 +1,6 @@ # Owner(s): ["oncall: package/deploy"] -from typing import Iterable +from collections.abc import Iterable from torch.package import GlobGroup from torch.testing._internal.common_utils import run_tests diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index eff4d187ef80..223c92946a05 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -18,7 +18,7 @@ import os import sys import tempfile import unittest -from typing import Any, Dict, List +from typing import Any import torch import torch.nn as nn @@ -51,7 +51,7 @@ from torch.testing._internal.common_utils import ( from torch.utils._triton import has_triton -Json = Dict[str, Any] +Json = dict[str, Any] class TestExecutionTrace(TestCase): @@ -97,7 +97,7 @@ class TestExecutionTrace(TestCase): nodes = et_graph["nodes"] return nodes - def get_execution_trace_rf_ids(self, nodes: List[Json]) -> List[int]: + def get_execution_trace_rf_ids(self, nodes: list[Json]) -> list[int]: """Returns a sorted list of rf_id (record function ids) in execution trace""" def get_rf_id(node): @@ -115,7 +115,7 @@ class TestExecutionTrace(TestCase): ) return sorted(rf_id for rf_id in rf_ids_ if rf_id is not None) - def get_kineto_rf_ids(self, events: List[Json]) -> List[int]: + def get_kineto_rf_ids(self, events: list[Json]) -> list[int]: """Returns a sorted list of Record function IDs for CPU operators and user annotations""" ops_and_annotations = ( e for e in events if e.get("cat", "") in 
["cpu_op", "user_annotation"] diff --git a/test/profiler/test_memory_profiler.py b/test/profiler/test_memory_profiler.py index 09d2eb8641b0..304587faf8a6 100644 --- a/test/profiler/test_memory_profiler.py +++ b/test/profiler/test_memory_profiler.py @@ -5,7 +5,8 @@ import itertools as it import sys import textwrap import unittest -from typing import Callable, Dict, Iterator, List, Optional, Tuple +from collections.abc import Iterator +from typing import Callable, Optional import torch from torch._C._profiler import _EventType, _TensorMetadata @@ -309,9 +310,9 @@ class TestDataFlow(TestCase): @staticmethod def formatSchemas( prof: torch.profiler.profile, indent: int = 12 - ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]: + ) -> tuple[tuple[str, tuple[bool, ...]], ...]: tree = prof.profiler.kineto_results.experimental_event_tree() - out: List[Tuple[str, Tuple[bool, ...]]] = [] + out: list[tuple[str, tuple[bool, ...]]] = [] for node in _utils.traverse_dfs(tree): if node.tag == _EventType.TorchOp: e = node.extra_fields @@ -327,8 +328,8 @@ class TestDataFlow(TestCase): @staticmethod def _run_and_format_data_flow( - inputs: Dict[str, torch.Tensor], - f: Callable[..., Optional[Dict[str, torch.Tensor]]], + inputs: dict[str, torch.Tensor], + f: Callable[..., Optional[dict[str, torch.Tensor]]], indent: int = 12, ) -> str: with profile() as prof: @@ -339,7 +340,7 @@ class TestDataFlow(TestCase): graph = memory_profile._data_flow_graph storage_to_id = {key.storage.ptr: key.id for key in graph._active_version} - lines: List[str] = [] + lines: list[str] = [] for name, t in it.chain(inputs.items(), outputs.items()): lines.append(f"{name + ':':<8} T{storage_to_id[t.storage().data_ptr()]}") if t.grad is not None: @@ -352,7 +353,7 @@ class TestDataFlow(TestCase): for node in graph.flow_nodes: destroyed = {k for k, v in node._edges.items() if v.is_deletion} - inputs: List[str] = [] + inputs: list[str] = [] for key, (_, v) in node.inputs.items(): inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})") @@ -833,7 +834,7 @@ class TestMemoryProfilerE2E(TestCase): @staticmethod def _lookup_tensor_categories( t: torch.Tensor, memory_profile: _memory_profiler.MemoryProfile - ) -> Dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]: + ) -> dict[_memory_profiler.TensorAndID, Optional[_memory_profiler.Category]]: storage = t.storage() if storage is None: raise ValueError("Cannot look up uninitialized Tensor.") @@ -889,7 +890,7 @@ class TestMemoryProfilerE2E(TestCase): fn(lambda name: record_ops.mark_region(f"-- {name} ".ljust(105, "-"))) memory_profile = prof._memory_profile() - ptr_pair_to_key: Dict[Tuple[int, int], _memory_profiler.TensorKey] = {} + ptr_pair_to_key: dict[tuple[int, int], _memory_profiler.TensorKey] = {} snapshot = memory_profile._category_snapshot() # Build map from observed live Tensors to the memory profiler's @@ -922,7 +923,7 @@ class TestMemoryProfilerE2E(TestCase): return f"{target_key.storage.allocation_id} ({','.join(categories)})" - out: List[str] = [] + out: list[str] = [] for name, inputs, outputs in record_ops.results: if inputs or outputs: # PyTorch ops diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index b0e2930f1445..48fd8c89aacf 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -17,7 +17,7 @@ import threading import time import unittest from dataclasses import dataclass, field -from typing import List, Optional +from typing import Optional from unittest.mock import patch import expecttest @@ 
-2277,7 +2277,7 @@ class MockProfilerEvent: start_time_ns: int duration_time_ns: int correlation_id: int = 0 - children: List["MockProfilerEvent"] = field(default_factory=list) + children: list["MockProfilerEvent"] = field(default_factory=list) parent: Optional["MockProfilerEvent"] = None @property @@ -2301,7 +2301,7 @@ class MockNode: @unittest.skipIf(sys.version_info >= (3, 13), "segfaults") class TestExperimentalUtils(TestCase): - def make_tree(self) -> List[MockNode]: + def make_tree(self) -> list[MockNode]: tree = { "root_0": { "1": {"2": {}}, diff --git a/test/profiler/test_record_function.py b/test/profiler/test_record_function.py index 1608699d1aeb..4bc4ad16acb4 100644 --- a/test/profiler/test_record_function.py +++ b/test/profiler/test_record_function.py @@ -14,7 +14,7 @@ try: except ImportError: None -from typing import Any, Dict +from typing import Any import torch import torch.optim @@ -29,7 +29,7 @@ from torch.profiler import kineto_available, record_function from torch.testing._internal.common_utils import run_tests, TestCase -Json = Dict[str, Any] +Json = dict[str, Any] class TestRecordFunction(TestCase): diff --git a/test/profiler/test_torch_tidy.py b/test/profiler/test_torch_tidy.py index be5884e93c6f..119db5bb856f 100644 --- a/test/profiler/test_torch_tidy.py +++ b/test/profiler/test_torch_tidy.py @@ -19,7 +19,7 @@ import sys import textwrap import unittest import weakref -from typing import Any, Dict, List +from typing import Any import torch import torch.nn as nn @@ -30,7 +30,7 @@ from torch.profiler import _utils, profile from torch.testing._internal.common_utils import run_tests, TestCase -Json = Dict[str, Any] +Json = dict[str, Any] from torch._C._profiler import _ExtraFields_PyCall @@ -455,7 +455,7 @@ class TestTorchTidyProfiler(TestCase): nodes = p.profiler.kineto_results.experimental_event_tree() - def find_chain(names: List[str]): + def find_chain(names: list[str]): out = [] for name in names: root = [out[-1]] if out else nodes diff --git a/test/test_autograd.py b/test/test_autograd.py index 3fce999c0b17..3898fb286ac0 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -25,7 +25,7 @@ from copy import deepcopy from functools import partial, reduce from itertools import product from operator import mul -from typing import List, Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING import torch import torch.autograd._functions @@ -10091,14 +10091,14 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) { def test_multi_grad_any_hooks(self): hook_id = 0 - any_hook_handles: List[RemovableHandle] = [] + any_hook_handles: list[RemovableHandle] = [] class MultiOutputModule(nn.Module): def __init__(self) -> None: super().__init__() self.lin = nn.Linear(3, 3) - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: z = self.lin(x) out = torch.sin(z), torch.cos(z) nonlocal hook_id @@ -10123,7 +10123,7 @@ TORCH_LIBRARY(test_multigrad_all_hooks, m) { z = y[0] + y[1] return self.mod2(z) - hook_order: List[int] = [] + hook_order: list[int] = [] hook_count = 0 def hook(hook_id: int, *unused): @@ -13975,7 +13975,7 @@ class TestSelectiveActivationCheckpoint(TestCase): counter = [0] @torch.library.custom_op("mylib::sin_with_extra", mutates_args=()) - def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]: + def sin_with_extra(x: torch.Tensor) -> tuple[torch.Tensor, int]: counter[0] += 1 return x.sin(), 2 diff --git a/test/test_bundled_inputs.py b/test/test_bundled_inputs.py 
index 8d5c02cf52c5..0ff5373993cf 100644 --- a/test/test_bundled_inputs.py +++ b/test/test_bundled_inputs.py @@ -4,7 +4,7 @@ import io import textwrap -from typing import Dict, List, Optional +from typing import Optional import torch import torch.utils.bundled_inputs @@ -32,7 +32,7 @@ class TestBundledInputs(TestCase): sm = torch.jit.script(SingleTensorModel()) original_size = model_size(sm) - get_expr: List[str] = [] + get_expr: list[str] = [] samples = [ # Tensor with small numel and small storage. (torch.tensor([1]),), @@ -328,8 +328,8 @@ class TestBundledInputs(TestCase): class MyModel(torch.nn.Module): def forward( self, - arg1: Optional[Dict[str, torch.Tensor]], - arg2: Optional[List[torch.Tensor]], + arg1: Optional[dict[str, torch.Tensor]], + arg2: Optional[list[torch.Tensor]], arg3: torch.Tensor, ): if arg1 is None: @@ -393,7 +393,7 @@ class TestBundledInputs(TestCase): """, ) - out: List[str] = [] + out: list[str] = [] sm = torch.jit.script(MyModel()) original_size = model_size(sm) small_inputs = ( diff --git a/test/test_cuda.py b/test/test_cuda.py index f4c910560ebc..4bc00826ee68 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -4054,11 +4054,7 @@ class TestCudaMallocAsync(TestCase): that the pytorch call is returning a correct list of UUIDs. """ cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid:.*GPU-//'" - uuids = ( - subprocess.check_output(cmd, shell=True, universal_newlines=True) - .strip() - .split("\n") - ) + uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n") uuids = [s.strip() for s in uuids] raw_uuids = torch.cuda._raw_device_uuid_amdsmi() for uuid in uuids: @@ -4082,11 +4078,7 @@ import os print(f"{torch.cuda.device_count()}") """ cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid://'" - uuids = ( - subprocess.check_output(cmd, shell=True, universal_newlines=True) - .strip() - .split("\n") - ) + uuids = subprocess.check_output(cmd, shell=True, text=True).strip().split("\n") uuids = [s.strip() for s in uuids] custom_envs = [] diff --git a/test/test_cuda_sanitizer.py b/test/test_cuda_sanitizer.py index 93a1f4e11505..6d2ecc36a093 100644 --- a/test/test_cuda_sanitizer.py +++ b/test/test_cuda_sanitizer.py @@ -3,7 +3,7 @@ import sys import textwrap import traceback -from typing import List, Optional +from typing import Optional import torch import torch.cuda._sanitizer as csan @@ -148,9 +148,9 @@ class TestEventHandler(TestCase): def kernel_launch( self, stream: StreamId, - read_only: Optional[List[DataPtr]] = None, - read_write: Optional[List[DataPtr]] = None, - ) -> List[csan.SynchronizationError]: + read_only: Optional[list[DataPtr]] = None, + read_write: Optional[list[DataPtr]] = None, + ) -> list[csan.SynchronizationError]: if read_only is None: read_only = [] if read_write is None: @@ -167,8 +167,8 @@ class TestEventHandler(TestCase): def assert_good_kernel_launch( self, stream: StreamId, - read_only: Optional[List[DataPtr]] = None, - read_write: Optional[List[DataPtr]] = None, + read_only: Optional[list[DataPtr]] = None, + read_write: Optional[list[DataPtr]] = None, ) -> None: self.assertEqual(self.kernel_launch(stream, read_only, read_write), []) @@ -176,8 +176,8 @@ class TestEventHandler(TestCase): self, number_of_errors: int, stream: StreamId, - read_only: Optional[List[DataPtr]] = None, - read_write: Optional[List[DataPtr]] = None, + read_only: Optional[list[DataPtr]] = None, + read_write: Optional[list[DataPtr]] = None, ) -> None: errors = self.kernel_launch(stream, read_only, read_write) 
self.assertEqual(len(errors), number_of_errors) diff --git a/test/test_dataloader.py b/test/test_dataloader.py index 5a251679795e..f1396cbda2cd 100644 --- a/test/test_dataloader.py +++ b/test/test_dataloader.py @@ -1262,14 +1262,12 @@ class TestDataLoader(TestCase): list(iter(loader)) def test_typing(self): - from typing import List - # Make sure there is no TypeError - class SomeDatasetClass(Dataset[List[torch.Tensor]]): + class SomeDatasetClass(Dataset[list[torch.Tensor]]): pass - def _create_dataloader(is_train: bool) -> DataLoader[List[torch.Tensor]]: + def _create_dataloader(is_train: bool) -> DataLoader[list[torch.Tensor]]: pass @unittest.skipIf(IS_SANDCASTLE, "subprocess doesn't work in FB internal CI") diff --git a/test/test_functional_optim.py b/test/test_functional_optim.py index 29b240801b92..e9c5566e26fd 100644 --- a/test/test_functional_optim.py +++ b/test/test_functional_optim.py @@ -1,7 +1,7 @@ # Owner(s): ["oncall: distributed"] import unittest -from typing import List, Optional, Tuple +from typing import Optional import torch import torch.distributed @@ -27,9 +27,9 @@ class MyModule(torch.nn.Module): class MyDummyFnOptimizer: def __init__( self, - params: List[Tensor], + params: list[Tensor], lr: float = 1e-3, - betas: Tuple[float, float] = (0.9, 0.999), + betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-6, weight_decay: float = 0.0, _allow_empty_param_list: bool = False, @@ -63,7 +63,7 @@ class MyDummyFnOptimizer: "MyDummyFnOptimizer does not support step_param() as of now" ) - def step(self, gradients: List[Optional[Tensor]]): + def step(self, gradients: list[Optional[Tensor]]): # call the custom optimizer step implementation with torch.no_grad(): raise RuntimeError("MyDummyFnOptimizer does not support step() as of now") diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 57cbd8f8be03..61c5687bf886 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1200,6 +1200,15 @@ class {test_classname}(torch.nn.Module): inp3_y = inp3.y return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y + class MyModule2(torch.nn.Module): + def forward(self, inp: tuple[CustomType, torch.Tensor], inp2: list[CustomType], inp3: CustomNamedTuple): + inp_0 = inp[0] + inp_1 = inp[1] + inp2_0 = inp2[0] + inp3_x = inp3.x + inp3_y = inp3.y + return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y + my_module = MyModule() my_module_traced = torch.fx.symbolic_trace(my_module) @@ -1214,6 +1223,20 @@ class {test_classname}(torch.nn.Module): if node.target == operator.getitem: self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.") + my_module = MyModule2() + my_module_traced = torch.fx.symbolic_trace(my_module) + + # by default, fx transform loses type annotation of getitem nodes. 
+ for node in my_module_traced.graph.nodes: + if node.target == operator.getitem: + assert node.type is None + + annotate_getitem_nodes(my_module_traced.graph) + + for node in my_module_traced.graph.nodes: + if node.target == operator.getitem: + self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.") + def test_subgraph_uniquename(self): class MyModule(torch.nn.Module): def __init__(self) -> None: diff --git a/test/test_masked.py b/test/test_masked.py index ecd1769a0d67..c5aee472a9a8 100644 --- a/test/test_masked.py +++ b/test/test_masked.py @@ -5,7 +5,7 @@ import itertools import torch -from typing import List, Any +from typing import Any from functools import wraps import unittest from torch.testing._internal.common_utils import skipIfTorchDynamo @@ -100,7 +100,7 @@ def apply_masked_reduction_along_dim(op, input, *args, **kwargs): # dimensions along which the reduction operation is applied: dim_ = torch.masked._canonical_dim(dim, input.ndim) # slices in product(*ranges) define all elementary slices: - ranges: List[Any] = [] + ranges: list[Any] = [] # shape of output for the keepdim=True case: shape = [] for i in range(input.ndim): diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py index 0f2c6b69ff8c..4cb27866ef23 100644 --- a/test/test_mkldnn_fusion.py +++ b/test/test_mkldnn_fusion.py @@ -1,7 +1,7 @@ # Owner(s): ["module: mkldnn"] import itertools import unittest -from typing import NamedTuple, List +from typing import NamedTuple import torch from torch import nn @@ -16,7 +16,7 @@ FUSION_GROUP = 'prim::TensorExprGroup' class PointwisePostOp(NamedTuple): attr : str pointwise_module : nn.Module - scalars : List = [] + scalars : list = [] algorithm : str = "" CONV_MODULES = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} diff --git a/test/test_native_functions.py b/test/test_native_functions.py index 2760ca9171ab..5a894c278fd0 100644 --- a/test/test_native_functions.py +++ b/test/test_native_functions.py @@ -1,6 +1,6 @@ # Owner(s): ["module: unknown"] -from typing import Optional, List +from typing import Optional import torch from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorchDynamo @@ -8,12 +8,12 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfTorc class FloatListWrapperModule(torch.nn.Module): - def forward(self, values, incr: Optional[List[float]]): + def forward(self, values, incr: Optional[list[float]]): return torch._C._nn._test_optional_floatlist(values, incr) class IntListWrapperModule(torch.nn.Module): - def forward(self, values, incr: Optional[List[int]]): + def forward(self, values, incr: Optional[list[int]]): return torch._C._nn._test_optional_intlist(values, incr) diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index d1ec2faed6cd..31cc006e305f 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -9,7 +9,7 @@ import sys import tempfile import unittest from functools import partial -from typing import Optional, Tuple +from typing import Optional import numpy as np @@ -3545,7 +3545,7 @@ def get_tolerances( true_value: torch.Tensor, computed_value: torch.Tensor, fudge_factor: Optional[float] = None, -) -> Tuple[float, float]: +) -> tuple[float, float]: """Returns the absolute and relative tolerances for comparing two tensors.""" fudge_factor = fudge_factor if fudge_factor is not None else 1.0 atol = get_atol(true_value, computed_value) diff --git a/test/test_nnapi.py b/test/test_nnapi.py index ef9fe7bb6dab..d8a6392d72f1 100644 --- 
a/test/test_nnapi.py +++ b/test/test_nnapi.py @@ -4,7 +4,6 @@ import ctypes import os import unittest -from typing import Tuple import torch from torch.backends._nnapi.prepare import convert_model_to_nnapi @@ -700,7 +699,7 @@ class TestNNAPI(TestCase): def test_multi_output(self): class MultiModel(torch.nn.Module): - def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, lhs, rhs) -> tuple[torch.Tensor, torch.Tensor]: the_sum = lhs + rhs the_diff = lhs - rhs return the_sum, the_diff diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index bff7681bbc81..15864a056041 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -278,7 +278,7 @@ class TestNumPyInterop(TestCase): def test_from_numpy_no_leak_on_invalid_dtype(self): # This used to leak memory as the `from_numpy` call raised an exception and didn't decref the temporary # object. See https://github.com/pytorch/pytorch/issues/121138 - x = np.array("value".encode("ascii")) + x = np.array(b"value") for _ in range(1000): try: torch.from_numpy(x) diff --git a/test/test_ops.py b/test/test_ops.py index d358f745eca0..afbd0507be2b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -11,7 +11,6 @@ from collections import defaultdict from collections.abc import Sequence from functools import partial from importlib import import_module -from typing import Dict, List import torch import torch._prims as prims @@ -1483,7 +1482,7 @@ class TestCommon(TestCase): unsupported_dtypes = set() supported_backward_dtypes = set() unsupported_backward_dtypes = set() - dtype_error: Dict[torch.dtype, Exception] = {} + dtype_error: dict[torch.dtype, Exception] = {} def unsupported(dtype, e): dtype_error[dtype] = e @@ -1987,7 +1986,7 @@ class TestCompositeCompliance(TestCase): for sample in op.sample_inputs(device, dtype, requires_grad=False): inp = sample.input outs = op(inp, *sample.args, **sample.kwargs) - if not isinstance(outs, (tuple, List)): + if not isinstance(outs, (tuple, list)): outs = [outs] # for all outputs that are views of the input, we should be able to replay the diff --git a/test/test_optim.py b/test/test_optim.py index f44a7e877779..91a8506ce4f1 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -4,7 +4,7 @@ import math import tempfile import unittest from copy import deepcopy -from typing import Any, Dict, Tuple +from typing import Any from unittest.mock import patch from optim.test_lrscheduler import TestLRScheduler # noqa: F401 @@ -1769,8 +1769,8 @@ class TestOptimRenewed(TestCase): @staticmethod def _state_dict_post_hook( - optimizer: Optimizer, state_dict: Dict[str, Any] - ) -> Dict[str, Any]: + optimizer: Optimizer, state_dict: dict[str, Any] + ) -> dict[str, Any]: if "test" in state_dict["state"]: state_dict["state"].pop("test") state_dict["ran_state_dict_pre_hook"] = True @@ -1821,14 +1821,14 @@ class TestOptimRenewed(TestCase): @staticmethod def _load_state_dict_pre_hook1( - optimizer: Optimizer, state_dict: Dict[str, Any] + optimizer: Optimizer, state_dict: dict[str, Any] ) -> None: state_dict["param_groups"][0]["lr"] = 0.002 @staticmethod def _load_state_dict_pre_hook2( - optimizer: Optimizer, state_dict: Dict[str, Any] - ) -> Dict[str, Any]: + optimizer: Optimizer, state_dict: dict[str, Any] + ) -> dict[str, Any]: # The typical use case for returning a state dict is to drastically modify the state dict. 
# I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used my_state_dict = deepcopy(state_dict) @@ -1906,7 +1906,7 @@ class TestOptimRenewed(TestCase): @optims(optim_db, dtypes=[torch.float32]) def test_step_post_hook(self, device, dtype, optim_info): - def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + def post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data += 2 @@ -1938,7 +1938,7 @@ class TestOptimRenewed(TestCase): @optims(optim_db, dtypes=[torch.float32]) def test_step_pre_hook(self, device, dtype, optim_info): - def pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + def pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data += 2 @@ -1970,19 +1970,19 @@ class TestOptimRenewed(TestCase): @optims(optim_db, dtypes=[torch.float32]) def test_step_all_hooks(self, device, dtype, optim_info): - def global_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + def global_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data.append(0) - def global_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + def global_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data.append(5) - def local_pre_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + def local_pre_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data.append(1) - def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]): + def local_post_hook(opt: Optimizer, args: tuple[Any], kwargs: dict[Any, Any]): nonlocal data data.append(2) diff --git a/test/test_reductions.py b/test/test_reductions.py index 8ce65c789852..dc84432777d3 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -5,7 +5,7 @@ import torch import numpy as np import math -from typing import Dict, List, Sequence +from collections.abc import Sequence import random from functools import partial from itertools import product, combinations, permutations @@ -736,7 +736,7 @@ class TestReductions(TestCase): # TODO: kill this ane replace with common creation ops def _make_tensors(self, shape, val_range=(-100, 100), use_floating=True, use_integral=True, - use_complex=False) -> Dict[str, List[torch.Tensor]]: + use_complex=False) -> dict[str, list[torch.Tensor]]: float_types = [torch.double, torch.float] int_types = [torch.int64, @@ -778,7 +778,7 @@ class TestReductions(TestCase): types += int_types if use_complex: types += complex_types - tensors: Dict[str, List[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []} + tensors: dict[str, list[torch.Tensor]] = {"cont": [], "noncont": [], "slice": []} for dtype in types: tensors["cont"].append(make_contiguous(shape, dtype)) tensors["noncont"].append(make_non_contiguous(shape, dtype)) diff --git a/test/test_sparse.py b/test/test_sparse.py index a1da3872fd8f..64d7ad9b1c2a 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -15,7 +15,7 @@ from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm skipIfCrossRef from torch.testing._internal.common_cuda import TEST_CUDA from numbers import Number -from typing import Dict, Any +from typing import Any from packaging import version from torch.testing._internal.common_cuda import \ (SM53OrLater, SM80OrLater, TEST_MULTIGPU) @@ -334,7 +334,7 @@ class TestSparse(TestSparseBase): self.assertEqual(t._values(), tc._values()) return tc - value_map: 
Dict[Any, Any] = {} + value_map: dict[Any, Any] = {} for idx, val in zip(t._indices().t(), t._values()): idx_tup = tuple(idx.tolist()) if idx_tup in value_map: diff --git a/test/test_spectral_ops.py b/test/test_spectral_ops.py index 580a7a3bfde2..95ce1160753e 100644 --- a/test/test_spectral_ops.py +++ b/test/test_spectral_ops.py @@ -21,7 +21,7 @@ from torch.testing._internal.common_methods_invocations import ( from torch.testing._internal.common_cuda import SM53OrLater from torch._prims_common import corresponding_complex_dtype -from typing import Optional, List +from typing import Optional from packaging import version @@ -597,7 +597,7 @@ class TestFFT(TestCase): else: numpy_fn = getattr(np.fft, fname) - def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None): + def fn(t: torch.Tensor, s: Optional[list[int]], dim: list[int] = (-2, -1), norm: Optional[str] = None): return torch_fn(t, s, dim, norm) torch_fns = (torch_fn, torch.jit.script(fn)) diff --git a/test/test_static_runtime.py b/test/test_static_runtime.py index a5cf00e95224..893aea8e3130 100644 --- a/test/test_static_runtime.py +++ b/test/test_static_runtime.py @@ -2,14 +2,13 @@ # ruff: noqa: F841 import unittest -from typing import Dict, Optional +from typing import Optional import numpy as np import torch from torch import nn from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.static_module import StaticModule -from typing import List def linear_shim( @@ -108,7 +107,7 @@ def fork_wait_graph2(input1, input2): :param iters: number of future/wait pairs to be created """ def fork_wait_graph3(input, iters: int): - futures : List[torch.jit.Future[torch.Tensor]] = [] + futures : list[torch.jit.Future[torch.Tensor]] = [] for _ in range(iters): futures.append(torch.jit.fork(torch.neg, input)) results = [] @@ -123,7 +122,7 @@ def fork_wait_graph3(input, iters: int): :param num_child_forks: number of child forks per parent fork """ def fork_wait_graph4(input, num_forks: int, num_child_forks: int): - futures : List[torch.jit.Future[torch.Tensor]] = [] + futures : list[torch.jit.Future[torch.Tensor]] = [] for _ in range(num_forks): futures.append(torch.jit.fork(fork_wait_graph3, input, num_child_forks)) results = [] @@ -150,7 +149,7 @@ def loop_graph(a, b, iters: int): def output_graph(a, b, c, iters: int): s = torch.tensor([[3, 3], [3, 3]]) k = a + b * c + s - d: Dict[int, torch.Tensor] = {} + d: dict[int, torch.Tensor] = {} for i in range(iters): d[i] = k + i return d diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 7771c60c0527..5cd930274173 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -5,7 +5,7 @@ import itertools import math import pickle import sys -from typing import Callable, List, Tuple, Type +from typing import Callable import sympy @@ -594,7 +594,7 @@ class TestSympyInterp(TestCase): self.fail(f"Unexpected error for {fn}{args}: {str(e)}") -def type_name_fn(type: Type) -> str: +def type_name_fn(type: type) -> str: return type.__name__ @@ -606,7 +606,7 @@ def parametrize_relational_types(*types): class TestSympySolve(TestCase): - def _create_integer_symbols(self) -> List[sympy.Symbol]: + def _create_integer_symbols(self) -> list[sympy.Symbol]: return sympy.symbols("a b c", integer=True) def test_give_up(self): @@ -665,9 +665,9 @@ class TestSympySolve(TestCase): def _test_cases( self, - cases: List[Tuple[sympy.Basic, sympy.Basic]], + cases: list[tuple[sympy.Basic, sympy.Basic]], thing: 
sympy.Basic, - op: Type[sympy.Rel], + op: type[sympy.Rel], **kwargs, ): for source, expected in cases: @@ -761,7 +761,7 @@ class TestSympySolve(TestCase): Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos), }[op] - cases: List[Tuple[sympy.Basic, sympy.Basic]] = [ + cases: list[tuple[sympy.Basic, sympy.Basic]] = [ # 'b' is not strictly positive (op(FloorDiv(a, b), integer), None), # 'c' is not strictly positive diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 8dde36573a51..8afdc1c8791f 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -11,7 +11,7 @@ import unittest from itertools import product, combinations, combinations_with_replacement, permutations import random import tempfile -from typing import Any, Dict, List, Tuple +from typing import Any from torch.testing import make_tensor from torch.testing._internal.common_utils import ( @@ -4125,7 +4125,7 @@ class TestAsArray(TestCase): def test_default_device(self, device): original = torch.arange(5) - examples: List[Tuple[Any, Dict]] = [ + examples: list[tuple[Any, dict]] = [ (3, {}), (original, {}), (to_numpy(original), {}), diff --git a/test/test_testing.py b/test/test_testing.py index 2ff5cb00a02d..bfb84759a383 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -12,7 +12,8 @@ import re import subprocess import sys import unittest.mock -from typing import Any, Callable, Iterator, List, Tuple +from typing import Any, Callable +from collections.abc import Iterator import torch @@ -496,7 +497,7 @@ if __name__ == '__main__': self.assertNotIn('OK', stderr.decode('ascii')) -def make_assert_close_inputs(actual: Any, expected: Any) -> List[Tuple[Any, Any]]: +def make_assert_close_inputs(actual: Any, expected: Any) -> list[tuple[Any, Any]]: """Makes inputs for :func:`torch.testing.assert_close` functions based on two examples. 
Args: diff --git a/test/test_torch.py b/test/test_torch.py index 3679d6a25441..ff882391d5b5 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -53,7 +53,6 @@ from torch.testing._internal.common_device_type import ( dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipMeta, PYTORCH_CUDA_MEMCHECK, largeTensorTest, onlyNativeDeviceTypes, skipCUDAIfNotRocm, get_all_device_types, skipXLA) -from typing import Tuple import torch.backends.quantized import torch.testing._internal.data from torch.testing._internal.common_cuda import ( @@ -3511,7 +3510,7 @@ else: def _prepare_data_for_index_copy_and_add_deterministic( self, dim: int, device: torch.device - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert (dim >= 0 and dim < 3) a = [5, 4, 3] a[dim] = 2000 diff --git a/test/test_transformers.py b/test/test_transformers.py index 715fbe4297bd..1e2b0727fc9e 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -18,7 +18,7 @@ import math import itertools import torch.optim as optim from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, onlyCPU -from typing import List, Tuple, Optional, Dict +from typing import Optional import torch.utils.cpp_extension from torch.testing._internal.common_nn import NNTestCase from torch.testing._internal.common_utils import ( @@ -149,12 +149,12 @@ def _check_equal( def check_out_and_grad( - out_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - grad_query_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - grad_key_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - grad_value_tuple: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - grad_attn_mask_tuple: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, - fudge_factors: Optional[Dict[str, float]] = None + out_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + grad_query_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + grad_key_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + grad_value_tuple: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + grad_attn_mask_tuple: Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None, + fudge_factors: Optional[dict[str, float]] = None ) -> None: """ Check output and gradients of attention mechanism tensors. 
@@ -2574,7 +2574,7 @@ class TestSDPACudaOnly(NNTestCase): @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") def test_cudnn_attention_preserves_query_layout(self, device): - def test_attention(backend: SDPBackend, permute_order: List[List[int]]): + def test_attention(backend: SDPBackend, permute_order: list[list[int]]): BHSqD = [4, 16, 256, 64] BHSkvD = [4, 16, 512, 64] @@ -2602,7 +2602,7 @@ class TestSDPACudaOnly(NNTestCase): @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) - def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]): + def test_mem_efficient_attention_mask_variants(self, device, mask_dim: list[int]): dtype = torch.float16 make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) batch, num_heads, head_dim = 8, 8, 64 @@ -3307,7 +3307,7 @@ class TestSDPACudaOnly(NNTestCase): @tf32_enabled() def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, - scale: str, enable_gqa: bool, n_heads: List[int]): + scale: str, enable_gqa: bool, n_heads: list[int]): if isSM8XDevice and head_dim in range(193, 256 + 1): self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled") if is_causal and seq_len_q != seq_len_k: @@ -3905,7 +3905,7 @@ class TestAttnBias(NNTestCase): "shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)], ) - def test_causal_variants(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]): + def test_causal_variants(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]): make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True ) @@ -3942,7 +3942,7 @@ class TestAttnBias(NNTestCase): ) @unittest.skipIf(IS_WINDOWS, "torch.compile is not supported on windows") @skipIfTorchDynamo("This function already calls torch.compile.") - def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: List[Tuple[int]]): + def test_causal_variants_compile(self, device, causal_variant: CausalVariant, shape: list[tuple[int]]): if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: self.skipTest("No support for LOWER_RIGHT variant for now") return @@ -3975,7 +3975,7 @@ class TestAttnBias(NNTestCase): self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!") @parametrize("shape", [(16, 16, 128, 128, 16), (16, 16, 128, 256, 32), (16, 16, 256, 128, 32), (1, 1, 23, 56, 15)]) - def test_is_causal_equals_upper_left(self, device, shape: List[Tuple[int]]): + def test_is_causal_equals_upper_left(self, device, shape: list[tuple[int]]): make_tensor = partial( torch.rand, device=device, dtype=torch.float16, requires_grad=True ) diff --git a/test/test_typing.py b/test/test_typing.py index bd7998fee7f2..6c265526e2cb 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -8,7 +8,7 @@ import shutil import unittest from collections import defaultdict from threading import Lock -from typing import Dict, IO, List, Optional +from typing import IO, Optional from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -49,12 +49,12 @@ def _strip_filename(msg: str) -> str: return tail.split(":", 1)[-1] -def _run_mypy() -> Dict[str, List[str]]: +def _run_mypy() -> dict[str, 
list[str]]: """Clears the cache and run mypy before running any of the typing tests.""" if os.path.isdir(CACHE_DIR): shutil.rmtree(CACHE_DIR) - rc: Dict[str, List[str]] = {} + rc: dict[str, list[str]] = {} for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR): # Run mypy stdout, stderr, _ = api.run( @@ -119,10 +119,10 @@ def _construct_format_dict(): #: A dictionary with all supported format keys (as keys) #: and matching values -FORMAT_DICT: Dict[str, str] = _construct_format_dict() +FORMAT_DICT: dict[str, str] = _construct_format_dict() -def _parse_reveals(file: IO[str]) -> List[str]: +def _parse_reveals(file: IO[str]) -> list[str]: """Extract and parse all ``" # E: "`` comments from the passed file-like object. All format keys will be substituted for their respective value from `FORMAT_DICT`, @@ -160,10 +160,10 @@ def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> N @unittest.skipIf(NO_MYPY, reason="Mypy is not installed") class TestTyping(TestCase): _lock = Lock() - _cached_output: Optional[Dict[str, List[str]]] = None + _cached_output: Optional[dict[str, list[str]]] = None @classmethod - def get_mypy_output(cls) -> Dict[str, List[str]]: + def get_mypy_output(cls) -> dict[str, list[str]]: with cls._lock: if cls._cached_output is None: cls._cached_output = _run_mypy() @@ -192,7 +192,7 @@ class TestTyping(TestCase): with open(path) as fin: lines = fin.readlines() - errors = defaultdict(lambda: "") + errors = defaultdict(str) output_mypy = self.get_mypy_output() self.assertIn(path, output_mypy) diff --git a/test/test_utils.py b/test/test_utils.py index 1b96f5551fe4..6ddc600be21f 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -12,7 +12,7 @@ import textwrap import traceback import unittest import warnings -from typing import Any, Dict, List +from typing import Any import torch import torch.cuda @@ -439,7 +439,7 @@ class TestCheckpoint(TestCase): # get de-allocated directly. So using cuda memory usage as a proxy def _do_test(fn, should_free): - stats: List[int] = [] + stats: list[int] = [] def track(x, idx): # Track that at each step of the backward, some Tensor were @@ -1203,7 +1203,7 @@ def f(x): return g(x) + 1 """ - out: Dict[str, Any] = {} + out: dict[str, Any] = {} scope = {"__compile_source__": source} exec(source, scope, out) diff --git a/torch/fx/passes/annotate_getitem_nodes.py b/torch/fx/passes/annotate_getitem_nodes.py index 0399cef52620..0a31a76420b3 100644 --- a/torch/fx/passes/annotate_getitem_nodes.py +++ b/torch/fx/passes/annotate_getitem_nodes.py @@ -7,7 +7,7 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None: """ Annotate the type of getitem nodes, inferred from the type of sequence node. If sequence node is not annotated with a type, do nothing. - Currently support getitem nodes from Tuple, List, and NamedTuple sequence node. + Currently support getitem nodes from tuple, list, and NamedTuple sequence node. This is helpful since annotations on local names within function are lost during FX transforms. Adding back known type annotation for getitem nodes to improve jit scriptability. @@ -35,6 +35,21 @@ def annotate_getitem_nodes(graph: torch.fx.Graph) -> None: elif sequence_node.type._name == "List": assert len(parameterized_types) == 1 node.type = parameterized_types[0] + # Generic Alias Type + elif hasattr(sequence_node.type, "__origin__"): + parameterized_types = sequence_node.type.__args__ + if sequence_node.type.__origin__ is tuple: + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) 
+ ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + elif sequence_node.type.__origin__ is list: + assert len(parameterized_types) == 1 + node.type = parameterized_types[0] # NamedTuple type elif hasattr(sequence_node.type, "__annotations__"): if sequence_node.type == torch.Tensor:
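The generic-alias branch added to annotate_getitem_nodes above relies on a few introspection facts about PEP 585 builtin generics. A minimal interpreter-level sketch of those facts (illustrative only, not part of the patch; the local names t, u, v are hypothetical):

    # tuple[int, ...] is the homogeneous-tuple alias: its second __args__ entry
    # is Ellipsis, which is what `isinstance(parameterized_types[1], type(...))`
    # in the new branch detects.
    t = tuple[int, ...]
    assert t.__origin__ is tuple
    assert t.__args__ == (int, Ellipsis)
    assert isinstance(t.__args__[1], type(...))

    # Heterogeneous tuples expose one parameter per position, so the getitem
    # index can be looked up directly in __args__.
    u = tuple[int, str]
    assert u.__origin__ is tuple and u.__args__ == (int, str)

    # list[T] always carries exactly one parameter, matching the len == 1 assert.
    v = list[float]
    assert v.__origin__ is list and v.__args__ == (float,)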
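More broadly, the patch applies one import pattern throughout: builtin generics (list, tuple, dict) replace their typing.* spellings, and ABCs such as Sequence, Iterable, and Mapping move to collections.abc, behind a TYPE_CHECKING guard where they are used only in annotations. A minimal sketch of that pattern, assuming `from __future__ import annotations` as in the touched functorch/einops files (the function `head` is hypothetical):

    from __future__ import annotations

    from typing import Optional, TYPE_CHECKING

    if TYPE_CHECKING:
        # Annotation-only import: no runtime cost, and postponed evaluation
        # (PEP 563) keeps the signature below valid at runtime.
        from collections.abc import Sequence


    def head(xs: Sequence[int]) -> Optional[int]:
        # PEP 585 builtin generics in place of typing.List / typing.Tuple.
        pairs: list[tuple[int, int]] = [(i, i * i) for i in xs]
        return pairs[0][0] if pairs else None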