diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py
index 10b011741d55..38ab7743d822 100644
--- a/torchgen/api/autograd.py
+++ b/torchgen/api/autograd.py
@@ -3,7 +3,6 @@ from dataclasses import dataclass
 from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
 
 from torchgen import local
-
 from torchgen.api import cpp
 from torchgen.api.types import BaseCType, Binding, NamedCType, tensorListT
 from torchgen.model import (
diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py
index 55ae8758b2b3..0e9d67375c78 100644
--- a/torchgen/api/cpp.py
+++ b/torchgen/api/cpp.py
@@ -48,6 +48,7 @@ from torchgen.model import (
 )
 from torchgen.utils import assert_never
 
+
 # This file describes the translation of JIT schema to the public C++
 # API, which is what people use when they call functions like at::add.
 #
diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py
index 58816959f7cd..aa3c97b2d34d 100644
--- a/torchgen/api/dispatcher.py
+++ b/torchgen/api/dispatcher.py
@@ -2,7 +2,6 @@ import itertools
 from typing import List, Sequence, Union
 
 from torchgen.api import cpp
-
 from torchgen.api.types import ArgName, Binding, CType, NamedCType
 from torchgen.model import (
     Argument,
@@ -14,6 +13,7 @@ from torchgen.model import (
 )
 from torchgen.utils import assert_never, concatMap
 
+
 # This file describes the translation of JIT schema to the dispatcher
 # API, the *unboxed* calling convention by which invocations through
 # the dispatcher are made. Historically, the dispatcher API matched
diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py
index b14e910be0b8..166c2fc8b53e 100644
--- a/torchgen/api/lazy.py
+++ b/torchgen/api/lazy.py
@@ -20,7 +20,6 @@ from torchgen.api.types import (
     SymIntT,
     VectorCType,
 )
-
 from torchgen.model import (
     Argument,
     BaseTy,
diff --git a/torchgen/api/meta.py b/torchgen/api/meta.py
index ad488d303d46..2e99d151faea 100644
--- a/torchgen/api/meta.py
+++ b/torchgen/api/meta.py
@@ -1,5 +1,6 @@
 from torchgen.model import NativeFunctionsGroup
 
+
 # Follows dispatcher calling convention, but:
 # - Mutable arguments not allowed. Meta functions are always
 #   written in functional form. Look at FunctionSchema.signature()
diff --git a/torchgen/api/native.py b/torchgen/api/native.py
index 7f8b3eb3af2e..df06b539d5ee 100644
--- a/torchgen/api/native.py
+++ b/torchgen/api/native.py
@@ -2,7 +2,6 @@ from typing import List, Optional, Sequence, Union
 
 from torchgen import local
 from torchgen.api import cpp
-
 from torchgen.api.types import (
     ArgName,
     BaseCType,
@@ -30,6 +29,7 @@ from torchgen.model import (
 )
 from torchgen.utils import assert_never
 
+
 # This file describes the translation of JIT schema to the native functions API.
 # This looks a lot like the C++ API (which makes historical sense, because the
 # idea was you wrote native functions to implement functions in the C++ API),
diff --git a/torchgen/api/python.py b/torchgen/api/python.py
index 8d3e6f3b3edd..2026c40f08b9 100644
--- a/torchgen/api/python.py
+++ b/torchgen/api/python.py
@@ -17,6 +17,7 @@ from torchgen.model import (
     Variant,
 )
 
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 #                           Data Models
diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py
index 392b8a67e01e..e3be72189bbb 100644
--- a/torchgen/api/structured.py
+++ b/torchgen/api/structured.py
@@ -1,7 +1,6 @@
 from typing import List, Union
 
 from torchgen.api import cpp
-
 from torchgen.api.types import (
     ArgName,
     ArrayRefCType,
@@ -33,6 +32,7 @@ from torchgen.model import (
 )
 from torchgen.utils import assert_never
 
+
 # This file describes the translation of JIT schema to the structured functions API.
 # This is similar to native API, but a number of historical problems with native
 # API have been fixed.
diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py
index 98f0c251acbd..87fc3348b694 100644
--- a/torchgen/api/translate.py
+++ b/torchgen/api/translate.py
@@ -33,6 +33,7 @@ from torchgen.api.types import (
     VectorCType,
 )
 
+
 # This file implements a small program synthesis engine that implements
 # conversions between one API to another.
 #
diff --git a/torchgen/api/types/__init__.py b/torchgen/api/types/__init__.py
index d3e2f9a431b4..a190896f9e01 100644
--- a/torchgen/api/types/__init__.py
+++ b/torchgen/api/types/__init__.py
@@ -1,3 +1,3 @@
-from .types import *
-from .types_base import *
-from .signatures import *  # isort:skip
+from torchgen.api.types.types import *
+from torchgen.api.types.types_base import *
+from torchgen.api.types.signatures import *  # usort:skip
diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py
index f21ab29178e5..0b7abe00012e 100644
--- a/torchgen/api/types/signatures.py
+++ b/torchgen/api/types/signatures.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
-
 from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union
 
+from torchgen.api.types.types_base import Binding, CType, Expr
 from torchgen.model import (
     BackendIndex,
     FunctionSchema,
@@ -10,8 +10,6 @@ from torchgen.model import (
     NativeFunctionsViewGroup,
 )
 
-from .types_base import Binding, CType, Expr
-
 
 @dataclass(frozen=True)
 class CppSignature:
diff --git a/torchgen/api/types/types.py b/torchgen/api/types/types.py
index debc640a6661..3f0a90c634fc 100644
--- a/torchgen/api/types/types.py
+++ b/torchgen/api/types/types.py
@@ -15,9 +15,7 @@ Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
 from dataclasses import dataclass
 from typing import Dict
 
-from torchgen.model import BaseTy, ScalarType
-
-from .types_base import (
+from torchgen.api.types.types_base import (
     BaseCppType,
     BaseCType,
     boolT,
@@ -30,6 +28,7 @@ from .types_base import (
     longT,
     shortT,
 )
+from torchgen.model import BaseTy, ScalarType
 
 
 TENSOR_LIST_LIKE_CTYPES = [
diff --git a/torchgen/api/types/types_base.py b/torchgen/api/types/types_base.py
index 2f8561e49abe..e59a4b3d8201 100644
--- a/torchgen/api/types/types_base.py
+++ b/torchgen/api/types/types_base.py
@@ -19,6 +19,7 @@ from typing import List, Optional, Union
 
 from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
 
+
 # An ArgName is just the str name of the argument in schema;
 # but in some special circumstances, we may add a little extra
 # context. The Enum SpecialArgName covers all of these cases;
diff --git a/torchgen/api/ufunc.py b/torchgen/api/ufunc.py
index 7f044706068c..7981c2b29d7a 100644
--- a/torchgen/api/ufunc.py
+++ b/torchgen/api/ufunc.py
@@ -2,7 +2,6 @@ from dataclasses import dataclass
 from typing import List, Optional
 
 import torchgen.api.types as api_types
-
 from torchgen.api import cpp, structured
 from torchgen.api.types import (
     ArgName,
diff --git a/torchgen/api/unboxing.py b/torchgen/api/unboxing.py
index 7ff0c59c77d2..70128b1845bd 100644
--- a/torchgen/api/unboxing.py
+++ b/torchgen/api/unboxing.py
@@ -12,6 +12,7 @@ from torchgen.model import (
     Type,
 )
 
+
 # This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the
 # ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is
 # an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the
diff --git a/torchgen/code_template.py b/torchgen/code_template.py
index b932a94ecc91..b4afde2d7be1 100644
--- a/torchgen/code_template.py
+++ b/torchgen/code_template.py
@@ -1,6 +1,7 @@
 import re
 from typing import Mapping, Match, Optional, Sequence
 
+
 # match $identifier or ${identifier} and replace with value in env
 # If this identifier is at the beginning of whitespace on a line
 # and its value is a list then it is treated as
diff --git a/torchgen/context.py b/torchgen/context.py
index f79bde17367e..40e765a97ec9 100644
--- a/torchgen/context.py
+++ b/torchgen/context.py
@@ -1,5 +1,4 @@
 import contextlib
-
 import functools
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
 
@@ -13,6 +12,7 @@ from torchgen.model import (
 )
 from torchgen.utils import context, S, T
 
+
 # Helper functions for defining generators on things in the model
 
 F = TypeVar(
diff --git a/torchgen/decompositions/gen_jit_decompositions.py b/torchgen/decompositions/gen_jit_decompositions.py
index 7a0024f91f25..b42948045cbd 100644
--- a/torchgen/decompositions/gen_jit_decompositions.py
+++ b/torchgen/decompositions/gen_jit_decompositions.py
@@ -4,6 +4,7 @@ from pathlib import Path
 
 from torch.jit._decompositions import decomposition_table
 
+
 # from torchgen.code_template import CodeTemplate
 
 DECOMP_HEADER = r"""
diff --git a/torchgen/dest/__init__.py b/torchgen/dest/__init__.py
index 0c684fc1915c..8f08a743ae2d 100644
--- a/torchgen/dest/__init__.py
+++ b/torchgen/dest/__init__.py
@@ -1,18 +1,18 @@
-from .lazy_ir import (
+from torchgen.dest.lazy_ir import (
     generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
     GenLazyIR as GenLazyIR,
     GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
     GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
 )
-from .native_functions import (
+from torchgen.dest.native_functions import (
     compute_native_function_declaration as compute_native_function_declaration,
 )
-from .register_dispatch_key import (
+from torchgen.dest.register_dispatch_key import (
     gen_registration_headers as gen_registration_headers,
     gen_registration_helpers as gen_registration_helpers,
     RegisterDispatchKey as RegisterDispatchKey,
 )
-from .ufunc import (
+from torchgen.dest.ufunc import (
     compute_ufunc_cpu as compute_ufunc_cpu,
     compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
     compute_ufunc_cuda as compute_ufunc_cuda,
diff --git a/torchgen/dest/native_functions.py b/torchgen/dest/native_functions.py
index 57a9217550d9..531c01b699fc 100644
--- a/torchgen/dest/native_functions.py
+++ b/torchgen/dest/native_functions.py
@@ -3,7 +3,6 @@ from typing import List, Optional, Union
 import torchgen.api.meta as meta
 import torchgen.api.structured as structured
 from torchgen.api.types import kernel_signature
-
 from torchgen.context import with_native_function_and_index
 from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
 from torchgen.utils import mapMaybe
diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py
index ffc879afb6cd..8c90160fa695 100644
--- a/torchgen/dest/ufunc.py
+++ b/torchgen/dest/ufunc.py
@@ -27,6 +27,7 @@ from torchgen.model import (
 )
 from torchgen.utils import OrderedSet
 
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 #                                  CUDA STUFF
diff --git a/torchgen/executorch/api/custom_ops.py b/torchgen/executorch/api/custom_ops.py
index 7e31025675ef..e4eec9e3fb2f 100644
--- a/torchgen/executorch/api/custom_ops.py
+++ b/torchgen/executorch/api/custom_ops.py
@@ -1,12 +1,11 @@
 from collections import defaultdict
-
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Sequence, Tuple
 
 from torchgen import dest
 
 # disable import sorting to avoid circular dependency.
-from torchgen.api.types import DispatcherSignature  # isort:skip
+from torchgen.api.types import DispatcherSignature  # usort:skip
 from torchgen.context import method_with_native_function
 from torchgen.executorch.model import ETKernelIndex
 from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py
index 24dda58ecdbc..18574f472fc9 100644
--- a/torchgen/executorch/api/et_cpp.py
+++ b/torchgen/executorch/api/et_cpp.py
@@ -15,6 +15,14 @@ from torchgen.api.types import (
     VectorCType,
     voidT,
 )
+from torchgen.executorch.api.types import (
+    ArrayRefCType,
+    BaseTypeToCppMapping,
+    OptionalCType,
+    scalarT,
+    tensorListT,
+    tensorT,
+)
 from torchgen.model import (
     Argument,
     Arguments,
@@ -29,14 +37,7 @@ from torchgen.model import (
     Type,
 )
 from torchgen.utils import assert_never
-from .types import (
-    ArrayRefCType,
-    BaseTypeToCppMapping,
-    OptionalCType,
-    scalarT,
-    tensorListT,
-    tensorT,
-)
+
 """
 This file describes the translation of JIT schema to
 the public C++ API, which is what people use when they call
diff --git a/torchgen/executorch/api/types/__init__.py b/torchgen/executorch/api/types/__init__.py
index eb5e802634f8..6fc9666768ba 100644
--- a/torchgen/executorch/api/types/__init__.py
+++ b/torchgen/executorch/api/types/__init__.py
@@ -1,2 +1,2 @@
-from .types import *
-from .signatures import *  # isort:skip
+from torchgen.executorch.api.types.types import *
+from torchgen.executorch.api.types.signatures import *  # usort:skip
diff --git a/torchgen/executorch/api/types/signatures.py b/torchgen/executorch/api/types/signatures.py
index a53d15c036a9..3449b2b9a525 100644
--- a/torchgen/executorch/api/types/signatures.py
+++ b/torchgen/executorch/api/types/signatures.py
@@ -2,12 +2,10 @@ from dataclasses import dataclass
 from typing import List, Optional, Set
 
 import torchgen.api.cpp as aten_cpp
-
 from torchgen.api.types import Binding, CType
+from torchgen.executorch.api.types.types import contextArg
 from torchgen.model import FunctionSchema, NativeFunction
 
-from .types import contextArg
-
 
 @dataclass(frozen=True)
 class ExecutorchCppSignature:
diff --git a/torchgen/executorch/api/types/types.py b/torchgen/executorch/api/types/types.py
index c9db1baa245f..6ec48c803ae6 100644
--- a/torchgen/executorch/api/types/types.py
+++ b/torchgen/executorch/api/types/types.py
@@ -15,6 +15,7 @@ from torchgen.api.types import (
 )
 from torchgen.model import BaseTy
 
+
 halfT = BaseCppType("torch::executor", "Half")
 bfloat16T = BaseCppType("torch::executor", "BFloat16")
 stringT = BaseCppType("torch::executor", "string_view")
diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py
index 50d69d34e96b..a81e6d11fea6 100644
--- a/torchgen/executorch/api/unboxing.py
+++ b/torchgen/executorch/api/unboxing.py
@@ -12,6 +12,7 @@ from torchgen.model import (
     Type,
 )
 
+
 connector = "\n\t"
diff --git a/torchgen/executorch/model.py b/torchgen/executorch/model.py
index cec9251a3187..a7d5f1ceb161 100644
--- a/torchgen/executorch/model.py
+++ b/torchgen/executorch/model.py
@@ -17,6 +17,7 @@ from torchgen.model import (
 )
 from torchgen.utils import assert_never
 
+
 KERNEL_KEY_VERSION = 1
diff --git a/torchgen/executorch/parse.py b/torchgen/executorch/parse.py
index 89b4b93558a6..94acb5c2115e 100644
--- a/torchgen/executorch/parse.py
+++ b/torchgen/executorch/parse.py
@@ -4,7 +4,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple
 import yaml
 
 from torchgen.executorch.model import ETKernelIndex, ETKernelKey
-
 from torchgen.gen import LineLoader, parse_native_yaml
 from torchgen.model import (
     BackendMetadata,
@@ -15,6 +14,7 @@ from torchgen.model import (
 )
 from torchgen.utils import NamespaceHelper
 
+
 # Parse native_functions.yaml into a sequence of NativeFunctions and ET Backend Indices.
 ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indices"])
diff --git a/torchgen/fuse/gen_patterns.py b/torchgen/fuse/gen_patterns.py
index 68bf938d2712..18562f54096b 100644
--- a/torchgen/fuse/gen_patterns.py
+++ b/torchgen/fuse/gen_patterns.py
@@ -4,6 +4,7 @@ import os
 from torch._inductor import pattern_matcher
 from torch._inductor.fx_passes import joint_graph
 
+
 if __name__ == "__main__":
     # Start by deleting all the existing patterns.
     for file in os.listdir(pattern_matcher.SERIALIZED_PATTERN_PATH):
diff --git a/torchgen/gen.py b/torchgen/gen.py
index a1c1a8f957f3..e9dc04d0b9b2 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -3,7 +3,6 @@ import functools
 import json
 import os
 import pathlib
-
 from collections import defaultdict, namedtuple, OrderedDict
 from dataclasses import dataclass, field
 from typing import (
@@ -27,7 +26,6 @@ import torchgen.api.meta as meta
 import torchgen.api.native as native
 import torchgen.api.structured as structured
 import torchgen.dest as dest
-
 from torchgen.aoti.fallback_ops import inductor_fallback_ops
 from torchgen.api import cpp
 from torchgen.api.translate import translate
@@ -59,7 +57,6 @@ from torchgen.gen_functionalization_type import (
     GenCompositeViewCopyKernel,
 )
 from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
-
 from torchgen.model import (
     Argument,
     BackendIndex,
@@ -105,6 +102,7 @@ from torchgen.utils import (
 )
 from torchgen.yaml_utils import YamlDumper, YamlLoader
 
+
 T = TypeVar("T")
 
 # Welcome to the ATen code generator v2! The ATen code generator is
diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py
index 73152e1c0a7b..d8d91fbaa16c 100644
--- a/torchgen/gen_aoti_c_shim.py
+++ b/torchgen/gen_aoti_c_shim.py
@@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from torchgen.api.types import DispatcherSignature
 from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
-
 from torchgen.context import method_with_native_function
 from torchgen.model import (
     Argument,
@@ -22,6 +21,7 @@ from torchgen.model import (
 )
 from torchgen.utils import mapMaybe
 
+
 base_type_to_c_type = {
     BaseTy.Tensor: "AtenTensorHandle",
     BaseTy.bool: "int32_t",  # Use int to pass bool
diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py
index 8d2c567c3478..34b5e617e490 100644
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -46,7 +46,6 @@ from torchgen.native_function_generation import (
     MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
     OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
 )
-
 from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import dataclass_repr
diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py
index eed0e8de7ae2..6dd6f45b1dab 100644
--- a/torchgen/gen_lazy_tensor.py
+++ b/torchgen/gen_lazy_tensor.py
@@ -18,22 +18,21 @@ from typing import (
 import yaml
 
 import torchgen.dest as dest
-
 from torchgen.api.lazy import setValueT
 from torchgen.api.types import BaseCppType
 from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
 from torchgen.gen import get_grouped_native_functions, parse_native_yaml
-
-from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
-from torchgen.selective_build.selector import SelectiveBuilder
-from torchgen.utils import FileManager, NamespaceHelper
-from torchgen.yaml_utils import YamlLoader
-from .gen_backend_stubs import (
+from torchgen.gen_backend_stubs import (
     error_on_missing_kernels,
     gen_dispatcher_registrations,
     gen_dispatchkey_nativefunc_headers,
     parse_backend_yaml,
 )
+from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
+from torchgen.selective_build.selector import SelectiveBuilder
+from torchgen.utils import FileManager, NamespaceHelper
+from torchgen.yaml_utils import YamlLoader
+
 
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
diff --git a/torchgen/local.py b/torchgen/local.py
index f72e53601ab1..09532c7bfc67 100644
--- a/torchgen/local.py
+++ b/torchgen/local.py
@@ -2,6 +2,7 @@ import threading
 from contextlib import contextmanager
 from typing import Iterator, Optional
 
+
 # Simple dynamic scoping implementation. The name "parametrize" comes
 # from Racket.
 #
diff --git a/torchgen/model.py b/torchgen/model.py
index bed8f262f592..e150cb7cf6f9 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -1,13 +1,13 @@
 import dataclasses
 import itertools
 import re
-
 from dataclasses import dataclass
 from enum import auto, Enum
 from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
 
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #
 #                           DATA MODEL
diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py
index 87cc2fef5366..3705944309d0 100644
--- a/torchgen/native_function_generation.py
+++ b/torchgen/native_function_generation.py
@@ -1,5 +1,4 @@
 from collections import defaultdict
-
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import torchgen.api.dispatcher as dispatcher
@@ -27,6 +26,7 @@ from torchgen.model import (
 )
 from torchgen.utils import concatMap
 
+
 # See Note: [Out ops with functional variants that don't get grouped properly]
 OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
     # This has a functional variant, but it's currently marked private.
diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py
index 29070761c55f..18b2952c9eae 100644
--- a/torchgen/operator_versions/gen_mobile_upgraders.py
+++ b/torchgen/operator_versions/gen_mobile_upgraders.py
@@ -7,7 +7,6 @@ from typing import Any, Dict, List
 
 import torch
 from torch.jit.generate_bytecode import generate_upgraders_bytecode
-
 from torchgen.code_template import CodeTemplate
 from torchgen.operator_versions.gen_mobile_upgraders_constant import (
     MOBILE_UPGRADERS_HEADER_DESCRIPTION,
diff --git a/torchgen/static_runtime/gen_static_runtime_ops.py b/torchgen/static_runtime/gen_static_runtime_ops.py
index 737d296d9a69..93a4436fd220 100644
--- a/torchgen/static_runtime/gen_static_runtime_ops.py
+++ b/torchgen/static_runtime/gen_static_runtime_ops.py
@@ -10,6 +10,7 @@ from torchgen.context import native_function_manager
 from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
 from torchgen.static_runtime import config, generator
 
+
 # Given a list of `grouped_native_functions` sorted by their op names, return a list of
 # lists each of which groups ops that share the base name. For example, `mean` and
 # `mean.dim` are grouped together by this function.
diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py
index b068af7728aa..7960679660b7 100644
--- a/torchgen/static_runtime/generator.py
+++ b/torchgen/static_runtime/generator.py
@@ -1,6 +1,5 @@
 import json
 import logging
-
 import math
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
@@ -21,6 +20,7 @@ from torchgen.model import (
 )
 from torchgen.static_runtime import config
 
+
 logger: logging.Logger = logging.getLogger()