[1/N] Apply py39 ruff fixes (#138578)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138578
Approved by: https://github.com/Skylion007
This commit is contained in:
cyy
2024-12-02 21:46:15 +00:00
committed by PyTorch MergeBot
parent b47bdb06d8
commit 55250b324d
26 changed files with 118 additions and 42 deletions

View File

@ -6,7 +6,7 @@
# files.
[mypy]
python_version = 3.8
python_version = 3.9
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
cache_dir = .mypy_cache/strict

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import re
from dataclasses import dataclass
from typing import cast, Sequence
from typing import cast, TYPE_CHECKING
from torchgen import local
from torchgen.api import cpp
@ -20,6 +20,10 @@ from torchgen.model import (
from torchgen.utils import IDENT_REGEX
if TYPE_CHECKING:
from collections.abc import Sequence
# Represents a saved attribute involved in backward calculation.
# Note that it can be a derived property of an input argument, e.g.:
# we could save `other.scalar_type()` instead of the entire `other` tensor.

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Sequence
from typing import TYPE_CHECKING
from torchgen import local
from torchgen.api.types import (
@ -51,6 +51,10 @@ from torchgen.model import (
from torchgen.utils import assert_never
if TYPE_CHECKING:
from collections.abc import Sequence
# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
#

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import itertools
from typing import Sequence
from typing import TYPE_CHECKING
from torchgen.api import cpp
from torchgen.api.types import ArgName, Binding, CType, NamedCType
@ -16,6 +16,10 @@ from torchgen.model import (
from torchgen.utils import assert_never, concatMap
if TYPE_CHECKING:
from collections.abc import Sequence
# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
# the dispatcher are made. Historically, the dispatcher API matched

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Sequence
from typing import TYPE_CHECKING
from torchgen import local
from torchgen.api import cpp
@ -32,6 +32,10 @@ from torchgen.model import (
from torchgen.utils import assert_never
if TYPE_CHECKING:
from collections.abc import Sequence
# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from typing import TYPE_CHECKING
from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
@ -20,6 +20,10 @@ from torchgen.model import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Data Models

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import NoReturn, Sequence
from typing import NoReturn, TYPE_CHECKING
from torchgen.api.types import (
ArrayRefCType,
@ -36,6 +36,10 @@ from torchgen.api.types import (
)
if TYPE_CHECKING:
from collections.abc import Sequence
# This file implements a small program synthesis engine that implements
# conversions between one API to another.
#

View File

@ -1,12 +1,14 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING
from torchgen.api.types.types_base import Binding, CType, Expr
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from torchgen.model import (
BackendIndex,
FunctionSchema,

View File

@ -1,7 +1,11 @@
from __future__ import annotations
import re
from typing import Mapping, Sequence
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
# match $identifier or ${identifier} and replace with value in env

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import contextlib
import functools
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
import torchgen.local as local
from torchgen.model import (
@ -15,6 +15,10 @@ from torchgen.model import (
from torchgen.utils import context, S, T
if TYPE_CHECKING:
from collections.abc import Iterator
# Helper functions for defining generators on things in the model
F = TypeVar(

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING
import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
@ -30,6 +30,8 @@ from torchgen.utils import OrderedSet
if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.api.ufunc import UfunctorBindings

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING
from torchgen import dest
@ -15,6 +15,8 @@ from torchgen.utils import concatMap, Target
if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.executorch.model import ETKernelIndex
from torchgen.selective_build.selector import SelectiveBuilder

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Sequence
from typing import TYPE_CHECKING
from torchgen import local
from torchgen.api.types import (
@ -40,6 +40,10 @@ from torchgen.model import (
from torchgen.utils import assert_never
if TYPE_CHECKING:
from collections.abc import Sequence
"""
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
functions like at::add. It also serves as a native function API, which is the signature of kernels,

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING
from torchgen.model import (
Argument,
@ -15,6 +15,8 @@ from torchgen.model import (
if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.api.types import Binding, CType, NamedCType

View File

@ -8,7 +8,7 @@ import os
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
import yaml
@ -96,6 +96,10 @@ from torchgen.utils import (
from torchgen.yaml_utils import YamlDumper, YamlLoader
if TYPE_CHECKING:
from collections.abc import Sequence
T = TypeVar("T")
# Welcome to the ATen code generator v2! The ATen code generator is
@ -229,7 +233,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
return rs
@functools.lru_cache(maxsize=None)
@functools.cache
def parse_tags_yaml(path: str) -> set[str]:
global _GLOBAL_PARSE_TAGS_YAML_CACHE
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import textwrap
from dataclasses import dataclass
from typing import Sequence
from typing import TYPE_CHECKING
from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
@ -24,6 +24,10 @@ from torchgen.model import (
from torchgen.utils import mapMaybe
if TYPE_CHECKING:
from collections.abc import Sequence
base_type_to_c_type = {
BaseTy.Tensor: "AtenTensorHandle",
BaseTy.bool: "int32_t", # Use int to pass bool

View File

@ -5,7 +5,7 @@ import os
import re
from collections import Counter, defaultdict, namedtuple
from pathlib import Path
from typing import Sequence
from typing import TYPE_CHECKING
import yaml
@ -28,6 +28,10 @@ from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Tar
from torchgen.yaml_utils import YamlLoader
if TYPE_CHECKING:
from collections.abc import Sequence
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
ParsedExternalYaml = namedtuple(

View File

@ -5,7 +5,7 @@ import os
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
from typing import Any, Callable, TextIO, TYPE_CHECKING
import yaml
@ -57,6 +57,8 @@ from torchgen.utils import (
if TYPE_CHECKING:
from collections.abc import Sequence
from torchgen.selective_build.selector import SelectiveBuilder

View File

@ -4,7 +4,7 @@ import argparse
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Sequence
from typing import Any, Callable, TYPE_CHECKING
import yaml
@ -25,6 +25,10 @@ from torchgen.utils import FileManager, NamespaceHelper
from torchgen.yaml_utils import YamlLoader
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# Lazy Tensor Codegen

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import textwrap
from dataclasses import dataclass
from typing import Sequence
from typing import TYPE_CHECKING
from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
@ -22,6 +22,10 @@ from torchgen.model import (
from torchgen.utils import mapMaybe
if TYPE_CHECKING:
from collections.abc import Sequence
def is_tensor(typ: Type) -> bool:
return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor

View File

@ -2,7 +2,11 @@ from __future__ import annotations
import threading
from contextlib import contextmanager
from typing import Iterator
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterator
# Simple dynamic scoping implementation. The name "parametrize" comes

View File

@ -5,11 +5,15 @@ import itertools
import re
from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Iterator, List, Sequence
from typing import Callable, List, TYPE_CHECKING
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# DATA MODEL

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import string
from collections import defaultdict
from typing import Sequence
from typing import TYPE_CHECKING
import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
@ -30,6 +30,10 @@ from torchgen.model import (
from torchgen.utils import concatMap
if TYPE_CHECKING:
from collections.abc import Sequence
# See Note: [Out ops with functional variants that don't get grouped properly]
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# This has a functional variant, but it's currently marked private.

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import argparse
import itertools
import os
from typing import Sequence, TypeVar, Union
from typing import TYPE_CHECKING, TypeVar, Union
from libfb.py.log import set_simple_logging # type: ignore[import]
@ -13,6 +13,10 @@ from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsVie
from torchgen.static_runtime import config, generator
if TYPE_CHECKING:
from collections.abc import Sequence
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
# lists each of which groups ops that share the base name. For example, `mean` and
# `mean.dim` are grouped together by this function.

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import json
import logging
import math
from typing import Sequence
from typing import TYPE_CHECKING
import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
@ -23,6 +23,10 @@ from torchgen.model import (
from torchgen.static_runtime import config
if TYPE_CHECKING:
from collections.abc import Sequence
logger: logging.Logger = logging.getLogger()

View File

@ -10,18 +10,7 @@ import textwrap
from dataclasses import fields, is_dataclass
from enum import auto, Enum
from pathlib import Path
from typing import (
Any,
Callable,
Generic,
Iterable,
Iterator,
Literal,
NoReturn,
Sequence,
TYPE_CHECKING,
TypeVar,
)
from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
from typing_extensions import Self
from torchgen.code_template import CodeTemplate
@ -29,6 +18,7 @@ from torchgen.code_template import CodeTemplate
if TYPE_CHECKING:
from argparse import Namespace
from collections.abc import Iterable, Iterator, Sequence
REPO_ROOT = Path(__file__).absolute().parent.parent
@ -113,7 +103,7 @@ def assert_never(x: NoReturn) -> NoReturn:
raise AssertionError(f"Unhandled type: {type(x).__name__}")
@functools.lru_cache(maxsize=None)
@functools.cache
def _read_template(template_fn: str) -> CodeTemplate:
return CodeTemplate.from_file(template_fn)