From 55250b324d3a9506c0c8d6e06271540b0ead0a3b Mon Sep 17 00:00:00 2001 From: cyy Date: Mon, 2 Dec 2024 21:46:15 +0000 Subject: [PATCH] [1/N] Apply py39 ruff fixes (#138578) Pull Request resolved: https://github.com/pytorch/pytorch/pull/138578 Approved by: https://github.com/Skylion007 --- mypy-strict.ini | 2 +- torchgen/api/autograd.py | 6 +++++- torchgen/api/cpp.py | 6 +++++- torchgen/api/dispatcher.py | 6 +++++- torchgen/api/native.py | 6 +++++- torchgen/api/python.py | 6 +++++- torchgen/api/translate.py | 6 +++++- torchgen/api/types/signatures.py | 4 +++- torchgen/code_template.py | 6 +++++- torchgen/context.py | 6 +++++- torchgen/dest/ufunc.py | 4 +++- torchgen/executorch/api/custom_ops.py | 4 +++- torchgen/executorch/api/et_cpp.py | 6 +++++- torchgen/executorch/api/unboxing.py | 4 +++- torchgen/gen.py | 8 ++++++-- torchgen/gen_aoti_c_shim.py | 10 +++++++--- torchgen/gen_backend_stubs.py | 6 +++++- torchgen/gen_executorch.py | 4 +++- torchgen/gen_lazy_tensor.py | 6 +++++- torchgen/gen_vmap_plumbing.py | 8 ++++++-- torchgen/local.py | 6 +++++- torchgen/model.py | 6 +++++- torchgen/native_function_generation.py | 6 +++++- .../static_runtime/gen_static_runtime_ops.py | 6 +++++- torchgen/static_runtime/generator.py | 6 +++++- torchgen/utils.py | 16 +++------------- 26 files changed, 118 insertions(+), 42 deletions(-) diff --git a/mypy-strict.ini b/mypy-strict.ini index 82aa689cff2e..2feea92cb8c0 100644 --- a/mypy-strict.ini +++ b/mypy-strict.ini @@ -6,7 +6,7 @@ # files. [mypy] -python_version = 3.8 +python_version = 3.9 plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin cache_dir = .mypy_cache/strict diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 644069395e1d..3f3b793825c9 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -2,7 +2,7 @@ from __future__ import annotations import re from dataclasses import dataclass -from typing import cast, Sequence +from typing import cast, TYPE_CHECKING from torchgen import local from torchgen.api import cpp @@ -20,6 +20,10 @@ from torchgen.model import ( from torchgen.utils import IDENT_REGEX +if TYPE_CHECKING: + from collections.abc import Sequence + + # Represents a saved attribute involved in backward calculation. # Note that it can be a derived property of an input argument, e.g.: # we could save `other.scalar_type()` instead of the entire `other` tensor. diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index 6cc40d66037d..c46f265b515f 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from typing import TYPE_CHECKING from torchgen import local from torchgen.api.types import ( @@ -51,6 +51,10 @@ from torchgen.model import ( from torchgen.utils import assert_never +if TYPE_CHECKING: + from collections.abc import Sequence + + # This file describes the translation of JIT schema to the public C++ # API, which is what people use when they call functions like at::add. # diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py index 103e6cf42990..4cc6186d7e0e 100644 --- a/torchgen/api/dispatcher.py +++ b/torchgen/api/dispatcher.py @@ -1,7 +1,7 @@ from __future__ import annotations import itertools -from typing import Sequence +from typing import TYPE_CHECKING from torchgen.api import cpp from torchgen.api.types import ArgName, Binding, CType, NamedCType @@ -16,6 +16,10 @@ from torchgen.model import ( from torchgen.utils import assert_never, concatMap +if TYPE_CHECKING: + from collections.abc import Sequence + + # This file describes the translation of JIT schema to the dispatcher # API, the *unboxed* calling convention by which invocations through # the dispatcher are made. Historically, the dispatcher API matched diff --git a/torchgen/api/native.py b/torchgen/api/native.py index a00e8266b8da..82bc051a6832 100644 --- a/torchgen/api/native.py +++ b/torchgen/api/native.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from typing import TYPE_CHECKING from torchgen import local from torchgen.api import cpp @@ -32,6 +32,10 @@ from torchgen.model import ( from torchgen.utils import assert_never +if TYPE_CHECKING: + from collections.abc import Sequence + + # This file describes the translation of JIT schema to the native functions API. # This looks a lot like the C++ API (which makes historical sense, because the # idea was you wrote native functions to implement functions in the C++ API), diff --git a/torchgen/api/python.py b/torchgen/api/python.py index 7c27e815b5e9..5552eede5749 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence +from typing import TYPE_CHECKING from torchgen.api import cpp from torchgen.api.types import Binding, CppSignature, CppSignatureGroup @@ -20,6 +20,10 @@ from torchgen.model import ( ) +if TYPE_CHECKING: + from collections.abc import Sequence + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # Data Models diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index 6e62816cac69..f98ce09bbfaf 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import NoReturn, Sequence +from typing import NoReturn, TYPE_CHECKING from torchgen.api.types import ( ArrayRefCType, @@ -36,6 +36,10 @@ from torchgen.api.types import ( ) +if TYPE_CHECKING: + from collections.abc import Sequence + + # This file implements a small program synthesis engine that implements # conversions between one API to another. # diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index 7e0a4b91037a..d7c60e52d93a 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -1,12 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Iterator, Sequence, TYPE_CHECKING +from typing import TYPE_CHECKING from torchgen.api.types.types_base import Binding, CType, Expr if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + from torchgen.model import ( BackendIndex, FunctionSchema, diff --git a/torchgen/code_template.py b/torchgen/code_template.py index cdb86a480642..8c33aad126f8 100644 --- a/torchgen/code_template.py +++ b/torchgen/code_template.py @@ -1,7 +1,11 @@ from __future__ import annotations import re -from typing import Mapping, Sequence +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence # match $identifier or ${identifier} and replace with value in env diff --git a/torchgen/context.py b/torchgen/context.py index d257bf99243d..afdd6ea58dbf 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -2,7 +2,7 @@ from __future__ import annotations import contextlib import functools -from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union import torchgen.local as local from torchgen.model import ( @@ -15,6 +15,10 @@ from torchgen.model import ( from torchgen.utils import context, S, T +if TYPE_CHECKING: + from collections.abc import Iterator + + # Helper functions for defining generators on things in the model F = TypeVar( diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index 3acf098f2323..8bb873d8f589 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, TYPE_CHECKING +from typing import TYPE_CHECKING import torchgen.api.ufunc as ufunc from torchgen.api.translate import translate @@ -30,6 +30,8 @@ from torchgen.utils import OrderedSet if TYPE_CHECKING: + from collections.abc import Sequence + from torchgen.api.ufunc import UfunctorBindings diff --git a/torchgen/executorch/api/custom_ops.py b/torchgen/executorch/api/custom_ops.py index cb56c34b660d..45f7f8e3cda8 100644 --- a/torchgen/executorch/api/custom_ops.py +++ b/torchgen/executorch/api/custom_ops.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections import defaultdict from dataclasses import dataclass -from typing import Sequence, TYPE_CHECKING +from typing import TYPE_CHECKING from torchgen import dest @@ -15,6 +15,8 @@ from torchgen.utils import concatMap, Target if TYPE_CHECKING: + from collections.abc import Sequence + from torchgen.executorch.model import ETKernelIndex from torchgen.selective_build.selector import SelectiveBuilder diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py index e4e92ff58d1e..554a63864e09 100644 --- a/torchgen/executorch/api/et_cpp.py +++ b/torchgen/executorch/api/et_cpp.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Sequence +from typing import TYPE_CHECKING from torchgen import local from torchgen.api.types import ( @@ -40,6 +40,10 @@ from torchgen.model import ( from torchgen.utils import assert_never +if TYPE_CHECKING: + from collections.abc import Sequence + + """ This file describes the translation of JIT schema to the public C++ API, which is what people use when they call functions like at::add. It also serves as a native function API, which is the signature of kernels, diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py index 999147212a1a..6d648f715114 100644 --- a/torchgen/executorch/api/unboxing.py +++ b/torchgen/executorch/api/unboxing.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Callable, Sequence, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from torchgen.model import ( Argument, @@ -15,6 +15,8 @@ from torchgen.model import ( if TYPE_CHECKING: + from collections.abc import Sequence + from torchgen.api.types import Binding, CType, NamedCType diff --git a/torchgen/gen.py b/torchgen/gen.py index 33d7944d8b9d..354297a9240d 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -8,7 +8,7 @@ import os from collections import defaultdict, namedtuple, OrderedDict from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Literal, Sequence, TypeVar +from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar import yaml @@ -96,6 +96,10 @@ from torchgen.utils import ( from torchgen.yaml_utils import YamlDumper, YamlLoader +if TYPE_CHECKING: + from collections.abc import Sequence + + T = TypeVar("T") # Welcome to the ATen code generator v2! The ATen code generator is @@ -229,7 +233,7 @@ def parse_tags_yaml_struct(es: object, path: str = "") -> set[str]: return rs -@functools.lru_cache(maxsize=None) +@functools.cache def parse_tags_yaml(path: str) -> set[str]: global _GLOBAL_PARSE_TAGS_YAML_CACHE if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE: diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 67cf64493f91..24a3b0c91381 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -2,7 +2,7 @@ from __future__ import annotations import textwrap from dataclasses import dataclass -from typing import Sequence +from typing import TYPE_CHECKING from torchgen.api.types import DispatcherSignature from torchgen.api.types.signatures import CppSignature, CppSignatureGroup @@ -24,6 +24,10 @@ from torchgen.model import ( from torchgen.utils import mapMaybe +if TYPE_CHECKING: + from collections.abc import Sequence + + base_type_to_c_type = { BaseTy.Tensor: "AtenTensorHandle", BaseTy.bool: "int32_t", # Use int to pass bool @@ -114,14 +118,14 @@ def convert_arg_type_and_name( # type: ignore[return] new_aten_types.append(f"::std::optional<{aten_type}>") base_type = aten_type[len("c10::ArrayRef<") : -1] new_callsite_exprs.append( - f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})" + f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j + 1]})" ) j += 2 elif aten_type == "c10::Device": # Device is passed as device_type + device_index new_aten_types.append("::std::optional") new_callsite_exprs.append( - f"pointer_to_optional_device({names[j]}, {names[j+1]})" + f"pointer_to_optional_device({names[j]}, {names[j + 1]})" ) j += 2 else: diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 86a355579930..b891c17671fc 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -5,7 +5,7 @@ import os import re from collections import Counter, defaultdict, namedtuple from pathlib import Path -from typing import Sequence +from typing import TYPE_CHECKING import yaml @@ -28,6 +28,10 @@ from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Tar from torchgen.yaml_utils import YamlLoader +if TYPE_CHECKING: + from collections.abc import Sequence + + # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key. # Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping) ParsedExternalYaml = namedtuple( diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index 902ffa3889e6..7d3cf4edb05f 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -5,7 +5,7 @@ import os from collections import defaultdict from dataclasses import dataclass from pathlib import Path -from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING +from typing import Any, Callable, TextIO, TYPE_CHECKING import yaml @@ -57,6 +57,8 @@ from torchgen.utils import ( if TYPE_CHECKING: + from collections.abc import Sequence + from torchgen.selective_build.selector import SelectiveBuilder diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index 106e9fb6acbb..a15fa62fd1ee 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -4,7 +4,7 @@ import argparse import os from collections import namedtuple from pathlib import Path -from typing import Any, Callable, Iterable, Iterator, Sequence +from typing import Any, Callable, TYPE_CHECKING import yaml @@ -25,6 +25,10 @@ from torchgen.utils import FileManager, NamespaceHelper from torchgen.yaml_utils import YamlLoader +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # Lazy Tensor Codegen diff --git a/torchgen/gen_vmap_plumbing.py b/torchgen/gen_vmap_plumbing.py index af9af6454eb0..0f1f14d45749 100644 --- a/torchgen/gen_vmap_plumbing.py +++ b/torchgen/gen_vmap_plumbing.py @@ -2,7 +2,7 @@ from __future__ import annotations import textwrap from dataclasses import dataclass -from typing import Sequence +from typing import TYPE_CHECKING from torchgen.api.translate import translate from torchgen.api.types import DispatcherSignature @@ -22,6 +22,10 @@ from torchgen.model import ( from torchgen.utils import mapMaybe +if TYPE_CHECKING: + from collections.abc import Sequence + + def is_tensor(typ: Type) -> bool: return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor @@ -111,7 +115,7 @@ def gen_returns( idx += 2 elif is_tensor_list(ret.type): wrapped_returns.append( - f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})" + f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})" ) idx += 2 else: diff --git a/torchgen/local.py b/torchgen/local.py index 7c687c3a7991..19045f4a9487 100644 --- a/torchgen/local.py +++ b/torchgen/local.py @@ -2,7 +2,11 @@ from __future__ import annotations import threading from contextlib import contextmanager -from typing import Iterator +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from collections.abc import Iterator # Simple dynamic scoping implementation. The name "parametrize" comes diff --git a/torchgen/model.py b/torchgen/model.py index f4f57f2ae138..bda503a2909c 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -5,11 +5,15 @@ import itertools import re from dataclasses import dataclass from enum import auto, Enum -from typing import Callable, Iterator, List, Sequence +from typing import Callable, List, TYPE_CHECKING from torchgen.utils import assert_never, NamespaceHelper, OrderedSet +if TYPE_CHECKING: + from collections.abc import Iterator, Sequence + + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # # DATA MODEL diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index b73bd4447369..83c8344f8a8e 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -2,7 +2,7 @@ from __future__ import annotations import string from collections import defaultdict -from typing import Sequence +from typing import TYPE_CHECKING import torchgen.api.dispatcher as dispatcher from torchgen.api.translate import translate @@ -30,6 +30,10 @@ from torchgen.model import ( from torchgen.utils import concatMap +if TYPE_CHECKING: + from collections.abc import Sequence + + # See Note: [Out ops with functional variants that don't get grouped properly] OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ # This has a functional variant, but it's currently marked private. diff --git a/torchgen/static_runtime/gen_static_runtime_ops.py b/torchgen/static_runtime/gen_static_runtime_ops.py index 9f7357173746..81faef4f1094 100644 --- a/torchgen/static_runtime/gen_static_runtime_ops.py +++ b/torchgen/static_runtime/gen_static_runtime_ops.py @@ -3,7 +3,7 @@ from __future__ import annotations import argparse import itertools import os -from typing import Sequence, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar, Union from libfb.py.log import set_simple_logging # type: ignore[import] @@ -13,6 +13,10 @@ from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsVie from torchgen.static_runtime import config, generator +if TYPE_CHECKING: + from collections.abc import Sequence + + # Given a list of `grouped_native_functions` sorted by their op names, return a list of # lists each of which groups ops that share the base name. For example, `mean` and # `mean.dim` are grouped together by this function. diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index 02fcbcf0376d..1ed70ec5200f 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -3,7 +3,7 @@ from __future__ import annotations import json import logging import math -from typing import Sequence +from typing import TYPE_CHECKING import torchgen.api.cpp as cpp from torchgen.context import native_function_manager @@ -23,6 +23,10 @@ from torchgen.model import ( from torchgen.static_runtime import config +if TYPE_CHECKING: + from collections.abc import Sequence + + logger: logging.Logger = logging.getLogger() diff --git a/torchgen/utils.py b/torchgen/utils.py index 6d83a27dc9e7..17b8146b672b 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -10,18 +10,7 @@ import textwrap from dataclasses import fields, is_dataclass from enum import auto, Enum from pathlib import Path -from typing import ( - Any, - Callable, - Generic, - Iterable, - Iterator, - Literal, - NoReturn, - Sequence, - TYPE_CHECKING, - TypeVar, -) +from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar from typing_extensions import Self from torchgen.code_template import CodeTemplate @@ -29,6 +18,7 @@ from torchgen.code_template import CodeTemplate if TYPE_CHECKING: from argparse import Namespace + from collections.abc import Iterable, Iterator, Sequence REPO_ROOT = Path(__file__).absolute().parent.parent @@ -113,7 +103,7 @@ def assert_never(x: NoReturn) -> NoReturn: raise AssertionError(f"Unhandled type: {type(x).__name__}") -@functools.lru_cache(maxsize=None) +@functools.cache def _read_template(template_fn: str) -> CodeTemplate: return CodeTemplate.from_file(template_fn)