From a4647cc1fab1e207926b07f4d0c8dd31c7dbb0f2 Mon Sep 17 00:00:00 2001 From: Huy Do Date: Sat, 16 Jul 2022 03:52:25 +0000 Subject: [PATCH] Apply ufmt linter to all py files under torchgen (#81570) Previous batches: * https://github.com/pytorch/pytorch/pull/81285 * https://github.com/pytorch/pytorch/pull/81335 We have multiple batches here to minimize merge conflicts and reviewing process. Once everything has been formatted by ufmt (black+usort), the current black linter will be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81570 Approved by: https://github.com/ezyang --- .lintrunner.toml | 1 + torchgen/api/autograd.py | 11 +- torchgen/api/cpp.py | 63 ++++----- torchgen/api/dispatcher.py | 14 +- torchgen/api/functionalization.py | 29 ++-- torchgen/api/lazy.py | 45 +++---- torchgen/api/native.py | 44 +++---- torchgen/api/python.py | 5 +- torchgen/api/structured.py | 45 ++++--- torchgen/api/translate.py | 49 +++---- torchgen/api/types.py | 17 +-- torchgen/api/ufunc.py | 35 +++-- torchgen/api/unboxing.py | 12 +- torchgen/code_template.py | 2 +- torchgen/context.py | 17 +-- torchgen/dest/__init__.py | 18 +-- torchgen/dest/lazy_ir.py | 43 +++--- torchgen/dest/native_functions.py | 11 +- torchgen/dest/register_dispatch_key.py | 50 +++---- torchgen/dest/ufunc.py | 39 +++--- torchgen/gen.py | 124 +++++++++--------- torchgen/gen_backend_stubs.py | 29 ++-- torchgen/gen_functionalization_type.py | 40 +++--- torchgen/gen_lazy_tensor.py | 43 +++--- torchgen/local.py | 2 +- .../operator_versions/gen_mobile_upgraders.py | 3 +- torchgen/selective_build/operator.py | 2 +- torchgen/selective_build/selector.py | 8 +- .../gen_jit_shape_functions.py | 4 +- torchgen/static_runtime/config.py | 4 +- .../static_runtime/gen_static_runtime_ops.py | 11 +- torchgen/static_runtime/generator.py | 17 +-- torchgen/utils.py | 29 ++-- 33 files changed, 434 insertions(+), 432 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 81df23cd65c5..bef189faa45d 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -672,6 +672,7 @@ code = 'UFMT' include_patterns = [ 'test/onnx/**/*.py', 'tools/**/*.py', + 'torchgen/**/*.py', ] command = [ 'python3', diff --git a/torchgen/api/autograd.py b/torchgen/api/autograd.py index 11dd831bacd3..417bba637ce0 100644 --- a/torchgen/api/autograd.py +++ b/torchgen/api/autograd.py @@ -1,15 +1,10 @@ -from dataclasses import dataclass import re -from typing import Optional, Sequence, Set, List, Tuple, Match +from dataclasses import dataclass +from typing import List, Match, Optional, Sequence, Set, Tuple from torchgen.api import cpp from torchgen.api.types import Binding, NamedCType -from torchgen.model import ( - NativeFunction, - Type, - SchemaKind, - NativeFunctionsViewGroup, -) +from torchgen.model import NativeFunction, NativeFunctionsViewGroup, SchemaKind, Type from torchgen.utils import IDENT_REGEX # Represents a saved attribute involved in backward calculation. diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index b45bd7456bc4..b7a1e90e81cf 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -1,3 +1,35 @@ +from typing import List, Optional, Sequence, Set, Union + +from torchgen import local +from torchgen.api.types import ( + ArgName, + ArrayCType, + ArrayRefCType, + BaseCType, + BaseTypeToCppMapping, + Binding, + boolT, + ConstRefCType, + CType, + dimnameListT, + intArrayRefT, + ListCType, + longT, + MutRefCType, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + scalarT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorListT, + tensorOptionsT, + tensorT, + TupleCType, + VectorCType, + voidT, +) from torchgen.model import ( Argument, Arguments, @@ -12,38 +44,7 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) -from torchgen.api.types import ( - ArgName, - BaseCType, - Binding, - ConstRefCType, - NamedCType, - CType, - MutRefCType, - ArrayCType, - ListCType, - VectorCType, - ArrayRefCType, - OptionalCType, - TupleCType, - SpecialArgName, - boolT, - scalarT, - tensorListT, - dimnameListT, - tensorT, - voidT, - longT, - SymIntT, - symIntArrayRefT, - BaseTypeToCppMapping, - intArrayRefT, - optionalIntArrayRefT, - tensorOptionsT, -) -from torchgen import local from torchgen.utils import assert_never -from typing import Optional, Sequence, Union, List, Set # This file describes the translation of JIT schema to the public C++ # API, which is what people use when they call functions like at::add. diff --git a/torchgen/api/dispatcher.py b/torchgen/api/dispatcher.py index ad1f17f71940..008e8c5664a4 100644 --- a/torchgen/api/dispatcher.py +++ b/torchgen/api/dispatcher.py @@ -1,3 +1,9 @@ +import itertools +from typing import List, Sequence, Union + +from torchgen.api import cpp + +from torchgen.api.types import ArgName, Binding, CType, NamedCType from torchgen.model import ( Argument, FunctionSchema, @@ -6,13 +12,7 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) - -from torchgen.api.types import ArgName, Binding, NamedCType, CType -from torchgen.api import cpp -from torchgen.utils import concatMap, assert_never - -import itertools -from typing import Sequence, List, Union +from torchgen.utils import assert_never, concatMap # This file describes the translation of JIT schema to the dispatcher # API, the *unboxed* calling convention by which invocations through diff --git a/torchgen/api/functionalization.py b/torchgen/api/functionalization.py index 22ce2c3c4d00..c071fd10087b 100644 --- a/torchgen/api/functionalization.py +++ b/torchgen/api/functionalization.py @@ -1,22 +1,23 @@ +from typing import List, Optional + +from torchgen.api import dispatcher +from torchgen.api.types import ( + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + longT, + NamedCType, + tensorT, +) from torchgen.model import ( - FunctionSchema, + Argument, BaseTy, BaseType, + FunctionSchema, NativeFunctionsViewGroup, - Argument, ) -from torchgen.api.types import ( - Binding, - NamedCType, - ConstRefCType, - BaseCType, - CType, - tensorT, - longT, - boolT, -) -from torchgen.api import dispatcher -from typing import List, Optional # This file describes the translation of JIT schema to API's used diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index d424ae02ecb4..6bce9db92bdb 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -1,35 +1,36 @@ -from typing import Any, Dict, List, Union, Tuple, Optional +from typing import Any, Dict, List, Optional, Tuple, Union -from torchgen.model import ( - Type, - BaseTy, - BaseType, - OptionalType, - ListType, - OperatorName, - FunctionSchema, - Return, - TensorOptionsArguments, - Argument, -) from torchgen.api.types import ( - CType, BaseCppType, BaseCType, - OptionalCType, - NamedCType, - deviceT, - layoutT, - VectorCType, boolT, - longT, + CType, + deviceT, doubleT, + layoutT, ListCType, - stringT, + longT, + memoryFormatT, + NamedCType, + OptionalCType, scalarT, scalarTypeT, - memoryFormatT, + stringT, SymIntT, + VectorCType, +) + +from torchgen.model import ( + Argument, + BaseTy, + BaseType, + FunctionSchema, + ListType, + OperatorName, + OptionalType, + Return, + TensorOptionsArguments, + Type, ) diff --git a/torchgen/api/native.py b/torchgen/api/native.py index 47610022e55a..16814e34867c 100644 --- a/torchgen/api/native.py +++ b/torchgen/api/native.py @@ -1,3 +1,25 @@ +from typing import List, Optional, Sequence, Union + +from torchgen import local +from torchgen.api import cpp + +from torchgen.api.types import ( + ArgName, + BaseCType, + Binding, + boolT, + ConstRefCType, + CType, + deviceT, + layoutT, + ListCType, + MutRefCType, + NamedCType, + OptionalCType, + scalarT, + scalarTypeT, + tensorT, +) from torchgen.model import ( Argument, FunctionSchema, @@ -6,30 +28,8 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) - -from torchgen.api.types import ( - ArgName, - BaseCType, - Binding, - ConstRefCType, - NamedCType, - CType, - MutRefCType, - ListCType, - OptionalCType, - tensorT, - scalarT, - layoutT, - deviceT, - boolT, - scalarTypeT, -) -from torchgen.api import cpp -from torchgen import local from torchgen.utils import assert_never -from typing import Union, Sequence, List, Optional - # This file describes the translation of JIT schema to the native functions API. # This looks a lot like the C++ API (which makes historical sense, because the # idea was you wrote native functions to implement functions in the C++ API), diff --git a/torchgen/api/python.py b/torchgen/api/python.py index 64ce1a9700f7..505325ccc74d 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -1,8 +1,9 @@ from dataclasses import dataclass -from typing import Optional, Union, Sequence, Set, List, Dict, Tuple +from typing import Dict, List, Optional, Sequence, Set, Tuple, Union + +from torchgen.api import cpp from torchgen.api.types import Binding, CppSignature, CppSignatureGroup -from torchgen.api import cpp from torchgen.gen import pythonify_default from torchgen.model import ( Argument, diff --git a/torchgen/api/structured.py b/torchgen/api/structured.py index 2a0ecd918292..4787adccae6b 100644 --- a/torchgen/api/structured.py +++ b/torchgen/api/structured.py @@ -1,3 +1,25 @@ +from typing import List, Union + +from torchgen.api import cpp + +from torchgen.api.types import ( + ArgName, + ArrayRefCType, + BaseCType, + Binding, + ConstRefCType, + dimnameListT, + intArrayRefT, + iOptTensorListRefT, + iTensorListRefT, + NamedCType, + OptionalCType, + optionalIntArrayRefT, + optionalScalarRefT, + optionalTensorRefT, + scalarT, + tensorT, +) from torchgen.model import ( Argument, BaseTy, @@ -9,31 +31,8 @@ from torchgen.model import ( TensorOptionsArguments, Type, ) - -from torchgen.api.types import ( - ArgName, - BaseCType, - Binding, - ArrayRefCType, - ConstRefCType, - OptionalCType, - NamedCType, - tensorT, - scalarT, - intArrayRefT, - dimnameListT, - optionalTensorRefT, - optionalScalarRefT, - optionalIntArrayRefT, - iTensorListRefT, - iOptTensorListRefT, -) - -from torchgen.api import cpp from torchgen.utils import assert_never -from typing import Union, List - # This file describes the translation of JIT schema to the structured functions API. # This is similar to native API, but a number of historical problems with native # API have been fixed. diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index c73978668abe..bee33b473dc9 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -1,35 +1,36 @@ -from typing import Dict, Sequence, List, NoReturn, Union +from typing import Dict, List, NoReturn, Sequence, Union + from torchgen.api.types import ( - ListCType, - tensorListT, BaseCType, Binding, - ConstRefCType, - Expr, - MutRefCType, - OptionalCType, - NamedCType, - SpecialArgName, - tensorT, - memoryFormatT, - tensorOptionsT, - scalarTypeT, - SymIntT, - symIntArrayRefT, boolT, + ConstRefCType, deviceT, - layoutT, - optionalTensorRefT, - iTensorListRefT, - iOptTensorListRefT, - scalarT, - optionalScalarRefT, - VectorCType, - longT, + Expr, intArrayRefT, - scalar_t, + iOptTensorListRefT, + iTensorListRefT, + layoutT, + ListCType, + longT, + memoryFormatT, + MutRefCType, + NamedCType, opmath_t, + OptionalCType, optionalIntArrayRefT, + optionalScalarRefT, + optionalTensorRefT, + scalar_t, + scalarT, + scalarTypeT, + SpecialArgName, + symIntArrayRefT, + SymIntT, + tensorListT, + tensorOptionsT, + tensorT, + VectorCType, ) # This file implements a small program synthesis engine that implements diff --git a/torchgen/api/types.py b/torchgen/api/types.py index 9717133a4cdb..6bee40a421d4 100644 --- a/torchgen/api/types.py +++ b/torchgen/api/types.py @@ -1,18 +1,19 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union + from torchgen.model import ( Argument, + BackendIndex, + BaseTy, FunctionSchema, NativeFunction, - BackendIndex, NativeFunctionsGroup, NativeFunctionsViewGroup, + ScalarType, SelfArgument, TensorOptionsArguments, - BaseTy, - ScalarType, ) -from dataclasses import dataclass -from typing import Optional, Union, Sequence, TypeVar, List, Set, Dict -from enum import Enum _T = TypeVar("_T") @@ -752,8 +753,8 @@ def kernel_signature( from torchgen.api import ( cpp, dispatcher, - native, - translate, functionalization, + native, structured, + translate, ) diff --git a/torchgen/api/ufunc.py b/torchgen/api/ufunc.py index 5836e276240e..34384ce340d5 100644 --- a/torchgen/api/ufunc.py +++ b/torchgen/api/ufunc.py @@ -1,30 +1,29 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torchgen.api.types as api_types + +from torchgen.api import cpp, structured +from torchgen.api.types import ( + ArgName, + BaseCppType, + BaseCType, + Binding, + ConstRefCType, + CType, + NamedCType, + scalarT, +) from torchgen.model import ( Argument, BaseTy, BaseType, + DispatchKey, FunctionSchema, NativeFunctionsGroup, Type, - DispatchKey, ) -import torchgen.api.types as api_types -from torchgen.api.types import ( - ArgName, - BaseCType, - Binding, - ConstRefCType, - NamedCType, - scalarT, - CType, - BaseCppType, -) - -from torchgen.api import cpp, structured - -from dataclasses import dataclass -from typing import List, Optional - def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str: assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas" diff --git a/torchgen/api/unboxing.py b/torchgen/api/unboxing.py index 06595353de29..b5afdc099fa9 100644 --- a/torchgen/api/unboxing.py +++ b/torchgen/api/unboxing.py @@ -1,15 +1,15 @@ from typing import List, Tuple from torchgen.api import cpp -from torchgen.api.types import Binding, CType, CppSignatureGroup +from torchgen.api.types import Binding, CppSignatureGroup, CType from torchgen.model import ( Argument, - NativeFunction, - Type, - BaseType, - OptionalType, - ListType, BaseTy, + BaseType, + ListType, + NativeFunction, + OptionalType, + Type, ) # This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the diff --git a/torchgen/code_template.py b/torchgen/code_template.py index e8241c65586f..9f877771afe9 100644 --- a/torchgen/code_template.py +++ b/torchgen/code_template.py @@ -1,5 +1,5 @@ import re -from typing import Match, Optional, Sequence, Mapping +from typing import Mapping, Match, Optional, Sequence # match $identifier or ${identifier} and replace with value in env # If this identifier is at the beginning of whitespace on a line diff --git a/torchgen/context.py b/torchgen/context.py index f65e3daaa8d9..bbb8ea4d5c4c 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -1,16 +1,17 @@ -from torchgen.utils import S, T, context +import contextlib + +import functools +from typing import Callable, Dict, Iterator, Optional, TypeVar, Union + +import torchgen.local as local from torchgen.model import ( + BackendIndex, + DispatchKey, NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup, - BackendIndex, - DispatchKey, ) -import torchgen.local as local - -import functools -from typing import TypeVar, Union, Iterator, Callable, Dict, Optional -import contextlib +from torchgen.utils import context, S, T # Helper functions for defining generators on things in the model diff --git a/torchgen/dest/__init__.py b/torchgen/dest/__init__.py index 498c437a88a3..0c684fc1915c 100644 --- a/torchgen/dest/__init__.py +++ b/torchgen/dest/__init__.py @@ -1,19 +1,19 @@ -from .lazy_ir import GenLazyIR as GenLazyIR -from .lazy_ir import GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition -from .lazy_ir import GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition from .lazy_ir import ( generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes, -) -from .register_dispatch_key import ( - RegisterDispatchKey as RegisterDispatchKey, - gen_registration_helpers as gen_registration_helpers, - gen_registration_headers as gen_registration_headers, + GenLazyIR as GenLazyIR, + GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition, + GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition, ) from .native_functions import ( compute_native_function_declaration as compute_native_function_declaration, ) +from .register_dispatch_key import ( + gen_registration_headers as gen_registration_headers, + gen_registration_helpers as gen_registration_helpers, + RegisterDispatchKey as RegisterDispatchKey, +) from .ufunc import ( - compute_ufunc_cuda as compute_ufunc_cuda, compute_ufunc_cpu as compute_ufunc_cpu, compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel, + compute_ufunc_cuda as compute_ufunc_cuda, ) diff --git a/torchgen/dest/lazy_ir.py b/torchgen/dest/lazy_ir.py index 2a0e224818c2..36d9d4edc469 100644 --- a/torchgen/dest/lazy_ir.py +++ b/torchgen/dest/lazy_ir.py @@ -1,35 +1,36 @@ -from abc import ABC import itertools +from abc import ABC from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union, Tuple -from torchgen.context import method_with_native_function -from torchgen.model import ( - FunctionSchema, - Argument, - BackendIndex, - NativeFunction, - NativeFunctionsGroup, +from typing import Any, Dict, List, Optional, Tuple, Union + +import torchgen.api.dispatcher as dispatcher +from torchgen.api.lazy import ( + getValueT, + isValueType, + LazyArgument, + LazyIrProperties, + LazyIrSchema, + tensorListValueT, ) +from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, Binding, + deviceT, DispatcherSignature, + kernel_signature, OptionalCType, VectorCType, - kernel_signature, - deviceT, -) -import torchgen.api.dispatcher as dispatcher -from torchgen.api.translate import translate -from torchgen.api.lazy import ( - LazyIrProperties, - LazyIrSchema, - LazyArgument, - getValueT, - isValueType, - tensorListValueT, ) +from torchgen.context import method_with_native_function from torchgen.dest.lazy_ts_lowering import ts_lowering_body +from torchgen.model import ( + Argument, + BackendIndex, + FunctionSchema, + NativeFunction, + NativeFunctionsGroup, +) def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str: diff --git a/torchgen/dest/native_functions.py b/torchgen/dest/native_functions.py index 67db9795f11e..57a9217550d9 100644 --- a/torchgen/dest/native_functions.py +++ b/torchgen/dest/native_functions.py @@ -1,11 +1,12 @@ -from typing import List, Union, Optional +from typing import List, Optional, Union -from torchgen.context import with_native_function_and_index -from torchgen.utils import mapMaybe -from torchgen.model import NativeFunction, NativeFunctionsGroup, BackendIndex -from torchgen.api.types import kernel_signature import torchgen.api.meta as meta import torchgen.api.structured as structured +from torchgen.api.types import kernel_signature + +from torchgen.context import with_native_function_and_index +from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup +from torchgen.utils import mapMaybe @with_native_function_and_index diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index 5a814ec10ba0..f7a3ef7bb644 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -1,42 +1,44 @@ -from typing import List, Optional, Tuple, Union import itertools -from typing_extensions import Literal -from dataclasses import dataclass import textwrap +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union -from torchgen.context import method_with_native_function, native_function_manager -from torchgen.utils import Target, mapMaybe, assert_never -from torchgen.model import ( - DispatchKey, - NativeFunction, - NativeFunctionsGroup, - SchemaKind, - TensorOptionsArguments, - DeviceCheckType, - Argument, - is_cuda_dispatch_key, - BackendIndex, - gets_generated_out_inplace_wrapper, -) +from typing_extensions import Literal + +import torchgen.api.cpp as cpp +import torchgen.api.meta as meta +import torchgen.api.structured as structured +from torchgen.api.translate import translate from torchgen.api.types import ( BaseCType, Binding, ConstRefCType, CppSignature, CppSignatureGroup, + DispatcherSignature, Expr, - MutRefCType, kernel_signature, + MutRefCType, + NamedCType, NativeSignature, tensorT, - NamedCType, - DispatcherSignature, ) -import torchgen.api.meta as meta -import torchgen.api.cpp as cpp -import torchgen.api.structured as structured -from torchgen.api.translate import translate + +from torchgen.context import method_with_native_function, native_function_manager +from torchgen.model import ( + Argument, + BackendIndex, + DeviceCheckType, + DispatchKey, + gets_generated_out_inplace_wrapper, + is_cuda_dispatch_key, + NativeFunction, + NativeFunctionsGroup, + SchemaKind, + TensorOptionsArguments, +) from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import assert_never, mapMaybe, Target def gen_registration_headers( diff --git a/torchgen/dest/ufunc.py b/torchgen/dest/ufunc.py index 09b964958c34..55dd214793e1 100644 --- a/torchgen/dest/ufunc.py +++ b/torchgen/dest/ufunc.py @@ -1,29 +1,30 @@ from dataclasses import dataclass -from typing import Union, Optional, List, Tuple, Dict, Sequence +from typing import Dict, List, Optional, Sequence, Tuple, Union + +import torchgen.api.ufunc as ufunc from torchgen.api.translate import translate +from torchgen.api.types import ( + BaseCType, + Binding, + CType, + Expr, + NamedCType, + opmath_t, + scalar_t, + StructuredImplSignature, + VectorizedCType, +) +from torchgen.api.ufunc import UfunctorBindings +from torchgen.context import with_native_function from torchgen.model import ( + Argument, + BaseTy, + BaseType, + DispatchKey, NativeFunctionsGroup, ScalarType, UfuncKey, - DispatchKey, - BaseType, - BaseTy, - Argument, ) -import torchgen.api.ufunc as ufunc -from torchgen.api.ufunc import UfunctorBindings -from torchgen.api.types import ( - StructuredImplSignature, - scalar_t, - opmath_t, - Binding, - CType, - BaseCType, - Expr, - NamedCType, - VectorizedCType, -) -from torchgen.context import with_native_function # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # diff --git a/torchgen/gen.py b/torchgen/gen.py index bcb68dd6eb32..c12c680ad625 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -1,45 +1,22 @@ -import os -from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar -from typing_extensions import Literal -import yaml -from collections import OrderedDict, defaultdict, namedtuple import argparse -import pathlib -import json -from dataclasses import dataclass import functools +import json +import os +import pathlib +from collections import defaultdict, namedtuple, OrderedDict +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union -from torchgen.model import ( - STRUCTURED_DISPATCH_KEYS, - Argument, - DispatchKey, - FunctionSchema, - Location, - NativeFunction, - NativeFunctionsGroup, - OperatorName, - BackendIndex, - BackendMetadata, - OptionalType, - SchemaKind, - SelfArgument, - TensorOptionsArguments, - Type, - Variant, - is_cuda_dispatch_key, - is_generic_dispatch_key, - is_ufunc_dispatch_key, - NativeFunctionsViewGroup, - ViewSchemaKind, - BaseOperatorName, - DEFAULT_KERNEL_NAMESPACE, -) -from torchgen.native_function_generation import ( - pre_group_native_functions, - add_generated_native_functions, - gen_composite_functional_kernel, - gen_composite_out_kernel, -) +import yaml +from typing_extensions import Literal + +import torchgen.api.dispatcher as dispatcher +import torchgen.api.meta as meta +import torchgen.api.native as native +import torchgen.api.structured as structured +import torchgen.dest as dest +from torchgen.api import cpp +from torchgen.api.translate import translate from torchgen.api.types import ( Binding, CppSignatureGroup, @@ -48,40 +25,65 @@ from torchgen.api.types import ( NativeSignature, SpecialArgName, ) -from torchgen.api import cpp -import torchgen.api.dispatcher as dispatcher -import torchgen.api.native as native -import torchgen.api.meta as meta -import torchgen.api.structured as structured -from torchgen.api.translate import translate -from torchgen.selective_build.selector import SelectiveBuilder -from torchgen.utils import ( - Target, - concatMap, - context, - mapMaybe, - YamlDumper, - YamlLoader, - FileManager, - assert_never, - make_file_manager, - NamespaceHelper, -) from torchgen.context import ( method_with_native_function, native_function_manager, - with_native_function_and_indices, with_native_function, + with_native_function_and_indices, ) -import torchgen.dest as dest from torchgen.gen_functionalization_type import ( + gen_composite_view_copy_kernel, gen_functionalization_definition, gen_functionalization_registration, gen_functionalization_view_inverse_declaration, - gen_composite_view_copy_kernel, gen_symint_view_copy_kernel, ) +from torchgen.model import ( + Argument, + BackendIndex, + BackendMetadata, + BaseOperatorName, + DEFAULT_KERNEL_NAMESPACE, + DispatchKey, + FunctionSchema, + is_cuda_dispatch_key, + is_generic_dispatch_key, + is_ufunc_dispatch_key, + Location, + NativeFunction, + NativeFunctionsGroup, + NativeFunctionsViewGroup, + OperatorName, + OptionalType, + SchemaKind, + SelfArgument, + STRUCTURED_DISPATCH_KEYS, + TensorOptionsArguments, + Type, + Variant, + ViewSchemaKind, +) +from torchgen.native_function_generation import ( + add_generated_native_functions, + gen_composite_functional_kernel, + gen_composite_out_kernel, + pre_group_native_functions, +) +from torchgen.selective_build.selector import SelectiveBuilder +from torchgen.utils import ( + assert_never, + concatMap, + context, + FileManager, + make_file_manager, + mapMaybe, + NamespaceHelper, + Target, + YamlDumper, + YamlLoader, +) + T = TypeVar("T") # Welcome to the ATen code generator v2! The ATen code generator is diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 74378d7eab30..e2c2b46fd76c 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -1,14 +1,18 @@ -import pathlib import argparse import os -import yaml +import pathlib import re -from collections import namedtuple, Counter, defaultdict -from typing import List, Dict, Union, Sequence, Optional -from torchgen.gen import ( - get_grouped_native_functions, - parse_native_yaml, -) +from collections import Counter, defaultdict, namedtuple +from typing import Dict, List, Optional, Sequence, Union + +import yaml + +import torchgen.api.dispatcher as dispatcher +import torchgen.dest as dest +from torchgen.api.types import DispatcherSignature +from torchgen.code_template import CodeTemplate +from torchgen.context import native_function_manager +from torchgen.gen import get_grouped_native_functions, parse_native_yaml from torchgen.model import ( BackendIndex, BackendMetadata, @@ -19,18 +23,13 @@ from torchgen.model import ( ) from torchgen.selective_build.selector import SelectiveBuilder from torchgen.utils import ( - Target, concatMap, context, - YamlLoader, FileManager, NamespaceHelper, + Target, + YamlLoader, ) -from torchgen.context import native_function_manager -from torchgen.code_template import CodeTemplate -import torchgen.dest as dest -import torchgen.api.dispatcher as dispatcher -from torchgen.api.types import DispatcherSignature # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key. diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index e0aa157c8029..988f0847729d 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -1,47 +1,47 @@ +from typing import Callable, List, Optional, Tuple, Union + from torchgen.api import cpp, dispatcher +from torchgen.api.translate import translate from torchgen.api.types import ( - DispatcherSignature, - Binding, - FunctionalizationLambda, - ViewInverseSignature, - NativeSignature, - CType, BaseCType, - VectorCType, + Binding, + CType, + DispatcherSignature, + FunctionalizationLambda, + NativeSignature, tensorListT, tensorT, + VectorCType, + ViewInverseSignature, ) -from torchgen.api.translate import translate from torchgen.context import ( + native_function_manager, with_native_function, with_native_function_and, - native_function_manager, ) from torchgen.model import ( Argument, - Return, + BackendIndex, + BaseTy, + BaseType, + FunctionSchema, + ListType, NativeFunction, NativeFunctionsGroup, - BackendIndex, - FunctionSchema, + NativeFunctionsViewGroup, + Return, SchemaKind, SelfArgument, TensorOptionsArguments, - BaseType, - BaseTy, - NativeFunctionsViewGroup, - ListType, ) from torchgen.native_function_generation import ( - OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, - MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY, + MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, + OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, ) from torchgen.selective_build.selector import SelectiveBuilder -from typing import List, Optional, Union, Tuple, Callable - # Note: [Mutable Ops Not Using Functionalization] # Ops in this list currently do not work with functionalization and should be fixed. diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index 85ce23966445..0f6887e36986 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -1,44 +1,39 @@ -import pathlib import argparse import os +import pathlib import re -import yaml -from collections import namedtuple, Counter +from collections import Counter, namedtuple from typing import ( Any, - List, - Dict, - Tuple, - Union, - Sequence, - Optional, Callable, + Dict, Iterable, Iterator, + List, + Optional, + Sequence, + Tuple, Type, + Union, ) -from torchgen.api.types import BaseCppType -from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR -from torchgen.gen import ( - get_grouped_native_functions, - parse_native_yaml, -) + +import yaml + +import torchgen.dest as dest from torchgen.api.lazy import setValueT +from torchgen.api.types import BaseCppType +from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR +from torchgen.gen import get_grouped_native_functions, parse_native_yaml -from torchgen.model import ( - NativeFunction, - NativeFunctionsGroup, - OperatorName, -) +from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName from torchgen.selective_build.selector import SelectiveBuilder -from torchgen.utils import concatMap, YamlLoader, FileManager, NamespaceHelper -import torchgen.dest as dest +from torchgen.utils import concatMap, FileManager, NamespaceHelper, YamlLoader from .gen_backend_stubs import ( - parse_backend_yaml, error_on_missing_kernels, - gen_dispatchkey_nativefunc_headers, gen_dispatcher_registrations, + gen_dispatchkey_nativefunc_headers, + parse_backend_yaml, ) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # diff --git a/torchgen/local.py b/torchgen/local.py index dd570dd8d7ee..65efce2c3b11 100644 --- a/torchgen/local.py +++ b/torchgen/local.py @@ -1,6 +1,6 @@ import threading from contextlib import contextmanager -from typing import Optional, Iterator +from typing import Iterator, Optional # Simple dynamic scoping implementation. The name "parametrize" comes # from Racket. diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index 54c5b3a5628a..5006f4f6d89a 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -5,8 +5,9 @@ from pathlib import Path from typing import Any, Dict, List import torch -from torchgen.code_template import CodeTemplate from torch.jit.generate_bytecode import generate_upgraders_bytecode + +from torchgen.code_template import CodeTemplate from torchgen.operator_versions.gen_mobile_upgraders_constant import ( MOBILE_UPGRADERS_HEADER_DESCRIPTION, ) diff --git a/torchgen/selective_build/operator.py b/torchgen/selective_build/operator.py index ca80f5ad7f2a..76f8b963b990 100644 --- a/torchgen/selective_build/operator.py +++ b/torchgen/selective_build/operator.py @@ -1,5 +1,5 @@ -from typing import Dict, Optional, Tuple from dataclasses import dataclass +from typing import Dict, Optional, Tuple # This class holds information about a single operator used to determine # the outcome of a selective/custom PyTorch build that doesn't include diff --git a/torchgen/selective_build/selector.py b/torchgen/selective_build/selector.py index e65ecf5eaf45..dd94dd17dd0e 100644 --- a/torchgen/selective_build/selector.py +++ b/torchgen/selective_build/selector.py @@ -1,13 +1,13 @@ -from typing import Dict, Set, Optional, Tuple, List -import yaml - from dataclasses import dataclass +from typing import Dict, List, Optional, Set, Tuple + +import yaml from torchgen.model import NativeFunction from torchgen.selective_build.operator import ( - SelectiveBuildOperator, merge_debug_info, merge_operator_dicts, + SelectiveBuildOperator, strip_operator_overload_name, ) diff --git a/torchgen/shape_functions/gen_jit_shape_functions.py b/torchgen/shape_functions/gen_jit_shape_functions.py index 9d1f7a75f9a5..6013c4de5350 100644 --- a/torchgen/shape_functions/gen_jit_shape_functions.py +++ b/torchgen/shape_functions/gen_jit_shape_functions.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 import os -from pathlib import Path from itertools import chain +from pathlib import Path from torch.jit._shape_functions import ( - shape_compute_graph_mapping, bounded_compute_graph_mapping, + shape_compute_graph_mapping, ) SHAPE_HEADER = r""" diff --git a/torchgen/static_runtime/config.py b/torchgen/static_runtime/config.py index bfcab625e2e3..03b20852c6f5 100644 --- a/torchgen/static_runtime/config.py +++ b/torchgen/static_runtime/config.py @@ -1,7 +1,7 @@ -from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup - from typing import Dict, Union +from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup + def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str: if isinstance(g, NativeFunctionsGroup): diff --git a/torchgen/static_runtime/gen_static_runtime_ops.py b/torchgen/static_runtime/gen_static_runtime_ops.py index 8608aa82401e..b2c5c9fe1534 100644 --- a/torchgen/static_runtime/gen_static_runtime_ops.py +++ b/torchgen/static_runtime/gen_static_runtime_ops.py @@ -1,14 +1,15 @@ -from torchgen import gen -from torchgen.context import native_function_manager -from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup -from torchgen.static_runtime import generator - import argparse import itertools import os from typing import Sequence, Union + from libfb.py.log import set_simple_logging +from torchgen import gen +from torchgen.context import native_function_manager +from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup +from torchgen.static_runtime import generator + # Given a list of `grouped_native_functions` sorted by their op names, return a list of # lists each of which groups ops that share the base name. For example, `mean` and # `mean.dim` are grouped together by this function. diff --git a/torchgen/static_runtime/generator.py b/torchgen/static_runtime/generator.py index 24593726056c..390f5cb1e4d0 100644 --- a/torchgen/static_runtime/generator.py +++ b/torchgen/static_runtime/generator.py @@ -1,25 +1,26 @@ +import json +import logging + +import math +from typing import List, Optional, Sequence, Tuple, Union + import torchgen.api.cpp as cpp from torchgen.context import native_function_manager from torchgen.model import ( Argument, BackendIndex, BaseTy, + BaseType, FunctionSchema, + NativeFunctionsGroup, + NativeFunctionsViewGroup, OptionalType, SelfArgument, - BaseType, - NativeFunctionsGroup, TensorOptionsArguments, Type, - NativeFunctionsViewGroup, ) from torchgen.static_runtime import config -import math -import logging -import json -from typing import List, Optional, Sequence, Tuple, Union - logger: logger = logging.getLogger() diff --git a/torchgen/utils.py b/torchgen/utils.py index 68371c12dc72..bb5860ce3ce8 100644 --- a/torchgen/utils.py +++ b/torchgen/utils.py @@ -3,29 +3,26 @@ import functools import hashlib import os import re -import textwrap import sys +import textwrap from argparse import Namespace -from dataclasses import ( - fields, - is_dataclass, -) +from dataclasses import fields, is_dataclass +from enum import Enum from typing import ( - Tuple, - List, + Any, + Callable, + Dict, Iterable, Iterator, - Callable, - Sequence, - TypeVar, - Optional, - Dict, - Any, - Union, - Set, + List, NoReturn, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, ) -from enum import Enum from torchgen.code_template import CodeTemplate