Revert "[BE][Easy] enable postponed annotations in torchgen (#129376)"

This reverts commit 494057d6d4e9b40daf81a6a4d7a8c839b7424b14.

Reverted https://github.com/pytorch/pytorch/pull/129376 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert to cleanly revert https://github.com/pytorch/pytorch/pull/129374, please do a rebase and reland this ([comment](https://github.com/pytorch/pytorch/pull/129375#issuecomment-2197800541))
This commit is contained in:
PyTorch MergeBot
2024-06-29 00:44:24 +00:00
parent 83caf4960f
commit 6063bb9d45
45 changed files with 900 additions and 976 deletions

View File

@ -1,8 +1,6 @@
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import cast, Sequence
from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
from torchgen import local
from torchgen.api import cpp
@ -50,16 +48,16 @@ class Derivative:
original_formula: str
# Names of the arguments for which this formula calculates derivatives.
var_names: tuple[str, ...]
var_names: Tuple[str, ...]
# Saved inputs that are referenced by the formula.
saved_inputs: tuple[SavedAttribute, ...]
saved_inputs: Tuple[SavedAttribute, ...]
# Saved outputs that are referenced by the formula.
saved_outputs: tuple[SavedAttribute, ...]
saved_outputs: Tuple[SavedAttribute, ...]
# Gradients that are referenced by name in the formula.
named_gradients: set[str]
named_gradients: Set[str]
# Represents a forward formula that calculates forward derivatives
@ -73,17 +71,17 @@ class ForwardDerivative:
# Name of the output arguments for which this formula calculates forward
# derivatives
var_names: tuple[str, ...]
var_names: Tuple[str, ...]
# Type of the output arguments for which this formula calculates forward
# derivatives
var_types: tuple[Type, ...]
var_types: Tuple[Type, ...]
# Inputs for which the forward derivatives are required for this formula
required_inputs_fw_grad: tuple[str, ...] | None
required_inputs_fw_grad: Optional[Tuple[str, ...]]
# Inputs for which the primal is required for this formula
required_inputs_primal: tuple[str, ...] | None
required_inputs_primal: Optional[Tuple[str, ...]]
# Flag to specify if this formula requires the original value of self
# This is only used by inplace operations
@ -118,7 +116,7 @@ class DifferentiabilityInfo:
# The name of the generated autograd function.
# It's set only if we will calculate a derivative, i.e.
# 'args_with_derivatives' is not empty.
op: str | None
op: Optional[str]
# The derivatives formulae for this function.
# Note that the length of this sequence is the number of differentiable inputs
@ -140,7 +138,7 @@ class DifferentiabilityInfo:
# The named gradients that are used in any of the derivatives.
# Invariant: all(name in available_named_gradients for name in used_named_gradients)
used_named_gradients: set[str]
used_named_gradients: Set[str]
# The function's input arguments for which it calculates derivatives.
# It's the union of 'var_names' of all 'derivatives', sorted by the
@ -151,7 +149,7 @@ class DifferentiabilityInfo:
non_differentiable_arg_names: Sequence[str]
# Raw data read from derivatives.yaml.
output_differentiability: list[bool] | None
output_differentiability: Optional[List[bool]]
# output_differentiability in derivatives.yaml can be a list of
# conditions that express if the output is differentiable. In this case,
@ -159,7 +157,7 @@ class DifferentiabilityInfo:
# (NB: we only support one condition right now).
# output_differentiability gets populated with True for each condition,
# while output_differentiability_conditions gets populated with the conditions
output_differentiability_conditions: list[str] | None
output_differentiability_conditions: Optional[List[str]]
@property
def has_derivatives(self) -> bool:
@ -172,7 +170,7 @@ class DifferentiabilityInfo:
# See Note [Codegen'd {view}_copy Operators]
def create_view_copy_from_view_derivative(
self, g: NativeFunctionsViewGroup
) -> DifferentiabilityInfo | None:
) -> Optional["DifferentiabilityInfo"]:
if g.view_copy is None:
return None
f = g.view_copy
@ -203,7 +201,7 @@ class DifferentiabilityInfo:
)
def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
if info is None:
return False
for derivative in info.derivatives:
@ -213,11 +211,11 @@ def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
return False
def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
return uses_ident(info, "retain_variables")
def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
return uses_ident(info, "grad")
@ -255,8 +253,8 @@ class DifferentiableOutput:
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
func: NativeFunction
info: dict[str, DifferentiabilityInfo] | None
fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
info: Optional[Dict[str, DifferentiabilityInfo]]
fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
# TODO: Update comment below since it is out of date.
@ -365,19 +363,19 @@ def is_reference_for_foreach(
# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
foreach_function: NativeFunction,
functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
functional_info_by_signature: Dict[
FunctionSchema, Dict[str, DifferentiabilityInfo]
],
non_functional_info_by_signature: dict[
FunctionSchema, dict[str, DifferentiabilityInfo]
non_functional_info_by_signature: Dict[
FunctionSchema, Dict[str, DifferentiabilityInfo]
],
dispatch_key: str = "Default",
) -> tuple[DifferentiabilityInfo | None, bool]:
) -> Tuple[Optional[DifferentiabilityInfo], bool]:
"""Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
The second return value indicates whether the info is generated in this function.
"""
ref_diff_info: DifferentiabilityInfo | None = None
ref_diff_info: Optional[DifferentiabilityInfo] = None
for function_schema, diff_info in functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
@ -487,13 +485,13 @@ def gen_foreach_derivativeinfo(
if arg.name in all_var_names
]
forward_derivatives: list[ForwardDerivative] = []
forward_derivatives: List[ForwardDerivative] = []
fw_derivative: ForwardDerivative
for fw_derivative in ref_diff_info.forward_derivatives:
var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
var_types: list[Type] = list(fw_derivative.var_types)
required_inputs_fw_grad: list[str] = []
required_inputs_primal: list[str] = []
var_names: List[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
var_types: List[Type] = list(fw_derivative.var_types)
required_inputs_fw_grad: List[str] = []
required_inputs_primal: List[str] = []
if fw_derivative.required_inputs_fw_grad is not None:
required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
if fw_derivative.required_inputs_primal:
@ -580,9 +578,9 @@ def gen_foreach_derivativeinfo(
def match_differentiability_info(
native_functions: list[NativeFunction],
differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
) -> list[NativeFunctionWithDifferentiabilityInfo]:
native_functions: List[NativeFunction],
differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
) -> List[NativeFunctionWithDifferentiabilityInfo]:
"""Sets the "derivative" key on declarations to matching autograd function
In-place functions will use the out-of-place derivative definition if there
is no in-place specific derivative.
@ -601,7 +599,7 @@ def match_differentiability_info(
def find_info(
f: NativeFunction,
) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
# Don't bother matching info to generated out= variants
if "generated" in f.tags and f.func.kind() == SchemaKind.out:
return None, False
@ -655,7 +653,7 @@ Attempted to convert a derivative formula for a mutable operator
return None, False
result: list[NativeFunctionWithDifferentiabilityInfo] = []
result: List[NativeFunctionWithDifferentiabilityInfo] = []
for f in native_functions:
info_dict, is_exact_match = find_info(f)
@ -679,7 +677,7 @@ Attempted to convert a derivative formula for a mutable operator
)
continue
fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
for key, info in info_dict.items():
if not info.forward_derivatives:
fw_derivative_dict[key] = []
@ -715,7 +713,7 @@ Attempted to convert a derivative formula for a mutable operator
formula = fw_info.formula
def replace_self_with_original_self(formula: str, postfix: str) -> str:
def repl(m: re.Match[str]) -> str:
def repl(m: Match[str]) -> str:
return f"{m.group(1)}original_self{postfix}{m.group(2)}"
return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
@ -736,7 +734,7 @@ Attempted to convert a derivative formula for a mutable operator
formula = replace_self_with_original_self(formula, "_t")
# replace "result" from the formula by "self_p"
def repl(m: re.Match[str]) -> str:
def repl(m: Match[str]) -> str:
return f"{m.group(1)}self_p{m.group(2)}"
formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
@ -760,8 +758,8 @@ Attempted to convert a derivative formula for a mutable operator
# If there is a need, we can relax (2) to allow any op that has an in-place variant
is_single_method_on_self_t = False
directly_do_inplace = False
op_name: str | None = None
between_parens: str | None = None
op_name: Optional[str] = None
between_parens: Optional[str] = None
match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
if match:
op_name, between_parens = match.group(1), match.group(2)
@ -825,7 +823,7 @@ Attempted to convert a derivative formula for a mutable operator
def is_differentiable(
name: str, type: Type, info: DifferentiabilityInfo | None
name: str, type: Type, info: Optional[DifferentiabilityInfo]
) -> bool:
return type.is_tensor_like() and (
info is None or name not in info.non_differentiable_arg_names
@ -834,10 +832,10 @@ def is_differentiable(
def gen_differentiable_outputs(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
) -> list[DifferentiableOutput]:
) -> List[DifferentiableOutput]:
f = fn.func
info = fn.info[key] if fn.info else None
outputs: list[DifferentiableOutput] = [
outputs: List[DifferentiableOutput] = [
DifferentiableOutput(
name=name,
type=ret.type,
@ -852,7 +850,7 @@ def gen_differentiable_outputs(
f"The length of output_differentiability ({len(output_differentiability)}), "
f"does not match the number of outputs ({len(outputs)})."
)
differentiable_outputs: list[DifferentiableOutput] = []
differentiable_outputs: List[DifferentiableOutput] = []
if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
raise RuntimeError(
"output_differentiability=False for inplace operation (version_counter won't get updated)"

View File

@ -1,6 +1,4 @@
from __future__ import annotations
from typing import Sequence
from typing import List, Optional, Sequence, Set, Union
from torchgen import local
from torchgen.api.types import (
@ -96,7 +94,7 @@ def valuetype_type(
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
) -> NamedCType | None:
) -> Optional[NamedCType]:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
@ -281,7 +279,7 @@ def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: list[str] = []
returns: List[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
@ -370,17 +368,17 @@ def default_expr(d: str, t: Type, *, symint: bool) -> str:
def argument(
a: Argument | TensorOptionsArguments | SelfArgument,
a: Union[Argument, TensorOptionsArguments, SelfArgument],
*,
cpp_no_default_args: set[str],
cpp_no_default_args: Set[str],
method: bool,
faithful: bool,
symint: bool = False,
has_tensor_options: bool,
) -> list[Binding]:
) -> List[Binding]:
def sub_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
@ -396,7 +394,7 @@ def argument(
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: str | None = None
default: Optional[str] = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type, symint=symint)
return [
@ -447,9 +445,9 @@ def arguments(
faithful: bool,
symint: bool = False,
method: bool,
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
cpp_no_default_args: Set[str],
) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)

View File

@ -1,7 +1,5 @@
from __future__ import annotations
import itertools
from typing import Sequence
from typing import List, Sequence, Union
from torchgen.api import cpp
from torchgen.api.types import ArgName, Binding, CType, NamedCType
@ -78,10 +76,10 @@ def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
return cpp.returns_type(rs, symint=symint)
def jit_arguments(func: FunctionSchema) -> list[Argument]:
def jit_arguments(func: FunctionSchema) -> List[Argument]:
def to_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Argument]:
a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Argument]:
if isinstance(a, Argument):
return [a]
elif isinstance(a, SelfArgument):
@ -116,5 +114,5 @@ def argument(
)
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
return [argument(a, symint=symint) for a in jit_arguments(func)]

View File

@ -1,4 +1,4 @@
from __future__ import annotations
from typing import List, Optional
from torchgen.api import dispatcher
from torchgen.api.types import (
@ -93,7 +93,7 @@ def name(
*,
is_reverse: bool,
include_namespace: bool,
reapply_views: bool | None = None,
reapply_views: Optional[bool] = None,
) -> str:
if reapply_views is None:
# reapply_views is only important for the fwd lambda,
@ -124,7 +124,7 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
return f"{api_name}_inverse"
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
# capture arguments include all arguments except `self`.
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
# So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
@ -152,14 +152,14 @@ def returns_type(func: FunctionSchema) -> CType:
return BaseCType(tensorT)
def outer_arguments(*, is_reverse: bool) -> list[Binding]:
def outer_arguments(*, is_reverse: bool) -> List[Binding]:
if is_reverse:
return [base_binding, mutated_view_binding, mutated_view_idx_binding]
else:
return [base_binding, mutated_view_idx_binding]
def inner_call_index(func: FunctionSchema) -> Binding | None:
def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately
if len(func.returns) > 1 or (
@ -169,7 +169,7 @@ def inner_call_index(func: FunctionSchema) -> Binding | None:
return None
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]

View File

@ -1,6 +1,4 @@
from __future__ import annotations
from typing import Any
from typing import Any, Dict, List, Optional, Tuple, Union
from torchgen.api.types import (
BaseCppType,
@ -36,7 +34,7 @@ from torchgen.model import (
)
_valueT: BaseCppType | None = None
_valueT: Optional[BaseCppType] = None
# A ValueT is an IR type which represents the computation of a Tensor. In other
@ -68,8 +66,8 @@ tensorListValueT = BaseCppType("torch::lazy", "Value")
def process_ir_type(
typ: Type, properties: LazyIrProperties, *, symint: bool
) -> BaseCType | VectorCType | OptionalCType | ListCType:
typ: Type, properties: "LazyIrProperties", *, symint: bool
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
"""
This function takes a type from NativeFunctions and converts it for use with
lazy tensor codegen.
@ -149,7 +147,7 @@ def process_ir_type(
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
"""
Given a type, determine if it is a Value-like type. This is equivalent to
being Tensor-like, but assumes the type has already been transformed.
@ -204,7 +202,7 @@ def isGeneratorType(typ: Type) -> bool:
class LazyArgument:
name: str
orig_type: Type
lazy_type_: CType | None
lazy_type_: Optional[CType]
is_wrapped_scalar: bool
is_generator: bool
# TODO: this is lies, it is false for symint list
@ -216,9 +214,7 @@ class LazyArgument:
# true if this argument is or contains a lazy IR value
is_lazy_value: bool
def __init__(
self, arg: Argument, properties: LazyIrProperties, *, symint: bool
) -> None:
def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
self.name = arg.name
self.orig_type = arg.type
self.symint = symint
@ -252,7 +248,7 @@ class LazyIrProperties:
attributes. The mutual exclusivity is automatically handled.
"""
Properties: tuple[tuple[str, ...], ...] = (
Properties: Tuple[Tuple[str, ...], ...] = (
(
"ShapePrecompute", # Assume shape has been precomputed
"ShapeCompute", # Need to compute the shape on construction
@ -275,8 +271,8 @@ class LazyIrProperties:
),
)
def __init__(self, *default_properties: str) -> None:
properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
def __init__(self, *default_properties: str):
properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
LazyIrProperties.Properties
)
self.__dict__["properties"] = properties
@ -309,17 +305,17 @@ class LazyIrProperties:
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
# The name of the operator this function schema describes.
name: OperatorName
name: "OperatorName"
positional_args: tuple[LazyArgument, ...]
keyword_args: tuple[LazyArgument, ...]
positional_args: Tuple[LazyArgument, ...]
keyword_args: Tuple[LazyArgument, ...]
# TODO: Need to handle collisions with argument names at some point
returns: tuple[Return, ...]
returns: Tuple["Return", ...]
# if this schema has a Generator arg, list its orig ctype/name but don't
# build a LazyArgument since lazy IR doesn't support it
generator_arg: NamedCType | None = None
generator_arg: Optional[NamedCType] = None
# original function schema
func: FunctionSchema
@ -333,21 +329,21 @@ class LazyIrSchema:
"Lower",
"CanBeReused",
)
opkind: str | None = None
opkind: Optional[str] = None
def __init__(
self,
func: FunctionSchema,
properties: LazyIrProperties | None = None,
properties: Optional[LazyIrProperties] = None,
*,
symint: bool,
) -> None:
):
if properties:
self.properties = properties
self.func = func
self.symint = symint
positional_args: list[LazyArgument] = []
positional_args: List[LazyArgument] = []
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
if arg_field == "self_arg" and func.arguments.self_arg is not None:
arg = func.arguments.self_arg.argument
@ -361,7 +357,7 @@ class LazyIrSchema:
)
self.positional_args = tuple(positional_args)
keyword_args: list[LazyArgument] = []
keyword_args: List[LazyArgument] = []
for arg_field in [
"pre_tensor_options_kwarg_only",
"tensor_options",
@ -415,13 +411,13 @@ class LazyIrSchema:
values: bool = True,
scalars: bool = True,
generator: bool = True,
) -> list[LazyArgument]:
) -> List[LazyArgument]:
# This function maintains the sorted order of arguments but provides different filtered views.
# Some parts of the code care about kwargs vs args (TS lowerings),
# other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
# Generators are special cased, as they are needed for fallback/shape-inference but not supported
# in TS lowerings and therefore also omitted from lazy IR.
args: list[LazyArgument] = []
args: List[LazyArgument] = []
if positional:
args.extend(self.positional_args)
if keyword:
@ -443,25 +439,25 @@ class LazyIrSchema:
return []
@property
def positional_values(self) -> list[LazyArgument]:
def positional_values(self) -> List[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=True, scalars=False
)
@property
def positional_scalars(self) -> list[LazyArgument]:
def positional_scalars(self) -> List[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=False, scalars=True
)
@property
def keyword_values(self) -> list[LazyArgument]:
def keyword_values(self) -> List[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=True, scalars=False
)
@property
def keyword_scalars(self) -> list[LazyArgument]:
def keyword_scalars(self) -> List[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=False, scalars=True
)

View File

@ -1,6 +1,4 @@
from __future__ import annotations
from typing import Sequence
from typing import List, Optional, Sequence, Union
from torchgen import local
from torchgen.api import cpp
@ -83,11 +81,11 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
def argument(
a: Argument | SelfArgument | TensorOptionsArguments,
a: Union[Argument, SelfArgument, TensorOptionsArguments],
*,
is_out: bool,
symint: bool,
) -> list[Binding]:
) -> List[Binding]:
# Ideally, we NEVER default native functions. However, there are a number
# of functions that call native:: directly and rely on the defaulting
# existing. So for BC, we generate defaults for non-out variants (but not
@ -95,7 +93,7 @@ def argument(
# default)
should_default = not is_out
if isinstance(a, Argument):
default: str | None = None
default: Optional[str] = None
if should_default and a.default is not None:
default = cpp.default_expr(a.default, a.type, symint=symint)
return [
@ -146,8 +144,8 @@ def argument(
assert_never(a)
def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
args.extend(func.arguments.non_out)
args.extend(func.arguments.out)
return [

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
@ -199,14 +197,14 @@ from torchgen.model import (
@dataclass(frozen=True)
class PythonReturns:
returns: tuple[Return, ...]
returns: Tuple[Return, ...]
@dataclass(frozen=True)
class PythonArgument:
name: str
type: Type
default: str | None
default: Optional[str]
# Used to generate the default init expr for some PythonArgParser outputs, e.g.:
#
@ -214,7 +212,7 @@ class PythonArgument:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ^
# +--- default_init str
default_init: str | None
default_init: Optional[str]
# Compute argument formal for python argument parsing.
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
@ -302,10 +300,12 @@ class PythonOutArgument(PythonArgument):
# 'auto out = _r.tensorlist_n<2>(2);',
# then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
# TODO: maybe don't need keep scattered out fields for python signature?
outputs: tuple[PythonArgument, ...]
outputs: Tuple[PythonArgument, ...]
@staticmethod
def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
def from_outputs(
outputs: Tuple[PythonArgument, ...]
) -> Optional["PythonOutArgument"]:
if not outputs:
return None
@ -339,13 +339,13 @@ class PythonSignature:
# Positional arguments.
# TODO: create a dedicated SelfArgument type for 'self'?
input_args: tuple[PythonArgument, ...]
input_args: Tuple[PythonArgument, ...]
# Keyword arguments excluding the 'out' argument and scattered kwargs belonging
# to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
input_kwargs: tuple[PythonArgument, ...]
input_kwargs: Tuple[PythonArgument, ...]
output_args: PythonOutArgument | None
output_args: Optional[PythonOutArgument]
# Return types, which are only used by pyi
returns: PythonReturns
@ -356,7 +356,7 @@ class PythonSignature:
# for out variant), in which case they will be used as scattered fields without
# being packed into 'options'.
# TODO: maybe create a PythonTensorOptionsArgument?
tensor_options_args: tuple[PythonArgument, ...]
tensor_options_args: Tuple[PythonArgument, ...]
# method or function signature?
method: bool
@ -367,8 +367,8 @@ class PythonSignature:
def arguments(
self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
) -> tuple[PythonArgument | PythonOutArgument, ...]:
result: list[PythonArgument | PythonOutArgument] = []
) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
result: List[Union[PythonArgument, PythonOutArgument]] = []
result.extend(self.input_args)
result.extend(self.input_kwargs)
if self.output_args is not None and not skip_outputs:
@ -394,7 +394,7 @@ class PythonSignature:
# signature_str_pyi().
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: list[str] = [
schema_formals: List[str] = [
a.argument_str(method=self.method, symint=symint) for a in args
]
positional_argc = len(self.input_args)
@ -405,7 +405,7 @@ class PythonSignature:
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: list[str] = [
schema_formals: List[str] = [
a.argument_str_pyi(method=self.method) for a in args
]
positional_argc = len(self.input_args)
@ -419,10 +419,10 @@ class PythonSignature:
schema_formals.insert(0, "self")
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
# only pyi uses vararg signatures
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: list[str] = [
schema_formals: List[str] = [
a.argument_str_pyi(method=self.method) for a in args
]
# vararg only applies to pyi signatures. vararg variants are not generated for all signatures
@ -470,7 +470,7 @@ class PythonSignatureDeprecated(PythonSignature):
# [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
# [func call]: self.addmm(mat1, mat2, beta, 1)
# We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
deprecated_args_exprs: tuple[str, ...]
deprecated_args_exprs: Tuple[str, ...]
@property
def deprecated(self) -> bool:
@ -486,7 +486,7 @@ class PythonSignatureDeprecated(PythonSignature):
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs)
schema_formals: list[str] = [
schema_formals: List[str] = [
a.argument_str_pyi(method=self.method, deprecated=True) for a in args
]
positional_argc = len(self.input_args)
@ -496,7 +496,7 @@ class PythonSignatureDeprecated(PythonSignature):
returns_str = returns_str_pyi(self)
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
# the codegen doesn't include vararg variants for deprecated signatures
return None
@ -530,14 +530,14 @@ class PythonSignatureGroup:
base: NativeFunction
# The out variant (e.g. conv2d_out)
outplace: NativeFunction | None
outplace: Optional[NativeFunction]
@classmethod
def from_pairs(
cls,
functional: PythonSignatureNativeFunctionPair,
out: PythonSignatureNativeFunctionPair | None,
) -> PythonSignatureGroup:
out: Optional[PythonSignatureNativeFunctionPair],
) -> "PythonSignatureGroup":
if out is None:
return PythonSignatureGroup(
signature=functional.signature,
@ -716,7 +716,7 @@ def argument_type_str(
raise RuntimeError(f"unrecognized type {repr(t)}")
def argument_type_size(t: Type) -> int | None:
def argument_type_size(t: Type) -> Optional[int]:
l = t.is_list_like()
if l is not None and str(l.elem) != "bool":
return l.size
@ -750,11 +750,11 @@ def signature(
def signature_from_schema(
func: FunctionSchema,
*,
category_override: str | None,
category_override: Optional[str],
method: bool = False,
pyi: bool = False,
) -> PythonSignature:
args: list[Argument] = []
args: List[Argument] = []
args.extend(func.arguments.pre_self_positional)
# Skip SelfArgument if this is method.
if not method and func.arguments.self_arg is not None:
@ -807,10 +807,10 @@ def signature_from_schema(
)
is_dummy_function = category_override == "dummy"
tensor_options_args: list[PythonArgument] = []
tensor_options_args: List[PythonArgument] = []
if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
def topt_default_init(name: str) -> str | None:
def topt_default_init(name: str) -> Optional[str]:
topt_args = func.arguments.tensor_options
if topt_args is None:
return None
@ -891,7 +891,7 @@ def signature_from_schema(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
if len(returns) <= 1 or all(r.name is None for r in returns):
return []
else:
@ -1002,7 +1002,7 @@ def return_type_str_pyi(t: Type) -> str:
return argument_type_str_pyi(t)
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
structseq_name = signature.name
field_names = structseq_fieldnames(signature.returns.returns)
@ -1104,7 +1104,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:
def dispatch_lambda_args(
ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
) -> Tuple[DispatchLambdaArgument, ...]:
if isinstance(ps, PythonSignatureDeprecated):
schema = ps.deprecated_schema
else:
@ -1118,7 +1118,7 @@ def dispatch_lambda_args(
method=False,
cpp_no_default_args=f.cpp_no_default_args,
)
out_args: set[str] = {a.name for a in schema.arguments.out}
out_args: Set[str] = {a.name for a in schema.arguments.out}
# Convert from cpp argument to lambda argument
def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
@ -1224,11 +1224,11 @@ def cpp_dispatch_target(f: NativeFunction) -> str:
def cpp_dispatch_exprs(
f: NativeFunction,
*,
python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
python_signature: Optional[PythonSignature] = None,
) -> Tuple[str, ...]:
cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
exprs: tuple[str, ...] = tuple()
exprs: Tuple[str, ...] = tuple()
if not isinstance(python_signature, PythonSignatureDeprecated):
# By default the exprs are consistent with the C++ signature.
exprs = tuple(a.name for a in cpp_args)
@ -1262,7 +1262,7 @@ def cpp_dispatch_exprs(
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accepts doublelist with definite size.
def arg_parser_unpack_method(
t: Type, default: str | None, default_init: str | None, *, symint: bool = True
t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
) -> str:
has_default_init = default_init is not None
if has_default_init and str(t) not in (
@ -1377,7 +1377,7 @@ def arg_parser_output_expr(
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
) -> Dict[str, PythonArgParserOutputExpr]:
return {
e.name: e
for i, a in enumerate(ps.arguments())
@ -1404,8 +1404,8 @@ def dispatch_lambda_exprs(
# outputs.
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
inits: list[str] = []
lambda_args_exprs: dict[str, str] = {}
inits: List[str] = []
lambda_args_exprs: Dict[str, str] = {}
has_toptions = has_tensor_options(f)

View File

@ -1,4 +1,4 @@
from __future__ import annotations
from typing import List, Union
from torchgen.api import cpp
from torchgen.api.types import (
@ -97,7 +97,7 @@ def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
# Structured kernels are never defaulted
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
if isinstance(a, Argument):
return [
Binding(
@ -115,15 +115,15 @@ def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Bindin
assert_never(a)
def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
if g.out.precomputed:
# A list of parameters for the impl function with
# certain parameters replaced with precomputed counterparts
# as specified in native_functions.yaml.
non_out_args_replaced: list[
Argument | TensorOptionsArguments | SelfArgument
non_out_args_replaced: List[
Union[Argument, TensorOptionsArguments, SelfArgument]
] = []
for a in g.out.func.arguments.non_out:
if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
@ -145,13 +145,13 @@ def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
return [r for arg in args for r in argument(arg)]
def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
args.extend(g.functional.func.arguments.non_out)
return [r for arg in args for r in argument(arg)]
def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
args.extend(g.out.func.arguments.out)
return [r for arg in args for r in argument(arg)]

View File

@ -1,6 +1,4 @@
from __future__ import annotations
from typing import NoReturn, Sequence
from typing import Dict, List, NoReturn, Sequence, Union
from torchgen.api.types import (
ArrayRefCType,
@ -97,13 +95,13 @@ class UnsatError(RuntimeError):
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
def translate(
bindings: Sequence[Expr | Binding],
goals: Sequence[NamedCType | Binding],
bindings: Sequence[Union[Expr, Binding]],
goals: Sequence[Union[NamedCType, Binding]],
*,
method: bool = False,
allow_expensive_conversions: bool = False,
) -> list[Expr]:
binding_exprs: list[Expr] = []
) -> List[Expr]:
binding_exprs: List[Expr] = []
for b in bindings:
if isinstance(b, Binding):
binding_exprs.append(
@ -115,7 +113,7 @@ def translate(
else:
binding_exprs.append(b)
goal_ctypes: list[NamedCType] = []
goal_ctypes: List[NamedCType] = []
for g in goals:
if isinstance(g, Binding):
goal_ctypes.append(g.nctype)
@ -123,7 +121,7 @@ def translate(
goal_ctypes.append(g)
# Add all the bindings to the context
ctx: dict[NamedCType, str] = {}
ctx: Dict[NamedCType, str] = {}
for b in binding_exprs:
ctx[b.type] = b.expr

View File

@ -1,19 +1,14 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterator, Sequence, TYPE_CHECKING
from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union
from torchgen.api.types.types_base import Binding, CType, Expr
if TYPE_CHECKING:
from torchgen.model import (
BackendIndex,
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
)
from torchgen.model import (
BackendIndex,
FunctionSchema,
NativeFunction,
NativeFunctionsGroup,
NativeFunctionsViewGroup,
)
@dataclass(frozen=True)
@ -43,7 +38,7 @@ class CppSignature:
symint: bool
# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: set[str]
cpp_no_default_args: Set[str]
# Is this a fallback C++ binding? Fallback bindings are enabled by
# manual_cpp_binding: True and are alternate, non-public API that
@ -77,7 +72,7 @@ class CppSignature:
def decl(
self,
*,
name: str | None = None,
name: Optional[str] = None,
prefix: str = "",
is_redispatching_fn: bool = False,
suppress_symint_suffix: bool = False,
@ -98,7 +93,7 @@ class CppSignature:
def defn(
self,
*,
name: str | None = None,
name: Optional[str] = None,
prefix: str = "",
is_redispatching_fn: bool = False,
) -> str:
@ -131,9 +126,9 @@ class CppSignature:
class CppSignatureGroup:
func: FunctionSchema
signature: CppSignature
faithful_signature: CppSignature | None
symint_signature: CppSignature | None
symint_faithful_signature: CppSignature | None
faithful_signature: Optional[CppSignature]
symint_signature: Optional[CppSignature]
symint_faithful_signature: Optional[CppSignature]
def most_faithful_signature(self) -> CppSignature:
if self.faithful_signature:
@ -154,7 +149,7 @@ class CppSignatureGroup:
@staticmethod
def from_native_function(
f: NativeFunction, *, method: bool, fallback_binding: bool = False
) -> CppSignatureGroup:
) -> "CppSignatureGroup":
func = f.func
def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
@ -167,16 +162,16 @@ class CppSignatureGroup:
cpp_no_default_args=f.cpp_no_default_args,
)
def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
faithful_signature: CppSignature | None = None
def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]:
faithful_signature: Optional[CppSignature] = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True, symint=symint)
signature = make_sig(faithful=False, symint=symint)
return signature, faithful_signature
signature, faithful_signature = make_sigs(symint=False)
symint_signature: CppSignature | None = None
symint_faithful_signature: CppSignature | None = None
symint_signature: Optional[CppSignature] = None
symint_faithful_signature: Optional[CppSignature] = None
if func.has_symint():
symint_signature, symint_faithful_signature = make_sigs(symint=True)
@ -201,20 +196,20 @@ class DispatcherSignature:
symint: bool = True
def arguments(self) -> list[Binding]:
def arguments(self) -> List[Binding]:
return dispatcher.arguments(self.func, symint=self.symint)
def name(self) -> str:
return self.prefix + dispatcher.name(self.func)
def decl(self, name: str | None = None) -> str:
def decl(self, name: Optional[str] = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn(
self, name: str | None = None, *, is_redispatching_fn: bool = False
self, name: Optional[str] = None, *, is_redispatching_fn: bool = False
) -> str:
args = [a.defn() for a in self.arguments()]
if is_redispatching_fn:
@ -224,7 +219,7 @@ class DispatcherSignature:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def exprs(self) -> list[Expr]:
def exprs(self) -> List[Expr]:
return [Expr(a.name, a.nctype) for a in self.arguments()]
def returns_type(self) -> CType:
@ -242,7 +237,7 @@ class DispatcherSignature:
@staticmethod
def from_schema(
func: FunctionSchema, *, prefix: str = "", symint: bool = True
) -> DispatcherSignature:
) -> "DispatcherSignature":
return DispatcherSignature(func, prefix, symint)
@ -258,13 +253,13 @@ class NativeSignature:
def name(self) -> str:
return self.prefix + native.name(self.func)
def decl(self, name: str | None = None) -> str:
def decl(self, name: Optional[str] = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments())
if name is None:
name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
def defn(self, name: str | None = None) -> str:
def defn(self, name: Optional[str] = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
if name is None:
name = self.name()
@ -275,13 +270,13 @@ class NativeSignature:
args_str = ", ".join(a.defn() for a in self.arguments())
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
def arguments(self) -> list[Binding]:
def arguments(self) -> List[Binding]:
return native.arguments(self.func, symint=self.symint)
def returns_type(self) -> CType:
return native.returns_type(self.func.returns, symint=self.symint)
def dispatcher_exprs(self) -> list[Expr]:
def dispatcher_exprs(self) -> List[Expr]:
return translate.translate(
self.arguments(), dispatcher.arguments(self.func), method=False
)
@ -312,7 +307,7 @@ class FunctionalizationLambda:
# are we generating the forward lambda or the reverse lambda?
is_reverse: bool
def captures(self) -> list[Expr]:
def captures(self) -> List[Expr]:
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda.
@ -341,7 +336,7 @@ class FunctionalizationLambda:
]
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
def inner_call(self, *, reapply_views: bool | None = None) -> str:
def inner_call(self, *, reapply_views: Optional[bool] = None) -> str:
inner_call_name = functionalization.name(
self.g,
is_reverse=self.is_reverse,
@ -371,7 +366,7 @@ class FunctionalizationLambda:
@staticmethod
def from_func(
g: NativeFunctionsViewGroup, *, is_reverse: bool
) -> FunctionalizationLambda:
) -> "FunctionalizationLambda":
return FunctionalizationLambda(g, is_reverse)
@ -380,11 +375,11 @@ class StructuredImplSignature:
g: NativeFunctionsGroup
name: str
def defn(self, name: str | None = None) -> str:
def defn(self, name: Optional[str] = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments())
return f"TORCH_IMPL_FUNC({self.name})({args_str})"
def arguments(self) -> list[Binding]:
def arguments(self) -> List[Binding]:
return structured.impl_arguments(self.g)
@ -393,7 +388,7 @@ class StructuredImplSignature:
def kernel_signature(
f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
) -> NativeSignature | DispatcherSignature:
) -> Union["NativeSignature", "DispatcherSignature"]:
# Note [External Backends Follow Dispatcher API]
# Kernel signatures for in-tree backends follow the "native" API,
# while kernels for out-of-tree backends follow the dispatcher API.

View File

@ -12,10 +12,8 @@ if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict
from torchgen.api.types.types_base import (
BaseCppType,
@ -85,7 +83,7 @@ symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t")
ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
ScalarType.Byte: byteT,
ScalarType.Char: charT,
ScalarType.Short: shortT,
@ -104,7 +102,7 @@ ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
}
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
@ -130,7 +128,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
@dataclass(frozen=True)
class OptionalCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -139,13 +137,13 @@ class OptionalCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ListCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -154,13 +152,13 @@ class ListCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return ListCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -169,7 +167,7 @@ class ArrayRefCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return ArrayRefCType(self.elem.remove_const_ref())
@ -187,5 +185,5 @@ class VectorizedCType(CType):
def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self

View File

@ -12,17 +12,12 @@ if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import auto, Enum
from typing import TYPE_CHECKING, Union
from typing import List, Optional, Union
if TYPE_CHECKING:
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
# An ArgName is just the str name of the argument in schema;
@ -41,7 +36,7 @@ ArgName = Union[str, SpecialArgName]
# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
ns: str | None
ns: Optional[str]
name: str
def __str__(self) -> str:
@ -76,7 +71,7 @@ class CType(ABC):
raise NotImplementedError
@abstractmethod
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self
@ -92,13 +87,13 @@ class BaseCType(CType):
def cpp_type_registration_declarations(self) -> str:
return str(self.type).replace("at::", "")
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self
@dataclass(frozen=True)
class ConstRefCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
@ -108,13 +103,13 @@ class ConstRefCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"const {self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()
@dataclass(frozen=True)
class VectorCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -123,13 +118,13 @@ class VectorCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return VectorCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayCType(CType):
elem: CType
elem: "CType"
size: int
def cpp_type(self, *, strip_ref: bool = False) -> str:
@ -139,13 +134,13 @@ class ArrayCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return ArrayCType(self.elem.remove_const_ref(), self.size)
@dataclass(frozen=True)
class TupleCType(CType):
elems: list[CType]
elems: List["CType"]
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -154,13 +149,13 @@ class TupleCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return TupleCType([e.remove_const_ref() for e in self.elems])
@dataclass(frozen=True)
class MutRefCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
@ -170,7 +165,7 @@ class MutRefCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"{self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()
@ -195,10 +190,10 @@ class NamedCType:
def cpp_type_registration_declarations(self) -> str:
return self.type.cpp_type_registration_declarations()
def remove_const_ref(self) -> NamedCType:
def remove_const_ref(self) -> "NamedCType":
return NamedCType(self.name, self.type.remove_const_ref())
def with_name(self, name: str) -> NamedCType:
def with_name(self, name: str) -> "NamedCType":
return NamedCType(name, self.type)
@ -213,11 +208,11 @@ class NamedCType:
class Binding:
name: str
nctype: NamedCType
argument: Argument | TensorOptionsArguments | SelfArgument
argument: Union[Argument, TensorOptionsArguments, SelfArgument]
# TODO: maybe don't represent default here
default: str | None = None
default: Optional[str] = None
def rename(self, name: str) -> Binding:
def rename(self, name: str) -> "Binding":
return Binding(
name=name,
nctype=self.nctype,
@ -229,7 +224,7 @@ class Binding:
def type(self) -> str:
return self.nctype.cpp_type()
def no_default(self) -> Binding:
def no_default(self) -> "Binding":
return Binding(
name=self.name,
nctype=self.nctype,
@ -260,7 +255,7 @@ class Binding:
def defn(self) -> str:
return f"{self.type} {self.name}"
def with_name(self, name: str) -> Binding:
def with_name(self, name: str) -> "Binding":
return Binding(
name=name, nctype=self.nctype, argument=self.argument, default=self.default
)

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
import torchgen.api.types as api_types
from torchgen.api import cpp, structured
@ -39,7 +38,7 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
# argument registers)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
# Dispatch stubs are always plain ints
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
@ -135,8 +134,8 @@ def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
@dataclass(frozen=True)
class UfunctorBindings:
ctor: list[Binding]
apply: list[Binding]
ctor: List[Binding]
apply: List[Binding]
# ufunctors are a CUDA-only concept representing functors that take some of
@ -157,7 +156,7 @@ class UfunctorBindings:
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition
def ufunctor_arguments(
g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
) -> UfunctorBindings:
ctor = []
apply = []
@ -186,7 +185,7 @@ def ufunctor_arguments(
# }
#
# In this file, we refer to T as compute_t which is bound by caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
return [
ufunc_argument(a, compute_t=compute_t)
for a in g.functional.func.arguments.flat_non_out
@ -198,7 +197,7 @@ def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Bindin
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
# stubs drop all tensor arguments (they are implicit in the TensorIterator
# argument and keep everything else)
return [

View File

@ -1,4 +1,4 @@
from __future__ import annotations
from typing import List, Tuple
from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignatureGroup, CType
@ -103,7 +103,7 @@ def name(f: NativeFunction) -> str:
# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
# we need the 'self' argument so method needs to be False
args = (
CppSignatureGroup.from_native_function(f, method=False)
@ -138,7 +138,7 @@ def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
def argumenttype_ivalue_convert(
t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
) -> Tuple[str, CType, List[str], List[str]]:
# Unboxing is for mobile, which doesn't care about SymInts
ctype = cpp.argumenttype_type(
t=t, mutable=mutable, binds=arg_name, symint=False
@ -172,7 +172,7 @@ def argumenttype_ivalue_convert(
def _gen_code_base_type(
arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
return [
f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
@ -180,7 +180,7 @@ def _gen_code_base_type(
def _gen_code_optional_type(
arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_opt_in"
res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
return (
@ -203,7 +203,7 @@ if ({arg_name}_opt.has_value()) {{
def _gen_code_list_type(
arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]

View File

@ -1,7 +1,5 @@
from __future__ import annotations
import re
from typing import Mapping, Sequence
from typing import Mapping, Match, Optional, Sequence
# match $identifier or ${identifier} and replace with value in env
@ -22,7 +20,7 @@ class CodeTemplate:
filename: str
@staticmethod
def from_file(filename: str) -> CodeTemplate:
def from_file(filename: str) -> "CodeTemplate":
with open(filename) as f:
return CodeTemplate(f.read(), filename)
@ -31,7 +29,7 @@ class CodeTemplate:
self.filename = filename
def substitute(
self, env: Mapping[str, object] | None = None, **kwargs: object
self, env: Optional[Mapping[str, object]] = None, **kwargs: object
) -> str:
if env is None:
env = {}
@ -45,7 +43,7 @@ class CodeTemplate:
[indent + l + "\n" for e in v for l in str(e).splitlines()]
).rstrip()
def replace(match: re.Match[str]) -> str:
def replace(match: Match[str]) -> str:
indent = match.group(1)
key = match.group(2)
comma_before = ""

View File

@ -1,8 +1,6 @@
from __future__ import annotations
import contextlib
import functools
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
import torchgen.local as local
from torchgen.model import (
@ -40,7 +38,7 @@ F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
@contextlib.contextmanager
def native_function_manager(
g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
) -> Iterator[None]:
if isinstance(g, NativeFunctionsGroup):
# By default, we associate all errors with structured native functions
@ -120,10 +118,10 @@ def with_native_function_and_index(
# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
def with_native_function_and_indices(
func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
@functools.wraps(func)
def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
with native_function_manager(f):
return func(f, backend_indices)

View File

@ -1,9 +1,7 @@
from __future__ import annotations
import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Any
from typing import Any, Dict, List, Optional, Tuple, Union
import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
@ -111,7 +109,7 @@ def node_ctor_inputs(schema: LazyIrSchema) -> str:
def gen_fallback_code(
schema: LazyIrSchema,
sig: DispatcherSignature | NativeSignature,
sig: Union[DispatcherSignature, NativeSignature],
overload_name: str,
) -> str:
"""
@ -149,9 +147,9 @@ def aten_symbol(schema: LazyIrSchema) -> str:
# converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
context: List[Binding] = []
unwrapped_tensor_args: List[str] = []
for arg in sig.arguments():
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
unwrapped_name = f"{arg.name}_meta"
@ -173,7 +171,7 @@ class GenLazyIR(ABC):
use_lazy_shape: bool
@method_with_native_function
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
metadata = self.backend_index.get_kernel(
f.functional if isinstance(f, NativeFunctionsGroup) else f
@ -238,7 +236,7 @@ class GenLazyIR(ABC):
/* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))"""
def gen(self, schema: LazyIrSchema) -> list[str]:
def gen(self, schema: LazyIrSchema) -> List[str]:
opkind = schema.opkind or aten_symbol(schema)
# for now, we just want one IR class decl and soon after also the method defs
@ -415,7 +413,7 @@ class GenLazyNativeFuncDefinition:
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
# Generates lazy_{name} variables for LazyTensors wrapping input tensors
lazy_tensor_decls: list[str] = []
lazy_tensor_decls: List[str] = []
for arg in value_args:
if arg.is_wrapped_scalar:
if isinstance(arg.lazy_type, OptionalCType):
@ -462,7 +460,7 @@ class GenLazyNativeFuncDefinition:
func: NativeFunction,
schema: LazyIrSchema,
metadata: BackendMetadata,
sig: DispatcherSignature | NativeSignature,
sig: Union[DispatcherSignature, NativeSignature],
) -> str:
if self.gen_forced_fallback_code:
return gen_fallback_code(
@ -576,7 +574,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
}}
"""
def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
# xla uses an instance method for tensor creation, for the time being
if self.create_from_first_tensor:
# TODO(whc) remove this if XLA switches to using static method for creation
@ -617,7 +615,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
return bridge_str
@method_with_native_function
def __call__(self, func: NativeFunction) -> list[str]:
def __call__(self, func: NativeFunction) -> List[str]:
sig = kernel_signature(func, self.backend_index)
metadata = self.backend_index.get_kernel(func)
assert metadata is not None
@ -641,7 +639,7 @@ class ComputeShapeSignature:
Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
"""
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool):
self.__schema = LazyIrSchema(f.func, symint=symint)
self.__dispatch_args = ", ".join(
[a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
@ -672,7 +670,7 @@ class GenLazyShapeInferenceDefinition:
tensor_class: str
@method_with_native_function
def __call__(self, f: NativeFunction) -> list[str]:
def __call__(self, f: NativeFunction) -> List[str]:
metadata = self.backend_index.get_kernel(f)
assert metadata is not None
@ -689,8 +687,8 @@ class GenLazyShapeInferenceDefinition:
def generate_non_native_lazy_ir_nodes(
non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> List[str]:
"""Generate the non-native lazy IR node classes"""
nodes = []
for op in non_native:

View File

@ -1,4 +1,4 @@
from __future__ import annotations
from typing import List, Optional, Union
import torchgen.api.meta as meta
import torchgen.api.structured as structured
@ -9,7 +9,7 @@ from torchgen.utils import mapMaybe
@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
sig = kernel_signature(f, backend_index)
metadata = backend_index.get_kernel(f)
if metadata is None:
@ -22,7 +22,7 @@ def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | No
@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
meta_name = meta.name(g)
out_args = structured.impl_arguments(g)
metadata = backend_index.get_kernel(g)
@ -42,8 +42,8 @@ void impl({', '.join(a.decl() for a in out_args)});
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
metadata = backend_index.get_kernel(g)
if isinstance(g, NativeFunctionsGroup):
if metadata is not None and metadata.structured:

View File

@ -1,9 +1,7 @@
from __future__ import annotations
import itertools
import textwrap
from dataclasses import dataclass
from typing import Literal, TYPE_CHECKING
from typing import List, Literal, Optional, Tuple, Union
import torchgen.api.cpp as cpp
import torchgen.api.meta as meta
@ -36,18 +34,15 @@ from torchgen.model import (
SchemaKind,
TensorOptionsArguments,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import assert_never, mapMaybe, Target
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
def gen_registration_headers(
backend_index: BackendIndex,
per_operator_headers: bool,
rocm: bool,
) -> list[str]:
) -> List[str]:
if per_operator_headers:
headers = ["#include <ATen/ops/as_strided_native.h>"]
else:
@ -78,7 +73,7 @@ def gen_registration_headers(
def gen_empty_impl_names(
backend_index: BackendIndex,
) -> tuple[str | None, str | None]:
) -> Tuple[Optional[str], Optional[str]]:
empty_impl = None
empty_strided_impl = None
@ -102,7 +97,7 @@ def gen_empty_impl_names(
return empty_impl, empty_strided_impl
def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
if backend_index.dispatch_key == DispatchKey.Meta:
empty_options = "options.device(at::kMeta)"
else:
@ -125,7 +120,7 @@ Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &o
]
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
_, empty_strided_impl = gen_empty_impl_names(backend_index)
return (
[]
@ -143,7 +138,7 @@ std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, I
)
def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
# The function isn't used by this key (since only functional ops have a kernel for this key),
# so we need to not include it to avoid a defined-but-not-used error.
@ -173,7 +168,7 @@ void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const
]
def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
return [
"""
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
@ -196,7 +191,7 @@ void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &o
]
def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
return [
'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
*gen_create_out_helper(backend_index),
@ -254,7 +249,7 @@ class RegisterDispatchKey:
# Finally, this field is currently Optional because it is only used by external backends.
# It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
# all of the existing kernel signatures scattered across aten/src/ATen/native.
class_method_name: str | None
class_method_name: Optional[str]
# Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
# operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
@ -262,7 +257,7 @@ class RegisterDispatchKey:
@staticmethod
def gen_device_check(
type: DeviceCheckType, args: list[Argument], method_name: str
type: DeviceCheckType, args: List[Argument], method_name: str
) -> str:
if type == DeviceCheckType.NoCheck:
return " // No device check\n"
@ -277,7 +272,7 @@ class RegisterDispatchKey:
return device_check
@method_with_native_function
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
if isinstance(f, NativeFunctionsGroup):
g: NativeFunctionsGroup = f
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
@ -296,7 +291,7 @@ class RegisterDispatchKey:
def wrapper_kernel_sig(
self, f: NativeFunction
) -> NativeSignature | DispatcherSignature:
) -> Union[NativeSignature, DispatcherSignature]:
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
return DispatcherSignature.from_schema(
f.func,
@ -305,8 +300,8 @@ class RegisterDispatchKey:
)
def gen_out_inplace_wrapper(
self, f: NativeFunction, g: NativeFunctionsGroup | None
) -> str | None:
self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
) -> Optional[str]:
if g is None:
return None
k = f.func.kind()
@ -355,7 +350,7 @@ class RegisterDispatchKey:
}}
"""
def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
metadata = self.backend_index.get_kernel(g)
if self.backend_index.dispatch_key == DispatchKey.Meta:
assert not self.backend_index.has_kernel(g.out), (
@ -385,8 +380,8 @@ class RegisterDispatchKey:
return list(mapMaybe(structured_gen.gen_one, g.functions()))
def gen_unstructured(
self, f: NativeFunction, g: NativeFunctionsGroup | None = None
) -> str | None:
self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
) -> Optional[str]:
with native_function_manager(f):
inplace_meta = False
gets_out_inplace_wrapper = False
@ -737,7 +732,7 @@ resize_out(out, sizes, strides, options);
return "\n".join(line for line in lines if line)
@method_with_native_function
def gen_one(self, f: NativeFunction) -> str | None:
def gen_one(self, f: NativeFunction) -> Optional[str]:
assert not f.manual_kernel_registration
if (
@ -811,7 +806,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
sig_body = []
# We'll use context to keep track of any variables we've brought
# into scope while generating code
context: list[Binding | Expr] = list(sig.arguments())
context: List[Union[Binding, Expr]] = list(sig.arguments())
# Initialize the class corresponding to this structured
# operator; feeding it the output argument(s) if it is known

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
@ -16,6 +14,7 @@ from torchgen.api.types import (
StructuredImplSignature,
VectorizedCType,
)
from torchgen.api.ufunc import UfunctorBindings
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
@ -29,10 +28,6 @@ from torchgen.model import (
from torchgen.utils import OrderedSet
if TYPE_CHECKING:
from torchgen.api.ufunc import UfunctorBindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CUDA STUFF
@ -65,7 +60,7 @@ if TYPE_CHECKING:
@dataclass(frozen=True)
class UfunctorSignature:
g: NativeFunctionsGroup
scalar_tensor_idx: int | None
scalar_tensor_idx: Optional[int]
name: str
def arguments(self) -> UfunctorBindings:
@ -73,7 +68,7 @@ class UfunctorSignature:
self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
)
def fields(self) -> list[Binding]:
def fields(self) -> List[Binding]:
# fields are renamed to have a trailing underscore, as is conventional
return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
@ -103,10 +98,10 @@ class UfuncSignature:
name: str
compute_t: CType
def arguments(self) -> list[Binding]:
def arguments(self) -> List[Binding]:
return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
def call(self, ctx: Sequence[Binding | Expr]) -> str:
def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
@ -137,10 +132,10 @@ def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
def compute_ufunc_cuda_functors(
g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
# First, build the functors.
ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
ufunctors: list[str] = []
ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
ufunctors: List[str] = []
loops = g.out.ufunc_inner_loop
scalar_tensor_idx_lookup = {
UfuncKey.CUDAFunctorOnSelf: 1,
@ -242,7 +237,7 @@ BinaryScalarSpecializationConfigs = [
def compute_ufunc_cuda_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfunctorSignature],
inner_loops: Dict[UfuncKey, UfunctorSignature],
parent_ctx: Sequence[Binding],
) -> str:
body = "using opmath_t = at::opmath_type<scalar_t>;"
@ -254,7 +249,7 @@ def compute_ufunc_cuda_dtype_body(
scalar_idx = config.scalar_idx + 1
# Make a copy and at the same time widen the type (not permissible
# without copy; we don't want to mutate the input argument anyway)
ctx: list[Expr | Binding] = list(parent_ctx)
ctx: List[Union[Expr, Binding]] = list(parent_ctx)
ctx.append(
Expr(
expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
@ -351,7 +346,7 @@ class StubSignature:
def type_name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_fn"
def arguments(self) -> list[Binding]:
def arguments(self) -> List[Binding]:
return ufunc.stub_arguments(self.g)
def type(self) -> str:
@ -398,7 +393,7 @@ def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
def compute_ufunc_cpu_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfuncSignature],
inner_loops: Dict[UfuncKey, UfuncSignature],
parent_ctx: Sequence[Binding],
) -> str:
assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
@ -464,8 +459,8 @@ def compute_ufunc_cpu_dtype_body(
)
)
def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
r: list[Expr | Binding] = []
def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
r: List[Union[Expr, Binding]] = []
r.extend(ctx)
r.extend(b)
return r
@ -494,7 +489,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
# Reindex the ufunc by dtypes; processing generic/scalaronly as well
loops = g.out.ufunc_inner_loop
ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
lks = []
# ORDER MATTERS: this specifies overriding precedence

View File

@ -1,29 +1,24 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from typing import Dict, List, Optional, Sequence, Tuple
from torchgen import dest
# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature # usort:skip
from torchgen.context import method_with_native_function
from torchgen.executorch.model import ETKernelIndex
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import concatMap, Target
if TYPE_CHECKING:
from torchgen.executorch.model import ETKernelIndex
from torchgen.selective_build.selector import SelectiveBuilder
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
# model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
if Variant.function not in f.variants:
return None
@ -85,7 +80,7 @@ def gen_custom_ops_registration(
selector: SelectiveBuilder,
kernel_index: ETKernelIndex,
rocm: bool,
) -> tuple[str, str]:
) -> Tuple[str, str]:
"""
Generate custom ops registration code for dest.RegisterDispatchKey.
@ -102,7 +97,7 @@ def gen_custom_ops_registration(
dispatch_key = DispatchKey.CPU
backend_index = kernel_index._to_backend_index()
static_init_dispatch_registrations = ""
ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
for native_function in native_functions:
ns_grouped_native_functions[native_function.namespace].append(native_function)

View File

@ -1,6 +1,4 @@
from __future__ import annotations
from typing import Sequence
from typing import List, Optional, Sequence, Set, Union
from torchgen import local
from torchgen.api.types import (
@ -65,7 +63,7 @@ def valuetype_type(
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
) -> NamedCType | None:
) -> Optional[NamedCType]:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
@ -211,7 +209,7 @@ def returns_type(rs: Sequence[Return]) -> CType:
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: list[str] = []
returns: List[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
@ -297,16 +295,16 @@ def default_expr(d: str, t: Type) -> str:
def argument(
a: Argument | TensorOptionsArguments | SelfArgument,
a: Union[Argument, TensorOptionsArguments, SelfArgument],
*,
cpp_no_default_args: set[str],
cpp_no_default_args: Set[str],
method: bool,
faithful: bool,
has_tensor_options: bool,
) -> list[Binding]:
) -> List[Binding]:
def sub_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
@ -321,7 +319,7 @@ def argument(
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: str | None = None
default: Optional[str] = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type)
return [
@ -349,9 +347,9 @@ def arguments(
*,
faithful: bool,
method: bool,
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
cpp_no_default_args: Set[str],
) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)

View File

@ -1,15 +1,10 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import List, Optional, Set
import torchgen.api.cpp as aten_cpp
from torchgen.api.types import Binding, CType
from torchgen.executorch.api.types.types import contextArg
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType
from torchgen.model import FunctionSchema, NativeFunction
from torchgen.model import FunctionSchema, NativeFunction
@dataclass(frozen=True)
@ -25,14 +20,14 @@ class ExecutorchCppSignature:
func: FunctionSchema
# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: set[str]
cpp_no_default_args: Set[str]
# Allows you to prepend an arbitrary prefix to the signature name.
# This is useful for parts of the codegen that generate wrappers around kernels,
# and need to avoid naming collisions.
prefix: str = ""
def arguments(self, *, include_context: bool = True) -> list[Binding]:
def arguments(self, *, include_context: bool = True) -> List[Binding]:
return ([contextArg] if include_context else []) + et_cpp.arguments(
self.func.arguments,
faithful=True, # always faithful, out argument at the end
@ -46,7 +41,7 @@ class ExecutorchCppSignature:
faithful_name_for_out_overloads=True,
)
def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str:
args_str = ", ".join(
a.decl() for a in self.arguments(include_context=include_context)
)
@ -54,7 +49,7 @@ class ExecutorchCppSignature:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn(self, name: str | None = None) -> str:
def defn(self, name: Optional[str] = None) -> str:
args = [a.defn() for a in self.arguments()]
args_str = ", ".join(args)
if name is None:
@ -67,7 +62,7 @@ class ExecutorchCppSignature:
@staticmethod
def from_native_function(
f: NativeFunction, *, prefix: str = ""
) -> ExecutorchCppSignature:
) -> "ExecutorchCppSignature":
return ExecutorchCppSignature(
func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
)

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict
from torchgen.api.types import (
BaseCppType,
@ -41,7 +40,7 @@ contextArg = Binding(
default=None,
)
BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
@ -55,7 +54,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
@dataclass(frozen=True)
class OptionalCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -64,13 +63,13 @@ class OptionalCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: CType
elem: "CType"
def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -79,5 +78,5 @@ class ArrayRefCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return ArrayRefCType(self.elem.remove_const_ref())

View File

@ -1,8 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING
from typing import Callable, List, Sequence, Tuple
from torchgen.api.types import Binding, CType, NamedCType
from torchgen.model import (
Argument,
BaseTy,
@ -14,10 +13,6 @@ from torchgen.model import (
)
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType, NamedCType
connector = "\n\t"
@ -57,7 +52,7 @@ class Unboxing:
# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(
self, args: Sequence[Binding]
) -> tuple[list[Binding], list[str]]:
) -> Tuple[List[Binding], List[str]]:
code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
binding_list = []
for arg in args:
@ -77,7 +72,7 @@ class Unboxing:
def argumenttype_evalue_convert(
self, t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
) -> Tuple[str, CType, List[str], List[str]]:
"""
Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
(1) the C++ code necessary to unbox the argument
@ -112,14 +107,14 @@ class Unboxing:
def _gen_code_base_type(
self, arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
return [
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
def _gen_code_optional_type(
self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_opt_in"
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
t.elem, in_name
@ -135,7 +130,7 @@ class Unboxing:
def _gen_code_list_type(
self, arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = []

View File

@ -1,12 +1,11 @@
# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
from __future__ import annotations
import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Tuple, Union
from torchgen.model import (
BackendIndex,
@ -42,7 +41,7 @@ class ETKernelKeyOpArgMeta:
arg_name: str
dtype: str
# The order of the dimensions if entry is a Tensor
dim_order: tuple[int, ...]
dim_order: Tuple[int, ...]
def to_native_string(self) -> str:
dtype_str = ScalarType[self.dtype].value
@ -53,7 +52,7 @@ class ETKernelKeyOpArgMeta:
@dataclass(frozen=True)
class ETKernelKey:
# Field undefined is default = True
arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()
# Indicator for this kernel being used as a catch all
default: bool = False
@ -62,10 +61,10 @@ class ETKernelKey:
@staticmethod
def gen_from_yaml(
args: dict[str, tuple[str, str]],
type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
dim_order_alias_map: dict[str, list[int]],
) -> list[ETKernelKey]:
args: Dict[str, Tuple[str, str]],
type_alias_map: Dict[str, List[str]], # TODO: Support unwrapped str val
dim_order_alias_map: Dict[str, List[int]],
) -> List["ETKernelKey"]:
"""Generate ETKernelKeys from arg kernel specs
Multiple ETKernelKeys are returned due to dtype permutations from utilizing
type_alias_map (actualizing each potential type permutation as a KernelKey)
@ -138,15 +137,15 @@ class ETKernelKey:
@dataclass(frozen=True)
class ETKernelIndex:
index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]
def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
m = self.get_kernels(g)
return m is not None
def get_kernels(
self, g: NativeFunction | NativeFunctionsGroup
) -> dict[ETKernelKey, BackendMetadata]:
self, g: Union[NativeFunction, NativeFunctionsGroup]
) -> Dict[ETKernelKey, BackendMetadata]:
if isinstance(g, NativeFunction):
f = g
elif isinstance(g, NativeFunctionsGroup):
@ -159,8 +158,8 @@ class ETKernelIndex:
@staticmethod
def grow_from_backend_indices(
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
for dk in backend_indices:
index = backend_indices[dk]
@ -172,17 +171,17 @@ class ETKernelIndex:
@staticmethod
def from_backend_indices(
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
kernel_index: dict[
OperatorName, dict[ETKernelKey, BackendMetadata]
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
) -> "ETKernelIndex":
kernel_index: Dict[
OperatorName, Dict[ETKernelKey, BackendMetadata]
] = defaultdict(dict)
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
return ETKernelIndex(kernel_index)
def grow(
self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
) -> "ETKernelIndex":
ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
return self
@ -190,7 +189,7 @@ class ETKernelIndex:
"""
WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
"""
index: dict[OperatorName, BackendMetadata] = {}
index: Dict[OperatorName, BackendMetadata] = {}
for op in self.index:
kernel_dict = self.index[op]
assert (
@ -210,7 +209,9 @@ class ETKernelIndex:
# Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
@staticmethod
def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
def merge_indices(
index_a: "ETKernelIndex", index_b: "ETKernelIndex"
) -> "ETKernelIndex":
combined = defaultdict(dict, index_a.index.copy())
for op, entry in index_b.index.items():

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from collections import defaultdict, namedtuple
from typing import Any
from typing import Any, Dict, List, Optional, Set, Tuple
import yaml
@ -24,7 +22,7 @@ ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indice
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]:
"""Given a loaded yaml representing kernel assignment information, extract the
mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
@ -36,11 +34,11 @@ def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]
if (kernels := e.pop("kernels", None)) is None:
return {}
type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
type_alias: Dict[str, List[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
dim_order_alias.pop("__line__", None)
kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {}
for entry in kernels: # type: ignore[attr-defined]
arg_meta = entry.get("arg_meta")
@ -78,7 +76,7 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
that should be used by the kernel key).
"""
indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {}
for ei in es: # type: ignore[attr-defined]
e = ei.copy()
@ -97,11 +95,11 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
return ETKernelIndex(indices)
def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
"""Given a loaded yaml representing a list of operators, extract the
kernel key related fields indexed by the operator name.
"""
fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict)
for ei in es: # type: ignore[attr-defined]
funcs = ei.get("func")
assert isinstance(funcs, str), f"not a str: {funcs}"
@ -120,9 +118,9 @@ def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
def parse_et_yaml(
path: str,
tags_yaml_path: str,
ignore_keys: set[DispatchKey] | None = None,
ignore_keys: Optional[Set[DispatchKey]] = None,
skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]:
"""Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
of fields to persist from native_functions.yaml to functions.yaml
"""

View File

@ -1,13 +1,23 @@
from __future__ import annotations
import argparse
import functools
import json
import os
import pathlib
from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Literal, Sequence, TypeVar
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
import yaml
@ -138,20 +148,20 @@ class LineLoader(YamlLoader):
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {}
def parse_native_yaml_struct(
es: object,
valid_tags: set[str],
ignore_keys: set[DispatchKey] | None = None,
valid_tags: Set[str],
ignore_keys: Optional[Set[DispatchKey]] = None,
path: str = "<stdin>",
skip_native_fns_gen: bool = False,
) -> ParsedYaml:
assert isinstance(es, list)
rs: list[NativeFunction] = []
bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
rs: List[NativeFunction] = []
bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
for e in es:
assert isinstance(e, dict), f"expected to be dict: {e}"
assert isinstance(e.get("__line__"), int), e
@ -164,7 +174,7 @@ def parse_native_yaml_struct(
BackendIndex.grow_index(bs, m)
error_check_native_functions(rs)
# Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
indices: dict[DispatchKey, BackendIndex] = defaultdict(
indices: Dict[DispatchKey, BackendIndex] = defaultdict(
lambda: BackendIndex(
dispatch_key=DispatchKey.Undefined,
use_out_as_primary=True,
@ -190,9 +200,9 @@ def parse_native_yaml_struct(
return ParsedYaml(rs, indices)
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
assert isinstance(es, list)
rs: set[str] = set()
rs: Set[str] = set()
for e in es:
assert isinstance(e.get("__line__"), int), e
loc = Location(path, e["__line__"])
@ -208,7 +218,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
def parse_tags_yaml(path: str) -> Set[str]:
global _GLOBAL_PARSE_TAGS_YAML_CACHE
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
with open(path) as f:
@ -221,10 +231,10 @@ def parse_tags_yaml(path: str) -> set[str]:
def parse_native_yaml(
path: str,
tags_yaml_path: str,
ignore_keys: set[DispatchKey] | None = None,
ignore_keys: Optional[Set[DispatchKey]] = None,
*,
skip_native_fns_gen: bool = False,
loaded_yaml: object | None = None,
loaded_yaml: Optional[object] = None,
) -> ParsedYaml:
global _GLOBAL_PARSE_NATIVE_YAML_CACHE
if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
@ -251,8 +261,8 @@ def parse_native_yaml(
# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
func_map: dict[OperatorName, NativeFunction] = {}
base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
func_map: Dict[OperatorName, NativeFunction] = {}
base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
for f in funcs:
func_map[f.func.name] = f
base_func_map[f.func.name.name].append(f)
@ -319,7 +329,7 @@ def cpp_string(s: str) -> str:
# and similar functional combinators.
def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
if len(backends) == 0:
return []
else:
@ -333,7 +343,7 @@ def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
def get_static_dispatch_backend(
f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
) -> Optional[DispatchKey]:
if f.structured_delegate is not None or backend_index.has_kernel(f):
# TODO: for ops with structured_delegate it should check the dispatch table of
# the out variant instead. For now, these structured ops all have CPU/CUDA kernels
@ -352,8 +362,8 @@ def get_static_dispatch_backend(
def static_dispatch_ops_header(
f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
f: NativeFunction, backend_index: List[BackendIndex]
) -> Optional[str]:
if backend_index is None or f.manual_kernel_registration:
return None
@ -367,7 +377,7 @@ def static_dispatch_ops_header(
return "\n".join(output)
def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
return [
f"#include <ATen/{dispatch_key}Functions.h>"
for dispatch_key in static_dispatch_keys(backends)
@ -378,12 +388,12 @@ def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
# Note that we have a special case for `memory_format` argument and this case is not covered by
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
def translate_args(
sig: CppSignature | DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
cpp_sig: CppSignature,
) -> str:
# Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
output_bindings: list[Binding] = []
def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]:
output_bindings: List[Binding] = []
for binding in input_bindings:
if binding.name == "memory_format":
spl_mem_format_binding = Binding(
@ -413,7 +423,7 @@ def translate_args(
def generate_static_dispatch_backend_call(
sig: CppSignature | DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
backend_index: BackendIndex,
) -> str:
@ -431,9 +441,9 @@ def generate_static_dispatch_backend_call(
def generate_static_dispatch_fallback_call(
sig: CppSignature | DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
backend_indices: list[BackendIndex],
backend_indices: List[BackendIndex],
) -> str:
cpp_sigs = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=False
@ -460,9 +470,9 @@ def generate_static_dispatch_fallback_call(
def static_dispatch(
sig: CppSignature | DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
backend_indices: list[BackendIndex],
backend_indices: List[BackendIndex],
) -> str:
"""
For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
@ -502,7 +512,7 @@ def static_dispatch(
tensor_opts = f.func.arguments.tensor_options
stmts = []
subexprs: list[str] = []
subexprs: List[str] = []
if tensor_opts is not None:
subexprs.append(
"DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
@ -538,10 +548,10 @@ def static_dispatch(
@dataclass(frozen=True)
class RegisterSchema:
selector: SelectiveBuilder
known_tags: dict[str, int] = field(default_factory=dict)
known_tags: Dict[str, int] = field(default_factory=dict)
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
if not self.selector.is_native_function_selected(f):
return None
tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
@ -563,7 +573,7 @@ class RegisterSchema:
@dataclass(frozen=True)
class ComputeOperators:
target: Literal[Target.DECLARATION, Target.DEFINITION]
static_dispatch_backend_indices: list[BackendIndex]
static_dispatch_backend_indices: List[BackendIndex]
@method_with_native_function
def __call__(self, f: NativeFunction) -> str:
@ -660,7 +670,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
@dataclass(frozen=True)
class ComputeFunction:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
sig_group = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=f.manual_cpp_binding
)
@ -708,10 +718,10 @@ namespace symint {{
@dataclass(frozen=True)
class ComputeTensorMethod:
target: Literal[Target.DECLARATION, Target.DEFINITION]
static_dispatch_backend_indices: list[BackendIndex]
static_dispatch_backend_indices: List[BackendIndex]
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
if Variant.method not in f.variants:
return None
@ -754,7 +764,7 @@ inline {sig.defn(prefix="Tensor::")} const {{
@dataclass(frozen=True)
class ComputeRedispatchFunction:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
# We unconditionally generate function variants of the redispatch API.
# This is mainly because we can namespace functions separately, but not methods,
sig_group = CppSignatureGroup.from_native_function(
@ -788,7 +798,7 @@ def compute_aten_op(f: NativeFunction) -> str:
# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
if not g.structured:
return None
with native_function_manager(g.out):
@ -933,7 +943,7 @@ class ComputeBackendSelect:
selector: SelectiveBuilder
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
if not needs_backend_select(f, self.selector):
return None
@ -949,7 +959,7 @@ class ComputeBackendSelect:
dispatcher_sig = DispatcherSignature.from_schema(f.func)
sig: NativeSignature | DispatcherSignature
sig: Union[NativeSignature, DispatcherSignature]
sig = dispatcher_sig
dispatcher_exprs = dispatcher_sig.exprs()
dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
@ -1049,7 +1059,7 @@ def dynamic_type(t: Type) -> str:
).cpp_type()
def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
# This is written out explicitly to ensure that Tensor and
# namespace are put into the list in the right order
method_of = ["Type"]
@ -1062,7 +1072,7 @@ def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
def compute_returns_yaml(
f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
# Note [name and field_name]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# To understand name_to_field_name, we must first talk about this
@ -1102,7 +1112,7 @@ def compute_returns_yaml(
# schema itself.
#
# See also https://github.com/pytorch/pytorch/issues/43114
name_to_field_name: dict[str, str] = {}
name_to_field_name: Dict[str, str] = {}
# Compute the returns field of the YAML entry
names = cpp.return_names(f)
@ -1131,12 +1141,12 @@ def compute_cpp_argument_yaml(
cpp_a: Binding,
*,
schema_order: bool,
kwarg_only_set: set[str],
out_arg_set: set[str],
name_to_field_name: dict[str, str],
kwarg_only_set: Set[str],
out_arg_set: Set[str],
name_to_field_name: Dict[str, str],
) -> object:
if isinstance(cpp_a.argument, TensorOptionsArguments):
arg: dict[str, object] = {
arg: Dict[str, object] = {
"annotation": None,
"dynamic_type": "at::TensorOptions",
"is_nullable": False,
@ -1163,11 +1173,11 @@ def compute_argument_yaml(
a: Argument,
*,
schema_order: bool,
kwarg_only_set: set[str],
out_arg_set: set[str],
name_to_field_name: dict[str, str],
kwarg_only_set: Set[str],
out_arg_set: Set[str],
name_to_field_name: Dict[str, str],
) -> object:
arg: dict[str, object] = {
arg: Dict[str, object] = {
"annotation": str(a.annotation) if a.annotation else None,
"dynamic_type": dynamic_type(a.type),
"is_nullable": a.type.is_nullable(),
@ -1293,7 +1303,7 @@ def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
@with_native_function_and_indices
def compute_registration_declarations(
f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
) -> str:
name = dispatcher.name(f.func)
returns_type = dispatcher.returns_type(
@ -1301,7 +1311,7 @@ def compute_registration_declarations(
).cpp_type_registration_declarations()
args = dispatcher.arguments(f.func)
args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
comment_data: dict[str, str] = {
comment_data: Dict[str, str] = {
"schema": f"aten::{f.func}",
# TODO: What exactly is the semantics of the 'dispatch' field?
"dispatch": str(
@ -1327,8 +1337,8 @@ def compute_registration_declarations(
def get_custom_build_selector(
provided_op_registration_allowlist: list[str] | None,
op_selection_yaml_path: str | None,
provided_op_registration_allowlist: Optional[List[str]],
op_selection_yaml_path: Optional[str],
) -> SelectiveBuilder:
assert not (
provided_op_registration_allowlist is not None
@ -1339,7 +1349,7 @@ def get_custom_build_selector(
+ "same time."
)
op_registration_allowlist: set[str] | None = None
op_registration_allowlist: Optional[Set[str]] = None
if provided_op_registration_allowlist is not None:
op_registration_allowlist = set(provided_op_registration_allowlist)
@ -1359,11 +1369,11 @@ def get_custom_build_selector(
def get_grouped_by_view_native_functions(
native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
def maybe_create_view_group(
d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
) -> list[NativeFunction | NativeFunctionsViewGroup]:
funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
if ViewSchemaKind.aliasing in d:
view = d.pop(ViewSchemaKind.aliasing)
view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
@ -1381,8 +1391,8 @@ def get_grouped_by_view_native_functions(
funcs.extend(d.values())
return funcs
grouped_by_views: dict[
FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
grouped_by_views: Dict[
FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
] = defaultdict(dict)
for f in native_functions:
schema = f.func.view_signature()
@ -1406,10 +1416,10 @@ def get_grouped_by_view_native_functions(
def get_grouped_native_functions(
native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
def flatten_pre_group(
d: dict[SchemaKind, NativeFunction]
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
d: Dict[SchemaKind, NativeFunction]
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
r = NativeFunctionsGroup.from_dict(d)
if r is None:
# Invariant: any NativeFunctions that are code-generated
@ -1428,13 +1438,13 @@ def get_grouped_native_functions(
def get_ns_grouped_kernels(
*,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: Dict[DispatchKey, BackendIndex],
native_function_decl_gen: Callable[
[NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
] = dest.compute_native_function_declaration,
) -> dict[str, list[str]]:
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
) -> Dict[str, List[str]]:
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
for f in grouped_native_functions:
native_function_namespaces = set()
dispatch_keys = set()
@ -1457,9 +1467,9 @@ def get_ns_grouped_kernels(
def get_native_function_declarations_from_ns_grouped_kernels(
*,
ns_grouped_kernels: dict[str, list[str]],
) -> list[str]:
declarations: list[str] = []
ns_grouped_kernels: Dict[str, List[str]],
) -> List[str]:
declarations: List[str] = []
newline = "\n"
for namespace, kernels in ns_grouped_kernels.items():
ns_helper = NamespaceHelper(
@ -1485,12 +1495,12 @@ def get_native_function_declarations_from_ns_grouped_kernels(
# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
*,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: Dict[DispatchKey, BackendIndex],
native_function_decl_gen: Callable[
[NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
] = dest.compute_native_function_declaration,
) -> list[str]:
) -> List[str]:
"""
Generate kernel declarations, in `NativeFunction(s).h`.
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
@ -1510,7 +1520,7 @@ def get_native_function_declarations(
def get_kernel_namespace(
*, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
*, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
) -> str:
backend_metadata = backend_idx.get_kernel(f)
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
@ -1528,7 +1538,7 @@ def get_kernel_namespace(
def get_native_function_definitions(
*,
fm: FileManager,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
dispatch_key: DispatchKey,
backend_idx: BackendIndex,
selector: SelectiveBuilder,
@ -1536,11 +1546,11 @@ def get_native_function_definitions(
symint: bool,
skip_dispatcher_op_registration: bool,
gen_dispatch_helpers: bool,
) -> list[str]:
definitions: list[str] = []
ns_definitions: dict[str, list[str]] = defaultdict(list)
anonymous_definitions: dict[str, list[str]] = defaultdict(list)
registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
) -> List[str]:
definitions: List[str] = []
ns_definitions: Dict[str, List[str]] = defaultdict(list)
anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
newline = "\n"
ns_gen = dest.RegisterDispatchKey(
backend_idx,
@ -1630,15 +1640,15 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
*,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
dispatch_key: DispatchKey,
backend_idx: BackendIndex,
selector: SelectiveBuilder,
rocm: bool,
symint: bool,
) -> list[str]:
declarations: list[str] = []
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
) -> List[str]:
declarations: List[str] = []
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
newline = "\n"
func = dest.RegisterDispatchKey(
backend_idx,
@ -1682,8 +1692,8 @@ def get_native_function_schema_registrations(
*,
native_functions: Sequence[NativeFunction],
schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
) -> Tuple[List[str], str]:
ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
for native_function in native_functions:
ns_native_functions[native_function.namespace].append(native_function)
schema_registrations = ""
@ -1717,14 +1727,14 @@ def get_native_function_schema_registrations(
def gen_aggregated_headers(
*,
native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
structured_native_functions: Sequence[NativeFunctionsGroup],
static_dispatch_idx: list[BackendIndex],
static_dispatch_idx: List[BackendIndex],
selector: SelectiveBuilder,
backend_indices: dict[DispatchKey, BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
cpu_fm: FileManager,
cuda_fm: FileManager,
functions_keys: set[DispatchKey],
functions_keys: Set[DispatchKey],
dispatch_keys: Sequence[DispatchKey],
rocm: bool,
) -> None:
@ -1838,25 +1848,25 @@ def gen_aggregated_headers(
def gen_per_operator_headers(
*,
native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
static_dispatch_idx: list[BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
static_dispatch_idx: List[BackendIndex],
selector: SelectiveBuilder,
backend_indices: dict[DispatchKey, BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
cpu_fm: FileManager,
cuda_fm: FileManager,
ops_fm: FileManager,
functions_keys: set[DispatchKey],
functions_keys: Set[DispatchKey],
dispatch_keys: Sequence[DispatchKey],
rocm: bool,
) -> None:
# For CMake builds, split operator declarations into separate headers in
# the ATen/ops folder to split up header dependencies
functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list)
for fn in native_functions:
functions_by_root_name[fn.root_name].append(fn)
grouped_functions_by_root_name: dict[
str, list[NativeFunction | NativeFunctionsGroup]
grouped_functions_by_root_name: Dict[
str, List[Union[NativeFunction, NativeFunctionsGroup]]
] = defaultdict(list)
for group in grouped_native_functions:
name = group.root_name
@ -2032,18 +2042,18 @@ def gen_per_operator_headers(
def gen_headers(
*,
native_functions: Sequence[NativeFunction],
valid_tags: set[str],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
valid_tags: Set[str],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
structured_native_functions: Sequence[NativeFunctionsGroup],
static_dispatch_idx: list[BackendIndex],
static_dispatch_idx: List[BackendIndex],
selector: SelectiveBuilder,
backend_indices: dict[DispatchKey, BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
core_fm: FileManager,
cpu_fm: FileManager,
cuda_fm: FileManager,
ops_fm: FileManager,
dispatch_keys: Sequence[DispatchKey],
functions_keys: set[DispatchKey],
functions_keys: Set[DispatchKey],
rocm: bool,
per_operator_headers: bool,
) -> None:
@ -2123,8 +2133,8 @@ def gen_headers(
"VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
)
def gen_aten_interned_strings() -> dict[str, str]:
attrs: set[str] = set() # All function argument names
def gen_aten_interned_strings() -> Dict[str, str]:
attrs: Set[str] = set() # All function argument names
names = set() # All ATen function names
for func in native_functions:
names.add(str(func.func.name.name))
@ -2161,7 +2171,7 @@ def gen_headers(
core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
def gen_tags_enum() -> dict[str, str]:
def gen_tags_enum() -> Dict[str, str]:
return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}
core_fm.write("enum_tag.h", gen_tags_enum)
@ -2170,19 +2180,19 @@ def gen_headers(
def gen_source_files(
*,
native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
structured_native_functions: Sequence[NativeFunctionsGroup],
view_groups: Sequence[NativeFunctionsViewGroup],
selector: SelectiveBuilder,
static_dispatch_idx: list[BackendIndex],
backend_indices: dict[DispatchKey, BackendIndex],
static_dispatch_idx: List[BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
aoti_fm: FileManager,
core_fm: FileManager,
cpu_fm: FileManager,
cpu_vec_fm: FileManager,
cuda_fm: FileManager,
dispatch_keys: Sequence[DispatchKey],
functions_keys: set[DispatchKey],
functions_keys: Set[DispatchKey],
rocm: bool,
force_schema_registration: bool,
per_operator_headers: bool,
@ -2206,7 +2216,7 @@ def gen_source_files(
if per_operator_headers:
def operator_headers() -> list[str]:
def operator_headers() -> List[str]:
headers = []
for g in grouped_native_functions:
is_registered = False
@ -2248,7 +2258,7 @@ def gen_source_files(
else:
def operator_headers() -> list[str]:
def operator_headers() -> List[str]:
headers = ["#include <ATen/NativeFunctions.h>"]
if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
headers.append("#include <ATen/Functions.h>")
@ -2439,7 +2449,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
del fm
# BackendSelect is generated specially
def gen_backend_select() -> dict[str, list[str]]:
def gen_backend_select() -> Dict[str, List[str]]:
relevant_fns = [
fn for fn in native_functions if needs_backend_select(fn, selector)
]
@ -2484,7 +2494,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
)
def key_func(
fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> str:
return fn.root_name
@ -2526,11 +2536,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
)
def functionalization_env_callable(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> dict[str, list[str]]:
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> Dict[str, List[str]]:
def gen_op_headers(
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
) -> List[str]:
if isinstance(g, NativeFunctionsViewGroup):
# view ops always get a functionalization kernel
headers = [
@ -2580,8 +2590,8 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
),
}
all_groups: list[
NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
all_groups: List[
Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
] = list(structured_native_functions) + list(
view_groups # type: ignore[assignment, arg-type, operator]
)
@ -2590,11 +2600,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
# Although this could go away long-term if we add a dedicated dispatch key for decompositions.
structured_map: dict[OperatorName, NativeFunction] = {
structured_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f
for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
}
view_map: dict[OperatorName, NativeFunction] = {
view_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
}
for f in native_functions:
@ -2705,12 +2715,12 @@ def gen_declarations_yaml(
)
def get_torchgen_root() -> Path:
def get_torchgen_root() -> pathlib.Path:
"""
If you're depending on torchgen out-of-tree, you can use the root to figure
out the path to native_functions.yaml
"""
return Path(__file__).parent.resolve()
return pathlib.Path(__file__).parent.resolve()
def main() -> None:
@ -2872,11 +2882,11 @@ def main() -> None:
#
# Invalid character escape '\c'.
core_install_dir = f"{options.install_dir}/core"
Path(core_install_dir).mkdir(parents=True, exist_ok=True)
pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
ops_install_dir = f"{options.install_dir}/ops"
Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
aoti_install_dir = f"{options.aoti_install_dir}"
Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
pathlib.Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
core_fm = make_file_manager(options=options, install_dir=core_install_dir)
cpu_fm = make_file_manager(options=options)
@ -2906,7 +2916,7 @@ def main() -> None:
if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
]
static_dispatch_idx: list[BackendIndex] = []
static_dispatch_idx: List[BackendIndex] = []
if options.static_dispatch_backend:
static_dispatch_idx = [
backend_indices[DispatchKey.parse(key)]
@ -2963,7 +2973,7 @@ def main() -> None:
gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
if options.output_dependencies:
depfile_path = Path(options.output_dependencies).resolve()
depfile_path = pathlib.Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name
depfile_stem = depfile_path.stem

View File

@ -1,8 +1,6 @@
from __future__ import annotations
import textwrap
from dataclasses import dataclass
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Tuple, Union
from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
@ -71,7 +69,7 @@ base_type_to_callsite_expr = {
# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]: # type: ignore[return]
def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str], List[str], List[str]]: # type: ignore[return]
if isinstance(typ, BaseType):
if typ.name in base_type_to_c_type:
return (
@ -169,12 +167,12 @@ def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str
)
def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
def zip_type_and_name(types: List[str], names: List[str]) -> List[str]:
return [typ + " " + name for typ, name in zip(types, names)]
# Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]:
types = []
new_names = []
callsite_exprs = []
@ -191,7 +189,7 @@ def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[s
# Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
types = []
names = []
for idx, ret in enumerate(schema.returns):
@ -224,7 +222,7 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
ret_pointer_can_be_null = True
break
callsite_exprs: list[str] = []
callsite_exprs: List[str] = []
for idx, ret in enumerate(schema.returns):
tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
assert isinstance(ret.type, BaseType)
@ -238,12 +236,12 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
# gen.py generates header first and then src, so caching the result here to avoid duplicate work
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}
declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {}
def gen_declaration_and_definition(
schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
) -> Tuple[str, str]:
func_name = schema.name.unambiguous_name()
global declaration_definition_cache
@ -256,7 +254,7 @@ def gen_declaration_and_definition(
args, callsite_exprs = gen_arguments(
[*schema.arguments.out, *schema.arguments.flat_non_out]
)
ret_assignments: list[str] = []
ret_assignments: List[str] = []
else:
args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
# ignore return values for inplace ops
@ -286,7 +284,7 @@ def gen_declaration_and_definition(
def gen_static_dispatch_backend_call_signature(
sig: CppSignature | DispatcherSignature,
sig: Union[CppSignature, DispatcherSignature],
f: NativeFunction,
) -> CppSignature:
sig = DispatcherSignature.from_schema(f.func)
@ -312,10 +310,10 @@ def gen_static_dispatch_backend_call(
def get_backend_index_for_aoti(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[BackendIndex]:
backend_index = None
if backend_indices[dispatch_key].has_kernel(func) or (
func.structured_delegate is not None
@ -343,10 +341,10 @@ def get_backend_index_for_aoti(
def get_header_for_aoti(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices
)
@ -367,11 +365,11 @@ def get_fallback_op_name(func: NativeFunction) -> str:
def gen_c_shim(
func: NativeFunction,
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
) -> str | None:
) -> Optional[str]:
backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices
)
@ -401,16 +399,16 @@ def gen_c_shim(
@dataclass(frozen=True)
class ShimGenerator:
func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup]
dispatch_key: DispatchKey
backend_indices: dict[DispatchKey, BackendIndex]
backend_indices: Dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp
@method_with_native_function
def __call__(
self,
func: NativeFunction,
) -> str | None:
) -> Optional[str]:
result = gen_c_shim(
func,
self.func_group_mapping,
@ -423,9 +421,9 @@ class ShimGenerator:
def gen_aoti_c_shim(
native_functions: Sequence[NativeFunction],
func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey,
backend_indices: dict[DispatchKey, BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
header: bool,
includes: str = "",
) -> str:

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import argparse
import os
import re
from collections import Counter, defaultdict, namedtuple
from pathlib import Path
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Set, Union
import yaml
@ -38,10 +36,10 @@ ParsedExternalYaml = namedtuple(
def parse_backend_yaml(
backend_yaml_path: str,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_indices: Dict[DispatchKey, BackendIndex],
) -> ParsedExternalYaml:
native_functions_map: dict[OperatorName, NativeFunction] = {
native_functions_map: Dict[OperatorName, NativeFunction] = {
f.func.name: f
for f in concatMap(
lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
@ -121,14 +119,14 @@ def parse_backend_yaml(
Only the following keys are supported: {", ".join(valid_keys)}'
def create_backend_index(
backend_ops: list[str],
symint_ops: set[str],
backend_ops: List[str],
symint_ops: Set[str],
dispatch_key: DispatchKey,
*,
use_out_as_primary: bool,
use_device_guard: bool,
) -> BackendIndex:
metadata: dict[OperatorName, BackendMetadata] = {}
metadata: Dict[OperatorName, BackendMetadata] = {}
for op in backend_ops:
op_name = OperatorName.parse(op)
assert (
@ -151,7 +149,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
index=metadata,
)
backend_key: DispatchKey | None = None
backend_key: Optional[DispatchKey] = None
if len(supported) > 0:
with context(
lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
@ -168,7 +166,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
assert backend_key not in backend_indices
backend_indices[backend_key] = backend_idx
autograd_key: DispatchKey | None = None
autograd_key: Optional[DispatchKey] = None
if len(supported_autograd) > 0:
with context(
lambda: f'The "autograd" key was specified, which indicates that you would like to override \
@ -247,12 +245,12 @@ autograd key. They cannot be mix and matched. If this is something you need, fee
def error_on_missing_kernels(
native_functions: Sequence[NativeFunction],
backend_indices: dict[DispatchKey, BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex],
backend_key: DispatchKey,
autograd_key: DispatchKey | None,
autograd_key: Optional[DispatchKey],
class_name: str,
kernel_defn_file_path: str,
full_codegen: list[OperatorName] | None = None,
full_codegen: Optional[List[OperatorName]] = None,
) -> None:
try:
with open(kernel_defn_file_path) as f:
@ -270,7 +268,7 @@ def error_on_missing_kernels(
)
# Quick mapping from each OperatorName used by the external backend
# to its backend kernel name
expected_backend_op_names: dict[OperatorName, str] = dict(
expected_backend_op_names: Dict[OperatorName, str] = dict(
list(
concatMap(
lambda index: [
@ -280,13 +278,13 @@ def error_on_missing_kernels(
)
)
)
expected_backend_native_funcs: list[NativeFunction] = [
expected_backend_native_funcs: List[NativeFunction] = [
f
for f in native_functions
if f.func.name in expected_backend_op_names.keys()
and f.func.name not in full_codegen
]
expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict(
list
)
for native_f in expected_backend_native_funcs:
@ -358,10 +356,10 @@ def gen_dispatchkey_nativefunc_headers(
fm: FileManager,
class_name: str,
cpp_namespace: str,
backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: Dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_dispatch_key: DispatchKey,
autograd_dispatch_key: DispatchKey | None,
autograd_dispatch_key: Optional[DispatchKey],
backend_name: str = "",
) -> None:
assert class_name is not None
@ -415,11 +413,11 @@ def gen_dispatcher_registrations(
fm: FileManager,
output_dir: str,
class_name: str,
backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: Dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
backend_dispatch_key: DispatchKey,
dispatch_key: DispatchKey,
selector: SelectiveBuilder,
selector: "SelectiveBuilder",
# build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
build_in_tree: bool = False,
per_operator_headers: bool = False,
@ -526,7 +524,7 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
def run(
source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[str] = None
) -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
pytorch_root = Path(__file__).absolute().parent.parent

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import argparse
import os
import pathlib
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
import yaml
@ -47,6 +45,7 @@ from torchgen.model import (
OperatorName,
Variant,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
context,
FileManager,
@ -56,11 +55,7 @@ from torchgen.utils import (
)
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str:
"""
A wrapper function to basically get `sig.decl(include_context=True)`.
For ATen kernel, the codegen has no idea about ET contextArg, so we
@ -77,9 +72,9 @@ def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
def static_dispatch(
sig: CppSignature | ExecutorchCppSignature,
sig: Union[CppSignature, ExecutorchCppSignature],
f: NativeFunction,
backend_indices: list[BackendIndex],
backend_indices: List[BackendIndex],
) -> str:
"""
For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
@ -118,7 +113,7 @@ TORCH_API inline {_sig_decl_wrapper(sig)} {{
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
static_dispatch_backend_indices: list[BackendIndex]
static_dispatch_backend_indices: List[BackendIndex]
selector: SelectiveBuilder
@ -127,7 +122,7 @@ class ComputeFunction:
is_custom_op: Callable[[NativeFunction], bool]
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
is_method_variant = False
if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
return None
@ -141,7 +136,7 @@ class ComputeFunction:
f"Can't handle native function {f.func} with the following variant specification {f.variants}."
)
sig: CppSignature | ExecutorchCppSignature = (
sig: Union[CppSignature, ExecutorchCppSignature] = (
CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=f.manual_cpp_binding
).most_faithful_signature()
@ -184,10 +179,10 @@ class ComputeCodegenUnboxedKernels:
@method_with_nested_native_function
def __call__(
self,
unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
unbox_kernel_entry: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]],
) -> str:
f: NativeFunction = unbox_kernel_entry[0]
kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
kernel_key: Union[ETKernelKey, List[ETKernelKey]] = unbox_kernel_entry[1][0]
kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]
op_name = f"{f.namespace}::{f.func.name}"
@ -201,7 +196,7 @@ class ComputeCodegenUnboxedKernels:
)
if not used_kernel_keys:
return ""
sig: CppSignature | ExecutorchCppSignature
sig: Union[CppSignature, ExecutorchCppSignature]
argument_type_gen: Callable[..., NamedCType]
return_type_gen: Callable[..., CType]
if self.use_aten_lib:
@ -295,11 +290,11 @@ def gen_unboxing(
) -> None:
# Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
def key_func(
item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
item: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]
) -> str:
return item[0].root_name + ":" + item[1][0].to_native_string()
items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [
items: List[Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]] = [
(native_function, (kernel_key, metadata))
for native_function in native_functions
for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
@ -330,8 +325,8 @@ def gen_unboxing(
@with_native_function_and_index # type: ignore[arg-type]
def compute_native_function_declaration(
g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
) -> list[str]:
g: Union[NativeFunctionsGroup, NativeFunction], kernel_index: ETKernelIndex
) -> List[str]:
assert isinstance(g, NativeFunction)
sig = ExecutorchCppSignature.from_native_function(f=g)
metadata_list = kernel_index.get_kernels(g).values()
@ -357,7 +352,7 @@ def gen_functions_declarations(
kernel_index: ETKernelIndex,
selector: SelectiveBuilder,
use_aten_lib: bool,
custom_ops_native_functions: Sequence[NativeFunction] | None = None,
custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None,
) -> str:
"""
Generates namespace separated C++ function API inline declaration/definitions.
@ -411,13 +406,13 @@ def get_ns_grouped_kernels(
kernel_index: ETKernelIndex,
native_function_decl_gen: Callable[
[
NativeFunctionsGroup | NativeFunction,
Union[NativeFunctionsGroup, NativeFunction],
ETKernelIndex,
],
list[str],
List[str],
],
) -> dict[str, list[str]]:
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
) -> Dict[str, List[str]]:
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
for f in native_functions:
native_function_namespaces = set()
op_kernels = kernel_index.get_kernels(f)
@ -600,7 +595,7 @@ def gen_custom_ops(
def translate_native_yaml(
tags_yaml_path: str,
aten_yaml_path: str,
native_yaml_path: str | None,
native_yaml_path: Optional[str],
use_aten_lib: bool,
out_file: TextIO,
) -> None:
@ -651,15 +646,15 @@ def translate_native_yaml(
skip_native_fns_gen=False,
)
func_to_scoped_name: dict[FunctionSchema, str] = {
func_to_scoped_name: Dict[FunctionSchema, str] = {
f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
}
op_to_scoped_name: dict[OperatorName, str] = {
op_to_scoped_name: Dict[OperatorName, str] = {
func.name: name for func, name in func_to_scoped_name.items()
}
schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
kernel_persist_dict: dict[str, dict[str, Any]] = {
kernel_persist_dict: Dict[str, Dict[str, Any]] = {
op_to_scoped_name[op]: v for op, v in persisted_fields.items()
}
@ -697,13 +692,13 @@ def translate_native_yaml(
def parse_yaml(
path: str | None,
path: Optional[str],
tags_yaml_path: str,
function_filter: Callable[[NativeFunction], bool],
skip_native_fns_gen: bool = False,
) -> tuple[
list[NativeFunction],
dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
) -> Tuple[
List[NativeFunction],
Union[Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ETKernelIndex],
]:
if path and os.path.exists(path) and os.stat(path).st_size > 0:
with open(path) as f:
@ -740,8 +735,8 @@ def parse_yaml(
# (2) Return BackendIndices if kernel index is absent
def map_index(
m: dict[OperatorName, BackendMetadata]
) -> dict[OperatorName, BackendMetadata]:
m: Dict[OperatorName, BackendMetadata]
) -> Dict[OperatorName, BackendMetadata]:
return {op: m[op] for op in m if op in op_names}
backend_indices = {
@ -756,11 +751,11 @@ def parse_yaml(
def parse_yaml_files(
tags_yaml_path: str,
aten_yaml_path: str,
native_yaml_path: str | None,
custom_ops_yaml_path: str | None,
native_yaml_path: Optional[str],
custom_ops_yaml_path: Optional[str],
selector: SelectiveBuilder,
use_aten_lib: bool,
) -> tuple[ETParsedYaml, ETParsedYaml | None]:
) -> Tuple[ETParsedYaml, Optional[ETParsedYaml]]:
"""Parses functions.yaml and custom_ops.yaml files.
Args:
@ -983,7 +978,7 @@ def main() -> None:
)
if options.output_dependencies:
depfile_path = Path(options.output_dependencies).resolve()
depfile_path = pathlib.Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name
depfile_stem = depfile_path.stem

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, TYPE_CHECKING
from typing import Callable, List, Optional, Tuple, Union
from torchgen.api import cpp, dispatcher
from torchgen.api.translate import translate
@ -48,13 +46,10 @@ from torchgen.native_function_generation import (
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import dataclass_repr
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
# Note: [Mutable Ops Not Using Functionalization]
# Ops in this list currently do not work with functionalization and should be fixed.
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
@ -93,7 +88,7 @@ class GenCompositeViewCopyKernel:
backend_index: BackendIndex
@method_with_native_function
def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]:
if g.view_copy is None:
return None
elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
@ -165,7 +160,7 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
"""
def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
return ""
@ -189,7 +184,7 @@ def wrapper_name(func: FunctionSchema) -> str:
return cpp.name(func)
def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> bool:
return isinstance(a, SelfArgument) or (
isinstance(a, Argument) and a.type.is_tensor_like()
)
@ -199,7 +194,7 @@ def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
# Some op schemas include non-owning types though (like TensorList),
# and when we unwrap them we expect to get out an owning type!.
# We also return a lambda that tells you how to conver the non-owning type argument into the owning type.
def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
if t == BaseCType(tensorListT):
return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
if t == BaseCType(iTensorListRefT):
@ -214,9 +209,9 @@ def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
# (2) a context, to be used by translate(), with all of the relevant bindings.
def unwrap_tensor_args(
sig: DispatcherSignature, *, is_view_op: bool
) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
) -> Tuple[str, List[Binding]]:
context: List[Binding] = []
unwrapped_tensor_args: List[str] = []
for arg in sig.arguments():
if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@ -252,9 +247,9 @@ def unwrap_tensor_args(
# converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
context: List[Binding] = []
unwrapped_tensor_args: List[str] = []
for arg in sig.arguments():
if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@ -322,7 +317,7 @@ def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
# Detects whether any of the SymInt arguments are, in fact, symbolic values.
# This is used in the constructor of ViewMeta.
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> Tuple[str, str]:
name = "has_symbolic_inputs"
statements = [
f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
@ -527,7 +522,7 @@ def maybe_create_output(f: NativeFunction, var_name: str) -> str:
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
def get_mutable_redispatch_return_names(
f: NativeFunction, inner_return_var: str
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
aliased_returns = []
non_aliased_returns = []
for i, name in enumerate(f.func.aliased_return_names()):
@ -756,11 +751,11 @@ def emit_inplace_functionalization_body(
# See Note [Functionalization Pass: View Inverses].
def gen_functionalization_view_inverse_declaration(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> str | None:
) -> Optional[str]:
# For every (non-composite) view op, we need a corresponding "inverse view" function.
# This generates the declarations so we get a good compiler error when someone adds a new view.
@with_native_function
def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]:
if g.view.has_composite_implicit_autograd_kernel:
return None
view_inverse_sig = ViewInverseSignature(g)
@ -771,9 +766,9 @@ def gen_functionalization_view_inverse_declaration(
def gen_functionalization_registration(
selector: SelectiveBuilder,
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
composite_implicit_autograd_index: BackendIndex,
) -> list[str]:
) -> List[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
assert not f.has_composite_implicit_autograd_kernel
@ -837,8 +832,8 @@ def gen_functionalization_definition(
# (and instead only need to operate on grouped NativeFunctions).
# The only reason currently is because we need to emit direct dispatch registrations
# For CompositeImplicitAutograd operators, which are potentially ungrouped.
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
) -> List[str]:
# Don't generate kernels in mobile build
if not selector.include_all_operators:
return []

View File

@ -1,10 +1,19 @@
from __future__ import annotations
import argparse
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Sequence
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import yaml
@ -93,8 +102,8 @@ ParsedExternalYaml = namedtuple(
def parse_native_functions_keys(
backend_yaml_path: str,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]:
with open(backend_yaml_path) as f:
yaml_values = yaml.load(f, Loader=YamlLoader)
assert isinstance(yaml_values, dict)
@ -111,7 +120,7 @@ def parse_native_functions_keys(
def validate_shape_inference_header(
shape_inference_hdr: str, expected_shape_infr_decls: list[str]
shape_inference_hdr: str, expected_shape_infr_decls: List[str]
) -> None:
try:
with open(shape_inference_hdr) as f:
@ -171,12 +180,12 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
class default_args:
node_base: str = "Node"
node_base_hdr: str | None = None
node_base_hdr: Optional[str] = None
shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
tensor_class: str = "torch::lazy::LazyTensor"
tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
lazy_ir_generator: type[GenLazyIR] = GenLazyIR
native_func_definition_generator: type[
lazy_ir_generator: Type[GenLazyIR] = GenLazyIR
native_func_definition_generator: Type[
GenLazyNativeFuncDefinition
] = GenLazyNativeFuncDefinition
backend_name: str = "TorchScript"
@ -254,10 +263,10 @@ def main() -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
torch_root = Path(__file__).absolute().parents[2]
aten_path = str(torch_root / "aten" / "src" / "ATen")
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator
if options.gen_ts_lowerings:
lazy_ir_generator = GenTSLazyIR
native_func_definition_generator: type[
native_func_definition_generator: Type[
GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator
@ -283,14 +292,14 @@ def run_gen_lazy_tensor(
source_yaml: str,
output_dir: str,
dry_run: bool,
impl_path: str | None,
impl_path: Optional[str],
node_base: str = default_args.node_base,
node_base_hdr: str | None = default_args.node_base_hdr,
node_base_hdr: Optional[str] = default_args.node_base_hdr,
tensor_class: str = default_args.tensor_class,
tensor_class_hdr: str = default_args.tensor_class_hdr,
shape_inference_hdr: str = default_args.shape_inference_hdr,
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
native_func_definition_generator: type[
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator,
native_func_definition_generator: Type[
GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator,
# build_in_tree is true for TS backend and affects include paths
@ -338,7 +347,7 @@ def run_gen_lazy_tensor(
)
grouped_native_functions = get_grouped_native_functions(native_functions)
def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
"""
We sort the native function because of the note in concat_map_codegen.
TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
@ -368,8 +377,8 @@ def run_gen_lazy_tensor(
def concat_map_codegen(
func: Callable[[NativeFunction], Sequence[str]],
xs: Iterable[NativeFunctionsGroup | NativeFunction],
ops_list: list[OperatorName] = full_codegen,
xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
ops_list: List[OperatorName] = full_codegen,
) -> Iterator[str]:
"""
We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we

View File

@ -1,8 +1,6 @@
from __future__ import annotations
import textwrap
from dataclasses import dataclass
from typing import Sequence
from typing import List, Optional, Sequence, Tuple
from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
@ -34,7 +32,7 @@ def is_tensor_list(typ: Type) -> bool:
return isinstance(typ, ListType) and is_tensor(typ.elem)
def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
result = f"""\
Tensor {name}_value;
optional<int64_t> {name}_bdim;
@ -42,7 +40,7 @@ def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
return textwrap.dedent(result).split("\n")
def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
result = f"""\
optional<Tensor> {name}_value;
optional<int64_t> {name}_bdim;
@ -54,7 +52,7 @@ def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
def gen_unwraps(
flat_arguments: Sequence[Argument], cur_level_var: str
) -> tuple[str, list[str]]:
) -> Tuple[str, List[str]]:
arg_names = [a.name for a in flat_arguments]
arg_types = [a.type for a in flat_arguments]
@ -101,7 +99,7 @@ if ({' && '.join(conditions)}) {{
def gen_returns(
returns: tuple[Return, ...], cur_level_var: str, results_var: str
returns: Tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
idx = 0
wrapped_returns = []
@ -134,7 +132,7 @@ def is_mutated_arg(argument: Argument) -> bool:
return argument.annotation is not None and argument.annotation.is_write
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
# Assumptions:
# - only one argument is being modified in-place
# - the argument that is being modified in-place is the first argument
@ -199,7 +197,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
}}"""
def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
schema = native_function.func
sig = DispatcherSignature.from_schema(schema)
returns = schema.returns
@ -246,7 +244,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
result = gen_vmap_plumbing(f)
return result

View File

@ -1,8 +1,6 @@
from __future__ import annotations
import threading
from contextlib import contextmanager
from typing import Iterator
from typing import Iterator, Optional
# Simple dynamic scoping implementation. The name "parametrize" comes
@ -19,8 +17,8 @@ from typing import Iterator
class Locals(threading.local):
use_const_ref_for_mutable_tensors: bool | None = None
use_ilistref_for_tensor_lists: bool | None = None
use_const_ref_for_mutable_tensors: Optional[bool] = None
use_ilistref_for_tensor_lists: Optional[bool] = None
_locals = Locals()

View File

@ -1,11 +1,9 @@
from __future__ import annotations
import dataclasses
import itertools
import re
from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Iterator, Sequence
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
@ -231,7 +229,7 @@ class DispatchKey(Enum):
return str(self).lower()
@staticmethod
def parse(value: str) -> DispatchKey:
def parse(value: str) -> "DispatchKey":
for k, v in DispatchKey.__members__.items():
if k == value:
return v
@ -352,20 +350,20 @@ class ScalarType(Enum):
return self.name
@staticmethod
def maybe_parse(value: str) -> ScalarType | None:
def maybe_parse(value: str) -> Optional["ScalarType"]:
for k, v in ScalarType.__members__.items():
if k == value:
return v
return None
@staticmethod
def parse(value: str) -> ScalarType:
def parse(value: str) -> "ScalarType":
mb_r = ScalarType.maybe_parse(value)
assert mb_r is not None, f"unknown dtype {value}"
return mb_r
@staticmethod
def parse_set(values: str) -> OrderedSet[ScalarType]:
def parse_set(values: str) -> OrderedSet["ScalarType"]:
dtypes: OrderedSet[ScalarType] = OrderedSet()
for value in values.split(", "):
if value in DTYPE_CLASSES:
@ -375,7 +373,7 @@ class ScalarType(Enum):
return dtypes
DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {}
# NB: Integral doesn't include boolean
DTYPE_CLASSES["Integral"] = OrderedSet(
[
@ -421,7 +419,7 @@ class UfuncKey(Enum):
return self.name
@staticmethod
def parse(value: str) -> UfuncKey:
def parse(value: str) -> "UfuncKey":
for k, v in UfuncKey.__members__.items():
if k == value:
return v
@ -464,7 +462,7 @@ class NativeFunction:
# (This type is quoted as we are forward referencing a type
# defined later in the file. I opted for this ordering of the
# classes for expository clarity.)
func: FunctionSchema
func: "FunctionSchema"
# Whether or not to generate mutable tensor arguments like regular
# ones
@ -477,14 +475,14 @@ class NativeFunction:
device_check: DeviceCheckType
# What python module to put the function in
python_module: str | None
python_module: Optional[str]
# TODO: figure out what this does
category_override: str | None
category_override: Optional[str]
# If no variants are specified in native_functions.yaml, this is
# assumed to be {'function'}.
variants: set[Variant]
variants: Set[Variant]
# Whether or not we should skip generating registrations for
# this kernel. This is a bit of a double-edged sword, as manual
@ -499,7 +497,7 @@ class NativeFunction:
# The location in the YAML file were this native function entry was
# defined. This is for conveniently reporting error messages!
loc: Location
loc: "Location"
# A list of operators that are expected to be auto-generated for this NativeFunction.
# Note: This list isn't actually directly used by the codegen to generate anything.
@ -507,11 +505,11 @@ class NativeFunction:
# function schema, and uses the autogen declarations to error check.
# We expect every NativeFunction that gets auto-generated be explicitly called out
# in native_functions.yaml
autogen: list[OperatorName]
autogen: List["OperatorName"]
# If non-empty, this kernel is subject to ufunc codegen.
# Sorted by ufunc_key
ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"]
# Whether or not this out functions is a "structured kernel". Structured
# kernels are defined a little differently from normal kernels; in
@ -524,13 +522,13 @@ class NativeFunction:
# Whether or not this non-out function is a structured kernel, defined
# in terms of the out kernel referenced by the string here.
structured_delegate: OperatorName | None
structured_delegate: Optional["OperatorName"]
# Only valid for structured kernels. Specifies alternative of what
# to inherit from when defining the meta class for the structured
# operator. This will usually be TensorIteratorBase. This also
# changes the semantics of set_output to call the parent class.
structured_inherits: str | None
structured_inherits: Optional[str]
# Structured kernels can declare elements as "precomputed". These elements
# are returned by the meta function in one struct and passed to the impl
@ -538,11 +536,11 @@ class NativeFunction:
# elements supersede. Information about the names and types of these
# precomputed elements and how they correspond to kernel arguments is stored
# in this member, if applicable.
precomputed: Precompute | None
precomputed: Optional["Precompute"]
# Argument names whose default should be excluded from the C++ interface.
# Intended for resolving overload ambiguities between signatures.
cpp_no_default_args: set[str]
cpp_no_default_args: Set[str]
# Note [Abstract ATen methods]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -562,7 +560,7 @@ class NativeFunction:
# Tags are used to describe semantic information about (groups of) operators,
# That aren't easily inferrable directly from the operator's schema.
tags: set[str]
tags: Set[str]
# NB: The benefit of defining a dataclass is that we automatically get
# a constructor defined for all the fields we specify. No need
@ -571,11 +569,13 @@ class NativeFunction:
# We parse both the NativeFunction + backend-specific information about it, which it stored in a corresponding BackendIndex.
@staticmethod
def from_yaml(
ei: dict[str, object],
loc: Location,
valid_tags: set[str],
ignore_keys: set[DispatchKey] | None = None,
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
ei: Dict[str, object],
loc: "Location",
valid_tags: Set[str],
ignore_keys: Optional[Set[DispatchKey]] = None,
) -> Tuple[
"NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]
]:
"""
Parse a NativeFunction from a dictionary as directly parsed
from native_functions.yaml
@ -602,7 +602,7 @@ class NativeFunction:
variants_s = e.pop("variants", "function")
assert isinstance(variants_s, str)
variants: set[Variant] = set()
variants: Set[Variant] = set()
for v in variants_s.split(", "):
if v == "function":
variants.add(Variant.function)
@ -646,7 +646,7 @@ class NativeFunction:
"namespace is not supported in structured delegate,"
" using the same namespace as the native function"
)
structured_delegate: OperatorName | None = None
structured_delegate: Optional[OperatorName] = None
if structured_delegate_s is not None:
structured_delegate = OperatorName.parse(structured_delegate_s)
@ -685,7 +685,7 @@ class NativeFunction:
if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
tags_inp.append("pt2_compliant_tag")
tags: set[str] = set()
tags: Set[str] = set()
for t in tags_inp:
assert len(valid_tags) > 0
# TODO: verify that the tag is valid and has an entry in tags.yaml
@ -698,7 +698,7 @@ class NativeFunction:
raw_dispatch = e.pop("dispatch", None)
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
dispatch: dict[DispatchKey, BackendMetadata] = {}
dispatch: Dict[DispatchKey, BackendMetadata] = {}
num_dispatch_keys: int = 0
if raw_dispatch is not None:
assert not manual_kernel_registration, (
@ -1081,8 +1081,8 @@ class SchemaKind(Enum):
@dataclass(frozen=True)
class NativeFunctionsGroup:
functional: NativeFunction
inplace: NativeFunction | None
mutable: NativeFunction | None
inplace: Optional[NativeFunction]
mutable: Optional[NativeFunction]
out: NativeFunction
@property
@ -1136,7 +1136,7 @@ class NativeFunctionsGroup:
[str(f.func.name) for f in self.functions() if "generated" in f.tags]
)
generated_fns_str = ", ".join(str(x) for x in generated_fns)
expected_generated_fns: set[str] = set()
expected_generated_fns: Set[str] = set()
for f in self.functions():
expected_generated_fns.update(str(op) for op in f.autogen)
expected_generated_fns_str = ", ".join(
@ -1155,7 +1155,7 @@ class NativeFunctionsGroup:
f" Instead, it found 'autogen: {expected_generated_fns_str}'"
)
def signature(self) -> FunctionSchema:
def signature(self) -> "FunctionSchema":
return self.out.func.signature()
def functions(self) -> Iterator[NativeFunction]:
@ -1171,7 +1171,9 @@ class NativeFunctionsGroup:
return self.functional.root_name
@staticmethod
def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
def from_dict(
d: Dict[SchemaKind, NativeFunction]
) -> Optional["NativeFunctionsGroup"]:
assert d
if len(d) == 1:
return None
@ -1227,7 +1229,7 @@ class UfuncInnerLoop:
ufunc_key: UfuncKey
@staticmethod
def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop":
name, supported_dtypes_str = value.split(" ", 1)
assert supported_dtypes_str[0] == "("
assert supported_dtypes_str[-1] == ")"
@ -1259,12 +1261,12 @@ class BackendIndex:
# Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
external: bool
# Other backend-specific information that is on a per-operator basis
index: dict[OperatorName, BackendMetadata]
index: Dict["OperatorName", BackendMetadata]
@staticmethod
def grow_index(
parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
) -> None:
for k, v in child_index.items():
for op_name, metadata in v.items():
@ -1279,13 +1281,13 @@ class BackendIndex:
else:
return g.functional
def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
m = self.get_kernel(g)
return m is not None
def get_kernel(
self, g: NativeFunction | NativeFunctionsGroup
) -> BackendMetadata | None:
self, g: Union[NativeFunction, NativeFunctionsGroup]
) -> Optional[BackendMetadata]:
if isinstance(g, NativeFunction):
f = g
elif isinstance(g, NativeFunctionsGroup):
@ -1296,7 +1298,7 @@ class BackendIndex:
return None
return self.index[f.func.name]
def native_function_class_name(self) -> str | None:
def native_function_class_name(self) -> Optional[str]:
if self.external:
return f"{str(self.dispatch_key)}NativeFunctions"
else:
@ -1362,16 +1364,16 @@ class BackendIndex:
@dataclass(frozen=True)
class FunctionSchema:
# The name of the operator this function schema describes.
name: OperatorName
name: "OperatorName"
arguments: Arguments
arguments: "Arguments"
# TODO: Need to handle collisions with argument names at some point
returns: tuple[Return, ...]
returns: Tuple["Return", ...]
@property
def is_mutable(self) -> bool:
def is_write(arg: Argument) -> bool:
def is_write(arg: "Argument") -> bool:
if arg.annotation is None:
return False
return arg.annotation.is_write
@ -1380,7 +1382,7 @@ class FunctionSchema:
# See aten/src/ATen/core/function_schema.h (keep these in sync)
return any(is_write(a) for a in self.arguments.flat_all)
def schema_order_arguments(self) -> Iterator[Argument]:
def schema_order_arguments(self) -> Iterator["Argument"]:
return itertools.chain(
self.arguments.flat_positional,
self.arguments.flat_kwarg_only,
@ -1390,7 +1392,7 @@ class FunctionSchema:
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
@staticmethod
def parse(func: str) -> FunctionSchema:
def parse(func: str) -> "FunctionSchema":
# We should probably get a proper parser here
decls = FunctionSchema.decl_re.findall(func)
assert len(decls) == 1, f"Invalid function schema: {func}"
@ -1585,8 +1587,8 @@ class FunctionSchema:
# - If the return aliases an input, we return the input name
# - Otherwise, we return None.
# If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
def aliased_return_names(self) -> list[str | None]:
outs: list[str | None] = []
def aliased_return_names(self) -> List[Optional[str]]:
outs: List[Optional[str]] = []
for r in self.returns:
aliased_args = [
a
@ -1610,7 +1612,7 @@ class FunctionSchema:
strip_default: bool = False,
strip_view_copy_name: bool = False,
keep_return_names: bool = False,
) -> FunctionSchema:
) -> "FunctionSchema":
"""
Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method
@ -1707,10 +1709,10 @@ class FunctionSchema:
returns=returns,
)
def view_signature(self) -> FunctionSchema:
def view_signature(self) -> "FunctionSchema":
return self.signature(strip_view_copy_name=True)
def with_name(self, name: OperatorName) -> FunctionSchema:
def with_name(self, name: "OperatorName") -> "FunctionSchema":
return FunctionSchema(
name=name,
arguments=self.arguments,
@ -1745,12 +1747,12 @@ class FunctionSchema:
class Annotation:
# Typically only has one element. Not actually a set so
# we can conveniently assume it is canonically ordered
alias_set: tuple[str, ...]
alias_set: Tuple[str, ...]
is_write: bool
alias_set_after: tuple[str, ...]
alias_set_after: Tuple[str, ...]
@staticmethod
def parse(ann: str) -> Annotation:
def parse(ann: str) -> "Annotation":
# TODO: implement a proper parser if this gets more ugly
# Regex Explanation:
# Example: "a! -> a|b"
@ -1803,13 +1805,13 @@ class Annotation:
@dataclass(frozen=True)
class Type:
@staticmethod
def parse(t: str) -> Type:
def parse(t: str) -> "Type":
r = Type._parse(t)
assert str(r) == t, f"{r} != {t}"
return r
@staticmethod
def _parse(t: str) -> Type:
def _parse(t: str) -> "Type":
m = re.match(r"^(.+)\?$", t)
if m is not None:
return OptionalType(Type.parse(m.group(1)))
@ -1835,7 +1837,7 @@ class Type:
# so we can conveniently generate legacy Declarations.yaml but
# really we should probably just remove these at some point
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
def is_base_ty_like(self, base_ty: "BaseTy") -> bool:
raise NotImplementedError
def is_tensor_like(self) -> bool:
@ -1850,7 +1852,7 @@ class Type:
def is_nullable(self) -> bool:
raise NotImplementedError
def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
raise NotImplementedError
@ -1890,7 +1892,7 @@ class BaseType(Type):
def is_nullable(self) -> bool:
return False
def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return None
def is_symint_like(self) -> bool:
@ -1914,7 +1916,7 @@ class OptionalType(Type):
def is_nullable(self) -> bool:
return True
def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return self.elem.is_list_like()
@ -1941,7 +1943,7 @@ class CustomClassType(Type):
"""
return False
def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return None
@ -1955,7 +1957,7 @@ class CustomClassType(Type):
@dataclass(frozen=True)
class ListType(Type):
elem: Type
size: int | None
size: Optional[int]
def __str__(self) -> str:
size = f"{self.size}" if self.size else ""
@ -1970,7 +1972,7 @@ class ListType(Type):
def is_nullable(self) -> bool:
return self.elem.is_nullable()
def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return self
@ -1981,7 +1983,7 @@ class Argument:
name: str
type: Type
default: str | None
default: Optional[str]
# The semantics of the annotation field are a little strange.
#
@ -2002,16 +2004,16 @@ class Argument:
# structure of annotated types is very simple. So we just hard
# code it here. But if we ever do get anything more complex, this
# model will have to change!
annotation: Annotation | None
annotation: Optional[Annotation]
@property
def alias_info(self) -> Annotation | None:
def alias_info(self) -> Optional[Annotation]:
return self.annotation
@staticmethod
def parse(arg: str) -> Argument:
def parse(arg: str) -> "Argument":
name: str
default: str | None
default: Optional[str]
assert " " in arg, f"illegal argument '{arg}'"
type_and_annot, name_and_default = arg.rsplit(" ", 1)
if "=" in name_and_default:
@ -2024,7 +2026,7 @@ class Argument:
default = None
# TODO: deduplicate annotation matching with Return
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
annotation: Annotation | None
annotation: Optional[Annotation]
if match:
# If you update this, make sure the __str__ still works too
assert match.group(2) in [
@ -2067,24 +2069,24 @@ class Argument:
@dataclass(frozen=True)
class Return:
name: str | None
name: Optional[str]
type: Type
annotation: Annotation | None
annotation: Optional[Annotation]
@property
def alias_info(self) -> Annotation | None:
def alias_info(self) -> Optional[Annotation]:
return self.annotation
@staticmethod
def parse(arg: str) -> Return:
name: str | None
def parse(arg: str) -> "Return":
name: Optional[str]
if " " in arg:
type_and_annot, name = arg.rsplit(" ", 1)
else:
type_and_annot = arg
name = None
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
annotation: Annotation | None
annotation: Optional[Annotation]
if match:
# If you update this, make sure the __str__ still works too
assert match.group(2) in [
@ -2146,34 +2148,34 @@ class Arguments:
# pre_self_positional is usually empty, but is notably non-empty
# for where.self, where the condition argument comes before the
# self argument
pre_self_positional: tuple[Argument, ...]
self_arg: SelfArgument | None
post_self_positional: tuple[Argument, ...]
pre_self_positional: Tuple[Argument, ...]
self_arg: Optional[SelfArgument]
post_self_positional: Tuple[Argument, ...]
pre_tensor_options_kwarg_only: tuple[Argument, ...]
tensor_options: TensorOptionsArguments | None
pre_tensor_options_kwarg_only: Tuple[Argument, ...]
tensor_options: Optional[TensorOptionsArguments]
# post_tensor_options is typically memory format, which should be
# part of tensor options but isn't right now, and is usually
# placed after the tensor options arguments
post_tensor_options_kwarg_only: tuple[Argument, ...]
post_tensor_options_kwarg_only: Tuple[Argument, ...]
# Unlike in the previous codegen, we have factored out 'out' arguments
# in the canonical representation, removing them from kwarg
# arguments. This choice is justified by numerous downstream
# transformations which treat out arguments specially; additionally,
# you can see that canonicity is not violated!
out: tuple[Argument, ...] # these are also kwarg-only
out: Tuple[Argument, ...] # these are also kwarg-only
@property
def flat_non_out(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.flat_positional)
ret.extend(self.flat_kwarg_only)
return ret
@property
def flat_positional(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.pre_self_positional)
if self.self_arg is not None:
ret.append(self.self_arg.argument)
@ -2187,7 +2189,7 @@ class Arguments:
# NB: doesn't contain out arguments
@property
def flat_kwarg_only(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None:
ret.extend(self.tensor_options.all())
@ -2196,7 +2198,7 @@ class Arguments:
@property
def flat_all(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.flat_positional)
ret.extend(self.flat_kwarg_only)
ret.extend(self.out)
@ -2205,15 +2207,15 @@ class Arguments:
@property
def non_out(
self,
) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
ret.extend(self.positional)
ret.extend(self.kwarg_only)
return ret
@property
def positional(self) -> Sequence[Argument | SelfArgument]:
ret: list[Argument | SelfArgument] = []
def positional(self) -> Sequence[Union[Argument, SelfArgument]]:
ret: List[Union[Argument, SelfArgument]] = []
ret.extend(self.pre_self_positional)
if self.self_arg is not None:
ret.append(self.self_arg)
@ -2221,8 +2223,8 @@ class Arguments:
return ret
@property
def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
ret: list[Argument | TensorOptionsArguments] = []
def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
ret: List[Union[Argument, TensorOptionsArguments]] = []
ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None:
ret.append(self.tensor_options)
@ -2230,14 +2232,14 @@ class Arguments:
return ret
@property
def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
ret.extend(self.positional)
ret.extend(self.kwarg_only)
ret.extend(self.out)
return ret
def mutable_arg_names(self) -> list[str]:
def mutable_arg_names(self) -> List[str]:
return [
a.name
for a in self.flat_all
@ -2253,7 +2255,7 @@ class Arguments:
def has_generator_arg(self) -> bool:
return any(a.type.is_generator_like() for a in self.flat_non_out)
def signature(self, *, strip_default: bool = False) -> Arguments:
def signature(self, *, strip_default: bool = False) -> "Arguments":
# dataclasses.replace could be used here, but it is less
# type safe so for now I've opted to type everything out
def strip_arg_annotation(a: Argument) -> Argument:
@ -2288,7 +2290,7 @@ class Arguments:
out=(),
)
def remove_self_annotation(self) -> Arguments:
def remove_self_annotation(self) -> "Arguments":
assert self.self_arg is not None
return dataclasses.replace(
self,
@ -2297,7 +2299,7 @@ class Arguments:
),
)
def with_out_args(self, outs: list[Argument]) -> Arguments:
def with_out_args(self, outs: List[Argument]) -> "Arguments":
assert len(self.out) == 0
return dataclasses.replace(
self,
@ -2305,10 +2307,10 @@ class Arguments:
)
@staticmethod
def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
positional: list[Argument] = []
kwarg_only: list[Argument] = []
out: list[Argument] = []
def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
positional: List[Argument] = []
kwarg_only: List[Argument] = []
out: List[Argument] = []
arguments_acc = positional
# TODO: Use a real parser here; this will get bamboozled
@ -2341,7 +2343,7 @@ class Arguments:
return positional, kwarg_only, out
@staticmethod
def parse(args: str) -> Arguments:
def parse(args: str) -> "Arguments":
"""
Input: 'int x, int y, int z'
"""
@ -2359,9 +2361,9 @@ class Arguments:
if a.name == "self":
self_ix = i
break
pre_self_positional: list[Argument]
self_arg: SelfArgument | None
post_self_positional: list[Argument]
pre_self_positional: List[Argument]
self_arg: Optional[SelfArgument]
post_self_positional: List[Argument]
if self_ix is not None:
pre_self_positional = positional[:self_ix]
self_arg = SelfArgument(positional[self_ix])
@ -2372,9 +2374,9 @@ class Arguments:
post_self_positional = positional
# Group tensor options arguments
pre_tensor_options_kwarg_only: list[Argument] = []
tensor_options: TensorOptionsArguments | None = None
post_tensor_options_kwarg_only: list[Argument] = []
pre_tensor_options_kwarg_only: List[Argument] = []
tensor_options: Optional[TensorOptionsArguments] = None
post_tensor_options_kwarg_only: List[Argument] = []
kwarg_only_acc = pre_tensor_options_kwarg_only
def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
@ -2421,7 +2423,7 @@ class Arguments:
)
def __str__(self) -> str:
all_arguments: list[str] = []
all_arguments: List[str] = []
all_arguments.extend(map(str, self.flat_positional))
if self.flat_kwarg_only or self.out:
all_arguments.append("*")
@ -2500,7 +2502,7 @@ class BaseOperatorName:
functional_overload: bool = False
@staticmethod
def parse(op: str) -> BaseOperatorName:
def parse(op: str) -> "BaseOperatorName":
assert op != ""
assert not op.endswith("_out"), (
"_out suffix is reserved and not permitted for operator names; "
@ -2572,7 +2574,7 @@ class OperatorName:
overload_name: str
@staticmethod
def parse(op_name: str) -> OperatorName:
def parse(op_name: str) -> "OperatorName":
if "." in op_name:
name, overload_name = op_name.split(".", 1)
else:
@ -2599,7 +2601,7 @@ class OperatorName:
else:
return f"{self.name}"
def remove_inplace(self) -> OperatorName:
def remove_inplace(self) -> "OperatorName":
return OperatorName(
name=BaseOperatorName(
base=self.name.base,
@ -2609,7 +2611,7 @@ class OperatorName:
overload_name=self.overload_name,
)
def with_overload(self, overload: str) -> OperatorName:
def with_overload(self, overload: str) -> "OperatorName":
return OperatorName(
name=BaseOperatorName(
base=self.name.base,
@ -2647,9 +2649,9 @@ class NativeFunctionsViewGroup:
# Note: the {view}_copy operator is optional because we currently don't generate copy variants
# for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
# (we already get them "for free" through decomposition)
view_copy: NativeFunction | None
view_copy: Optional[NativeFunction]
# view_inplace ops are also optional, but every view_inplace op should have out-of-place variant.
view_inplace: NativeFunction | None
view_inplace: Optional[NativeFunction]
def __post_init__(self) -> None:
assert self.view.is_view_op
@ -2729,7 +2731,7 @@ def gets_generated_view_copy(f: NativeFunction) -> bool:
# Given a NativeFunction that corresponds to a view op,
# returns the OperatorName of the corresponding "copy" variant of the op.
def get_view_copy_name(f: NativeFunction) -> OperatorName:
def get_view_copy_name(f: NativeFunction) -> "OperatorName":
# Right now, when asking for a view op's corresponding "view_copy" name
# we assert for sanity that the op is allowed to have a generated view_copy variant.
# (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op).
@ -2753,7 +2755,7 @@ def get_view_copy_name(f: NativeFunction) -> OperatorName:
# Helper functions for parsing argument lists (both inputs and returns)
def parse_returns(return_decl: str) -> tuple[Return, ...]:
def parse_returns(return_decl: str) -> Tuple[Return, ...]:
"""
Input: '()'
Output: []
@ -2772,12 +2774,12 @@ def parse_returns(return_decl: str) -> tuple[Return, ...]:
class Precompute:
# A map from kernel argument name -> a list of precomputed
# elements that replaces/supersedes it.
replace: dict[str, list[Argument]]
replace: Dict[str, List[Argument]]
# List of precomputed args added without replacement
add: list[Argument]
add: List[Argument]
@staticmethod
def parse(src: object) -> Precompute:
def parse(src: object) -> "Precompute":
assert isinstance(src, list)
# src is a list of strings of the format:
@ -2822,7 +2824,7 @@ class Precompute:
for a in args:
assert a.name.upper() != a.name
def to_list(self) -> list[str]:
def to_list(self) -> List[str]:
replace_list = []
for kernel_param, replacement_params in self.replace.items():
replacements = ", ".join(str(param) for param in replacement_params)

View File

@ -1,7 +1,5 @@
from __future__ import annotations
from collections import defaultdict
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
@ -103,9 +101,9 @@ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# But have differing SchemaKinds.
def pre_group_native_functions(
native_functions: Sequence[NativeFunction],
) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: dict[
FunctionSchema, dict[SchemaKind, NativeFunction]
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: Dict[
FunctionSchema, Dict[SchemaKind, NativeFunction]
] = defaultdict(dict)
for f in native_functions:
d = pre_grouped_native_functions[f.func.signature()]
@ -115,7 +113,7 @@ def pre_group_native_functions(
# Returns the out variant overload name given a base function overload name
def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
def get_expected_out_variant_overload_name(overload_name: Optional[str]) -> str:
return "out" if not overload_name else f"{overload_name}_out"
@ -180,7 +178,7 @@ def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations.
def generate_out_args_from_schema(
func: FunctionSchema,
) -> tuple[list[Return], list[Argument]]:
) -> Tuple[List[Return], List[Argument]]:
# More of a sanity check - our existing restrictions on schemas should enforce that
# mutable schema kinds never return their mutable arguments.
assert not any(
@ -200,11 +198,11 @@ def generate_out_args_from_schema(
all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
new_out_args: list[Argument] = []
new_out_args: List[Argument] = []
# The end result of new_returns is that:
# - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
# - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
new_returns: list[Return] = []
new_returns: List[Return] = []
for i, r in enumerate(func.returns):
if r.type.is_tensor_like():
new_out = Argument(
@ -268,7 +266,7 @@ def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function(
f: NativeFunction, k: SchemaKind
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
from torchgen.api import cpp
if k == SchemaKind.functional:
@ -377,8 +375,8 @@ def generate_function(
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
rs: list[NativeFunction],
indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
rs: List[NativeFunction],
indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
# The main code for generating new NativeFunctions
# First we group of NativeFunctions by schema kind,
@ -499,7 +497,7 @@ out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THA
rs.append(fn)
def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
return ""
@ -511,7 +509,7 @@ def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
# Given a function, and the name of a variable corresponding to the output of that function,
# gather up all of the individual returns that are not aliased
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]:
aliased_rets = func.aliased_return_names()
non_aliased_names = []
is_out_var_a_tuple = len(func.returns) > 1
@ -526,7 +524,7 @@ def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str
# Generates functional kernels in terms of their inplace.mutable counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# We should only be generating these for code-generated NativeFunctions
if "generated" not in g.functional.tags:
return None
@ -543,7 +541,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
sig = DispatcherSignature(g.functional.func)
target_sig = DispatcherSignature(target_f.func)
context: list[Binding | Expr] = []
context: List[Union[Binding, Expr]] = []
clone_mutable_inputs = []
cloned_return_names = []
# We can't just directly pass all of the arguments from the functional op into the mutating op.
@ -589,7 +587,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
# Generates out= kernels in terms of their functional counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# We should only be generating these for code-generated NativeFunctions
if "generated" not in g.out.tags:
return None

View File

@ -1,12 +1,9 @@
#!/usr/bin/env python3
from __future__ import annotations
import os
from enum import Enum
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Any, Dict, List
import torch
from torch.jit.generate_bytecode import generate_upgraders_bytecode
@ -188,7 +185,7 @@ PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
)
def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
instruction_list_part = []
for instruction in instruction_list_from_yaml:
instruction_list_part.append(
@ -203,7 +200,7 @@ def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
)
def construct_constants(constants_list_from_yaml: list[Any]) -> str:
def construct_constants(constants_list_from_yaml: List[Any]) -> str:
constants_list_part = []
for constant_from_yaml in constants_list_from_yaml:
convert_constant = None
@ -229,7 +226,7 @@ def construct_constants(constants_list_from_yaml: list[Any]) -> str:
)
def construct_operators(operator_list_from_yaml: list[Any]) -> str:
def construct_operators(operator_list_from_yaml: List[Any]) -> str:
operator_list_part = []
for operator in operator_list_from_yaml:
operator_list_part.append(
@ -244,7 +241,7 @@ def construct_operators(operator_list_from_yaml: list[Any]) -> str:
)
def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
def construct_types(types_tr_list_from_yaml: List[Any]) -> str:
types_tr_list_part = []
for types_tr in types_tr_list_from_yaml:
types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr))
@ -263,7 +260,7 @@ def construct_register_size(register_size_from_yaml: int) -> str:
def construct_version_maps(
upgrader_bytecode_function_to_index_map: dict[str, Any]
upgrader_bytecode_function_to_index_map: Dict[str, Any]
) -> str:
version_map = torch._C._get_operator_version_map()
sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return]
@ -305,8 +302,8 @@ def construct_version_maps(
def get_upgrader_bytecode_function_to_index_map(
upgrader_dict: list[dict[str, Any]]
) -> dict[str, Any]:
upgrader_dict: List[Dict[str, Any]]
) -> Dict[str, Any]:
upgrader_bytecode_function_to_index_map = {}
index = 0
for upgrader_bytecode in upgrader_dict:
@ -318,7 +315,7 @@ def get_upgrader_bytecode_function_to_index_map(
return upgrader_bytecode_function_to_index_map
def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
body_parts = []
upgrader_bytecode_function_to_index_map = (
get_upgrader_bytecode_function_to_index_map(upgrader_dict)
@ -373,7 +370,7 @@ def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
out_file.write(upgrader_file_content.encode("utf-8"))
def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
sorted_upgrader_list = sorted(
upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
)

View File

@ -1,6 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
# This class holds information about a single operator used to determine
@ -47,12 +46,12 @@ class SelectiveBuildOperator:
include_all_overloads: bool
# Debug Information at the operator level
_debug_info: tuple[str, ...] | None
_debug_info: Optional[Tuple[str, ...]]
@staticmethod
def from_yaml_dict(
op_name: str, op_info: dict[str, object]
) -> SelectiveBuildOperator:
op_name: str, op_info: Dict[str, object]
) -> "SelectiveBuildOperator":
allowed_keys = {
"name",
"is_root_operator",
@ -80,7 +79,7 @@ class SelectiveBuildOperator:
include_all_overloads = op_info.get("include_all_overloads", True)
assert isinstance(include_all_overloads, bool)
debug_info: tuple[str, ...] | None = None
debug_info: Optional[Tuple[str, ...]] = None
if "debug_info" in op_info:
di_list = op_info["debug_info"]
assert isinstance(di_list, list)
@ -97,7 +96,7 @@ class SelectiveBuildOperator:
@staticmethod
def from_legacy_operator_name_without_overload(
name: str,
) -> SelectiveBuildOperator:
) -> "SelectiveBuildOperator":
return SelectiveBuildOperator(
name=name,
is_root_operator=True,
@ -106,8 +105,8 @@ class SelectiveBuildOperator:
_debug_info=None,
)
def to_dict(self) -> dict[str, object]:
ret: dict[str, object] = {
def to_dict(self) -> Dict[str, object]:
ret: Dict[str, object] = {
"is_root_operator": self.is_root_operator,
"is_used_for_training": self.is_used_for_training,
"include_all_overloads": self.include_all_overloads,
@ -119,9 +118,9 @@ class SelectiveBuildOperator:
def merge_debug_info(
lhs: tuple[str, ...] | None,
rhs: tuple[str, ...] | None,
) -> tuple[str, ...] | None:
lhs: Optional[Tuple[str, ...]],
rhs: Optional[Tuple[str, ...]],
) -> Optional[Tuple[str, ...]]:
# Ensure that when merging, each entry shows up just once.
if lhs is None and rhs is None:
return None
@ -130,8 +129,8 @@ def merge_debug_info(
def combine_operators(
lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
) -> SelectiveBuildOperator:
lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator"
) -> "SelectiveBuildOperator":
if str(lhs.name) != str(rhs.name):
raise Exception( # noqa: TRY002
f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
@ -153,10 +152,10 @@ def combine_operators(
def merge_operator_dicts(
lhs: dict[str, SelectiveBuildOperator],
rhs: dict[str, SelectiveBuildOperator],
) -> dict[str, SelectiveBuildOperator]:
operators: dict[str, SelectiveBuildOperator] = {}
lhs: Dict[str, SelectiveBuildOperator],
rhs: Dict[str, SelectiveBuildOperator],
) -> Dict[str, SelectiveBuildOperator]:
operators: Dict[str, SelectiveBuildOperator] = {}
for op_name, op in list(lhs.items()) + list(rhs.items()):
new_op = op
if op_name in operators:

View File

@ -1,12 +1,11 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Dict, List, Optional, Set, Tuple
import yaml
from torchgen.model import NativeFunction
from torchgen.selective_build.operator import (
merge_debug_info,
merge_operator_dicts,
@ -15,10 +14,6 @@ from torchgen.selective_build.operator import (
)
if TYPE_CHECKING:
from torchgen.model import NativeFunction
# A SelectiveBuilder holds information extracted from the selective build
# YAML specification.
#
@ -33,10 +28,10 @@ class SelectiveBuilder:
include_all_operators: bool
# Debug Information at the selective/custom build level.
_debug_info: tuple[str, ...] | None
_debug_info: Optional[Tuple[str, ...]]
# A dictionary of operator -> operator metadata.
operators: dict[str, SelectiveBuildOperator]
operators: Dict[str, SelectiveBuildOperator]
# A dictionary of selected kernel tags and dtypes. Typically a
# PyTorch Operator Kernel (function) may have many code paths
@ -44,22 +39,22 @@ class SelectiveBuilder:
# one per kernel function, but there could be many per kernel
# function. The tag isn't a kernel function name, but some fragment
# of the kernel function implementation itself.
kernel_metadata: dict[str, list[str]]
kernel_metadata: Dict[str, List[str]]
# ExecuTorch only. A dictionary of kernel tag -> list of (list of input
# dtypes for tensor-like input args).
# This is from selective.yaml
et_kernel_metadata: dict[str, list[str]]
et_kernel_metadata: Dict[str, List[str]]
# A set of all the custom torch bind classes used by the selected models
# Stored as a set internally to remove duplicates proactively, but written
# as a list to yamls
custom_classes: set[str]
custom_classes: Set[str]
# A set of all the build features used by the selected models
# Stored as a set internally to remove duplicates proactively, but written
# as a list to yamls
build_features: set[str]
build_features: Set[str]
# If true, then fragments for all dtypes for all kernel functions
# are included as well as all custom classes. This is typically set when any one of the
@ -68,11 +63,11 @@ class SelectiveBuilder:
include_all_non_op_selectives: bool
@staticmethod
def get_nop_selector() -> SelectiveBuilder:
def get_nop_selector() -> "SelectiveBuilder":
return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})
@staticmethod
def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder":
valid_top_level_keys = {
"include_all_non_op_selectives",
"include_all_operators",
@ -140,20 +135,20 @@ class SelectiveBuilder:
)
@staticmethod
def from_yaml_str(config_contents: str) -> SelectiveBuilder:
def from_yaml_str(config_contents: str) -> "SelectiveBuilder":
contents = yaml.safe_load(config_contents)
return SelectiveBuilder.from_yaml_dict(contents)
@staticmethod
def from_yaml_path(config_path: str) -> SelectiveBuilder:
def from_yaml_path(config_path: str) -> "SelectiveBuilder":
with open(config_path) as f:
contents = yaml.safe_load(f)
return SelectiveBuilder.from_yaml_dict(contents)
@staticmethod
def from_legacy_op_registration_allow_list(
allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
) -> SelectiveBuilder:
allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool
) -> "SelectiveBuilder":
operators = {}
for op in allow_list:
operators[op] = {
@ -236,7 +231,7 @@ class SelectiveBuilder:
and dtype in self.kernel_metadata[kernel_tag]
)
def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
def et_get_selected_kernels(self, op_name: str, kernel_key: List[str]) -> List[str]:
"""
Return a list of kernel keys that cover the used ops
"""
@ -266,8 +261,8 @@ class SelectiveBuilder:
return list(result_set)
def to_dict(self) -> dict[str, object]:
ret: dict[str, object] = {
def to_dict(self) -> Dict[str, object]:
ret: Dict[str, object] = {
"include_all_non_op_selectives": self.include_all_non_op_selectives,
"include_all_operators": self.include_all_operators,
}
@ -293,10 +288,10 @@ class SelectiveBuilder:
def merge_kernel_metadata(
lhs: dict[str, list[str]],
rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
kernel_metadata: dict[str, list[str]] = {}
lhs: Dict[str, List[str]],
rhs: Dict[str, List[str]],
) -> Dict[str, List[str]]:
kernel_metadata: Dict[str, List[str]] = {}
for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
dtypes_copy = set(dtypes)
if tag_name in kernel_metadata:
@ -308,10 +303,10 @@ def merge_kernel_metadata(
def merge_et_kernel_metadata(
lhs: dict[str, list[str]],
rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set)
lhs: Dict[str, List[str]],
rhs: Dict[str, List[str]],
) -> Dict[str, List[str]]:
merge_et_kernel_metadata: Dict[str, Set[str]] = defaultdict(set)
for op in list(lhs.keys()) + list(rhs.keys()):
merge_et_kernel_metadata[op].update(lhs.get(op, []))
merge_et_kernel_metadata[op].update(rhs.get(op, []))

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3
import importlib.util
import os
import sys
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain
from pathlib import Path
@ -18,9 +18,9 @@ you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
raise Exception(err_msg) # noqa: TRY002
spec = spec_from_file_location(module_name, file_path)
spec = importlib.util.spec_from_file_location(module_name, file_path)
assert spec is not None
module = module_from_spec(spec)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
assert module is not None

View File

@ -1,9 +1,9 @@
from __future__ import annotations
from typing import Dict, Union
from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str:
if isinstance(g, NativeFunctionsGroup):
return str(g.functional.func.name.name.base)
else:
@ -55,12 +55,12 @@ is_hand_written_ops_ = frozenset(
)
def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
def is_hand_written(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
name_base = func_name_base_str(g)
return name_base in is_hand_written_ops_
def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
assert index == 0 or index == 1
if op_name == "addr":
if index == 0:

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import argparse
import itertools
import os
@ -30,7 +28,7 @@ def group_functions_by_op_name(
return []
groups = []
def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
with native_function_manager(g):
return generator.is_supported(g)

View File

@ -1,9 +1,7 @@
from __future__ import annotations
import json
import logging
import math
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
@ -27,7 +25,7 @@ logger: logging.Logger = logging.getLogger()
def has_alias(
arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
for arg in arguments:
annotation = getattr(arg, "annotation", None)
@ -239,7 +237,7 @@ BLOCKED_OPS = frozenset(
)
def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
base_op_name = ""
func = None
if isinstance(g, NativeFunctionsViewGroup):
@ -300,8 +298,8 @@ def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
def ivalue_type_conversion_method(
arg_type: BaseType | OptionalType | Type,
) -> tuple[bool, str] | None:
arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
"""
Return the method call expression of `c10::ivalue' to convert its contained value to
the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
@ -396,7 +394,7 @@ def test_tensor_dim(op_name: str) -> int:
test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string)
def test_tensor_shape(op_name: str) -> str:
@ -407,7 +405,7 @@ def test_tensor_shape(op_name: str) -> str:
def test_value_expression(
arg_type: BaseType | OptionalType | Type, index: int, op_name: str
arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
) -> str:
tensor_size_ex = test_tensor_shape(op_name)
if tensor_size_ex == "":
@ -477,8 +475,8 @@ generate_test_ir_arguments_base_ty_to_type_str_ = {
def generate_test_ir_arguments(
schema: FunctionSchema,
) -> list[tuple[str, str | None]]:
def ir_argument(arg: Argument) -> tuple[str, str | None]:
) -> List[Tuple[str, Optional[str]]]:
def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
t = arg.type
add_optional = False
if isinstance(t, OptionalType):

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import contextlib
import functools
import hashlib
@ -7,29 +5,31 @@ import os
import re
import sys
import textwrap
from argparse import Namespace
from dataclasses import fields, is_dataclass
from enum import auto, Enum
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Literal,
NoReturn,
Optional,
Sequence,
TYPE_CHECKING,
Set,
Tuple,
TypeVar,
Union,
)
from typing_extensions import Self
from torchgen.code_template import CodeTemplate
if TYPE_CHECKING:
from argparse import Namespace
# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
@ -57,7 +57,7 @@ IDENT_REGEX = r"(^|\W){}($|\W)"
# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> tuple[str, list[str]]:
def split_name_params(schema: str) -> Tuple[str, List[str]]:
m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
if m is None:
raise RuntimeError(f"Unsupported function schema: {schema}")
@ -73,7 +73,7 @@ S = TypeVar("S")
# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
for x in xs:
r = func(x)
if r is not None:
@ -127,7 +127,7 @@ class FileManager:
install_dir: str
template_dir: str
dry_run: bool
filenames: set[str]
filenames: Set[str]
def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
self.install_dir = install_dir
@ -136,7 +136,7 @@ class FileManager:
self.dry_run = dry_run
def _write_if_changed(self, filename: str, contents: str) -> None:
old_contents: str | None
old_contents: Optional[str]
try:
with open(filename) as f:
old_contents = f.read()
@ -150,7 +150,7 @@ class FileManager:
# Read from template file and replace pattern with callable (type could be dict or str).
def substitute_with_template(
self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
) -> str:
template_path = os.path.join(self.template_dir, template_fn)
env = env_callable()
@ -171,7 +171,7 @@ class FileManager:
self,
filename: str,
template_fn: str,
env_callable: Callable[[], str | dict[str, Any]],
env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> None:
filename = f"{self.install_dir}/{filename}"
assert filename not in self.filenames, "duplicate file write {filename}"
@ -186,7 +186,7 @@ class FileManager:
def write(
self,
filename: str,
env_callable: Callable[[], str | dict[str, Any]],
env_callable: Callable[[], Union[str, Dict[str, Any]]],
) -> None:
self.write_with_template(filename, filename, env_callable)
@ -196,13 +196,13 @@ class FileManager:
items: Iterable[T],
*,
key_fn: Callable[[T], str],
env_callable: Callable[[T], dict[str, list[str]]],
env_callable: Callable[[T], Dict[str, List[str]]],
num_shards: int,
base_env: dict[str, Any] | None = None,
sharded_keys: set[str],
base_env: Optional[Dict[str, Any]] = None,
sharded_keys: Set[str],
) -> None:
everything: dict[str, Any] = {"shard_id": "Everything"}
shards: list[dict[str, Any]] = [
everything: Dict[str, Any] = {"shard_id": "Everything"}
shards: List[Dict[str, Any]] = [
{"shard_id": f"_{i}"} for i in range(num_shards)
]
all_shards = [everything] + shards
@ -221,7 +221,7 @@ class FileManager:
else:
shard[key] = []
def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
for k, v in from_.items():
assert k in sharded_keys, f"undeclared sharded key {k}"
into[k] += v
@ -275,7 +275,7 @@ class FileManager:
# Helper function to generate file manager
def make_file_manager(
options: Namespace, install_dir: str | None = None
options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
template_dir = os.path.join(options.source_path, "templates")
install_dir = install_dir if install_dir else options.install_dir
@ -335,7 +335,7 @@ def _pformat(
def _format_dict(
attr: dict[Any, Any],
attr: Dict[Any, Any],
indent: int,
width: int,
curr_indent: int,
@ -355,7 +355,7 @@ def _format_dict(
def _format_list(
attr: list[Any] | set[Any] | tuple[Any, ...],
attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
indent: int,
width: int,
curr_indent: int,
@ -370,7 +370,7 @@ def _format_list(
def _format(
fields_str: list[str],
fields_str: List[str],
indent: int,
width: int,
curr_indent: int,
@ -402,9 +402,7 @@ class NamespaceHelper:
} // namespace torch
"""
def __init__(
self, namespace_str: str, entity_name: str = "", max_level: int = 2
) -> None:
def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
# cpp_namespace can be a colon joined string such as torch::lazy
cpp_namespaces = namespace_str.split("::")
assert (
@ -421,7 +419,7 @@ class NamespaceHelper:
@staticmethod
def from_namespaced_entity(
namespaced_entity: str, max_level: int = 2
) -> NamespaceHelper:
) -> "NamespaceHelper":
"""
Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
"""
@ -454,9 +452,9 @@ class NamespaceHelper:
class OrderedSet(Generic[T]):
storage: dict[T, Literal[None]]
storage: Dict[T, Literal[None]]
def __init__(self, iterable: Iterable[T] | None = None) -> None:
def __init__(self, iterable: Optional[Iterable[T]] = None):
if iterable is None:
self.storage = {}
else:
@ -468,28 +466,28 @@ class OrderedSet(Generic[T]):
def __iter__(self) -> Iterator[T]:
return iter(self.storage.keys())
def update(self, items: OrderedSet[T]) -> None:
def update(self, items: "OrderedSet[T]") -> None:
self.storage.update(items.storage)
def add(self, item: T) -> None:
self.storage[item] = None
def copy(self) -> OrderedSet[T]:
def copy(self) -> "OrderedSet[T]":
ret: OrderedSet[T] = OrderedSet()
ret.storage = self.storage.copy()
return ret
@staticmethod
def union(*args: OrderedSet[T]) -> OrderedSet[T]:
def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
ret = args[0].copy()
for s in args[1:]:
ret.update(s)
return ret
def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
return OrderedSet.union(self, other)
def __ior__(self, other: OrderedSet[T]) -> Self:
def __ior__(self, other: "OrderedSet[T]") -> Self:
self.update(other)
return self