[1/N] Apply py39 ruff fixes (#138578)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138578
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy
2024-12-02 21:46:15 +00:00
committed by PyTorch MergeBot
parent b47bdb06d8
commit 55250b324d
26 changed files with 118 additions and 42 deletions

View File

@ -6,7 +6,7 @@
# files. # files.
[mypy] [mypy]
python_version = 3.8 python_version = 3.9
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/strict cache_dir = .mypy_cache/strict

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import cast, Sequence from typing import cast, TYPE_CHECKING
from torchgen import local from torchgen import local
from torchgen.api import cpp from torchgen.api import cpp
@ -20,6 +20,10 @@ from torchgen.model import (
from torchgen.utils import IDENT_REGEX from torchgen.utils import IDENT_REGEX
if TYPE_CHECKING:
from collections.abc import Sequence
# Represents a saved attribute involved in backward calculation. # Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.: # Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor. # we could save `other.scalar_type()` instead of the entire `other` tensor.

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import TYPE_CHECKING
from torchgen import local from torchgen import local
from torchgen.api.types import ( from torchgen.api.types import (
@ -51,6 +51,10 @@ from torchgen.model import (
from torchgen.utils import assert_never from torchgen.utils import assert_never
if TYPE_CHECKING:
from collections.abc import Sequence
# This file describes the translation of JIT schema to the public C++ # This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add. # API, which is what people use when they call functions like at::add.
# #

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import itertools import itertools
from typing import Sequence from typing import TYPE_CHECKING
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.types import ArgName, Binding, CType, NamedCType from torchgen.api.types import ArgName, Binding, CType, NamedCType
@ -16,6 +16,10 @@ from torchgen.model import (
from torchgen.utils import assert_never, concatMap from torchgen.utils import assert_never, concatMap
if TYPE_CHECKING:
from collections.abc import Sequence
# This file describes the translation of JIT schema to the dispatcher # This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through # API, the *unboxed* calling convention by which invocations through
# the dispatcher are made. Historically, the dispatcher API matched # the dispatcher are made. Historically, the dispatcher API matched

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import TYPE_CHECKING
from torchgen import local from torchgen import local
from torchgen.api import cpp from torchgen.api import cpp
@ -32,6 +32,10 @@ from torchgen.model import (
from torchgen.utils import assert_never from torchgen.utils import assert_never
if TYPE_CHECKING:
from collections.abc import Sequence
# This file describes the translation of JIT schema to the native functions API. # This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the # This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API), # idea was you wrote native functions to implement functions in the C++ API),

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence from typing import TYPE_CHECKING
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
@ -20,6 +20,10 @@ from torchgen.model import (
) )
if TYPE_CHECKING:
from collections.abc import Sequence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# #
# Data Models # Data Models

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import NoReturn, Sequence from typing import NoReturn, TYPE_CHECKING
from torchgen.api.types import ( from torchgen.api.types import (
ArrayRefCType, ArrayRefCType,
@ -36,6 +36,10 @@ from torchgen.api.types import (
) )
if TYPE_CHECKING:
from collections.abc import Sequence
# This file implements a small program synthesis engine that implements # This file implements a small program synthesis engine that implements
# conversions between one API to another. # conversions between one API to another.
# #

View File

@ -1,12 +1,14 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterator, Sequence, TYPE_CHECKING from typing import TYPE_CHECKING
from torchgen.api.types.types_base import Binding, CType, Expr from torchgen.api.types.types_base import Binding, CType, Expr
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from torchgen.model import ( from torchgen.model import (
BackendIndex, BackendIndex,
FunctionSchema, FunctionSchema,

View File

@ -1,7 +1,11 @@
from __future__ import annotations from __future__ import annotations
import re import re
from typing import Mapping, Sequence from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
# match $identifier or ${identifier} and replace with value in env # match $identifier or ${identifier} and replace with value in env

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import contextlib import contextlib
import functools import functools
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
import torchgen.local as local import torchgen.local as local
from torchgen.model import ( from torchgen.model import (
@ -15,6 +15,10 @@ from torchgen.model import (
from torchgen.utils import context, S, T from torchgen.utils import context, S, T
if TYPE_CHECKING:
from collections.abc import Iterator
# Helper functions for defining generators on things in the model # Helper functions for defining generators on things in the model
F = TypeVar( F = TypeVar(

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING from typing import TYPE_CHECKING
import torchgen.api.ufunc as ufunc import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate from torchgen.api.translate import translate
@ -30,6 +30,8 @@ from torchgen.utils import OrderedSet
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.api.ufunc import UfunctorBindings from torchgen.api.ufunc import UfunctorBindings

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING from typing import TYPE_CHECKING
from torchgen import dest from torchgen import dest
@ -15,6 +15,8 @@ from torchgen.utils import concatMap, Target
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.executorch.model import ETKernelIndex from torchgen.executorch.model import ETKernelIndex
from torchgen.selective_build.selector import SelectiveBuilder from torchgen.selective_build.selector import SelectiveBuilder

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Sequence from typing import TYPE_CHECKING
from torchgen import local from torchgen import local
from torchgen.api.types import ( from torchgen.api.types import (
@ -40,6 +40,10 @@ from torchgen.model import (
from torchgen.utils import assert_never from torchgen.utils import assert_never
if TYPE_CHECKING:
from collections.abc import Sequence
""" """
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
functions like at::add. It also serves as a native function API, which is the signature of kernels, functions like at::add. It also serves as a native function API, which is the signature of kernels,

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING from typing import Callable, TYPE_CHECKING
from torchgen.model import ( from torchgen.model import (
Argument, Argument,
@ -15,6 +15,8 @@ from torchgen.model import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.api.types import Binding, CType, NamedCType from torchgen.api.types import Binding, CType, NamedCType

View File

@ -8,7 +8,7 @@ import os
from collections import defaultdict, namedtuple, OrderedDict from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
import yaml import yaml
@ -96,6 +96,10 @@ from torchgen.utils import (
from torchgen.yaml_utils import YamlDumper, YamlLoader from torchgen.yaml_utils import YamlDumper, YamlLoader
if TYPE_CHECKING:
from collections.abc import Sequence
T = TypeVar("T") T = TypeVar("T")
# Welcome to the ATen code generator v2! The ATen code generator is # Welcome to the ATen code generator v2! The ATen code generator is
@ -229,7 +233,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
return rs return rs
@functools.lru_cache(maxsize=None) @functools.cache
def parse_tags_yaml(path: str) -> set[str]: def parse_tags_yaml(path: str) -> set[str]:
global _GLOBAL_PARSE_TAGS_YAML_CACHE global _GLOBAL_PARSE_TAGS_YAML_CACHE
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE: if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence from typing import TYPE_CHECKING
from torchgen.api.types import DispatcherSignature from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
@ -24,6 +24,10 @@ from torchgen.model import (
from torchgen.utils import mapMaybe from torchgen.utils import mapMaybe
if TYPE_CHECKING:
from collections.abc import Sequence
base_type_to_c_type = { base_type_to_c_type = {
BaseTy.Tensor: "AtenTensorHandle", BaseTy.Tensor: "AtenTensorHandle",
BaseTy.bool: "int32_t", # Use int to pass bool BaseTy.bool: "int32_t", # Use int to pass bool
@ -114,14 +118,14 @@ def convert_arg_type_and_name( # type: ignore[return]
new_aten_types.append(f"::std::optional<{aten_type}>") new_aten_types.append(f"::std::optional<{aten_type}>")
base_type = aten_type[len("c10::ArrayRef<") : -1] base_type = aten_type[len("c10::ArrayRef<") : -1]
new_callsite_exprs.append( new_callsite_exprs.append(
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})" f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j + 1]})"
) )
j += 2 j += 2
elif aten_type == "c10::Device": elif aten_type == "c10::Device":
# Device is passed as device_type + device_index # Device is passed as device_type + device_index
new_aten_types.append("::std::optional<c10::Device>") new_aten_types.append("::std::optional<c10::Device>")
new_callsite_exprs.append( new_callsite_exprs.append(
f"pointer_to_optional_device({names[j]}, {names[j+1]})" f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
) )
j += 2 j += 2
else: else:

View File

@ -5,7 +5,7 @@ import os
import re import re
from collections import Counter, defaultdict, namedtuple from collections import Counter, defaultdict, namedtuple
from pathlib import Path from pathlib import Path
from typing import Sequence from typing import TYPE_CHECKING
import yaml import yaml
@ -28,6 +28,10 @@ from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Tar
from torchgen.yaml_utils import YamlLoader from torchgen.yaml_utils import YamlLoader
if TYPE_CHECKING:
from collections.abc import Sequence
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key. # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping) # Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
ParsedExternalYaml = namedtuple( ParsedExternalYaml = namedtuple(

View File

@ -5,7 +5,7 @@ import os
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING from typing import Any, Callable, TextIO, TYPE_CHECKING
import yaml import yaml
@ -57,6 +57,8 @@ from torchgen.utils import (
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.selective_build.selector import SelectiveBuilder from torchgen.selective_build.selector import SelectiveBuilder

View File

@ -4,7 +4,7 @@ import argparse
import os import os
from collections import namedtuple from collections import namedtuple
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Sequence from typing import Any, Callable, TYPE_CHECKING
import yaml import yaml
@ -25,6 +25,10 @@ from torchgen.utils import FileManager, NamespaceHelper
from torchgen.yaml_utils import YamlLoader from torchgen.yaml_utils import YamlLoader
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# #
# Lazy Tensor Codegen # Lazy Tensor Codegen

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from typing import Sequence from typing import TYPE_CHECKING
from torchgen.api.translate import translate from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature from torchgen.api.types import DispatcherSignature
@ -22,6 +22,10 @@ from torchgen.model import (
from torchgen.utils import mapMaybe from torchgen.utils import mapMaybe
if TYPE_CHECKING:
from collections.abc import Sequence
def is_tensor(typ: Type) -> bool: def is_tensor(typ: Type) -> bool:
return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
@ -111,7 +115,7 @@ def gen_returns(
idx += 2 idx += 2
elif is_tensor_list(ret.type): elif is_tensor_list(ret.type):
wrapped_returns.append( wrapped_returns.append(
f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})" f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
) )
idx += 2 idx += 2
else: else:

View File

@ -2,7 +2,11 @@ from __future__ import annotations
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterator from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterator
# Simple dynamic scoping implementation. The name "parametrize" comes # Simple dynamic scoping implementation. The name "parametrize" comes

View File

@ -5,11 +5,15 @@ import itertools
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import auto, Enum from enum import auto, Enum
from typing import Callable, Iterator, List, Sequence from typing import Callable, List, TYPE_CHECKING
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# #
# DATA MODEL # DATA MODEL

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import string import string
from collections import defaultdict from collections import defaultdict
from typing import Sequence from typing import TYPE_CHECKING
import torchgen.api.dispatcher as dispatcher import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate from torchgen.api.translate import translate
@ -30,6 +30,10 @@ from torchgen.model import (
from torchgen.utils import concatMap from torchgen.utils import concatMap
if TYPE_CHECKING:
from collections.abc import Sequence
# See Note: [Out ops with functional variants that don't get grouped properly] # See Note: [Out ops with functional variants that don't get grouped properly]
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [ OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# This has a functional variant, but it's currently marked private. # This has a functional variant, but it's currently marked private.

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import argparse import argparse
import itertools import itertools
import os import os
from typing import Sequence, TypeVar, Union from typing import TYPE_CHECKING, TypeVar, Union
from libfb.py.log import set_simple_logging # type: ignore[import] from libfb.py.log import set_simple_logging # type: ignore[import]
@ -13,6 +13,10 @@ from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsVie
from torchgen.static_runtime import config, generator from torchgen.static_runtime import config, generator
if TYPE_CHECKING:
from collections.abc import Sequence
# Given a list of `grouped_native_functions` sorted by their op names, return a list of # Given a list of `grouped_native_functions` sorted by their op names, return a list of
# lists each of which groups ops that share the base name. For example, `mean` and # lists each of which groups ops that share the base name. For example, `mean` and
# `mean.dim` are grouped together by this function. # `mean.dim` are grouped together by this function.

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
import logging import logging
import math import math
from typing import Sequence from typing import TYPE_CHECKING
import torchgen.api.cpp as cpp import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager from torchgen.context import native_function_manager
@ -23,6 +23,10 @@ from torchgen.model import (
from torchgen.static_runtime import config from torchgen.static_runtime import config
if TYPE_CHECKING:
from collections.abc import Sequence
logger: logging.Logger = logging.getLogger() logger: logging.Logger = logging.getLogger()

View File

@ -10,18 +10,7 @@ import textwrap
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from enum import auto, Enum from enum import auto, Enum
from pathlib import Path from pathlib import Path
from typing import ( from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
Any,
Callable,
Generic,
Iterable,
Iterator,
Literal,
NoReturn,
Sequence,
TYPE_CHECKING,
TypeVar,
)
from typing_extensions import Self from typing_extensions import Self
from torchgen.code_template import CodeTemplate from torchgen.code_template import CodeTemplate
@ -29,6 +18,7 @@ from torchgen.code_template import CodeTemplate
if TYPE_CHECKING: if TYPE_CHECKING:
from argparse import Namespace from argparse import Namespace
from collections.abc import Iterable, Iterator, Sequence
REPO_ROOT = Path(__file__).absolute().parent.parent REPO_ROOT = Path(__file__).absolute().parent.parent
@ -113,7 +103,7 @@ def assert_never(x: NoReturn) -> NoReturn:
raise AssertionError(f"Unhandled type: {type(x).__name__}") raise AssertionError(f"Unhandled type: {type(x).__name__}")
@functools.lru_cache(maxsize=None) @functools.cache
def _read_template(template_fn: str) -> CodeTemplate: def _read_template(template_fn: str) -> CodeTemplate:
return CodeTemplate.from_file(template_fn) return CodeTemplate.from_file(template_fn)