mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[1/N] Apply py39 ruff fixes (#138578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138578 Approved by: https://github.com/Skylion007
This commit is contained in:
@ -6,7 +6,7 @@
|
||||
# files.
|
||||
|
||||
[mypy]
|
||||
python_version = 3.8
|
||||
python_version = 3.9
|
||||
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
|
||||
|
||||
cache_dir = .mypy_cache/strict
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, Sequence
|
||||
from typing import cast, TYPE_CHECKING
|
||||
|
||||
from torchgen import local
|
||||
from torchgen.api import cpp
|
||||
@ -20,6 +20,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import IDENT_REGEX
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# Represents a saved attribute involved in backward calculation.
|
||||
# Note that it can be a derived property of an input argument, e.g.:
|
||||
# we could save `other.scalar_type()` instead of the entire `other` tensor.
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen import local
|
||||
from torchgen.api.types import (
|
||||
@ -51,6 +51,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import assert_never
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# This file describes the translation of JIT schema to the public C++
|
||||
# API, which is what people use when they call functions like at::add.
|
||||
#
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.types import ArgName, Binding, CType, NamedCType
|
||||
@ -16,6 +16,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import assert_never, concatMap
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# This file describes the translation of JIT schema to the dispatcher
|
||||
# API, the *unboxed* calling convention by which invocations through
|
||||
# the dispatcher are made. Historically, the dispatcher API matched
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen import local
|
||||
from torchgen.api import cpp
|
||||
@ -32,6 +32,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import assert_never
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# This file describes the translation of JIT schema to the native functions API.
|
||||
# This looks a lot like the C++ API (which makes historical sense, because the
|
||||
# idea was you wrote native functions to implement functions in the C++ API),
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
|
||||
@ -20,6 +20,10 @@ from torchgen.model import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
# Data Models
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NoReturn, Sequence
|
||||
from typing import NoReturn, TYPE_CHECKING
|
||||
|
||||
from torchgen.api.types import (
|
||||
ArrayRefCType,
|
||||
@ -36,6 +36,10 @@ from torchgen.api.types import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# This file implements a small program synthesis engine that implements
|
||||
# conversions between one API to another.
|
||||
#
|
||||
|
@ -1,12 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, Sequence, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.api.types.types_base import Binding, CType, Expr
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
|
||||
from torchgen.model import (
|
||||
BackendIndex,
|
||||
FunctionSchema,
|
||||
|
@ -1,7 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping, Sequence
|
||||
|
||||
|
||||
# match $identifier or ${identifier} and replace with value in env
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
|
||||
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
import torchgen.local as local
|
||||
from torchgen.model import (
|
||||
@ -15,6 +15,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import context, S, T
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
# Helper functions for defining generators on things in the model
|
||||
|
||||
F = TypeVar(
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torchgen.api.ufunc as ufunc
|
||||
from torchgen.api.translate import translate
|
||||
@ -30,6 +30,8 @@ from torchgen.utils import OrderedSet
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from torchgen.api.ufunc import UfunctorBindings
|
||||
|
||||
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen import dest
|
||||
|
||||
@ -15,6 +15,8 @@ from torchgen.utils import concatMap, Target
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from torchgen.executorch.model import ETKernelIndex
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen import local
|
||||
from torchgen.api.types import (
|
||||
@ -40,6 +40,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import assert_never
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
"""
|
||||
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
|
||||
functions like at::add. It also serves as a native function API, which is the signature of kernels,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Sequence, TYPE_CHECKING
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
@ -15,6 +15,8 @@ from torchgen.model import (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from torchgen.api.types import Binding, CType, NamedCType
|
||||
|
||||
|
||||
|
@ -8,7 +8,7 @@ import os
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Sequence, TypeVar
|
||||
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
|
||||
|
||||
import yaml
|
||||
|
||||
@ -96,6 +96,10 @@ from torchgen.utils import (
|
||||
from torchgen.yaml_utils import YamlDumper, YamlLoader
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Welcome to the ATen code generator v2! The ATen code generator is
|
||||
@ -229,7 +233,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
|
||||
return rs
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def parse_tags_yaml(path: str) -> set[str]:
|
||||
global _GLOBAL_PARSE_TAGS_YAML_CACHE
|
||||
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
|
||||
@ -24,6 +24,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import mapMaybe
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
base_type_to_c_type = {
|
||||
BaseTy.Tensor: "AtenTensorHandle",
|
||||
BaseTy.bool: "int32_t", # Use int to pass bool
|
||||
@ -114,14 +118,14 @@ def convert_arg_type_and_name( # type: ignore[return]
|
||||
new_aten_types.append(f"::std::optional<{aten_type}>")
|
||||
base_type = aten_type[len("c10::ArrayRef<") : -1]
|
||||
new_callsite_exprs.append(
|
||||
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
|
||||
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j + 1]})"
|
||||
)
|
||||
j += 2
|
||||
elif aten_type == "c10::Device":
|
||||
# Device is passed as device_type + device_index
|
||||
new_aten_types.append("::std::optional<c10::Device>")
|
||||
new_callsite_exprs.append(
|
||||
f"pointer_to_optional_device({names[j]}, {names[j+1]})"
|
||||
f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
|
||||
)
|
||||
j += 2
|
||||
else:
|
||||
|
@ -5,7 +5,7 @@ import os
|
||||
import re
|
||||
from collections import Counter, defaultdict, namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import yaml
|
||||
|
||||
@ -28,6 +28,10 @@ from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Tar
|
||||
from torchgen.yaml_utils import YamlLoader
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
|
||||
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
|
||||
ParsedExternalYaml = namedtuple(
|
||||
|
@ -5,7 +5,7 @@ import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
|
||||
from typing import Any, Callable, TextIO, TYPE_CHECKING
|
||||
|
||||
import yaml
|
||||
|
||||
@ -57,6 +57,8 @@ from torchgen.utils import (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
|
||||
|
@ -4,7 +4,7 @@ import argparse
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Iterable, Iterator, Sequence
|
||||
from typing import Any, Callable, TYPE_CHECKING
|
||||
|
||||
import yaml
|
||||
|
||||
@ -25,6 +25,10 @@ from torchgen.utils import FileManager, NamespaceHelper
|
||||
from torchgen.yaml_utils import YamlLoader
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
# Lazy Tensor Codegen
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
@ -22,6 +22,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import mapMaybe
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
def is_tensor(typ: Type) -> bool:
|
||||
return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
|
||||
|
||||
@ -111,7 +115,7 @@ def gen_returns(
|
||||
idx += 2
|
||||
elif is_tensor_list(ret.type):
|
||||
wrapped_returns.append(
|
||||
f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})"
|
||||
f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
|
||||
)
|
||||
idx += 2
|
||||
else:
|
||||
|
@ -2,7 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator
|
||||
|
||||
|
||||
# Simple dynamic scoping implementation. The name "parametrize" comes
|
||||
|
@ -5,11 +5,15 @@ import itertools
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import auto, Enum
|
||||
from typing import Callable, Iterator, List, Sequence
|
||||
from typing import Callable, List, TYPE_CHECKING
|
||||
|
||||
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterator, Sequence
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
# DATA MODEL
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import string
|
||||
from collections import defaultdict
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torchgen.api.dispatcher as dispatcher
|
||||
from torchgen.api.translate import translate
|
||||
@ -30,6 +30,10 @@ from torchgen.model import (
|
||||
from torchgen.utils import concatMap
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# See Note: [Out ops with functional variants that don't get grouped properly]
|
||||
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
|
||||
# This has a functional variant, but it's currently marked private.
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
from typing import Sequence, TypeVar, Union
|
||||
from typing import TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
from libfb.py.log import set_simple_logging # type: ignore[import]
|
||||
|
||||
@ -13,6 +13,10 @@ from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsVie
|
||||
from torchgen.static_runtime import config, generator
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
|
||||
# lists each of which groups ops that share the base name. For example, `mean` and
|
||||
# `mean.dim` are grouped together by this function.
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from typing import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torchgen.api.cpp as cpp
|
||||
from torchgen.context import native_function_manager
|
||||
@ -23,6 +23,10 @@ from torchgen.model import (
|
||||
from torchgen.static_runtime import config
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
logger: logging.Logger = logging.getLogger()
|
||||
|
||||
|
||||
|
@ -10,18 +10,7 @@ import textwrap
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import auto, Enum
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Literal,
|
||||
NoReturn,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
|
||||
from typing_extensions import Self
|
||||
|
||||
from torchgen.code_template import CodeTemplate
|
||||
@ -29,6 +18,7 @@ from torchgen.code_template import CodeTemplate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from argparse import Namespace
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).absolute().parent.parent
|
||||
@ -113,7 +103,7 @@ def assert_never(x: NoReturn) -> NoReturn:
|
||||
raise AssertionError(f"Unhandled type: {type(x).__name__}")
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@functools.cache
|
||||
def _read_template(template_fn: str) -> CodeTemplate:
|
||||
return CodeTemplate.from_file(template_fn)
|
||||
|
||||
|
Reference in New Issue
Block a user