mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[1/N] Apply py39 ruff fixes (#138578)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138578 Approved by: https://github.com/Skylion007
This commit is contained in:
@ -6,7 +6,7 @@
|
|||||||
# files.
|
# files.
|
||||||
|
|
||||||
[mypy]
|
[mypy]
|
||||||
python_version = 3.8
|
python_version = 3.9
|
||||||
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
|
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin
|
||||||
|
|
||||||
cache_dir = .mypy_cache/strict
|
cache_dir = .mypy_cache/strict
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import cast, Sequence
|
from typing import cast, TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen import local
|
from torchgen import local
|
||||||
from torchgen.api import cpp
|
from torchgen.api import cpp
|
||||||
@ -20,6 +20,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import IDENT_REGEX
|
from torchgen.utils import IDENT_REGEX
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# Represents a saved attribute involved in backward calculation.
|
# Represents a saved attribute involved in backward calculation.
|
||||||
# Note that it can be a derived property of an input argument, e.g.:
|
# Note that it can be a derived property of an input argument, e.g.:
|
||||||
# we could save `other.scalar_type()` instead of the entire `other` tensor.
|
# we could save `other.scalar_type()` instead of the entire `other` tensor.
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen import local
|
from torchgen import local
|
||||||
from torchgen.api.types import (
|
from torchgen.api.types import (
|
||||||
@ -51,6 +51,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import assert_never
|
from torchgen.utils import assert_never
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# This file describes the translation of JIT schema to the public C++
|
# This file describes the translation of JIT schema to the public C++
|
||||||
# API, which is what people use when they call functions like at::add.
|
# API, which is what people use when they call functions like at::add.
|
||||||
#
|
#
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.api import cpp
|
from torchgen.api import cpp
|
||||||
from torchgen.api.types import ArgName, Binding, CType, NamedCType
|
from torchgen.api.types import ArgName, Binding, CType, NamedCType
|
||||||
@ -16,6 +16,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import assert_never, concatMap
|
from torchgen.utils import assert_never, concatMap
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# This file describes the translation of JIT schema to the dispatcher
|
# This file describes the translation of JIT schema to the dispatcher
|
||||||
# API, the *unboxed* calling convention by which invocations through
|
# API, the *unboxed* calling convention by which invocations through
|
||||||
# the dispatcher are made. Historically, the dispatcher API matched
|
# the dispatcher are made. Historically, the dispatcher API matched
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen import local
|
from torchgen import local
|
||||||
from torchgen.api import cpp
|
from torchgen.api import cpp
|
||||||
@ -32,6 +32,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import assert_never
|
from torchgen.utils import assert_never
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# This file describes the translation of JIT schema to the native functions API.
|
# This file describes the translation of JIT schema to the native functions API.
|
||||||
# This looks a lot like the C++ API (which makes historical sense, because the
|
# This looks a lot like the C++ API (which makes historical sense, because the
|
||||||
# idea was you wrote native functions to implement functions in the C++ API),
|
# idea was you wrote native functions to implement functions in the C++ API),
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.api import cpp
|
from torchgen.api import cpp
|
||||||
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
|
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
|
||||||
@ -20,6 +20,10 @@ from torchgen.model import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||||
#
|
#
|
||||||
# Data Models
|
# Data Models
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import NoReturn, Sequence
|
from typing import NoReturn, TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.api.types import (
|
from torchgen.api.types import (
|
||||||
ArrayRefCType,
|
ArrayRefCType,
|
||||||
@ -36,6 +36,10 @@ from torchgen.api.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# This file implements a small program synthesis engine that implements
|
# This file implements a small program synthesis engine that implements
|
||||||
# conversions between one API to another.
|
# conversions between one API to another.
|
||||||
#
|
#
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterator, Sequence, TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.api.types.types_base import Binding, CType, Expr
|
from torchgen.api.types.types_base import Binding, CType, Expr
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator, Sequence
|
||||||
|
|
||||||
from torchgen.model import (
|
from torchgen.model import (
|
||||||
BackendIndex,
|
BackendIndex,
|
||||||
FunctionSchema,
|
FunctionSchema,
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Mapping, Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
|
||||||
|
|
||||||
# match $identifier or ${identifier} and replace with value in env
|
# match $identifier or ${identifier} and replace with value in env
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
|
from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, TypeVar, Union
|
||||||
|
|
||||||
import torchgen.local as local
|
import torchgen.local as local
|
||||||
from torchgen.model import (
|
from torchgen.model import (
|
||||||
@ -15,6 +15,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import context, S, T
|
from torchgen.utils import context, S, T
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
|
||||||
# Helper functions for defining generators on things in the model
|
# Helper functions for defining generators on things in the model
|
||||||
|
|
||||||
F = TypeVar(
|
F = TypeVar(
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Sequence, TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torchgen.api.ufunc as ufunc
|
import torchgen.api.ufunc as ufunc
|
||||||
from torchgen.api.translate import translate
|
from torchgen.api.translate import translate
|
||||||
@ -30,6 +30,8 @@ from torchgen.utils import OrderedSet
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from torchgen.api.ufunc import UfunctorBindings
|
from torchgen.api.ufunc import UfunctorBindings
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Sequence, TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen import dest
|
from torchgen import dest
|
||||||
|
|
||||||
@ -15,6 +15,8 @@ from torchgen.utils import concatMap, Target
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from torchgen.executorch.model import ETKernelIndex
|
from torchgen.executorch.model import ETKernelIndex
|
||||||
from torchgen.selective_build.selector import SelectiveBuilder
|
from torchgen.selective_build.selector import SelectiveBuilder
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen import local
|
from torchgen import local
|
||||||
from torchgen.api.types import (
|
from torchgen.api.types import (
|
||||||
@ -40,6 +40,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import assert_never
|
from torchgen.utils import assert_never
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
|
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
|
||||||
functions like at::add. It also serves as a native function API, which is the signature of kernels,
|
functions like at::add. It also serves as a native function API, which is the signature of kernels,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Sequence, TYPE_CHECKING
|
from typing import Callable, TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.model import (
|
from torchgen.model import (
|
||||||
Argument,
|
Argument,
|
||||||
@ -15,6 +15,8 @@ from torchgen.model import (
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from torchgen.api.types import Binding, CType, NamedCType
|
from torchgen.api.types import Binding, CType, NamedCType
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import os
|
|||||||
from collections import defaultdict, namedtuple, OrderedDict
|
from collections import defaultdict, namedtuple, OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Literal, Sequence, TypeVar
|
from typing import Any, Callable, Literal, TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -96,6 +96,10 @@ from torchgen.utils import (
|
|||||||
from torchgen.yaml_utils import YamlDumper, YamlLoader
|
from torchgen.yaml_utils import YamlDumper, YamlLoader
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
# Welcome to the ATen code generator v2! The ATen code generator is
|
# Welcome to the ATen code generator v2! The ATen code generator is
|
||||||
@ -229,7 +233,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
|
|||||||
return rs
|
return rs
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
@functools.cache
|
||||||
def parse_tags_yaml(path: str) -> set[str]:
|
def parse_tags_yaml(path: str) -> set[str]:
|
||||||
global _GLOBAL_PARSE_TAGS_YAML_CACHE
|
global _GLOBAL_PARSE_TAGS_YAML_CACHE
|
||||||
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
|
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.api.types import DispatcherSignature
|
from torchgen.api.types import DispatcherSignature
|
||||||
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
|
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
|
||||||
@ -24,6 +24,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import mapMaybe
|
from torchgen.utils import mapMaybe
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
base_type_to_c_type = {
|
base_type_to_c_type = {
|
||||||
BaseTy.Tensor: "AtenTensorHandle",
|
BaseTy.Tensor: "AtenTensorHandle",
|
||||||
BaseTy.bool: "int32_t", # Use int to pass bool
|
BaseTy.bool: "int32_t", # Use int to pass bool
|
||||||
@ -114,14 +118,14 @@ def convert_arg_type_and_name( # type: ignore[return]
|
|||||||
new_aten_types.append(f"::std::optional<{aten_type}>")
|
new_aten_types.append(f"::std::optional<{aten_type}>")
|
||||||
base_type = aten_type[len("c10::ArrayRef<") : -1]
|
base_type = aten_type[len("c10::ArrayRef<") : -1]
|
||||||
new_callsite_exprs.append(
|
new_callsite_exprs.append(
|
||||||
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j+1]})"
|
f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j + 1]})"
|
||||||
)
|
)
|
||||||
j += 2
|
j += 2
|
||||||
elif aten_type == "c10::Device":
|
elif aten_type == "c10::Device":
|
||||||
# Device is passed as device_type + device_index
|
# Device is passed as device_type + device_index
|
||||||
new_aten_types.append("::std::optional<c10::Device>")
|
new_aten_types.append("::std::optional<c10::Device>")
|
||||||
new_callsite_exprs.append(
|
new_callsite_exprs.append(
|
||||||
f"pointer_to_optional_device({names[j]}, {names[j+1]})"
|
f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
|
||||||
)
|
)
|
||||||
j += 2
|
j += 2
|
||||||
else:
|
else:
|
||||||
|
@ -5,7 +5,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
from collections import Counter, defaultdict, namedtuple
|
from collections import Counter, defaultdict, namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -28,6 +28,10 @@ from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Tar
|
|||||||
from torchgen.yaml_utils import YamlLoader
|
from torchgen.yaml_utils import YamlLoader
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
|
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
|
||||||
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
|
# Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
|
||||||
ParsedExternalYaml = namedtuple(
|
ParsedExternalYaml = namedtuple(
|
||||||
|
@ -5,7 +5,7 @@ import os
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
|
from typing import Any, Callable, TextIO, TYPE_CHECKING
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -57,6 +57,8 @@ from torchgen.utils import (
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from torchgen.selective_build.selector import SelectiveBuilder
|
from torchgen.selective_build.selector import SelectiveBuilder
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Iterable, Iterator, Sequence
|
from typing import Any, Callable, TYPE_CHECKING
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -25,6 +25,10 @@ from torchgen.utils import FileManager, NamespaceHelper
|
|||||||
from torchgen.yaml_utils import YamlLoader
|
from torchgen.yaml_utils import YamlLoader
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterable, Iterator, Sequence
|
||||||
|
|
||||||
|
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||||
#
|
#
|
||||||
# Lazy Tensor Codegen
|
# Lazy Tensor Codegen
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import textwrap
|
import textwrap
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.api.translate import translate
|
from torchgen.api.translate import translate
|
||||||
from torchgen.api.types import DispatcherSignature
|
from torchgen.api.types import DispatcherSignature
|
||||||
@ -22,6 +22,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import mapMaybe
|
from torchgen.utils import mapMaybe
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
def is_tensor(typ: Type) -> bool:
|
def is_tensor(typ: Type) -> bool:
|
||||||
return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
|
return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
|
||||||
|
|
||||||
@ -111,7 +115,7 @@ def gen_returns(
|
|||||||
idx += 2
|
idx += 2
|
||||||
elif is_tensor_list(ret.type):
|
elif is_tensor_list(ret.type):
|
||||||
wrapped_returns.append(
|
wrapped_returns.append(
|
||||||
f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})"
|
f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
|
||||||
)
|
)
|
||||||
idx += 2
|
idx += 2
|
||||||
else:
|
else:
|
||||||
|
@ -2,7 +2,11 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import threading
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Iterator
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
|
||||||
# Simple dynamic scoping implementation. The name "parametrize" comes
|
# Simple dynamic scoping implementation. The name "parametrize" comes
|
||||||
|
@ -5,11 +5,15 @@ import itertools
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from typing import Callable, Iterator, List, Sequence
|
from typing import Callable, List, TYPE_CHECKING
|
||||||
|
|
||||||
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
|
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterator, Sequence
|
||||||
|
|
||||||
|
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||||
#
|
#
|
||||||
# DATA MODEL
|
# DATA MODEL
|
||||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import string
|
import string
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torchgen.api.dispatcher as dispatcher
|
import torchgen.api.dispatcher as dispatcher
|
||||||
from torchgen.api.translate import translate
|
from torchgen.api.translate import translate
|
||||||
@ -30,6 +30,10 @@ from torchgen.model import (
|
|||||||
from torchgen.utils import concatMap
|
from torchgen.utils import concatMap
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# See Note: [Out ops with functional variants that don't get grouped properly]
|
# See Note: [Out ops with functional variants that don't get grouped properly]
|
||||||
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
|
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
|
||||||
# This has a functional variant, but it's currently marked private.
|
# This has a functional variant, but it's currently marked private.
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import argparse
|
import argparse
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
from typing import Sequence, TypeVar, Union
|
from typing import TYPE_CHECKING, TypeVar, Union
|
||||||
|
|
||||||
from libfb.py.log import set_simple_logging # type: ignore[import]
|
from libfb.py.log import set_simple_logging # type: ignore[import]
|
||||||
|
|
||||||
@ -13,6 +13,10 @@ from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsVie
|
|||||||
from torchgen.static_runtime import config, generator
|
from torchgen.static_runtime import config, generator
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
|
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
|
||||||
# lists each of which groups ops that share the base name. For example, `mean` and
|
# lists each of which groups ops that share the base name. For example, `mean` and
|
||||||
# `mean.dim` are grouped together by this function.
|
# `mean.dim` are grouped together by this function.
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Sequence
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torchgen.api.cpp as cpp
|
import torchgen.api.cpp as cpp
|
||||||
from torchgen.context import native_function_manager
|
from torchgen.context import native_function_manager
|
||||||
@ -23,6 +23,10 @@ from torchgen.model import (
|
|||||||
from torchgen.static_runtime import config
|
from torchgen.static_runtime import config
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
|
||||||
logger: logging.Logger = logging.getLogger()
|
logger: logging.Logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,18 +10,7 @@ import textwrap
|
|||||||
from dataclasses import fields, is_dataclass
|
from dataclasses import fields, is_dataclass
|
||||||
from enum import auto, Enum
|
from enum import auto, Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import Any, Callable, Generic, Literal, NoReturn, TYPE_CHECKING, TypeVar
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Generic,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
Literal,
|
|
||||||
NoReturn,
|
|
||||||
Sequence,
|
|
||||||
TYPE_CHECKING,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from torchgen.code_template import CodeTemplate
|
from torchgen.code_template import CodeTemplate
|
||||||
@ -29,6 +18,7 @@ from torchgen.code_template import CodeTemplate
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from collections.abc import Iterable, Iterator, Sequence
|
||||||
|
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).absolute().parent.parent
|
REPO_ROOT = Path(__file__).absolute().parent.parent
|
||||||
@ -113,7 +103,7 @@ def assert_never(x: NoReturn) -> NoReturn:
|
|||||||
raise AssertionError(f"Unhandled type: {type(x).__name__}")
|
raise AssertionError(f"Unhandled type: {type(x).__name__}")
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
@functools.cache
|
||||||
def _read_template(template_fn: str) -> CodeTemplate:
|
def _read_template(template_fn: str) -> CodeTemplate:
|
||||||
return CodeTemplate.from_file(template_fn)
|
return CodeTemplate.from_file(template_fn)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user