mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Apply ufmt linter to all py files under torchgen (#81570)
Previous batches: * https://github.com/pytorch/pytorch/pull/81285 * https://github.com/pytorch/pytorch/pull/81335 We have multiple batches here to minimize merge conflicts and reviewing process. Once everything has been formatted by ufmt (black+usort), the current black linter will be removed. Pull Request resolved: https://github.com/pytorch/pytorch/pull/81570 Approved by: https://github.com/ezyang
This commit is contained in:
@ -672,6 +672,7 @@ code = 'UFMT'
|
||||
include_patterns = [
|
||||
'test/onnx/**/*.py',
|
||||
'tools/**/*.py',
|
||||
'torchgen/**/*.py',
|
||||
]
|
||||
command = [
|
||||
'python3',
|
||||
|
@ -1,15 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
from typing import Optional, Sequence, Set, List, Tuple, Match
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Match, Optional, Sequence, Set, Tuple
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.types import Binding, NamedCType
|
||||
from torchgen.model import (
|
||||
NativeFunction,
|
||||
Type,
|
||||
SchemaKind,
|
||||
NativeFunctionsViewGroup,
|
||||
)
|
||||
from torchgen.model import NativeFunction, NativeFunctionsViewGroup, SchemaKind, Type
|
||||
from torchgen.utils import IDENT_REGEX
|
||||
|
||||
# Represents a saved attribute involved in backward calculation.
|
||||
|
@ -1,3 +1,35 @@
|
||||
from typing import List, Optional, Sequence, Set, Union
|
||||
|
||||
from torchgen import local
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
ArrayCType,
|
||||
ArrayRefCType,
|
||||
BaseCType,
|
||||
BaseTypeToCppMapping,
|
||||
Binding,
|
||||
boolT,
|
||||
ConstRefCType,
|
||||
CType,
|
||||
dimnameListT,
|
||||
intArrayRefT,
|
||||
ListCType,
|
||||
longT,
|
||||
MutRefCType,
|
||||
NamedCType,
|
||||
OptionalCType,
|
||||
optionalIntArrayRefT,
|
||||
scalarT,
|
||||
SpecialArgName,
|
||||
symIntArrayRefT,
|
||||
SymIntT,
|
||||
tensorListT,
|
||||
tensorOptionsT,
|
||||
tensorT,
|
||||
TupleCType,
|
||||
VectorCType,
|
||||
voidT,
|
||||
)
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
Arguments,
|
||||
@ -12,38 +44,7 @@ from torchgen.model import (
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
)
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
BaseCType,
|
||||
Binding,
|
||||
ConstRefCType,
|
||||
NamedCType,
|
||||
CType,
|
||||
MutRefCType,
|
||||
ArrayCType,
|
||||
ListCType,
|
||||
VectorCType,
|
||||
ArrayRefCType,
|
||||
OptionalCType,
|
||||
TupleCType,
|
||||
SpecialArgName,
|
||||
boolT,
|
||||
scalarT,
|
||||
tensorListT,
|
||||
dimnameListT,
|
||||
tensorT,
|
||||
voidT,
|
||||
longT,
|
||||
SymIntT,
|
||||
symIntArrayRefT,
|
||||
BaseTypeToCppMapping,
|
||||
intArrayRefT,
|
||||
optionalIntArrayRefT,
|
||||
tensorOptionsT,
|
||||
)
|
||||
from torchgen import local
|
||||
from torchgen.utils import assert_never
|
||||
from typing import Optional, Sequence, Union, List, Set
|
||||
|
||||
# This file describes the translation of JIT schema to the public C++
|
||||
# API, which is what people use when they call functions like at::add.
|
||||
|
@ -1,3 +1,9 @@
|
||||
import itertools
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
from torchgen.api import cpp
|
||||
|
||||
from torchgen.api.types import ArgName, Binding, CType, NamedCType
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
FunctionSchema,
|
||||
@ -6,13 +12,7 @@ from torchgen.model import (
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
)
|
||||
|
||||
from torchgen.api.types import ArgName, Binding, NamedCType, CType
|
||||
from torchgen.api import cpp
|
||||
from torchgen.utils import concatMap, assert_never
|
||||
|
||||
import itertools
|
||||
from typing import Sequence, List, Union
|
||||
from torchgen.utils import assert_never, concatMap
|
||||
|
||||
# This file describes the translation of JIT schema to the dispatcher
|
||||
# API, the *unboxed* calling convention by which invocations through
|
||||
|
@ -1,22 +1,23 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from torchgen.api import dispatcher
|
||||
from torchgen.api.types import (
|
||||
BaseCType,
|
||||
Binding,
|
||||
boolT,
|
||||
ConstRefCType,
|
||||
CType,
|
||||
longT,
|
||||
NamedCType,
|
||||
tensorT,
|
||||
)
|
||||
from torchgen.model import (
|
||||
FunctionSchema,
|
||||
Argument,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
FunctionSchema,
|
||||
NativeFunctionsViewGroup,
|
||||
Argument,
|
||||
)
|
||||
from torchgen.api.types import (
|
||||
Binding,
|
||||
NamedCType,
|
||||
ConstRefCType,
|
||||
BaseCType,
|
||||
CType,
|
||||
tensorT,
|
||||
longT,
|
||||
boolT,
|
||||
)
|
||||
from torchgen.api import dispatcher
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
# This file describes the translation of JIT schema to API's used
|
||||
|
@ -1,35 +1,36 @@
|
||||
from typing import Any, Dict, List, Union, Tuple, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from torchgen.model import (
|
||||
Type,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
OptionalType,
|
||||
ListType,
|
||||
OperatorName,
|
||||
FunctionSchema,
|
||||
Return,
|
||||
TensorOptionsArguments,
|
||||
Argument,
|
||||
)
|
||||
from torchgen.api.types import (
|
||||
CType,
|
||||
BaseCppType,
|
||||
BaseCType,
|
||||
OptionalCType,
|
||||
NamedCType,
|
||||
deviceT,
|
||||
layoutT,
|
||||
VectorCType,
|
||||
boolT,
|
||||
longT,
|
||||
CType,
|
||||
deviceT,
|
||||
doubleT,
|
||||
layoutT,
|
||||
ListCType,
|
||||
stringT,
|
||||
longT,
|
||||
memoryFormatT,
|
||||
NamedCType,
|
||||
OptionalCType,
|
||||
scalarT,
|
||||
scalarTypeT,
|
||||
memoryFormatT,
|
||||
stringT,
|
||||
SymIntT,
|
||||
VectorCType,
|
||||
)
|
||||
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
FunctionSchema,
|
||||
ListType,
|
||||
OperatorName,
|
||||
OptionalType,
|
||||
Return,
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1,3 +1,25 @@
|
||||
from typing import List, Optional, Sequence, Union
|
||||
|
||||
from torchgen import local
|
||||
from torchgen.api import cpp
|
||||
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
BaseCType,
|
||||
Binding,
|
||||
boolT,
|
||||
ConstRefCType,
|
||||
CType,
|
||||
deviceT,
|
||||
layoutT,
|
||||
ListCType,
|
||||
MutRefCType,
|
||||
NamedCType,
|
||||
OptionalCType,
|
||||
scalarT,
|
||||
scalarTypeT,
|
||||
tensorT,
|
||||
)
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
FunctionSchema,
|
||||
@ -6,30 +28,8 @@ from torchgen.model import (
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
)
|
||||
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
BaseCType,
|
||||
Binding,
|
||||
ConstRefCType,
|
||||
NamedCType,
|
||||
CType,
|
||||
MutRefCType,
|
||||
ListCType,
|
||||
OptionalCType,
|
||||
tensorT,
|
||||
scalarT,
|
||||
layoutT,
|
||||
deviceT,
|
||||
boolT,
|
||||
scalarTypeT,
|
||||
)
|
||||
from torchgen.api import cpp
|
||||
from torchgen import local
|
||||
from torchgen.utils import assert_never
|
||||
|
||||
from typing import Union, Sequence, List, Optional
|
||||
|
||||
# This file describes the translation of JIT schema to the native functions API.
|
||||
# This looks a lot like the C++ API (which makes historical sense, because the
|
||||
# idea was you wrote native functions to implement functions in the C++ API),
|
||||
|
@ -1,8 +1,9 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union, Sequence, Set, List, Dict, Tuple
|
||||
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from torchgen.api import cpp
|
||||
|
||||
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
|
||||
from torchgen.api import cpp
|
||||
from torchgen.gen import pythonify_default
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
|
@ -1,3 +1,25 @@
|
||||
from typing import List, Union
|
||||
|
||||
from torchgen.api import cpp
|
||||
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
ArrayRefCType,
|
||||
BaseCType,
|
||||
Binding,
|
||||
ConstRefCType,
|
||||
dimnameListT,
|
||||
intArrayRefT,
|
||||
iOptTensorListRefT,
|
||||
iTensorListRefT,
|
||||
NamedCType,
|
||||
OptionalCType,
|
||||
optionalIntArrayRefT,
|
||||
optionalScalarRefT,
|
||||
optionalTensorRefT,
|
||||
scalarT,
|
||||
tensorT,
|
||||
)
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BaseTy,
|
||||
@ -9,31 +31,8 @@ from torchgen.model import (
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
)
|
||||
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
BaseCType,
|
||||
Binding,
|
||||
ArrayRefCType,
|
||||
ConstRefCType,
|
||||
OptionalCType,
|
||||
NamedCType,
|
||||
tensorT,
|
||||
scalarT,
|
||||
intArrayRefT,
|
||||
dimnameListT,
|
||||
optionalTensorRefT,
|
||||
optionalScalarRefT,
|
||||
optionalIntArrayRefT,
|
||||
iTensorListRefT,
|
||||
iOptTensorListRefT,
|
||||
)
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.utils import assert_never
|
||||
|
||||
from typing import Union, List
|
||||
|
||||
# This file describes the translation of JIT schema to the structured functions API.
|
||||
# This is similar to native API, but a number of historical problems with native
|
||||
# API have been fixed.
|
||||
|
@ -1,35 +1,36 @@
|
||||
from typing import Dict, Sequence, List, NoReturn, Union
|
||||
from typing import Dict, List, NoReturn, Sequence, Union
|
||||
|
||||
from torchgen.api.types import (
|
||||
ListCType,
|
||||
tensorListT,
|
||||
BaseCType,
|
||||
Binding,
|
||||
ConstRefCType,
|
||||
Expr,
|
||||
MutRefCType,
|
||||
OptionalCType,
|
||||
NamedCType,
|
||||
SpecialArgName,
|
||||
tensorT,
|
||||
memoryFormatT,
|
||||
tensorOptionsT,
|
||||
scalarTypeT,
|
||||
SymIntT,
|
||||
symIntArrayRefT,
|
||||
boolT,
|
||||
ConstRefCType,
|
||||
deviceT,
|
||||
layoutT,
|
||||
optionalTensorRefT,
|
||||
iTensorListRefT,
|
||||
iOptTensorListRefT,
|
||||
scalarT,
|
||||
optionalScalarRefT,
|
||||
VectorCType,
|
||||
longT,
|
||||
Expr,
|
||||
intArrayRefT,
|
||||
scalar_t,
|
||||
iOptTensorListRefT,
|
||||
iTensorListRefT,
|
||||
layoutT,
|
||||
ListCType,
|
||||
longT,
|
||||
memoryFormatT,
|
||||
MutRefCType,
|
||||
NamedCType,
|
||||
opmath_t,
|
||||
OptionalCType,
|
||||
optionalIntArrayRefT,
|
||||
optionalScalarRefT,
|
||||
optionalTensorRefT,
|
||||
scalar_t,
|
||||
scalarT,
|
||||
scalarTypeT,
|
||||
SpecialArgName,
|
||||
symIntArrayRefT,
|
||||
SymIntT,
|
||||
tensorListT,
|
||||
tensorOptionsT,
|
||||
tensorT,
|
||||
VectorCType,
|
||||
)
|
||||
|
||||
# This file implements a small program synthesis engine that implements
|
||||
|
@ -1,18 +1,19 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Sequence, Set, TypeVar, Union
|
||||
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BackendIndex,
|
||||
BaseTy,
|
||||
FunctionSchema,
|
||||
NativeFunction,
|
||||
BackendIndex,
|
||||
NativeFunctionsGroup,
|
||||
NativeFunctionsViewGroup,
|
||||
ScalarType,
|
||||
SelfArgument,
|
||||
TensorOptionsArguments,
|
||||
BaseTy,
|
||||
ScalarType,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union, Sequence, TypeVar, List, Set, Dict
|
||||
from enum import Enum
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
@ -752,8 +753,8 @@ def kernel_signature(
|
||||
from torchgen.api import (
|
||||
cpp,
|
||||
dispatcher,
|
||||
native,
|
||||
translate,
|
||||
functionalization,
|
||||
native,
|
||||
structured,
|
||||
translate,
|
||||
)
|
||||
|
@ -1,30 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torchgen.api.types as api_types
|
||||
|
||||
from torchgen.api import cpp, structured
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
BaseCppType,
|
||||
BaseCType,
|
||||
Binding,
|
||||
ConstRefCType,
|
||||
CType,
|
||||
NamedCType,
|
||||
scalarT,
|
||||
)
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
DispatchKey,
|
||||
FunctionSchema,
|
||||
NativeFunctionsGroup,
|
||||
Type,
|
||||
DispatchKey,
|
||||
)
|
||||
|
||||
import torchgen.api.types as api_types
|
||||
from torchgen.api.types import (
|
||||
ArgName,
|
||||
BaseCType,
|
||||
Binding,
|
||||
ConstRefCType,
|
||||
NamedCType,
|
||||
scalarT,
|
||||
CType,
|
||||
BaseCppType,
|
||||
)
|
||||
|
||||
from torchgen.api import cpp, structured
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
|
||||
assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
|
||||
|
@ -1,15 +1,15 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.types import Binding, CType, CppSignatureGroup
|
||||
from torchgen.api.types import Binding, CppSignatureGroup, CType
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
NativeFunction,
|
||||
Type,
|
||||
BaseType,
|
||||
OptionalType,
|
||||
ListType,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
ListType,
|
||||
NativeFunction,
|
||||
OptionalType,
|
||||
Type,
|
||||
)
|
||||
|
||||
# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the
|
||||
|
@ -1,5 +1,5 @@
|
||||
import re
|
||||
from typing import Match, Optional, Sequence, Mapping
|
||||
from typing import Mapping, Match, Optional, Sequence
|
||||
|
||||
# match $identifier or ${identifier} and replace with value in env
|
||||
# If this identifier is at the beginning of whitespace on a line
|
||||
|
@ -1,16 +1,17 @@
|
||||
from torchgen.utils import S, T, context
|
||||
import contextlib
|
||||
|
||||
import functools
|
||||
from typing import Callable, Dict, Iterator, Optional, TypeVar, Union
|
||||
|
||||
import torchgen.local as local
|
||||
from torchgen.model import (
|
||||
BackendIndex,
|
||||
DispatchKey,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
NativeFunctionsViewGroup,
|
||||
BackendIndex,
|
||||
DispatchKey,
|
||||
)
|
||||
import torchgen.local as local
|
||||
|
||||
import functools
|
||||
from typing import TypeVar, Union, Iterator, Callable, Dict, Optional
|
||||
import contextlib
|
||||
from torchgen.utils import context, S, T
|
||||
|
||||
# Helper functions for defining generators on things in the model
|
||||
|
||||
|
@ -1,19 +1,19 @@
|
||||
from .lazy_ir import GenLazyIR as GenLazyIR
|
||||
from .lazy_ir import GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition
|
||||
from .lazy_ir import GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition
|
||||
from .lazy_ir import (
|
||||
generate_non_native_lazy_ir_nodes as generate_non_native_lazy_ir_nodes,
|
||||
)
|
||||
from .register_dispatch_key import (
|
||||
RegisterDispatchKey as RegisterDispatchKey,
|
||||
gen_registration_helpers as gen_registration_helpers,
|
||||
gen_registration_headers as gen_registration_headers,
|
||||
GenLazyIR as GenLazyIR,
|
||||
GenLazyNativeFuncDefinition as GenLazyNativeFuncDefinition,
|
||||
GenLazyShapeInferenceDefinition as GenLazyShapeInferenceDefinition,
|
||||
)
|
||||
from .native_functions import (
|
||||
compute_native_function_declaration as compute_native_function_declaration,
|
||||
)
|
||||
from .register_dispatch_key import (
|
||||
gen_registration_headers as gen_registration_headers,
|
||||
gen_registration_helpers as gen_registration_helpers,
|
||||
RegisterDispatchKey as RegisterDispatchKey,
|
||||
)
|
||||
from .ufunc import (
|
||||
compute_ufunc_cuda as compute_ufunc_cuda,
|
||||
compute_ufunc_cpu as compute_ufunc_cpu,
|
||||
compute_ufunc_cpu_kernel as compute_ufunc_cpu_kernel,
|
||||
compute_ufunc_cuda as compute_ufunc_cuda,
|
||||
)
|
||||
|
@ -1,35 +1,36 @@
|
||||
from abc import ABC
|
||||
import itertools
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
from torchgen.context import method_with_native_function
|
||||
from torchgen.model import (
|
||||
FunctionSchema,
|
||||
Argument,
|
||||
BackendIndex,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torchgen.api.dispatcher as dispatcher
|
||||
from torchgen.api.lazy import (
|
||||
getValueT,
|
||||
isValueType,
|
||||
LazyArgument,
|
||||
LazyIrProperties,
|
||||
LazyIrSchema,
|
||||
tensorListValueT,
|
||||
)
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
BaseCType,
|
||||
Binding,
|
||||
deviceT,
|
||||
DispatcherSignature,
|
||||
kernel_signature,
|
||||
OptionalCType,
|
||||
VectorCType,
|
||||
kernel_signature,
|
||||
deviceT,
|
||||
)
|
||||
import torchgen.api.dispatcher as dispatcher
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.lazy import (
|
||||
LazyIrProperties,
|
||||
LazyIrSchema,
|
||||
LazyArgument,
|
||||
getValueT,
|
||||
isValueType,
|
||||
tensorListValueT,
|
||||
)
|
||||
from torchgen.context import method_with_native_function
|
||||
from torchgen.dest.lazy_ts_lowering import ts_lowering_body
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BackendIndex,
|
||||
FunctionSchema,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
)
|
||||
|
||||
|
||||
def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
|
||||
|
@ -1,11 +1,12 @@
|
||||
from typing import List, Union, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from torchgen.context import with_native_function_and_index
|
||||
from torchgen.utils import mapMaybe
|
||||
from torchgen.model import NativeFunction, NativeFunctionsGroup, BackendIndex
|
||||
from torchgen.api.types import kernel_signature
|
||||
import torchgen.api.meta as meta
|
||||
import torchgen.api.structured as structured
|
||||
from torchgen.api.types import kernel_signature
|
||||
|
||||
from torchgen.context import with_native_function_and_index
|
||||
from torchgen.model import BackendIndex, NativeFunction, NativeFunctionsGroup
|
||||
from torchgen.utils import mapMaybe
|
||||
|
||||
|
||||
@with_native_function_and_index
|
||||
|
@ -1,42 +1,44 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import itertools
|
||||
from typing_extensions import Literal
|
||||
from dataclasses import dataclass
|
||||
import textwrap
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from torchgen.context import method_with_native_function, native_function_manager
|
||||
from torchgen.utils import Target, mapMaybe, assert_never
|
||||
from torchgen.model import (
|
||||
DispatchKey,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
SchemaKind,
|
||||
TensorOptionsArguments,
|
||||
DeviceCheckType,
|
||||
Argument,
|
||||
is_cuda_dispatch_key,
|
||||
BackendIndex,
|
||||
gets_generated_out_inplace_wrapper,
|
||||
)
|
||||
from typing_extensions import Literal
|
||||
|
||||
import torchgen.api.cpp as cpp
|
||||
import torchgen.api.meta as meta
|
||||
import torchgen.api.structured as structured
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
BaseCType,
|
||||
Binding,
|
||||
ConstRefCType,
|
||||
CppSignature,
|
||||
CppSignatureGroup,
|
||||
DispatcherSignature,
|
||||
Expr,
|
||||
MutRefCType,
|
||||
kernel_signature,
|
||||
MutRefCType,
|
||||
NamedCType,
|
||||
NativeSignature,
|
||||
tensorT,
|
||||
NamedCType,
|
||||
DispatcherSignature,
|
||||
)
|
||||
import torchgen.api.meta as meta
|
||||
import torchgen.api.cpp as cpp
|
||||
import torchgen.api.structured as structured
|
||||
from torchgen.api.translate import translate
|
||||
|
||||
from torchgen.context import method_with_native_function, native_function_manager
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BackendIndex,
|
||||
DeviceCheckType,
|
||||
DispatchKey,
|
||||
gets_generated_out_inplace_wrapper,
|
||||
is_cuda_dispatch_key,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
SchemaKind,
|
||||
TensorOptionsArguments,
|
||||
)
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import assert_never, mapMaybe, Target
|
||||
|
||||
|
||||
def gen_registration_headers(
|
||||
|
@ -1,29 +1,30 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Union, Optional, List, Tuple, Dict, Sequence
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torchgen.api.ufunc as ufunc
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
BaseCType,
|
||||
Binding,
|
||||
CType,
|
||||
Expr,
|
||||
NamedCType,
|
||||
opmath_t,
|
||||
scalar_t,
|
||||
StructuredImplSignature,
|
||||
VectorizedCType,
|
||||
)
|
||||
from torchgen.api.ufunc import UfunctorBindings
|
||||
from torchgen.context import with_native_function
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
DispatchKey,
|
||||
NativeFunctionsGroup,
|
||||
ScalarType,
|
||||
UfuncKey,
|
||||
DispatchKey,
|
||||
BaseType,
|
||||
BaseTy,
|
||||
Argument,
|
||||
)
|
||||
import torchgen.api.ufunc as ufunc
|
||||
from torchgen.api.ufunc import UfunctorBindings
|
||||
from torchgen.api.types import (
|
||||
StructuredImplSignature,
|
||||
scalar_t,
|
||||
opmath_t,
|
||||
Binding,
|
||||
CType,
|
||||
BaseCType,
|
||||
Expr,
|
||||
NamedCType,
|
||||
VectorizedCType,
|
||||
)
|
||||
from torchgen.context import with_native_function
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
|
124
torchgen/gen.py
124
torchgen/gen.py
@ -1,45 +1,22 @@
|
||||
import os
|
||||
from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar
|
||||
from typing_extensions import Literal
|
||||
import yaml
|
||||
from collections import OrderedDict, defaultdict, namedtuple
|
||||
import argparse
|
||||
import pathlib
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, TypeVar, Union
|
||||
|
||||
from torchgen.model import (
|
||||
STRUCTURED_DISPATCH_KEYS,
|
||||
Argument,
|
||||
DispatchKey,
|
||||
FunctionSchema,
|
||||
Location,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
OperatorName,
|
||||
BackendIndex,
|
||||
BackendMetadata,
|
||||
OptionalType,
|
||||
SchemaKind,
|
||||
SelfArgument,
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
Variant,
|
||||
is_cuda_dispatch_key,
|
||||
is_generic_dispatch_key,
|
||||
is_ufunc_dispatch_key,
|
||||
NativeFunctionsViewGroup,
|
||||
ViewSchemaKind,
|
||||
BaseOperatorName,
|
||||
DEFAULT_KERNEL_NAMESPACE,
|
||||
)
|
||||
from torchgen.native_function_generation import (
|
||||
pre_group_native_functions,
|
||||
add_generated_native_functions,
|
||||
gen_composite_functional_kernel,
|
||||
gen_composite_out_kernel,
|
||||
)
|
||||
import yaml
|
||||
from typing_extensions import Literal
|
||||
|
||||
import torchgen.api.dispatcher as dispatcher
|
||||
import torchgen.api.meta as meta
|
||||
import torchgen.api.native as native
|
||||
import torchgen.api.structured as structured
|
||||
import torchgen.dest as dest
|
||||
from torchgen.api import cpp
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
Binding,
|
||||
CppSignatureGroup,
|
||||
@ -48,40 +25,65 @@ from torchgen.api.types import (
|
||||
NativeSignature,
|
||||
SpecialArgName,
|
||||
)
|
||||
from torchgen.api import cpp
|
||||
import torchgen.api.dispatcher as dispatcher
|
||||
import torchgen.api.native as native
|
||||
import torchgen.api.meta as meta
|
||||
import torchgen.api.structured as structured
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import (
|
||||
Target,
|
||||
concatMap,
|
||||
context,
|
||||
mapMaybe,
|
||||
YamlDumper,
|
||||
YamlLoader,
|
||||
FileManager,
|
||||
assert_never,
|
||||
make_file_manager,
|
||||
NamespaceHelper,
|
||||
)
|
||||
from torchgen.context import (
|
||||
method_with_native_function,
|
||||
native_function_manager,
|
||||
with_native_function_and_indices,
|
||||
with_native_function,
|
||||
with_native_function_and_indices,
|
||||
)
|
||||
import torchgen.dest as dest
|
||||
from torchgen.gen_functionalization_type import (
|
||||
gen_composite_view_copy_kernel,
|
||||
gen_functionalization_definition,
|
||||
gen_functionalization_registration,
|
||||
gen_functionalization_view_inverse_declaration,
|
||||
gen_composite_view_copy_kernel,
|
||||
gen_symint_view_copy_kernel,
|
||||
)
|
||||
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BackendIndex,
|
||||
BackendMetadata,
|
||||
BaseOperatorName,
|
||||
DEFAULT_KERNEL_NAMESPACE,
|
||||
DispatchKey,
|
||||
FunctionSchema,
|
||||
is_cuda_dispatch_key,
|
||||
is_generic_dispatch_key,
|
||||
is_ufunc_dispatch_key,
|
||||
Location,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
NativeFunctionsViewGroup,
|
||||
OperatorName,
|
||||
OptionalType,
|
||||
SchemaKind,
|
||||
SelfArgument,
|
||||
STRUCTURED_DISPATCH_KEYS,
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
Variant,
|
||||
ViewSchemaKind,
|
||||
)
|
||||
from torchgen.native_function_generation import (
|
||||
add_generated_native_functions,
|
||||
gen_composite_functional_kernel,
|
||||
gen_composite_out_kernel,
|
||||
pre_group_native_functions,
|
||||
)
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import (
|
||||
assert_never,
|
||||
concatMap,
|
||||
context,
|
||||
FileManager,
|
||||
make_file_manager,
|
||||
mapMaybe,
|
||||
NamespaceHelper,
|
||||
Target,
|
||||
YamlDumper,
|
||||
YamlLoader,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# Welcome to the ATen code generator v2! The ATen code generator is
|
||||
|
@ -1,14 +1,18 @@
|
||||
import pathlib
|
||||
import argparse
|
||||
import os
|
||||
import yaml
|
||||
import pathlib
|
||||
import re
|
||||
from collections import namedtuple, Counter, defaultdict
|
||||
from typing import List, Dict, Union, Sequence, Optional
|
||||
from torchgen.gen import (
|
||||
get_grouped_native_functions,
|
||||
parse_native_yaml,
|
||||
)
|
||||
from collections import Counter, defaultdict, namedtuple
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import yaml
|
||||
|
||||
import torchgen.api.dispatcher as dispatcher
|
||||
import torchgen.dest as dest
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
from torchgen.code_template import CodeTemplate
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
|
||||
from torchgen.model import (
|
||||
BackendIndex,
|
||||
BackendMetadata,
|
||||
@ -19,18 +23,13 @@ from torchgen.model import (
|
||||
)
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import (
|
||||
Target,
|
||||
concatMap,
|
||||
context,
|
||||
YamlLoader,
|
||||
FileManager,
|
||||
NamespaceHelper,
|
||||
Target,
|
||||
YamlLoader,
|
||||
)
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.code_template import CodeTemplate
|
||||
import torchgen.dest as dest
|
||||
import torchgen.api.dispatcher as dispatcher
|
||||
from torchgen.api.types import DispatcherSignature
|
||||
|
||||
|
||||
# Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
|
||||
|
@ -1,47 +1,47 @@
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
from torchgen.api import cpp, dispatcher
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.api.types import (
|
||||
DispatcherSignature,
|
||||
Binding,
|
||||
FunctionalizationLambda,
|
||||
ViewInverseSignature,
|
||||
NativeSignature,
|
||||
CType,
|
||||
BaseCType,
|
||||
VectorCType,
|
||||
Binding,
|
||||
CType,
|
||||
DispatcherSignature,
|
||||
FunctionalizationLambda,
|
||||
NativeSignature,
|
||||
tensorListT,
|
||||
tensorT,
|
||||
VectorCType,
|
||||
ViewInverseSignature,
|
||||
)
|
||||
from torchgen.api.translate import translate
|
||||
from torchgen.context import (
|
||||
native_function_manager,
|
||||
with_native_function,
|
||||
with_native_function_and,
|
||||
native_function_manager,
|
||||
)
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
Return,
|
||||
BackendIndex,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
FunctionSchema,
|
||||
ListType,
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
BackendIndex,
|
||||
FunctionSchema,
|
||||
NativeFunctionsViewGroup,
|
||||
Return,
|
||||
SchemaKind,
|
||||
SelfArgument,
|
||||
TensorOptionsArguments,
|
||||
BaseType,
|
||||
BaseTy,
|
||||
NativeFunctionsViewGroup,
|
||||
ListType,
|
||||
)
|
||||
from torchgen.native_function_generation import (
|
||||
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
|
||||
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
|
||||
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
|
||||
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
|
||||
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
|
||||
)
|
||||
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
|
||||
from typing import List, Optional, Union, Tuple, Callable
|
||||
|
||||
|
||||
# Note: [Mutable Ops Not Using Functionalization]
|
||||
# Ops in this list currently do not work with functionalization and should be fixed.
|
||||
|
@ -1,44 +1,39 @@
|
||||
import pathlib
|
||||
import argparse
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import yaml
|
||||
from collections import namedtuple, Counter
|
||||
from collections import Counter, namedtuple
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Dict,
|
||||
Tuple,
|
||||
Union,
|
||||
Sequence,
|
||||
Optional,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from torchgen.api.types import BaseCppType
|
||||
from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
|
||||
from torchgen.gen import (
|
||||
get_grouped_native_functions,
|
||||
parse_native_yaml,
|
||||
)
|
||||
|
||||
import yaml
|
||||
|
||||
import torchgen.dest as dest
|
||||
|
||||
from torchgen.api.lazy import setValueT
|
||||
from torchgen.api.types import BaseCppType
|
||||
from torchgen.dest.lazy_ir import GenLazyIR, GenTSLazyIR
|
||||
from torchgen.gen import get_grouped_native_functions, parse_native_yaml
|
||||
|
||||
from torchgen.model import (
|
||||
NativeFunction,
|
||||
NativeFunctionsGroup,
|
||||
OperatorName,
|
||||
)
|
||||
from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
|
||||
from torchgen.selective_build.selector import SelectiveBuilder
|
||||
from torchgen.utils import concatMap, YamlLoader, FileManager, NamespaceHelper
|
||||
import torchgen.dest as dest
|
||||
from torchgen.utils import concatMap, FileManager, NamespaceHelper, YamlLoader
|
||||
from .gen_backend_stubs import (
|
||||
parse_backend_yaml,
|
||||
error_on_missing_kernels,
|
||||
gen_dispatchkey_nativefunc_headers,
|
||||
gen_dispatcher_registrations,
|
||||
gen_dispatchkey_nativefunc_headers,
|
||||
parse_backend_yaml,
|
||||
)
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
|
@ -1,6 +1,6 @@
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Iterator
|
||||
from typing import Iterator, Optional
|
||||
|
||||
# Simple dynamic scoping implementation. The name "parametrize" comes
|
||||
# from Racket.
|
||||
|
@ -5,8 +5,9 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
from torchgen.code_template import CodeTemplate
|
||||
from torch.jit.generate_bytecode import generate_upgraders_bytecode
|
||||
|
||||
from torchgen.code_template import CodeTemplate
|
||||
from torchgen.operator_versions.gen_mobile_upgraders_constant import (
|
||||
MOBILE_UPGRADERS_HEADER_DESCRIPTION,
|
||||
)
|
||||
|
@ -1,5 +1,5 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
# This class holds information about a single operator used to determine
|
||||
# the outcome of a selective/custom PyTorch build that doesn't include
|
||||
|
@ -1,13 +1,13 @@
|
||||
from typing import Dict, Set, Optional, Tuple, List
|
||||
import yaml
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
from torchgen.model import NativeFunction
|
||||
from torchgen.selective_build.operator import (
|
||||
SelectiveBuildOperator,
|
||||
merge_debug_info,
|
||||
merge_operator_dicts,
|
||||
SelectiveBuildOperator,
|
||||
strip_operator_overload_name,
|
||||
)
|
||||
|
||||
|
@ -1,11 +1,11 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
from pathlib import Path
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
from torch.jit._shape_functions import (
|
||||
shape_compute_graph_mapping,
|
||||
bounded_compute_graph_mapping,
|
||||
shape_compute_graph_mapping,
|
||||
)
|
||||
|
||||
SHAPE_HEADER = r"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
|
||||
|
||||
|
||||
def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str:
|
||||
if isinstance(g, NativeFunctionsGroup):
|
||||
|
@ -1,14 +1,15 @@
|
||||
from torchgen import gen
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
|
||||
from torchgen.static_runtime import generator
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import os
|
||||
from typing import Sequence, Union
|
||||
|
||||
from libfb.py.log import set_simple_logging
|
||||
|
||||
from torchgen import gen
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
|
||||
from torchgen.static_runtime import generator
|
||||
|
||||
# Given a list of `grouped_native_functions` sorted by their op names, return a list of
|
||||
# lists each of which groups ops that share the base name. For example, `mean` and
|
||||
# `mean.dim` are grouped together by this function.
|
||||
|
@ -1,25 +1,26 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torchgen.api.cpp as cpp
|
||||
from torchgen.context import native_function_manager
|
||||
from torchgen.model import (
|
||||
Argument,
|
||||
BackendIndex,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
FunctionSchema,
|
||||
NativeFunctionsGroup,
|
||||
NativeFunctionsViewGroup,
|
||||
OptionalType,
|
||||
SelfArgument,
|
||||
BaseType,
|
||||
NativeFunctionsGroup,
|
||||
TensorOptionsArguments,
|
||||
Type,
|
||||
NativeFunctionsViewGroup,
|
||||
)
|
||||
from torchgen.static_runtime import config
|
||||
|
||||
import math
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
logger: logger = logging.getLogger()
|
||||
|
||||
|
||||
|
@ -3,29 +3,26 @@ import functools
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
import sys
|
||||
import textwrap
|
||||
from argparse import Namespace
|
||||
from dataclasses import (
|
||||
fields,
|
||||
is_dataclass,
|
||||
)
|
||||
from dataclasses import fields, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Tuple,
|
||||
List,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Callable,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Optional,
|
||||
Dict,
|
||||
Any,
|
||||
Union,
|
||||
Set,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from enum import Enum
|
||||
|
||||
from torchgen.code_template import CodeTemplate
|
||||
|
||||
|
Reference in New Issue
Block a user