[BE][Easy] enable postponed annotations in torchgen (#129376)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129376
Approved by: https://github.com/ezyang
ghstack dependencies: #129375
Author: Xuehai Pan
Date: 2024-06-29 12:48:07 +08:00
Committed by: PyTorch MergeBot
parent 8a67daf283
commit 9120992c72
45 changed files with 977 additions and 901 deletions
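
The change itself is mechanical: with `from __future__ import annotations` (PEP 563) at the top of a module, annotations are stored as strings and are never evaluated at import time, so PEP 585 builtin generics (`list[str]`, `dict[str, int]`), PEP 604 unions (`X | None`), and unquoted forward references all become legal in annotations even on interpreters older than Python 3.9/3.10, where these forms would otherwise fail at runtime. Names that are only needed inside annotations can also be moved under a `typing.TYPE_CHECKING` guard, while names used at runtime (such as `cast`) must stay as real imports. A minimal illustrative sketch of the style (the `Node` class is hypothetical, not from the diff):

    from __future__ import annotations  # PEP 563: postpone evaluation of annotations

    from dataclasses import dataclass
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Only needed inside annotations, so it never has to be imported at runtime.
        from collections.abc import Sequence


    @dataclass(frozen=True)
    class Node:
        # Builtin generics (PEP 585) and "X | None" unions (PEP 604) are fine here
        # even on older interpreters, because the annotations are never evaluated.
        name: str
        children: tuple[Node, ...] = ()          # unquoted forward reference to Node
        metadata: dict[str, str] | None = None   # instead of Optional[Dict[str, str]]

        def find_all(self, names: Sequence[str]) -> list[Node]:
            """Return the children whose names appear in `names`."""
            return [child for child in self.children if child.name in names]

The files in this commit apply exactly this style: `Optional[X]` becomes `X | None`, `Dict`/`List`/`Set`/`Tuple` become `dict`/`list`/`set`/`tuple`, `Union[A, B]` becomes `A | B`, and quoted self-references in return types are unquoted.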


@@ -1,6 +1,8 @@
+from __future__ import annotations
import re
from dataclasses import dataclass
-from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
+from typing import cast, Sequence
from torchgen import local
from torchgen.api import cpp
@@ -48,16 +50,16 @@ class Derivative:
original_formula: str
# Names of the arguments for which this formula calculates derivatives.
-var_names: Tuple[str, ...]
+var_names: tuple[str, ...]
# Saved inputs that are referenced by the formula.
-saved_inputs: Tuple[SavedAttribute, ...]
+saved_inputs: tuple[SavedAttribute, ...]
# Saved outputs that are referenced by the formula.
-saved_outputs: Tuple[SavedAttribute, ...]
+saved_outputs: tuple[SavedAttribute, ...]
# Gradients that are referenced by name in the formula.
-named_gradients: Set[str]
+named_gradients: set[str]
# Represents a forward formula that calculates forward derivatives
@@ -71,17 +73,17 @@ class ForwardDerivative:
# Name of the output arguments for which this formula calculates forward
# derivatives
-var_names: Tuple[str, ...]
+var_names: tuple[str, ...]
# Type of the output arguments for which this formula calculates forward
# derivatives
-var_types: Tuple[Type, ...]
+var_types: tuple[Type, ...]
# Inputs for which the forward derivatives are required for this formula
-required_inputs_fw_grad: Optional[Tuple[str, ...]]
+required_inputs_fw_grad: tuple[str, ...] | None
# Inputs for which the primal is required for this formula
-required_inputs_primal: Optional[Tuple[str, ...]]
+required_inputs_primal: tuple[str, ...] | None
# Flag to specify if this formula requires the original value of self
# This is only used by inplace operations
@@ -116,7 +118,7 @@ class DifferentiabilityInfo:
# The name of the generated autograd function.
# It's set only if we will calculate a derivative, i.e.
# 'args_with_derivatives' is not empty.
-op: Optional[str]
+op: str | None
# The derivatives formulae for this function.
# Note that the length of this sequence is the number of differentiable inputs
@@ -138,7 +140,7 @@ class DifferentiabilityInfo:
# The named gradients that are used in any of the derivatives.
# Invariant: all(name in available_named_gradients for name in used_named_gradients)
-used_named_gradients: Set[str]
+used_named_gradients: set[str]
# The function's input arguments for which it calculates derivatives.
# It's the union of 'var_names' of all 'derivatives', sorted by the
@@ -149,7 +151,7 @@ class DifferentiabilityInfo:
non_differentiable_arg_names: Sequence[str]
# Raw data read from derivatives.yaml.
-output_differentiability: Optional[List[bool]]
+output_differentiability: list[bool] | None
# output_differentiability in derivatives.yaml can be a list of
# conditions that express if the output is differentiable. In this case,
@@ -157,7 +159,7 @@ class DifferentiabilityInfo:
# (NB: we only support one condition right now).
# output_differentiability gets populated with True for each condition,
# while output_differentiability_conditions gets populated with the conditions
-output_differentiability_conditions: Optional[List[str]]
+output_differentiability_conditions: list[str] | None
@property
def has_derivatives(self) -> bool:
@@ -170,7 +172,7 @@ class DifferentiabilityInfo:
# See Note [Codegen'd {view}_copy Operators]
def create_view_copy_from_view_derivative(
self, g: NativeFunctionsViewGroup
-) -> Optional["DifferentiabilityInfo"]:
+) -> DifferentiabilityInfo | None:
if g.view_copy is None:
return None
f = g.view_copy
@@ -201,7 +203,7 @@ class DifferentiabilityInfo:
)
-def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
+def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
if info is None:
return False
for derivative in info.derivatives:
@@ -211,11 +213,11 @@ def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
return False
-def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
+def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "retain_variables")
-def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
+def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
return uses_ident(info, "grad")
@@ -253,8 +255,8 @@ class DifferentiableOutput:
@dataclass(frozen=True)
class NativeFunctionWithDifferentiabilityInfo:
func: NativeFunction
-info: Optional[Dict[str, DifferentiabilityInfo]]
+info: dict[str, DifferentiabilityInfo] | None
-fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
+fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
# TODO: Update comment below since it is out of date.
@@ -363,19 +365,19 @@ def is_reference_for_foreach(
# TODO(crcrpar): Avoid hard coding "Default" ideally.
def gen_foreach_derivativeinfo(
foreach_function: NativeFunction,
-functional_info_by_signature: Dict[
+functional_info_by_signature: dict[
-FunctionSchema, Dict[str, DifferentiabilityInfo]
+FunctionSchema, dict[str, DifferentiabilityInfo]
],
-non_functional_info_by_signature: Dict[
+non_functional_info_by_signature: dict[
-FunctionSchema, Dict[str, DifferentiabilityInfo]
+FunctionSchema, dict[str, DifferentiabilityInfo]
],
dispatch_key: str = "Default",
-) -> Tuple[Optional[DifferentiabilityInfo], bool]:
+) -> tuple[DifferentiabilityInfo | None, bool]:
"""Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
The second return value indicates whether the info is generated in this function.
"""
-ref_diff_info: Optional[DifferentiabilityInfo] = None
+ref_diff_info: DifferentiabilityInfo | None = None
for function_schema, diff_info in functional_info_by_signature.items():
if not is_reference_for_foreach(foreach_function, function_schema):
@@ -485,13 +487,13 @@ def gen_foreach_derivativeinfo(
if arg.name in all_var_names
]
-forward_derivatives: List[ForwardDerivative] = []
+forward_derivatives: list[ForwardDerivative] = []
fw_derivative: ForwardDerivative
for fw_derivative in ref_diff_info.forward_derivatives:
-var_names: List[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
+var_names: list[str] = list(fw_derivative.var_names) # type: ignore[no-redef]
-var_types: List[Type] = list(fw_derivative.var_types)
+var_types: list[Type] = list(fw_derivative.var_types)
-required_inputs_fw_grad: List[str] = []
+required_inputs_fw_grad: list[str] = []
-required_inputs_primal: List[str] = []
+required_inputs_primal: list[str] = []
if fw_derivative.required_inputs_fw_grad is not None:
required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
if fw_derivative.required_inputs_primal:
@@ -578,9 +580,9 @@ def gen_foreach_derivativeinfo(
def match_differentiability_info(
-native_functions: List[NativeFunction],
+native_functions: list[NativeFunction],
-differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
-) -> List[NativeFunctionWithDifferentiabilityInfo]:
+) -> list[NativeFunctionWithDifferentiabilityInfo]:
"""Sets the "derivative" key on declarations to matching autograd function
In-place functions will use the out-of-place derivative definition if there
is no in-place specific derivative.
@@ -599,7 +601,7 @@ def match_differentiability_info(
def find_info(
f: NativeFunction,
-) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
+) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
# Don't bother matching info to generated out= variants
if "generated" in f.tags and f.func.kind() == SchemaKind.out:
return None, False
@@ -653,7 +655,7 @@ Attempted to convert a derivative formula for a mutable operator
return None, False
-result: List[NativeFunctionWithDifferentiabilityInfo] = []
+result: list[NativeFunctionWithDifferentiabilityInfo] = []
for f in native_functions:
info_dict, is_exact_match = find_info(f)
@@ -677,7 +679,7 @@ Attempted to convert a derivative formula for a mutable operator
)
continue
-fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
+fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
for key, info in info_dict.items():
if not info.forward_derivatives:
fw_derivative_dict[key] = []
@@ -713,7 +715,7 @@ Attempted to convert a derivative formula for a mutable operator
formula = fw_info.formula
def replace_self_with_original_self(formula: str, postfix: str) -> str:
-def repl(m: Match[str]) -> str:
+def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}original_self{postfix}{m.group(2)}"
return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
@@ -734,7 +736,7 @@ Attempted to convert a derivative formula for a mutable operator
formula = replace_self_with_original_self(formula, "_t")
# replace "result" from the formula by "self_p"
-def repl(m: Match[str]) -> str:
+def repl(m: re.Match[str]) -> str:
return f"{m.group(1)}self_p{m.group(2)}"
formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
@@ -758,8 +760,8 @@ Attempted to convert a derivative formula for a mutable operator
# If there is a need, we can relax (2) to allow any op that has an in-place variant
is_single_method_on_self_t = False
directly_do_inplace = False
-op_name: Optional[str] = None
+op_name: str | None = None
-between_parens: Optional[str] = None
+between_parens: str | None = None
match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
if match:
op_name, between_parens = match.group(1), match.group(2)
@@ -823,7 +825,7 @@ Attempted to convert a derivative formula for a mutable operator
def is_differentiable(
-name: str, type: Type, info: Optional[DifferentiabilityInfo]
+name: str, type: Type, info: DifferentiabilityInfo | None
) -> bool:
return type.is_tensor_like() and (
info is None or name not in info.non_differentiable_arg_names
@@ -832,10 +834,10 @@ def is_differentiable(
def gen_differentiable_outputs(
fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
-) -> List[DifferentiableOutput]:
+) -> list[DifferentiableOutput]:
f = fn.func
info = fn.info[key] if fn.info else None
-outputs: List[DifferentiableOutput] = [
+outputs: list[DifferentiableOutput] = [
DifferentiableOutput(
name=name,
type=ret.type,
@@ -850,7 +852,7 @@ def gen_differentiable_outputs(
f"The length of output_differentiability ({len(output_differentiability)}), "
f"does not match the number of outputs ({len(outputs)})."
)
-differentiable_outputs: List[DifferentiableOutput] = []
+differentiable_outputs: list[DifferentiableOutput] = []
if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
raise RuntimeError(
"output_differentiability=False for inplace operation (version_counter won't get updated)"


@@ -1,4 +1,6 @@
-from typing import List, Optional, Sequence, Set, Union
+from __future__ import annotations
+from typing import Sequence
from torchgen import local
from torchgen.api.types import (
@@ -94,7 +96,7 @@ def valuetype_type(
binds: ArgName,
remove_non_owning_ref_types: bool = False,
symint: bool = False,
-) -> Optional[NamedCType]:
+) -> NamedCType | None:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
@@ -279,7 +281,7 @@ def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
-returns: List[str] = []
+returns: list[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
@@ -368,17 +370,17 @@ def default_expr(d: str, t: Type, *, symint: bool) -> str:
def argument(
-a: Union[Argument, TensorOptionsArguments, SelfArgument],
+a: Argument | TensorOptionsArguments | SelfArgument,
*,
-cpp_no_default_args: Set[str],
+cpp_no_default_args: set[str],
method: bool,
faithful: bool,
symint: bool = False,
has_tensor_options: bool,
-) -> List[Binding]:
+) -> list[Binding]:
def sub_argument(
-a: Union[Argument, TensorOptionsArguments, SelfArgument]
+a: Argument | TensorOptionsArguments | SelfArgument,
-) -> List[Binding]:
+) -> list[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
@@ -394,7 +396,7 @@ def argument(
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
-default: Optional[str] = None
+default: str | None = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type, symint=symint)
return [
@@ -445,9 +447,9 @@ def arguments(
faithful: bool,
symint: bool = False,
method: bool,
-cpp_no_default_args: Set[str],
+cpp_no_default_args: set[str],
-) -> List[Binding]:
+) -> list[Binding]:
-args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)


@@ -1,5 +1,7 @@
+from __future__ import annotations
import itertools
-from typing import List, Sequence, Union
+from typing import Sequence
from torchgen.api import cpp
from torchgen.api.types import ArgName, Binding, CType, NamedCType
@@ -76,10 +78,10 @@ def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
return cpp.returns_type(rs, symint=symint)
-def jit_arguments(func: FunctionSchema) -> List[Argument]:
+def jit_arguments(func: FunctionSchema) -> list[Argument]:
def to_argument(
-a: Union[Argument, TensorOptionsArguments, SelfArgument]
+a: Argument | TensorOptionsArguments | SelfArgument,
-) -> List[Argument]:
+) -> list[Argument]:
if isinstance(a, Argument):
return [a]
elif isinstance(a, SelfArgument):
@@ -114,5 +116,5 @@ def argument(
)
-def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
+def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
return [argument(a, symint=symint) for a in jit_arguments(func)]


@@ -1,4 +1,4 @@
-from typing import List, Optional
+from __future__ import annotations
from torchgen.api import dispatcher
from torchgen.api.types import (
@@ -93,7 +93,7 @@ def name(
*,
is_reverse: bool,
include_namespace: bool,
-reapply_views: Optional[bool] = None,
+reapply_views: bool | None = None,
) -> str:
if reapply_views is None:
# reapply_views is only important for the fwd lambda,
@@ -124,7 +124,7 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
return f"{api_name}_inverse"
-def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
+def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
# capture arguments include all arguments except `self`.
# Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
# So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
@@ -152,14 +152,14 @@ def returns_type(func: FunctionSchema) -> CType:
return BaseCType(tensorT)
-def outer_arguments(*, is_reverse: bool) -> List[Binding]:
+def outer_arguments(*, is_reverse: bool) -> list[Binding]:
if is_reverse:
return [base_binding, mutated_view_binding, mutated_view_idx_binding]
else:
return [base_binding, mutated_view_idx_binding]
-def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
+def inner_call_index(func: FunctionSchema) -> Binding | None:
# For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
# When we replay a view op that returns multiple tensors, we need to index into the output appropriately
if len(func.returns) > 1 or (
@@ -169,7 +169,7 @@ def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
return None
-def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
+def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
args = func.arguments.flat_all
assert args[0].type == BaseType(BaseTy.Tensor)
non_self_args = args[1:]


@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from __future__ import annotations
+from typing import Any
from torchgen.api.types import (
BaseCppType,
@@ -34,7 +36,7 @@ from torchgen.model import (
)
-_valueT: Optional[BaseCppType] = None
+_valueT: BaseCppType | None = None
# A ValueT is an IR type which represents the computation of a Tensor. In other
@@ -66,8 +68,8 @@ tensorListValueT = BaseCppType("torch::lazy", "Value")
def process_ir_type(
-typ: Type, properties: "LazyIrProperties", *, symint: bool
+typ: Type, properties: LazyIrProperties, *, symint: bool
-) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
+) -> BaseCType | VectorCType | OptionalCType | ListCType:
"""
This function takes a type from NativeFunctions and converts it for use with
lazy tensor codegen.
@@ -147,7 +149,7 @@ def process_ir_type(
#
# Invariant: passed typ should be an *owning* CType (e.g., we will report
# that ArrayRef<Value> is NOT a value type)
-def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
+def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
"""
Given a type, determine if it is a Value-like type. This is equivalent to
being Tensor-like, but assumes the type has already been transformed.
@@ -202,7 +204,7 @@ def isGeneratorType(typ: Type) -> bool:
class LazyArgument:
name: str
orig_type: Type
-lazy_type_: Optional[CType]
+lazy_type_: CType | None
is_wrapped_scalar: bool
is_generator: bool
# TODO: this is lies, it is false for symint list
@@ -214,7 +216,9 @@ class LazyArgument:
# true if this argument is or contains a lazy IR value
is_lazy_value: bool
-def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
+def __init__(
+self, arg: Argument, properties: LazyIrProperties, *, symint: bool
+) -> None:
self.name = arg.name
self.orig_type = arg.type
self.symint = symint
@@ -248,7 +252,7 @@ class LazyIrProperties:
attributes. The mutual exclusivity is automatically handled.
"""
-Properties: Tuple[Tuple[str, ...], ...] = (
+Properties: tuple[tuple[str, ...], ...] = (
(
"ShapePrecompute", # Assume shape has been precomputed
"ShapeCompute", # Need to compute the shape on construction
@@ -271,8 +275,8 @@ class LazyIrProperties:
),
)
-def __init__(self, *default_properties: str):
+def __init__(self, *default_properties: str) -> None:
-properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
+properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
LazyIrProperties.Properties
)
self.__dict__["properties"] = properties
@@ -305,17 +309,17 @@ class LazyIrProperties:
# TODO: This is not idiomatic with how other torchgen APIs transform on schema.
class LazyIrSchema:
# The name of the operator this function schema describes.
-name: "OperatorName"
+name: OperatorName
-positional_args: Tuple[LazyArgument, ...]
+positional_args: tuple[LazyArgument, ...]
-keyword_args: Tuple[LazyArgument, ...]
+keyword_args: tuple[LazyArgument, ...]
# TODO: Need to handle collisions with argument names at some point
-returns: Tuple["Return", ...]
+returns: tuple[Return, ...]
# if this schema has a Generator arg, list its orig ctype/name but don't
# build a LazyArgument since lazy IR doesn't support it
-generator_arg: Optional[NamedCType] = None
+generator_arg: NamedCType | None = None
# original function schema
func: FunctionSchema
@@ -329,21 +333,21 @@ class LazyIrSchema:
"Lower",
"CanBeReused",
)
-opkind: Optional[str] = None
+opkind: str | None = None
def __init__(
self,
func: FunctionSchema,
-properties: Optional[LazyIrProperties] = None,
+properties: LazyIrProperties | None = None,
*,
symint: bool,
-):
+) -> None:
if properties:
self.properties = properties
self.func = func
self.symint = symint
-positional_args: List[LazyArgument] = []
+positional_args: list[LazyArgument] = []
for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
if arg_field == "self_arg" and func.arguments.self_arg is not None:
arg = func.arguments.self_arg.argument
@@ -357,7 +361,7 @@ class LazyIrSchema:
)
self.positional_args = tuple(positional_args)
-keyword_args: List[LazyArgument] = []
+keyword_args: list[LazyArgument] = []
for arg_field in [
"pre_tensor_options_kwarg_only",
"tensor_options",
@@ -411,13 +415,13 @@ class LazyIrSchema:
values: bool = True,
scalars: bool = True,
generator: bool = True,
-) -> List[LazyArgument]:
+) -> list[LazyArgument]:
# This function maintains the sorted order of arguments but provides different filtered views.
# Some parts of the code care about kwargs vs args (TS lowerings),
# other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
# Generators are special cased, as they are needed for fallback/shape-inference but not supported
# in TS lowerings and therefore also omitted from lazy IR.
-args: List[LazyArgument] = []
+args: list[LazyArgument] = []
if positional:
args.extend(self.positional_args)
if keyword:
@@ -439,25 +443,25 @@ class LazyIrSchema:
return []
@property
-def positional_values(self) -> List[LazyArgument]:
+def positional_values(self) -> list[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=True, scalars=False
)
@property
-def positional_scalars(self) -> List[LazyArgument]:
+def positional_scalars(self) -> list[LazyArgument]:
return self.filtered_args(
positional=True, keyword=False, values=False, scalars=True
)
@property
-def keyword_values(self) -> List[LazyArgument]:
+def keyword_values(self) -> list[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=True, scalars=False
)
@property
-def keyword_scalars(self) -> List[LazyArgument]:
+def keyword_scalars(self) -> list[LazyArgument]:
return self.filtered_args(
positional=False, keyword=True, values=False, scalars=True
)


@@ -1,4 +1,6 @@
-from typing import List, Optional, Sequence, Union
+from __future__ import annotations
+from typing import Sequence
from torchgen import local
from torchgen.api import cpp
@@ -81,11 +83,11 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
def argument(
-a: Union[Argument, SelfArgument, TensorOptionsArguments],
+a: Argument | SelfArgument | TensorOptionsArguments,
*,
is_out: bool,
symint: bool,
-) -> List[Binding]:
+) -> list[Binding]:
# Ideally, we NEVER default native functions. However, there are a number
# of functions that call native:: directly and rely on the defaulting
# existing. So for BC, we generate defaults for non-out variants (but not
@@ -93,7 +95,7 @@ def argument(
# default)
should_default = not is_out
if isinstance(a, Argument):
-default: Optional[str] = None
+default: str | None = None
if should_default and a.default is not None:
default = cpp.default_expr(a.default, a.type, symint=symint)
return [
@@ -144,8 +146,8 @@ def argument(
assert_never(a)
-def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
+def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
-args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(func.arguments.non_out)
args.extend(func.arguments.out)
return [


@@ -1,5 +1,7 @@
+from __future__ import annotations
from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Sequence
from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
@@ -197,14 +199,14 @@ from torchgen.model import (
@dataclass(frozen=True)
class PythonReturns:
-returns: Tuple[Return, ...]
+returns: tuple[Return, ...]
@dataclass(frozen=True)
class PythonArgument:
name: str
type: Type
-default: Optional[str]
+default: str | None
# Used to generate the default init expr for some PythonArgParser outputs, e.g.:
#
@@ -212,7 +214,7 @@ class PythonArgument:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ^
# +--- default_init str
-default_init: Optional[str]
+default_init: str | None
# Compute argument formal for python argument parsing.
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
@@ -300,12 +302,10 @@ class PythonOutArgument(PythonArgument):
# 'auto out = _r.tensorlist_n<2>(2);',
# then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
# TODO: maybe don't need keep scattered out fields for python signature?
-outputs: Tuple[PythonArgument, ...]
+outputs: tuple[PythonArgument, ...]
@staticmethod
-def from_outputs(
-outputs: Tuple[PythonArgument, ...]
-) -> Optional["PythonOutArgument"]:
+def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
if not outputs:
return None
@@ -339,13 +339,13 @@ class PythonSignature:
# Positional arguments.
# TODO: create a dedicated SelfArgument type for 'self'?
-input_args: Tuple[PythonArgument, ...]
+input_args: tuple[PythonArgument, ...]
# Keyword arguments excluding the 'out' argument and scattered kwargs belonging
# to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
-input_kwargs: Tuple[PythonArgument, ...]
+input_kwargs: tuple[PythonArgument, ...]
-output_args: Optional[PythonOutArgument]
+output_args: PythonOutArgument | None
# Return types, which are only used by pyi
returns: PythonReturns
@@ -356,7 +356,7 @@ class PythonSignature:
# for out variant), in which case they will be used as scattered fields without
# being packed into 'options'.
# TODO: maybe create a PythonTensorOptionsArgument?
-tensor_options_args: Tuple[PythonArgument, ...]
+tensor_options_args: tuple[PythonArgument, ...]
# method or function signature?
method: bool
@@ -367,8 +367,8 @@ class PythonSignature:
def arguments(
self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
-) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
+) -> tuple[PythonArgument | PythonOutArgument, ...]:
-result: List[Union[PythonArgument, PythonOutArgument]] = []
+result: list[PythonArgument | PythonOutArgument] = []
result.extend(self.input_args)
result.extend(self.input_kwargs)
if self.output_args is not None and not skip_outputs:
@@ -394,7 +394,7 @@ class PythonSignature:
# signature_str_pyi().
def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
args = self.arguments(skip_outputs=skip_outputs)
-schema_formals: List[str] = [
+schema_formals: list[str] = [
a.argument_str(method=self.method, symint=symint) for a in args
]
positional_argc = len(self.input_args)
@@ -405,7 +405,7 @@ class PythonSignature:
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs)
-schema_formals: List[str] = [
+schema_formals: list[str] = [
a.argument_str_pyi(method=self.method) for a in args
]
positional_argc = len(self.input_args)
@@ -419,10 +419,10 @@ class PythonSignature:
schema_formals.insert(0, "self")
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
-def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
+def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
# only pyi uses vararg signatures
args = self.arguments(skip_outputs=skip_outputs)
-schema_formals: List[str] = [
+schema_formals: list[str] = [
a.argument_str_pyi(method=self.method) for a in args
]
# vararg only applies to pyi signatures. vararg variants are not generated for all signatures
@@ -470,7 +470,7 @@ class PythonSignatureDeprecated(PythonSignature):
# [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
# [func call]: self.addmm(mat1, mat2, beta, 1)
# We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
-deprecated_args_exprs: Tuple[str, ...]
+deprecated_args_exprs: tuple[str, ...]
@property
def deprecated(self) -> bool:
@@ -486,7 +486,7 @@ class PythonSignatureDeprecated(PythonSignature):
def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
args = self.arguments(skip_outputs=skip_outputs)
-schema_formals: List[str] = [
+schema_formals: list[str] = [
a.argument_str_pyi(method=self.method, deprecated=True) for a in args
]
positional_argc = len(self.input_args)
@@ -496,7 +496,7 @@ class PythonSignatureDeprecated(PythonSignature):
returns_str = returns_str_pyi(self)
return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
-def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
+def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
# the codegen doesn't include vararg variants for deprecated signatures
return None
@@ -530,14 +530,14 @@ class PythonSignatureGroup:
base: NativeFunction
# The out variant (e.g. conv2d_out)
-outplace: Optional[NativeFunction]
+outplace: NativeFunction | None
@classmethod
def from_pairs(
cls,
functional: PythonSignatureNativeFunctionPair,
-out: Optional[PythonSignatureNativeFunctionPair],
+out: PythonSignatureNativeFunctionPair | None,
-) -> "PythonSignatureGroup":
+) -> PythonSignatureGroup:
if out is None:
return PythonSignatureGroup(
signature=functional.signature,
@@ -716,7 +716,7 @@ def argument_type_str(
raise RuntimeError(f"unrecognized type {repr(t)}")
-def argument_type_size(t: Type) -> Optional[int]:
+def argument_type_size(t: Type) -> int | None:
l = t.is_list_like()
if l is not None and str(l.elem) != "bool":
return l.size
@@ -750,11 +750,11 @@ def signature(
def signature_from_schema(
func: FunctionSchema,
*,
-category_override: Optional[str],
+category_override: str | None,
method: bool = False,
pyi: bool = False,
) -> PythonSignature:
-args: List[Argument] = []
+args: list[Argument] = []
args.extend(func.arguments.pre_self_positional)
# Skip SelfArgument if this is method.
if not method and func.arguments.self_arg is not None:
@@ -807,10 +807,10 @@ def signature_from_schema(
)
is_dummy_function = category_override == "dummy"
-tensor_options_args: List[PythonArgument] = []
+tensor_options_args: list[PythonArgument] = []
if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
-def topt_default_init(name: str) -> Optional[str]:
+def topt_default_init(name: str) -> str | None:
topt_args = func.arguments.tensor_options
if topt_args is None:
return None
@@ -891,7 +891,7 @@ def signature_from_schema(
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
-def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
+def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
if len(returns) <= 1 or all(r.name is None for r in returns):
return []
else:
@@ -1002,7 +1002,7 @@ def return_type_str_pyi(t: Type) -> str:
return argument_type_str_pyi(t)
-def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
+def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
structseq_name = signature.name
field_names = structseq_fieldnames(signature.returns.returns)
@@ -1104,7 +1104,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:
def dispatch_lambda_args(
ps: PythonSignature, f: NativeFunction, symint: bool = True
-) -> Tuple[DispatchLambdaArgument, ...]:
+) -> tuple[DispatchLambdaArgument, ...]:
if isinstance(ps, PythonSignatureDeprecated):
schema = ps.deprecated_schema
else:
@@ -1118,7 +1118,7 @@ def dispatch_lambda_args(
method=False,
cpp_no_default_args=f.cpp_no_default_args,
)
-out_args: Set[str] = {a.name for a in schema.arguments.out}
+out_args: set[str] = {a.name for a in schema.arguments.out}
# Convert from cpp argument to lambda argument
def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
@@ -1224,11 +1224,11 @@ def cpp_dispatch_target(f: NativeFunction) -> str:
def cpp_dispatch_exprs(
f: NativeFunction,
*,
-python_signature: Optional[PythonSignature] = None,
+python_signature: PythonSignature | None = None,
-) -> Tuple[str, ...]:
+) -> tuple[str, ...]:
cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
-exprs: Tuple[str, ...] = tuple()
+exprs: tuple[str, ...] = tuple()
if not isinstance(python_signature, PythonSignatureDeprecated):
# By default the exprs are consistent with the C++ signature.
exprs = tuple(a.name for a in cpp_args)
@@ -1262,7 +1262,7 @@ def cpp_dispatch_exprs(
# For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accepts doublelist with definite size.
def arg_parser_unpack_method(
-t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
+t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
has_default_init = default_init is not None
if has_default_init and str(t) not in (
@@ -1377,7 +1377,7 @@ def arg_parser_output_expr(
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
-) -> Dict[str, PythonArgParserOutputExpr]:
+) -> dict[str, PythonArgParserOutputExpr]:
return {
e.name: e
for i, a in enumerate(ps.arguments())
@@ -1404,8 +1404,8 @@ def dispatch_lambda_exprs(
# outputs.
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
-inits: List[str] = []
+inits: list[str] = []
-lambda_args_exprs: Dict[str, str] = {}
+lambda_args_exprs: dict[str, str] = {}
has_toptions = has_tensor_options(f)


@@ -1,4 +1,4 @@
-from typing import List, Union
+from __future__ import annotations
from torchgen.api import cpp
from torchgen.api.types import (
@@ -97,7 +97,7 @@ def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
# Structured kernels are never defaulted
-def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
+def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
if isinstance(a, Argument):
return [
Binding(
@@ -115,15 +115,15 @@ def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[B
assert_never(a)
-def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
-args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if g.out.precomputed:
# A list of parameters for the impl function with
# certain parameters replaced with precomputed counterparts
# as specified in native_functions.yaml.
-non_out_args_replaced: List[
+non_out_args_replaced: list[
-Union[Argument, TensorOptionsArguments, SelfArgument]
+Argument | TensorOptionsArguments | SelfArgument
] = []
for a in g.out.func.arguments.non_out:
if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
@@ -145,13 +145,13 @@ def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
return [r for arg in args for r in argument(arg)]
-def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
-args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(g.functional.func.arguments.non_out)
return [r for arg in args for r in argument(arg)]
-def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
-args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+args: list[Argument | TensorOptionsArguments | SelfArgument] = []
args.extend(g.out.func.arguments.out)
return [r for arg in args for r in argument(arg)]


@@ -1,4 +1,6 @@
-from typing import Dict, List, NoReturn, Sequence, Union
+from __future__ import annotations
+from typing import NoReturn, Sequence
from torchgen.api.types import (
ArrayRefCType,
@@ -95,13 +97,13 @@ class UnsatError(RuntimeError):
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
def translate(
-bindings: Sequence[Union[Expr, Binding]],
+bindings: Sequence[Expr | Binding],
-goals: Sequence[Union[NamedCType, Binding]],
+goals: Sequence[NamedCType | Binding],
*,
method: bool = False,
allow_expensive_conversions: bool = False,
-) -> List[Expr]:
+) -> list[Expr]:
-binding_exprs: List[Expr] = []
+binding_exprs: list[Expr] = []
for b in bindings:
if isinstance(b, Binding):
binding_exprs.append(
@@ -113,7 +115,7 @@ def translate(
else:
binding_exprs.append(b)
-goal_ctypes: List[NamedCType] = []
+goal_ctypes: list[NamedCType] = []
for g in goals:
if isinstance(g, Binding):
goal_ctypes.append(g.nctype)
@@ -121,7 +123,7 @@ def translate(
goal_ctypes.append(g)
# Add all the bindings to the context
-ctx: Dict[NamedCType, str] = {}
+ctx: dict[NamedCType, str] = {}
for b in binding_exprs:
ctx[b.type] = b.expr


@ -1,7 +1,12 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union from typing import Iterator, Sequence, TYPE_CHECKING
from torchgen.api.types.types_base import Binding, CType, Expr from torchgen.api.types.types_base import Binding, CType, Expr
if TYPE_CHECKING:
from torchgen.model import ( from torchgen.model import (
BackendIndex, BackendIndex,
FunctionSchema, FunctionSchema,
@ -38,7 +43,7 @@ class CppSignature:
symint: bool symint: bool
# The set of C++ arguments which should not have defaults applied to them # The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: Set[str] cpp_no_default_args: set[str]
# Is this a fallback C++ binding? Fallback bindings are enabled by # Is this a fallback C++ binding? Fallback bindings are enabled by
# manual_cpp_binding: True and are alternate, non-public API that # manual_cpp_binding: True and are alternate, non-public API that
@ -72,7 +77,7 @@ class CppSignature:
def decl( def decl(
self, self,
*, *,
name: Optional[str] = None, name: str | None = None,
prefix: str = "", prefix: str = "",
is_redispatching_fn: bool = False, is_redispatching_fn: bool = False,
suppress_symint_suffix: bool = False, suppress_symint_suffix: bool = False,
@ -93,7 +98,7 @@ class CppSignature:
def defn( def defn(
self, self,
*, *,
name: Optional[str] = None, name: str | None = None,
prefix: str = "", prefix: str = "",
is_redispatching_fn: bool = False, is_redispatching_fn: bool = False,
) -> str: ) -> str:
@ -126,9 +131,9 @@ class CppSignature:
class CppSignatureGroup: class CppSignatureGroup:
func: FunctionSchema func: FunctionSchema
signature: CppSignature signature: CppSignature
faithful_signature: Optional[CppSignature] faithful_signature: CppSignature | None
symint_signature: Optional[CppSignature] symint_signature: CppSignature | None
symint_faithful_signature: Optional[CppSignature] symint_faithful_signature: CppSignature | None
def most_faithful_signature(self) -> CppSignature: def most_faithful_signature(self) -> CppSignature:
if self.faithful_signature: if self.faithful_signature:
@ -149,7 +154,7 @@ class CppSignatureGroup:
@staticmethod @staticmethod
def from_native_function( def from_native_function(
f: NativeFunction, *, method: bool, fallback_binding: bool = False f: NativeFunction, *, method: bool, fallback_binding: bool = False
) -> "CppSignatureGroup": ) -> CppSignatureGroup:
func = f.func func = f.func
def make_sig(*, faithful: bool, symint: bool) -> CppSignature: def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
@ -162,16 +167,16 @@ class CppSignatureGroup:
cpp_no_default_args=f.cpp_no_default_args, cpp_no_default_args=f.cpp_no_default_args,
) )
def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]: def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
faithful_signature: Optional[CppSignature] = None faithful_signature: CppSignature | None = None
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0: if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = make_sig(faithful=True, symint=symint) faithful_signature = make_sig(faithful=True, symint=symint)
signature = make_sig(faithful=False, symint=symint) signature = make_sig(faithful=False, symint=symint)
return signature, faithful_signature return signature, faithful_signature
signature, faithful_signature = make_sigs(symint=False) signature, faithful_signature = make_sigs(symint=False)
symint_signature: Optional[CppSignature] = None symint_signature: CppSignature | None = None
symint_faithful_signature: Optional[CppSignature] = None symint_faithful_signature: CppSignature | None = None
if func.has_symint(): if func.has_symint():
symint_signature, symint_faithful_signature = make_sigs(symint=True) symint_signature, symint_faithful_signature = make_sigs(symint=True)
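
make_sigs above now spells its return type as tuple[CppSignature, CppSignature | None]. A small stand-alone sketch (the function is illustrative) showing that the PEP 585/604 spellings in annotations do not require Python 3.10 once annotations are postponed:

    # `str | None` and `tuple[...]` would raise at runtime on older
    # interpreters if evaluated; as postponed annotations they are fine.
    from __future__ import annotations


    def split_default(value: str) -> tuple[str, str | None]:
        # Returns (name, default), where the default may be absent.
        name, sep, default = value.partition("=")
        return name, (default if sep else None)


    print(split_default("dtype=float32"))   # ('dtype', 'float32')
    print(split_default("dtype"))           # ('dtype', None)
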
@ -196,20 +201,20 @@ class DispatcherSignature:
symint: bool = True symint: bool = True
def arguments(self) -> List[Binding]: def arguments(self) -> list[Binding]:
return dispatcher.arguments(self.func, symint=self.symint) return dispatcher.arguments(self.func, symint=self.symint)
def name(self) -> str: def name(self) -> str:
return self.prefix + dispatcher.name(self.func) return self.prefix + dispatcher.name(self.func)
def decl(self, name: Optional[str] = None) -> str: def decl(self, name: str | None = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments()) args_str = ", ".join(a.decl() for a in self.arguments())
if name is None: if name is None:
name = self.name() name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})" return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn( def defn(
self, name: Optional[str] = None, *, is_redispatching_fn: bool = False self, name: str | None = None, *, is_redispatching_fn: bool = False
) -> str: ) -> str:
args = [a.defn() for a in self.arguments()] args = [a.defn() for a in self.arguments()]
if is_redispatching_fn: if is_redispatching_fn:
@ -219,7 +224,7 @@ class DispatcherSignature:
name = self.name() name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})" return f"{self.returns_type().cpp_type()} {name}({args_str})"
def exprs(self) -> List[Expr]: def exprs(self) -> list[Expr]:
return [Expr(a.name, a.nctype) for a in self.arguments()] return [Expr(a.name, a.nctype) for a in self.arguments()]
def returns_type(self) -> CType: def returns_type(self) -> CType:
@ -237,7 +242,7 @@ class DispatcherSignature:
@staticmethod @staticmethod
def from_schema( def from_schema(
func: FunctionSchema, *, prefix: str = "", symint: bool = True func: FunctionSchema, *, prefix: str = "", symint: bool = True
) -> "DispatcherSignature": ) -> DispatcherSignature:
return DispatcherSignature(func, prefix, symint) return DispatcherSignature(func, prefix, symint)
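
from_schema can now annotate its return as DispatcherSignature without quotes because the annotation is no longer evaluated while the class body is still being built. A minimal sketch with an illustrative Signature class:

    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class Signature:
        name: str
        prefix: str = ""

        @staticmethod
        def from_name(name: str, *, prefix: str = "") -> Signature:
            # Without postponed annotations this would need "-> 'Signature'".
            return Signature(name, prefix)


    sig = Signature.from_name("add", prefix="wrapper_")
    print(sig)                                            # Signature(name='add', prefix='wrapper_')
    print(Signature.from_name.__annotations__["return"])  # 'Signature'
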
@ -253,13 +258,13 @@ class NativeSignature:
def name(self) -> str: def name(self) -> str:
return self.prefix + native.name(self.func) return self.prefix + native.name(self.func)
def decl(self, name: Optional[str] = None) -> str: def decl(self, name: str | None = None) -> str:
args_str = ", ".join(a.decl() for a in self.arguments()) args_str = ", ".join(a.decl() for a in self.arguments())
if name is None: if name is None:
name = self.name() name = self.name()
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})" return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
def defn(self, name: Optional[str] = None) -> str: def defn(self, name: str | None = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments()) args_str = ", ".join(a.defn() for a in self.arguments())
if name is None: if name is None:
name = self.name() name = self.name()
@ -270,13 +275,13 @@ class NativeSignature:
args_str = ", ".join(a.defn() for a in self.arguments()) args_str = ", ".join(a.defn() for a in self.arguments())
return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})" return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
def arguments(self) -> List[Binding]: def arguments(self) -> list[Binding]:
return native.arguments(self.func, symint=self.symint) return native.arguments(self.func, symint=self.symint)
def returns_type(self) -> CType: def returns_type(self) -> CType:
return native.returns_type(self.func.returns, symint=self.symint) return native.returns_type(self.func.returns, symint=self.symint)
def dispatcher_exprs(self) -> List[Expr]: def dispatcher_exprs(self) -> list[Expr]:
return translate.translate( return translate.translate(
self.arguments(), dispatcher.arguments(self.func), method=False self.arguments(), dispatcher.arguments(self.func), method=False
) )
@ -307,7 +312,7 @@ class FunctionalizationLambda:
# are we generating the forward lambda or the reverse lambda? # are we generating the forward lambda or the reverse lambda?
is_reverse: bool is_reverse: bool
def captures(self) -> List[Expr]: def captures(self) -> list[Expr]:
# The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
# We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed, # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
# and plumb it into the lambda. # and plumb it into the lambda.
@ -336,7 +341,7 @@ class FunctionalizationLambda:
] ]
return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}" return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
def inner_call(self, *, reapply_views: Optional[bool] = None) -> str: def inner_call(self, *, reapply_views: bool | None = None) -> str:
inner_call_name = functionalization.name( inner_call_name = functionalization.name(
self.g, self.g,
is_reverse=self.is_reverse, is_reverse=self.is_reverse,
@ -366,7 +371,7 @@ class FunctionalizationLambda:
@staticmethod @staticmethod
def from_func( def from_func(
g: NativeFunctionsViewGroup, *, is_reverse: bool g: NativeFunctionsViewGroup, *, is_reverse: bool
) -> "FunctionalizationLambda": ) -> FunctionalizationLambda:
return FunctionalizationLambda(g, is_reverse) return FunctionalizationLambda(g, is_reverse)
@ -375,11 +380,11 @@ class StructuredImplSignature:
g: NativeFunctionsGroup g: NativeFunctionsGroup
name: str name: str
def defn(self, name: Optional[str] = None) -> str: def defn(self, name: str | None = None) -> str:
args_str = ", ".join(a.defn() for a in self.arguments()) args_str = ", ".join(a.defn() for a in self.arguments())
return f"TORCH_IMPL_FUNC({self.name})({args_str})" return f"TORCH_IMPL_FUNC({self.name})({args_str})"
def arguments(self) -> List[Binding]: def arguments(self) -> list[Binding]:
return structured.impl_arguments(self.g) return structured.impl_arguments(self.g)
@ -388,7 +393,7 @@ class StructuredImplSignature:
def kernel_signature( def kernel_signature(
f: NativeFunction, backend_index: BackendIndex, *, prefix: str = "" f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
) -> Union["NativeSignature", "DispatcherSignature"]: ) -> NativeSignature | DispatcherSignature:
# Note [External Backends Follow Dispatcher API] # Note [External Backends Follow Dispatcher API]
# Kernel signatures for in-tree backends follow the "native" API, # Kernel signatures for in-tree backends follow the "native" API,
# while kernels for out-of-tree backends follow the dispatcher API. # while kernels for out-of-tree backends follow the dispatcher API.
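
kernel_signature's return annotation becomes the plain union NativeSignature | DispatcherSignature. An illustrative sketch, not the torchgen implementation, of a helper with the same shape: it picks a dispatcher-style signature for external backends and a native-style one otherwise.

    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass
    class NativeSig:
        name: str


    @dataclass
    class DispatcherSig:
        name: str


    def pick_signature(name: str, *, external: bool) -> NativeSig | DispatcherSig:
        # Out-of-tree backends follow the dispatcher API; in-tree ones the native API.
        return DispatcherSig(name) if external else NativeSig(name)


    print(pick_signature("add", external=False))   # NativeSig(name='add')
    print(pick_signature("add", external=True))    # DispatcherSig(name='add')
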

@ -12,8 +12,10 @@ if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related. Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10. Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
""" """
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict
from torchgen.api.types.types_base import ( from torchgen.api.types.types_base import (
BaseCppType, BaseCppType,
@ -83,7 +85,7 @@ symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
scalar_t = BaseCppType("", "scalar_t") scalar_t = BaseCppType("", "scalar_t")
opmath_t = BaseCppType("", "opmath_t") opmath_t = BaseCppType("", "opmath_t")
ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = { ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
ScalarType.Byte: byteT, ScalarType.Byte: byteT,
ScalarType.Char: charT, ScalarType.Char: charT,
ScalarType.Short: shortT, ScalarType.Short: shortT,
@ -102,7 +104,7 @@ ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT, ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
} }
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = { BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.int: longT, BaseTy.int: longT,
BaseTy.float: doubleT, BaseTy.float: doubleT,
BaseTy.bool: boolT, BaseTy.bool: boolT,
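
ScalarTypeToCppMapping and BaseTypeToCppMapping are module-level variables, and their dict[...] annotations are deferred as well; only the dict literal on the right-hand side is evaluated. A tiny sketch with an invented mapping:

    from __future__ import annotations

    # The annotation below is stored as a string in __annotations__;
    # only the literal on the right is evaluated.
    BaseTypeToSize: dict[str, int] = {
        "int": 8,
        "float": 8,
        "bool": 1,
    }

    print(BaseTypeToSize["bool"])   # 1
    print(__annotations__)          # {'BaseTypeToSize': 'dict[str, int]'}
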
@ -128,7 +130,7 @@ BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
@dataclass(frozen=True) @dataclass(frozen=True)
class OptionalCType(CType): class OptionalCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively. # Do not pass `strip_ref` recursively.
@ -137,13 +139,13 @@ class OptionalCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>" return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return OptionalCType(self.elem.remove_const_ref()) return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True) @dataclass(frozen=True)
class ListCType(CType): class ListCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively. # Do not pass `strip_ref` recursively.
@ -152,13 +154,13 @@ class ListCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"c10::List<{self.elem.cpp_type_registration_declarations()}>" return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return ListCType(self.elem.remove_const_ref()) return ListCType(self.elem.remove_const_ref())
@dataclass(frozen=True) @dataclass(frozen=True)
class ArrayRefCType(CType): class ArrayRefCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively. # Do not pass `strip_ref` recursively.
@ -167,7 +169,7 @@ class ArrayRefCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>" return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return ArrayRefCType(self.elem.remove_const_ref()) return ArrayRefCType(self.elem.remove_const_ref())
@ -185,5 +187,5 @@ class VectorizedCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
raise NotImplementedError raise NotImplementedError
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return self return self

@ -12,11 +12,16 @@ if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related. Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10. Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
""" """
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import auto, Enum from enum import auto, Enum
from typing import List, Optional, Union from typing import TYPE_CHECKING, Union
if TYPE_CHECKING:
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
@ -36,7 +41,7 @@ ArgName = Union[str, SpecialArgName]
# This class shouldn't be created directly; instead, use/create one of the singletons below. # This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True) @dataclass(frozen=True)
class BaseCppType: class BaseCppType:
ns: Optional[str] ns: str | None
name: str name: str
def __str__(self) -> str: def __str__(self) -> str:
@ -71,7 +76,7 @@ class CType(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return self return self
@ -87,13 +92,13 @@ class BaseCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return str(self.type).replace("at::", "") return str(self.type).replace("at::", "")
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return self return self
@dataclass(frozen=True) @dataclass(frozen=True)
class ConstRefCType(CType): class ConstRefCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref: if strip_ref:
@ -103,13 +108,13 @@ class ConstRefCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"const {self.elem.cpp_type_registration_declarations()} &" return f"const {self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return self.elem.remove_const_ref() return self.elem.remove_const_ref()
@dataclass(frozen=True) @dataclass(frozen=True)
class VectorCType(CType): class VectorCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively. # Do not pass `strip_ref` recursively.
@ -118,13 +123,13 @@ class VectorCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>" return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return VectorCType(self.elem.remove_const_ref()) return VectorCType(self.elem.remove_const_ref())
@dataclass(frozen=True) @dataclass(frozen=True)
class ArrayCType(CType): class ArrayCType(CType):
elem: "CType" elem: CType
size: int size: int
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
@ -134,13 +139,13 @@ class ArrayCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>" return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return ArrayCType(self.elem.remove_const_ref(), self.size) return ArrayCType(self.elem.remove_const_ref(), self.size)
@dataclass(frozen=True) @dataclass(frozen=True)
class TupleCType(CType): class TupleCType(CType):
elems: List["CType"] elems: list[CType]
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively. # Do not pass `strip_ref` recursively.
@ -149,13 +154,13 @@ class TupleCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>' return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return TupleCType([e.remove_const_ref() for e in self.elems]) return TupleCType([e.remove_const_ref() for e in self.elems])
@dataclass(frozen=True) @dataclass(frozen=True)
class MutRefCType(CType): class MutRefCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref: if strip_ref:
@ -165,7 +170,7 @@ class MutRefCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"{self.elem.cpp_type_registration_declarations()} &" return f"{self.elem.cpp_type_registration_declarations()} &"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return self.elem.remove_const_ref() return self.elem.remove_const_ref()
@ -190,10 +195,10 @@ class NamedCType:
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return self.type.cpp_type_registration_declarations() return self.type.cpp_type_registration_declarations()
def remove_const_ref(self) -> "NamedCType": def remove_const_ref(self) -> NamedCType:
return NamedCType(self.name, self.type.remove_const_ref()) return NamedCType(self.name, self.type.remove_const_ref())
def with_name(self, name: str) -> "NamedCType": def with_name(self, name: str) -> NamedCType:
return NamedCType(name, self.type) return NamedCType(name, self.type)
@ -208,11 +213,11 @@ class NamedCType:
class Binding: class Binding:
name: str name: str
nctype: NamedCType nctype: NamedCType
argument: Union[Argument, TensorOptionsArguments, SelfArgument] argument: Argument | TensorOptionsArguments | SelfArgument
# TODO: maybe don't represent default here # TODO: maybe don't represent default here
default: Optional[str] = None default: str | None = None
def rename(self, name: str) -> "Binding": def rename(self, name: str) -> Binding:
return Binding( return Binding(
name=name, name=name,
nctype=self.nctype, nctype=self.nctype,
@ -224,7 +229,7 @@ class Binding:
def type(self) -> str: def type(self) -> str:
return self.nctype.cpp_type() return self.nctype.cpp_type()
def no_default(self) -> "Binding": def no_default(self) -> Binding:
return Binding( return Binding(
name=self.name, name=self.name,
nctype=self.nctype, nctype=self.nctype,
@ -255,7 +260,7 @@ class Binding:
def defn(self) -> str: def defn(self) -> str:
return f"{self.type} {self.name}" return f"{self.type} {self.name}"
def with_name(self, name: str) -> "Binding": def with_name(self, name: str) -> Binding:
return Binding( return Binding(
name=name, nctype=self.nctype, argument=self.argument, default=self.default name=name, nctype=self.nctype, argument=self.argument, default=self.default
) )

@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional
import torchgen.api.types as api_types import torchgen.api.types as api_types
from torchgen.api import cpp, structured from torchgen.api import cpp, structured
@ -38,7 +39,7 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
# argument registers) # argument registers)
# #
# NB: used for CPU only # NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]: def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
# Dispatch stubs are always plain ints # Dispatch stubs are always plain ints
r = cpp.valuetype_type(t, binds=binds, symint=False) r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None: if r is not None:
@ -134,8 +135,8 @@ def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
@dataclass(frozen=True) @dataclass(frozen=True)
class UfunctorBindings: class UfunctorBindings:
ctor: List[Binding] ctor: list[Binding]
apply: List[Binding] apply: list[Binding]
# ufunctors are a CUDA-only concept representing functors that take some of # ufunctors are a CUDA-only concept representing functors that take some of
@ -156,7 +157,7 @@ class UfunctorBindings:
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers # The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition # to the operator() definition
def ufunctor_arguments( def ufunctor_arguments(
g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings: ) -> UfunctorBindings:
ctor = [] ctor = []
apply = [] apply = []
@ -185,7 +186,7 @@ def ufunctor_arguments(
# } # }
# #
# In this file, we refer to T as compute_t which is bound by caller # In this file, we refer to T as compute_t which is bound by caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]: def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
return [ return [
ufunc_argument(a, compute_t=compute_t) ufunc_argument(a, compute_t=compute_t)
for a in g.functional.func.arguments.flat_non_out for a in g.functional.func.arguments.flat_non_out
@ -197,7 +198,7 @@ def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Bindin
# #
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha); # using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub); # DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]: def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
# stubs drop all tensor arguments (they are implicit in the TensorIterator # stubs drop all tensor arguments (they are implicit in the TensorIterator
# argument and keep everything else) # argument and keep everything else)
return [ return [

@ -1,4 +1,4 @@
from typing import List, Tuple from __future__ import annotations
from torchgen.api import cpp from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignatureGroup, CType from torchgen.api.types import Binding, CppSignatureGroup, CType
@ -103,7 +103,7 @@ def name(f: NativeFunction) -> str:
# Convert all the arguments in a NativeFunction to C++ code # Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]: def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
# we need the 'self' argument so method needs to be False # we need the 'self' argument so method needs to be False
args = ( args = (
CppSignatureGroup.from_native_function(f, method=False) CppSignatureGroup.from_native_function(f, method=False)
@ -138,7 +138,7 @@ def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType # (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
def argumenttype_ivalue_convert( def argumenttype_ivalue_convert(
t: Type, arg_name: str, *, mutable: bool = False t: Type, arg_name: str, *, mutable: bool = False
) -> Tuple[str, CType, List[str], List[str]]: ) -> tuple[str, CType, list[str], list[str]]:
# Unboxing is for mobile, which doesn't care about SymInts # Unboxing is for mobile, which doesn't care about SymInts
ctype = cpp.argumenttype_type( ctype = cpp.argumenttype_type(
t=t, mutable=mutable, binds=arg_name, symint=False t=t, mutable=mutable, binds=arg_name, symint=False
@ -172,7 +172,7 @@ def argumenttype_ivalue_convert(
def _gen_code_base_type( def _gen_code_base_type(
arg_name: str, out_name: str, ctype: CType arg_name: str, out_name: str, ctype: CType
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
return [ return [
f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], [] ], []
@ -180,7 +180,7 @@ def _gen_code_base_type(
def _gen_code_optional_type( def _gen_code_optional_type(
arg_name: str, out_name: str, t: OptionalType, ctype: CType arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_opt_in" in_name = f"{arg_name}_opt_in"
res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name) res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
return ( return (
@ -203,7 +203,7 @@ if ({arg_name}_opt.has_value()) {{
def _gen_code_list_type( def _gen_code_list_type(
arg_name: str, out_name: str, t: ListType, ctype: CType arg_name: str, out_name: str, t: ListType, ctype: CType
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_list_in" in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem" elem_name = f"{arg_name}_elem"
code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"] code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]

@ -1,5 +1,7 @@
from __future__ import annotations
import re import re
from typing import Mapping, Match, Optional, Sequence from typing import Mapping, Sequence
# match $identifier or ${identifier} and replace with value in env # match $identifier or ${identifier} and replace with value in env
@ -20,7 +22,7 @@ class CodeTemplate:
filename: str filename: str
@staticmethod @staticmethod
def from_file(filename: str) -> "CodeTemplate": def from_file(filename: str) -> CodeTemplate:
with open(filename) as f: with open(filename) as f:
return CodeTemplate(f.read(), filename) return CodeTemplate(f.read(), filename)
@ -29,7 +31,7 @@ class CodeTemplate:
self.filename = filename self.filename = filename
def substitute( def substitute(
self, env: Optional[Mapping[str, object]] = None, **kwargs: object self, env: Mapping[str, object] | None = None, **kwargs: object
) -> str: ) -> str:
if env is None: if env is None:
env = {} env = {}
@ -43,7 +45,7 @@ class CodeTemplate:
[indent + l + "\n" for e in v for l in str(e).splitlines()] [indent + l + "\n" for e in v for l in str(e).splitlines()]
).rstrip() ).rstrip()
def replace(match: Match[str]) -> str: def replace(match: re.Match[str]) -> str:
indent = match.group(1) indent = match.group(1)
key = match.group(2) key = match.group(2)
comma_before = "" comma_before = ""
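
CodeTemplate.replace now takes re.Match[str] instead of the deprecated typing.Match alias. A self-contained sketch of the same substitution idea; the pattern and callback are illustrative, not the CodeTemplate implementation:

    from __future__ import annotations

    import re

    # Matches $identifier or ${identifier}, loosely mirroring the template syntax.
    _pattern = re.compile(r"\$\{?(\w+)\}?")


    def replace(match: re.Match[str]) -> str:
        # Substitute the matched identifier with an uppercased placeholder.
        return match.group(1).upper()


    print(_pattern.sub(replace, "prefix ${name} suffix"))   # prefix NAME suffix
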

@ -1,6 +1,8 @@
from __future__ import annotations
import contextlib import contextlib
import functools import functools
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
import torchgen.local as local import torchgen.local as local
from torchgen.model import ( from torchgen.model import (
@ -38,7 +40,7 @@ F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])
@contextlib.contextmanager @contextlib.contextmanager
def native_function_manager( def native_function_manager(
g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction] g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
) -> Iterator[None]: ) -> Iterator[None]:
if isinstance(g, NativeFunctionsGroup): if isinstance(g, NativeFunctionsGroup):
# By default, we associate all errors with structured native functions # By default, we associate all errors with structured native functions
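
Note that this file keeps List, Optional, Tuple and Union in its typing import: the F3 TypeVar shown in the hunk header above uses them as constraints, and those are ordinary expressions that the future import does not defer. A short sketch of the distinction, with illustrative names:

    from __future__ import annotations

    from typing import List, Tuple, TypeVar

    Pair = Tuple[str, int]                          # ordinary expression: evaluated now
    F = TypeVar("F", List[str], Tuple[str, int])    # constraints are evaluated now too


    def first(items: list[str]) -> str | None:      # deferred: new syntax is fine here
        return items[0] if items else None


    print(first(["ok"]))   # ok
    print(first([]))       # None
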
@ -118,10 +120,10 @@ def with_native_function_and_index(
# Convenience decorator for functions that explicitly take in a Dict of BackendIndices # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
def with_native_function_and_indices( def with_native_function_and_indices(
func: Callable[[F, Dict[DispatchKey, BackendIndex]], T] func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]: ) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
@functools.wraps(func) @functools.wraps(func)
def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T: def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
with native_function_manager(f): with native_function_manager(f):
return func(f, backend_indices) return func(f, backend_indices)

@ -1,7 +1,9 @@
from __future__ import annotations
import itertools import itertools
from abc import ABC from abc import ABC
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any
import torchgen.api.dispatcher as dispatcher import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import ( from torchgen.api.lazy import (
@ -109,7 +111,7 @@ def node_ctor_inputs(schema: LazyIrSchema) -> str:
def gen_fallback_code( def gen_fallback_code(
schema: LazyIrSchema, schema: LazyIrSchema,
sig: Union[DispatcherSignature, NativeSignature], sig: DispatcherSignature | NativeSignature,
overload_name: str, overload_name: str,
) -> str: ) -> str:
""" """
@ -147,9 +149,9 @@ def aten_symbol(schema: LazyIrSchema) -> str:
# converts all tensor-like arguments to meta tensors. Returns: # converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions. # (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings. # (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]: def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: List[Binding] = [] context: list[Binding] = []
unwrapped_tensor_args: List[str] = [] unwrapped_tensor_args: list[str] = []
for arg in sig.arguments(): for arg in sig.arguments():
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like(): if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
unwrapped_name = f"{arg.name}_meta" unwrapped_name = f"{arg.name}_meta"
@ -171,7 +173,7 @@ class GenLazyIR(ABC):
use_lazy_shape: bool use_lazy_shape: bool
@method_with_native_function @method_with_native_function
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
metadata = self.backend_index.get_kernel( metadata = self.backend_index.get_kernel(
f.functional if isinstance(f, NativeFunctionsGroup) else f f.functional if isinstance(f, NativeFunctionsGroup) else f
@ -236,7 +238,7 @@ class GenLazyIR(ABC):
/* num_outputs */ {len(schema.returns)}, /* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))""" torch::lazy::MHash({scalar_hashes}))"""
def gen(self, schema: LazyIrSchema) -> List[str]: def gen(self, schema: LazyIrSchema) -> list[str]:
opkind = schema.opkind or aten_symbol(schema) opkind = schema.opkind or aten_symbol(schema)
# for now, we just want one IR class decl and soon after also the method defs # for now, we just want one IR class decl and soon after also the method defs
@ -413,7 +415,7 @@ class GenLazyNativeFuncDefinition:
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str: def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False) value_args = schema.filtered_args(values=True, scalars=False)
# Generates lazy_{name} variables for LazyTensors wrapping input tensors # Generates lazy_{name} variables for LazyTensors wrapping input tensors
lazy_tensor_decls: List[str] = [] lazy_tensor_decls: list[str] = []
for arg in value_args: for arg in value_args:
if arg.is_wrapped_scalar: if arg.is_wrapped_scalar:
if isinstance(arg.lazy_type, OptionalCType): if isinstance(arg.lazy_type, OptionalCType):
@ -460,7 +462,7 @@ class GenLazyNativeFuncDefinition:
func: NativeFunction, func: NativeFunction,
schema: LazyIrSchema, schema: LazyIrSchema,
metadata: BackendMetadata, metadata: BackendMetadata,
sig: Union[DispatcherSignature, NativeSignature], sig: DispatcherSignature | NativeSignature,
) -> str: ) -> str:
if self.gen_forced_fallback_code: if self.gen_forced_fallback_code:
return gen_fallback_code( return gen_fallback_code(
@ -574,7 +576,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
}} }}
""" """
def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str: def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
# xla uses an instance method for tensor creation, for the time being # xla uses an instance method for tensor creation, for the time being
if self.create_from_first_tensor: if self.create_from_first_tensor:
# TODO(whc) remove this if XLA switches to using static method for creation # TODO(whc) remove this if XLA switches to using static method for creation
@ -615,7 +617,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
return bridge_str return bridge_str
@method_with_native_function @method_with_native_function
def __call__(self, func: NativeFunction) -> List[str]: def __call__(self, func: NativeFunction) -> list[str]:
sig = kernel_signature(func, self.backend_index) sig = kernel_signature(func, self.backend_index)
metadata = self.backend_index.get_kernel(func) metadata = self.backend_index.get_kernel(func)
assert metadata is not None assert metadata is not None
@ -639,7 +641,7 @@ class ComputeShapeSignature:
Here we use the base name as the suffix of the signature to avoid generating for in-place variants. Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
""" """
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool): def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
self.__schema = LazyIrSchema(f.func, symint=symint) self.__schema = LazyIrSchema(f.func, symint=symint)
self.__dispatch_args = ", ".join( self.__dispatch_args = ", ".join(
[a.decl() for a in dispatcher.arguments(f.func, symint=symint)] [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
@ -670,7 +672,7 @@ class GenLazyShapeInferenceDefinition:
tensor_class: str tensor_class: str
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> List[str]: def __call__(self, f: NativeFunction) -> list[str]:
metadata = self.backend_index.get_kernel(f) metadata = self.backend_index.get_kernel(f)
assert metadata is not None assert metadata is not None
@ -687,8 +689,8 @@ class GenLazyShapeInferenceDefinition:
def generate_non_native_lazy_ir_nodes( def generate_non_native_lazy_ir_nodes(
non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> List[str]: ) -> list[str]:
"""Generate the non-native lazy IR node classes""" """Generate the non-native lazy IR node classes"""
nodes = [] nodes = []
for op in non_native: for op in non_native:

@ -1,4 +1,4 @@
from typing import List, Optional, Union from __future__ import annotations
import torchgen.api.meta as meta import torchgen.api.meta as meta
import torchgen.api.structured as structured import torchgen.api.structured as structured
@ -9,7 +9,7 @@ from torchgen.utils import mapMaybe
@with_native_function_and_index @with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]: def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
sig = kernel_signature(f, backend_index) sig = kernel_signature(f, backend_index)
metadata = backend_index.get_kernel(f) metadata = backend_index.get_kernel(f)
if metadata is None: if metadata is None:
@ -22,7 +22,7 @@ def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional
@with_native_function_and_index @with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]: def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
meta_name = meta.name(g) meta_name = meta.name(g)
out_args = structured.impl_arguments(g) out_args = structured.impl_arguments(g)
metadata = backend_index.get_kernel(g) metadata = backend_index.get_kernel(g)
@ -42,8 +42,8 @@ void impl({', '.join(a.decl() for a in out_args)});
# actual kernel definitions we keep in aten/src/ATen/native/ # actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index @with_native_function_and_index
def compute_native_function_declaration( def compute_native_function_declaration(
g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> List[str]: ) -> list[str]:
metadata = backend_index.get_kernel(g) metadata = backend_index.get_kernel(g)
if isinstance(g, NativeFunctionsGroup): if isinstance(g, NativeFunctionsGroup):
if metadata is not None and metadata.structured: if metadata is not None and metadata.structured:

@ -1,7 +1,9 @@
from __future__ import annotations
import itertools import itertools
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Literal, Optional, Tuple, Union from typing import Literal, TYPE_CHECKING
import torchgen.api.cpp as cpp import torchgen.api.cpp as cpp
import torchgen.api.meta as meta import torchgen.api.meta as meta
@ -34,15 +36,18 @@ from torchgen.model import (
SchemaKind, SchemaKind,
TensorOptionsArguments, TensorOptionsArguments,
) )
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import assert_never, mapMaybe, Target from torchgen.utils import assert_never, mapMaybe, Target
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
def gen_registration_headers( def gen_registration_headers(
backend_index: BackendIndex, backend_index: BackendIndex,
per_operator_headers: bool, per_operator_headers: bool,
rocm: bool, rocm: bool,
) -> List[str]: ) -> list[str]:
if per_operator_headers: if per_operator_headers:
headers = ["#include <ATen/ops/as_strided_native.h>"] headers = ["#include <ATen/ops/as_strided_native.h>"]
else: else:
@ -73,7 +78,7 @@ def gen_registration_headers(
def gen_empty_impl_names( def gen_empty_impl_names(
backend_index: BackendIndex, backend_index: BackendIndex,
) -> Tuple[Optional[str], Optional[str]]: ) -> tuple[str | None, str | None]:
empty_impl = None empty_impl = None
empty_strided_impl = None empty_strided_impl = None
@ -97,7 +102,7 @@ def gen_empty_impl_names(
return empty_impl, empty_strided_impl return empty_impl, empty_strided_impl
def gen_create_out_helper(backend_index: BackendIndex) -> List[str]: def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
if backend_index.dispatch_key == DispatchKey.Meta: if backend_index.dispatch_key == DispatchKey.Meta:
empty_options = "options.device(at::kMeta)" empty_options = "options.device(at::kMeta)"
else: else:
@ -120,7 +125,7 @@ Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &o
] ]
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]: def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
_, empty_strided_impl = gen_empty_impl_names(backend_index) _, empty_strided_impl = gen_empty_impl_names(backend_index)
return ( return (
[] []
@ -138,7 +143,7 @@ std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, I
) )
def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]: def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
# The function isn't used by this key (since only functional ops have a kernel for this key), # The function isn't used by this key (since only functional ops have a kernel for this key),
# so we need to not include it to avoid a defined-but-not-used error. # so we need to not include it to avoid a defined-but-not-used error.
@ -168,7 +173,7 @@ void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const
] ]
def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]: def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
return [ return [
""" """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) { void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
@ -191,7 +196,7 @@ void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &o
] ]
def gen_registration_helpers(backend_index: BackendIndex) -> List[str]: def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
return [ return [
'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")', 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
*gen_create_out_helper(backend_index), *gen_create_out_helper(backend_index),
@ -249,7 +254,7 @@ class RegisterDispatchKey:
# Finally, this field is currently Optional because it is only used by external backends. # Finally, this field is currently Optional because it is only used by external backends.
# It would be nice if we can add the same logic to in-tree kernels too, but that requires updating # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
# all of the existing kernel signatures scattered across aten/src/ATen/native. # all of the existing kernel signatures scattered across aten/src/ATen/native.
class_method_name: Optional[str] class_method_name: str | None
# Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
# operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher. # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
@ -257,7 +262,7 @@ class RegisterDispatchKey:
@staticmethod @staticmethod
def gen_device_check( def gen_device_check(
type: DeviceCheckType, args: List[Argument], method_name: str type: DeviceCheckType, args: list[Argument], method_name: str
) -> str: ) -> str:
if type == DeviceCheckType.NoCheck: if type == DeviceCheckType.NoCheck:
return " // No device check\n" return " // No device check\n"
@ -272,7 +277,7 @@ class RegisterDispatchKey:
return device_check return device_check
@method_with_native_function @method_with_native_function
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
if isinstance(f, NativeFunctionsGroup): if isinstance(f, NativeFunctionsGroup):
g: NativeFunctionsGroup = f g: NativeFunctionsGroup = f
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend. # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
@ -291,7 +296,7 @@ class RegisterDispatchKey:
def wrapper_kernel_sig( def wrapper_kernel_sig(
self, f: NativeFunction self, f: NativeFunction
) -> Union[NativeSignature, DispatcherSignature]: ) -> NativeSignature | DispatcherSignature:
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names. # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
return DispatcherSignature.from_schema( return DispatcherSignature.from_schema(
f.func, f.func,
@ -300,8 +305,8 @@ class RegisterDispatchKey:
) )
def gen_out_inplace_wrapper( def gen_out_inplace_wrapper(
self, f: NativeFunction, g: Optional[NativeFunctionsGroup] self, f: NativeFunction, g: NativeFunctionsGroup | None
) -> Optional[str]: ) -> str | None:
if g is None: if g is None:
return None return None
k = f.func.kind() k = f.func.kind()
@ -350,7 +355,7 @@ class RegisterDispatchKey:
}} }}
""" """
def gen_structured(self, g: NativeFunctionsGroup) -> List[str]: def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
metadata = self.backend_index.get_kernel(g) metadata = self.backend_index.get_kernel(g)
if self.backend_index.dispatch_key == DispatchKey.Meta: if self.backend_index.dispatch_key == DispatchKey.Meta:
assert not self.backend_index.has_kernel(g.out), ( assert not self.backend_index.has_kernel(g.out), (
@ -380,8 +385,8 @@ class RegisterDispatchKey:
return list(mapMaybe(structured_gen.gen_one, g.functions())) return list(mapMaybe(structured_gen.gen_one, g.functions()))
def gen_unstructured( def gen_unstructured(
self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None self, f: NativeFunction, g: NativeFunctionsGroup | None = None
) -> Optional[str]: ) -> str | None:
with native_function_manager(f): with native_function_manager(f):
inplace_meta = False inplace_meta = False
gets_out_inplace_wrapper = False gets_out_inplace_wrapper = False
@ -732,7 +737,7 @@ resize_out(out, sizes, strides, options);
return "\n".join(line for line in lines if line) return "\n".join(line for line in lines if line)
@method_with_native_function @method_with_native_function
def gen_one(self, f: NativeFunction) -> Optional[str]: def gen_one(self, f: NativeFunction) -> str | None:
assert not f.manual_kernel_registration assert not f.manual_kernel_registration
if ( if (
@ -806,7 +811,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
sig_body = [] sig_body = []
# We'll use context to keep track of any variables we've brought # We'll use context to keep track of any variables we've brought
# into scope while generating code # into scope while generating code
context: List[Union[Binding, Expr]] = list(sig.arguments()) context: list[Binding | Expr] = list(sig.arguments())
# Initialize the class corresponding to this structured # Initialize the class corresponding to this structured
# operator; feeding it the output argument(s) if it is known # operator; feeding it the output argument(s) if it is known
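
The context local above is annotated as list[Binding | Expr]; under PEP 526, annotations on local variables are not evaluated at runtime at all, so the new spelling is harmless there even without the future import. An illustrative sketch:

    def collect(*chunks: str) -> list:
        # PEP 526: this local annotation is never evaluated or stored, so the
        # subscripted union below costs nothing and needs no future import.
        context: list[str | None] = []
        for chunk in chunks:
            context.append(chunk or None)
        return context


    print(collect("a", "", "b"))   # ['a', None, 'b']
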

@ -1,5 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Sequence, TYPE_CHECKING
import torchgen.api.ufunc as ufunc import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate from torchgen.api.translate import translate
@ -14,7 +16,6 @@ from torchgen.api.types import (
StructuredImplSignature, StructuredImplSignature,
VectorizedCType, VectorizedCType,
) )
from torchgen.api.ufunc import UfunctorBindings
from torchgen.context import with_native_function from torchgen.context import with_native_function
from torchgen.model import ( from torchgen.model import (
Argument, Argument,
@ -28,6 +29,10 @@ from torchgen.model import (
from torchgen.utils import OrderedSet from torchgen.utils import OrderedSet
if TYPE_CHECKING:
from torchgen.api.ufunc import UfunctorBindings
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# #
# CUDA STUFF # CUDA STUFF
@ -60,7 +65,7 @@ from torchgen.utils import OrderedSet
@dataclass(frozen=True) @dataclass(frozen=True)
class UfunctorSignature: class UfunctorSignature:
g: NativeFunctionsGroup g: NativeFunctionsGroup
scalar_tensor_idx: Optional[int] scalar_tensor_idx: int | None
name: str name: str
def arguments(self) -> UfunctorBindings: def arguments(self) -> UfunctorBindings:
@ -68,7 +73,7 @@ class UfunctorSignature:
self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
) )
def fields(self) -> List[Binding]: def fields(self) -> list[Binding]:
# fields are renamed to have a trailing underscore, as is conventional # fields are renamed to have a trailing underscore, as is conventional
return [b.rename(f"{b.name}_") for b in self.arguments().ctor] return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
@ -98,10 +103,10 @@ class UfuncSignature:
name: str name: str
compute_t: CType compute_t: CType
def arguments(self) -> List[Binding]: def arguments(self) -> list[Binding]:
return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t) return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str: def call(self, ctx: Sequence[Binding | Expr]) -> str:
return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})" return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
@ -132,10 +137,10 @@ def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
def compute_ufunc_cuda_functors( def compute_ufunc_cuda_functors(
g: NativeFunctionsGroup, g: NativeFunctionsGroup,
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]: ) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
# First, build the functors. # First, build the functors.
ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {} ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
ufunctors: List[str] = [] ufunctors: list[str] = []
loops = g.out.ufunc_inner_loop loops = g.out.ufunc_inner_loop
scalar_tensor_idx_lookup = { scalar_tensor_idx_lookup = {
UfuncKey.CUDAFunctorOnSelf: 1, UfuncKey.CUDAFunctorOnSelf: 1,
@ -237,7 +242,7 @@ BinaryScalarSpecializationConfigs = [
def compute_ufunc_cuda_dtype_body( def compute_ufunc_cuda_dtype_body(
g: NativeFunctionsGroup, g: NativeFunctionsGroup,
dtype: ScalarType, dtype: ScalarType,
inner_loops: Dict[UfuncKey, UfunctorSignature], inner_loops: dict[UfuncKey, UfunctorSignature],
parent_ctx: Sequence[Binding], parent_ctx: Sequence[Binding],
) -> str: ) -> str:
body = "using opmath_t = at::opmath_type<scalar_t>;" body = "using opmath_t = at::opmath_type<scalar_t>;"
@ -249,7 +254,7 @@ def compute_ufunc_cuda_dtype_body(
scalar_idx = config.scalar_idx + 1 scalar_idx = config.scalar_idx + 1
# Make a copy and at the same time widen the type (not permissible # Make a copy and at the same time widen the type (not permissible
# without copy; we don't want to mutate the input argument anyway) # without copy; we don't want to mutate the input argument anyway)
ctx: List[Union[Expr, Binding]] = list(parent_ctx) ctx: list[Expr | Binding] = list(parent_ctx)
ctx.append( ctx.append(
Expr( Expr(
expr=f"iter.scalar_value<opmath_t>({scalar_idx})", expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
@ -346,7 +351,7 @@ class StubSignature:
def type_name(self) -> str: def type_name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_fn" return f"{str(self.g.functional.func.name.name)}_fn"
def arguments(self) -> List[Binding]: def arguments(self) -> list[Binding]:
return ufunc.stub_arguments(self.g) return ufunc.stub_arguments(self.g)
def type(self) -> str: def type(self) -> str:
@ -393,7 +398,7 @@ def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
def compute_ufunc_cpu_dtype_body( def compute_ufunc_cpu_dtype_body(
g: NativeFunctionsGroup, g: NativeFunctionsGroup,
dtype: ScalarType, dtype: ScalarType,
inner_loops: Dict[UfuncKey, UfuncSignature], inner_loops: dict[UfuncKey, UfuncSignature],
parent_ctx: Sequence[Binding], parent_ctx: Sequence[Binding],
) -> str: ) -> str:
assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
@ -459,8 +464,8 @@ def compute_ufunc_cpu_dtype_body(
) )
) )
def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]: def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
r: List[Union[Expr, Binding]] = [] r: list[Expr | Binding] = []
r.extend(ctx) r.extend(ctx)
r.extend(b) r.extend(b)
return r return r
@ -489,7 +494,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
# Reindex the ufunc by dtypes; processing generic/scalaronly as well # Reindex the ufunc by dtypes; processing generic/scalaronly as well
loops = g.out.ufunc_inner_loop loops = g.out.ufunc_inner_loop
ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {} ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]: for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
lks = [] lks = []
# ORDER MATTERS: this specifies overriding precedence # ORDER MATTERS: this specifies overriding precedence

@ -1,24 +1,29 @@
from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple from typing import Sequence, TYPE_CHECKING
from torchgen import dest from torchgen import dest
# disable import sorting to avoid circular dependency. # disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature # usort:skip from torchgen.api.types import DispatcherSignature # usort:skip
from torchgen.context import method_with_native_function from torchgen.context import method_with_native_function
from torchgen.executorch.model import ETKernelIndex
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import concatMap, Target from torchgen.utils import concatMap, Target
if TYPE_CHECKING:
from torchgen.executorch.model import ETKernelIndex
from torchgen.selective_build.selector import SelectiveBuilder
# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at # Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
# model authoring side. # model authoring side.
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeNativeFunctionStub: class ComputeNativeFunctionStub:
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
if Variant.function not in f.variants: if Variant.function not in f.variants:
return None return None
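
The same TYPE_CHECKING treatment is applied to ETKernelIndex and SelectiveBuilder above. One caveat worth keeping in mind, sketched below with an illustrative function: names imported only for type checking cannot be resolved by typing.get_type_hints at runtime, although raw __annotations__ access still works because it just returns strings.

    from __future__ import annotations

    import typing
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from fractions import Fraction   # annotation-only dependency


    def halve(x: Fraction) -> Fraction:
        return x / 2


    print(halve.__annotations__)              # {'x': 'Fraction', 'return': 'Fraction'}
    try:
        typing.get_type_hints(halve)
    except NameError as exc:
        print("unresolved at runtime:", exc)  # name 'Fraction' is not defined
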
@ -80,7 +85,7 @@ def gen_custom_ops_registration(
selector: SelectiveBuilder, selector: SelectiveBuilder,
kernel_index: ETKernelIndex, kernel_index: ETKernelIndex,
rocm: bool, rocm: bool,
) -> Tuple[str, str]: ) -> tuple[str, str]:
""" """
Generate custom ops registration code for dest.RegisterDispatchKey. Generate custom ops registration code for dest.RegisterDispatchKey.
@ -97,7 +102,7 @@ def gen_custom_ops_registration(
dispatch_key = DispatchKey.CPU dispatch_key = DispatchKey.CPU
backend_index = kernel_index._to_backend_index() backend_index = kernel_index._to_backend_index()
static_init_dispatch_registrations = "" static_init_dispatch_registrations = ""
ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
for native_function in native_functions: for native_function in native_functions:
ns_grouped_native_functions[native_function.namespace].append(native_function) ns_grouped_native_functions[native_function.namespace].append(native_function)
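
The files in this commit follow one modernization recipe: add from __future__ import annotations, replace typing.Dict/List/Optional/Set/Tuple/Union with builtin generics and PEP 604 unions, and move typing-only imports under TYPE_CHECKING. A minimal self-contained sketch of that style, with illustrative helper names that are not part of torchgen:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type checkers only: the annotations below are never
    # evaluated at runtime, so this import (and any circular dependency it
    # could introduce) is skipped when the module is executed.
    from collections.abc import Sequence


def group_by_namespace(names: Sequence[str]) -> dict[str, list[str]]:
    # PEP 585 builtin generics instead of typing.Dict / typing.List.
    grouped: dict[str, list[str]] = {}
    for name in names:
        ns, _, op = name.partition("::")
        grouped.setdefault(ns, []).append(op or ns)
    return grouped


def first_overload(names: Sequence[str]) -> str | None:
    # PEP 604 union syntax instead of typing.Optional.
    return names[0] if names else None

Because the future import defers evaluation, these annotations also work on interpreters that predate runtime support for builtin generics; only annotation positions are affected, so runtime uses such as cast() still need runtime-valid types.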


@ -1,4 +1,6 @@
from typing import List, Optional, Sequence, Set, Union from __future__ import annotations
from typing import Sequence
from torchgen import local from torchgen import local
from torchgen.api.types import ( from torchgen.api.types import (
@ -63,7 +65,7 @@ def valuetype_type(
*, *,
binds: ArgName, binds: ArgName,
remove_non_owning_ref_types: bool = False, remove_non_owning_ref_types: bool = False,
) -> Optional[NamedCType]: ) -> NamedCType | None:
if isinstance(t, BaseType): if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar: if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None return None
@ -209,7 +211,7 @@ def returns_type(rs: Sequence[Return]) -> CType:
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]: def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: List[str] = [] returns: list[str] = []
for i, r in enumerate(f.func.returns): for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is # If we have an inplace function, the return argument is
# implicitly named self. # implicitly named self.
@ -295,16 +297,16 @@ def default_expr(d: str, t: Type) -> str:
def argument( def argument(
a: Union[Argument, TensorOptionsArguments, SelfArgument], a: Argument | TensorOptionsArguments | SelfArgument,
*, *,
cpp_no_default_args: Set[str], cpp_no_default_args: set[str],
method: bool, method: bool,
faithful: bool, faithful: bool,
has_tensor_options: bool, has_tensor_options: bool,
) -> List[Binding]: ) -> list[Binding]:
def sub_argument( def sub_argument(
a: Union[Argument, TensorOptionsArguments, SelfArgument] a: Argument | TensorOptionsArguments | SelfArgument,
) -> List[Binding]: ) -> list[Binding]:
return argument( return argument(
a, a,
cpp_no_default_args=cpp_no_default_args, cpp_no_default_args=cpp_no_default_args,
@ -319,7 +321,7 @@ def argument(
binds = SpecialArgName.possibly_redundant_memory_format binds = SpecialArgName.possibly_redundant_memory_format
else: else:
binds = a.name binds = a.name
default: Optional[str] = None default: str | None = None
if a.name not in cpp_no_default_args and a.default is not None: if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type) default = default_expr(a.default, a.type)
return [ return [
@ -347,9 +349,9 @@ def arguments(
*, *,
faithful: bool, faithful: bool,
method: bool, method: bool,
cpp_no_default_args: Set[str], cpp_no_default_args: set[str],
) -> List[Binding]: ) -> list[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [] args: list[Argument | TensorOptionsArguments | SelfArgument] = []
if faithful: if faithful:
args.extend(arguments.non_out) args.extend(arguments.non_out)
args.extend(arguments.out) args.extend(arguments.out)


@ -1,9 +1,14 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Set from typing import TYPE_CHECKING
import torchgen.api.cpp as aten_cpp import torchgen.api.cpp as aten_cpp
from torchgen.api.types import Binding, CType
from torchgen.executorch.api.types.types import contextArg from torchgen.executorch.api.types.types import contextArg
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType
from torchgen.model import FunctionSchema, NativeFunction from torchgen.model import FunctionSchema, NativeFunction
@ -20,14 +25,14 @@ class ExecutorchCppSignature:
func: FunctionSchema func: FunctionSchema
# The set of C++ arguments which should not have defaults applied to them # The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: Set[str] cpp_no_default_args: set[str]
# Allows you to prepend an arbitrary prefix to the signature name. # Allows you to prepend an arbitrary prefix to the signature name.
# This is useful for parts of the codegen that generate wrappers around kernels, # This is useful for parts of the codegen that generate wrappers around kernels,
# and need to avoid naming collisions. # and need to avoid naming collisions.
prefix: str = "" prefix: str = ""
def arguments(self, *, include_context: bool = True) -> List[Binding]: def arguments(self, *, include_context: bool = True) -> list[Binding]:
return ([contextArg] if include_context else []) + et_cpp.arguments( return ([contextArg] if include_context else []) + et_cpp.arguments(
self.func.arguments, self.func.arguments,
faithful=True, # always faithful, out argument at the end faithful=True, # always faithful, out argument at the end
@ -41,7 +46,7 @@ class ExecutorchCppSignature:
faithful_name_for_out_overloads=True, faithful_name_for_out_overloads=True,
) )
def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str: def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
args_str = ", ".join( args_str = ", ".join(
a.decl() for a in self.arguments(include_context=include_context) a.decl() for a in self.arguments(include_context=include_context)
) )
@ -49,7 +54,7 @@ class ExecutorchCppSignature:
name = self.name() name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})" return f"{self.returns_type().cpp_type()} {name}({args_str})"
def defn(self, name: Optional[str] = None) -> str: def defn(self, name: str | None = None) -> str:
args = [a.defn() for a in self.arguments()] args = [a.defn() for a in self.arguments()]
args_str = ", ".join(args) args_str = ", ".join(args)
if name is None: if name is None:
@ -62,7 +67,7 @@ class ExecutorchCppSignature:
@staticmethod @staticmethod
def from_native_function( def from_native_function(
f: NativeFunction, *, prefix: str = "" f: NativeFunction, *, prefix: str = ""
) -> "ExecutorchCppSignature": ) -> ExecutorchCppSignature:
return ExecutorchCppSignature( return ExecutorchCppSignature(
func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
) )


@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict
from torchgen.api.types import ( from torchgen.api.types import (
BaseCppType, BaseCppType,
@ -40,7 +41,7 @@ contextArg = Binding(
default=None, default=None,
) )
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = { BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTy.int: longT, BaseTy.int: longT,
BaseTy.float: doubleT, BaseTy.float: doubleT,
BaseTy.bool: boolT, BaseTy.bool: boolT,
@ -54,7 +55,7 @@ BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
@dataclass(frozen=True) @dataclass(frozen=True)
class OptionalCType(CType): class OptionalCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively. # Do not pass `strip_ref` recursively.
@ -63,13 +64,13 @@ class OptionalCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>" return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return OptionalCType(self.elem.remove_const_ref()) return OptionalCType(self.elem.remove_const_ref())
@dataclass(frozen=True) @dataclass(frozen=True)
class ArrayRefCType(CType): class ArrayRefCType(CType):
elem: "CType" elem: CType
def cpp_type(self, *, strip_ref: bool = False) -> str: def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively. # Do not pass `strip_ref` recursively.
@ -78,5 +79,5 @@ class ArrayRefCType(CType):
def cpp_type_registration_declarations(self) -> str: def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>" return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
def remove_const_ref(self) -> "CType": def remove_const_ref(self) -> CType:
return ArrayRefCType(self.elem.remove_const_ref()) return ArrayRefCType(self.elem.remove_const_ref())
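
A related effect, visible in the quoted-to-bare changes above (for example "CType" becoming CType and "ExecutorchCppSignature" becoming ExecutorchCppSignature): with postponed annotations, a class may be referenced in an annotation before it is fully defined, so forward and self references no longer need quotes. A small illustrative sketch, not torchgen code:

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    value: int
    # Without postponed annotations this would have to be written as "Node"
    # in quotes, because the class object does not exist yet while its body
    # is being evaluated.
    next: Node | None = None

    def prepend(self, value: int) -> Node:
        # The bare class name is also fine as a return annotation.
        return Node(value, self)


chain = Node(1).prepend(2).prepend(3)
assert chain.value == 3 and chain.next is not None and chain.next.value == 2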


@ -1,7 +1,8 @@
from dataclasses import dataclass from __future__ import annotations
from typing import Callable, List, Sequence, Tuple
from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING
from torchgen.api.types import Binding, CType, NamedCType
from torchgen.model import ( from torchgen.model import (
Argument, Argument,
BaseTy, BaseTy,
@ -13,6 +14,10 @@ from torchgen.model import (
) )
if TYPE_CHECKING:
from torchgen.api.types import Binding, CType, NamedCType
connector = "\n\t" connector = "\n\t"
@ -52,7 +57,7 @@ class Unboxing:
# Convert all the arguments in a NativeFunction to C++ code # Convert all the arguments in a NativeFunction to C++ code
def convert_arguments( def convert_arguments(
self, args: Sequence[Binding] self, args: Sequence[Binding]
) -> Tuple[List[Binding], List[str]]: ) -> tuple[list[Binding], list[str]]:
code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))] code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
binding_list = [] binding_list = []
for arg in args: for arg in args:
@ -72,7 +77,7 @@ class Unboxing:
def argumenttype_evalue_convert( def argumenttype_evalue_convert(
self, t: Type, arg_name: str, *, mutable: bool = False self, t: Type, arg_name: str, *, mutable: bool = False
) -> Tuple[str, CType, List[str], List[str]]: ) -> tuple[str, CType, list[str], list[str]]:
""" """
Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
(1) the C++ code necessary to unbox the argument (1) the C++ code necessary to unbox the argument
@ -107,14 +112,14 @@ class Unboxing:
def _gen_code_base_type( def _gen_code_base_type(
self, arg_name: str, out_name: str, ctype: CType self, arg_name: str, out_name: str, ctype: CType
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
return [ return [
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], [] ], []
def _gen_code_optional_type( def _gen_code_optional_type(
self, arg_name: str, out_name: str, t: OptionalType, ctype: CType self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_opt_in" in_name = f"{arg_name}_opt_in"
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert( res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
t.elem, in_name t.elem, in_name
@ -130,7 +135,7 @@ class Unboxing:
def _gen_code_list_type( def _gen_code_list_type(
self, arg_name: str, out_name: str, t: ListType, ctype: CType self, arg_name: str, out_name: str, t: ListType, ctype: CType
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
in_name = f"{arg_name}_list_in" in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem" elem_name = f"{arg_name}_elem"
code = [] code = []
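
For orientation, convert_arguments and the _gen_code_* helpers above assemble C++ source as plain strings: each argument is first bound to an EValue slot on the stack, then converted to its C++ type. A rough standalone sketch of that pattern; Arg and the _base suffix are simplified stand-ins, not torchgen's actual types or naming:

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Arg:  # simplified stand-in for a torchgen Binding
    name: str
    cpp_type: str


def emit_unboxing(args: list[Arg]) -> list[str]:
    # One EValue binding per stack slot, then one conversion per argument.
    lines = [f"EValue& {a.name} = *stack[{i}];" for i, a in enumerate(args)]
    for a in args:
        lines.append(f"{a.cpp_type} {a.name}_base = {a.name}.to<{a.cpp_type}>();")
    return lines


print("\n".join(emit_unboxing([Arg("self", "Tensor"), Arg("dim", "int64_t")])))
# EValue& self = *stack[0];
# EValue& dim = *stack[1];
# Tensor self_base = self.to<Tensor>();
# int64_t dim_base = dim.to<int64_t>();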


@ -1,11 +1,12 @@
# Represents all kernels used by an Executorch model. # Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure. # It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
from __future__ import annotations
import itertools import itertools
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from dataclasses import dataclass from dataclasses import dataclass
from enum import IntEnum from enum import IntEnum
from typing import Dict, List, Tuple, Union
from torchgen.model import ( from torchgen.model import (
BackendIndex, BackendIndex,
@ -41,7 +42,7 @@ class ETKernelKeyOpArgMeta:
arg_name: str arg_name: str
dtype: str dtype: str
# The order of the dimensions if entry is a Tensor # The order of the dimensions if entry is a Tensor
dim_order: Tuple[int, ...] dim_order: tuple[int, ...]
def to_native_string(self) -> str: def to_native_string(self) -> str:
dtype_str = ScalarType[self.dtype].value dtype_str = ScalarType[self.dtype].value
@ -52,7 +53,7 @@ class ETKernelKeyOpArgMeta:
@dataclass(frozen=True) @dataclass(frozen=True)
class ETKernelKey: class ETKernelKey:
# Field undefined is default = True # Field undefined is default = True
arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = () arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
# Indicator for this kernel being used as a catch all # Indicator for this kernel being used as a catch all
default: bool = False default: bool = False
@ -61,10 +62,10 @@ class ETKernelKey:
@staticmethod @staticmethod
def gen_from_yaml( def gen_from_yaml(
args: Dict[str, Tuple[str, str]], args: dict[str, tuple[str, str]],
type_alias_map: Dict[str, List[str]], # TODO: Support unwrapped str val type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
dim_order_alias_map: Dict[str, List[int]], dim_order_alias_map: dict[str, list[int]],
) -> List["ETKernelKey"]: ) -> list[ETKernelKey]:
"""Generate ETKernelKeys from arg kernel specs """Generate ETKernelKeys from arg kernel specs
Multiple ETKernelKeys are returned due to dtype permutations from utilizing Multiple ETKernelKeys are returned due to dtype permutations from utilizing
type_alias_map (actualizing each potential type permutation as a KernelKey) type_alias_map (actualizing each potential type permutation as a KernelKey)
@ -137,15 +138,15 @@ class ETKernelKey:
@dataclass(frozen=True) @dataclass(frozen=True)
class ETKernelIndex: class ETKernelIndex:
index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool: def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
m = self.get_kernels(g) m = self.get_kernels(g)
return m is not None return m is not None
def get_kernels( def get_kernels(
self, g: Union[NativeFunction, NativeFunctionsGroup] self, g: NativeFunction | NativeFunctionsGroup
) -> Dict[ETKernelKey, BackendMetadata]: ) -> dict[ETKernelKey, BackendMetadata]:
if isinstance(g, NativeFunction): if isinstance(g, NativeFunction):
f = g f = g
elif isinstance(g, NativeFunctionsGroup): elif isinstance(g, NativeFunctionsGroup):
@ -158,8 +159,8 @@ class ETKernelIndex:
@staticmethod @staticmethod
def grow_from_backend_indices( def grow_from_backend_indices(
kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]], kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None: ) -> None:
for dk in backend_indices: for dk in backend_indices:
index = backend_indices[dk] index = backend_indices[dk]
@ -171,17 +172,17 @@ class ETKernelIndex:
@staticmethod @staticmethod
def from_backend_indices( def from_backend_indices(
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> "ETKernelIndex": ) -> ETKernelIndex:
kernel_index: Dict[ kernel_index: dict[
OperatorName, Dict[ETKernelKey, BackendMetadata] OperatorName, dict[ETKernelKey, BackendMetadata]
] = defaultdict(dict) ] = defaultdict(dict)
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices) ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
return ETKernelIndex(kernel_index) return ETKernelIndex(kernel_index)
def grow( def grow(
self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> "ETKernelIndex": ) -> ETKernelIndex:
ETKernelIndex.grow_from_backend_indices(self.index, backend_indices) ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
return self return self
@ -189,7 +190,7 @@ class ETKernelIndex:
""" """
WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex. WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
""" """
index: Dict[OperatorName, BackendMetadata] = {} index: dict[OperatorName, BackendMetadata] = {}
for op in self.index: for op in self.index:
kernel_dict = self.index[op] kernel_dict = self.index[op]
assert ( assert (
@ -209,9 +210,7 @@ class ETKernelIndex:
# Note duplicate ETKernelKey from index_b will clobber the metadata from index_a # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
@staticmethod @staticmethod
def merge_indices( def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
index_a: "ETKernelIndex", index_b: "ETKernelIndex"
) -> "ETKernelIndex":
combined = defaultdict(dict, index_a.index.copy()) combined = defaultdict(dict, index_a.index.copy())
for op, entry in index_b.index.items(): for op, entry in index_b.index.items():
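
A small sketch of the two-level mapping this class maintains (operator name -> kernel key -> backend metadata) and of the clobber-on-merge behaviour noted in the comment above, with plain strings standing in for OperatorName, ETKernelKey, and BackendMetadata:

from __future__ import annotations

from collections import defaultdict

# operator name -> kernel key -> backend metadata, all plain strings here
index_a: dict[str, dict[str, str]] = {"aten::add.out": {"float32": "add_f32"}}
index_b: dict[str, dict[str, str]] = {
    "aten::add.out": {"float32": "add_f32_override", "int64": "add_i64"},
    "aten::mul.out": {"default": "mul_fallback"},
}

combined: dict[str, dict[str, str]] = defaultdict(
    dict, {op: dict(entry) for op, entry in index_a.items()}
)
for op, entry in index_b.items():
    # A key that exists in both indices keeps the metadata from index_b.
    combined[op].update(entry)

assert combined["aten::add.out"]["float32"] == "add_f32_override"
assert combined["aten::add.out"]["int64"] == "add_i64"
assert combined["aten::mul.out"]["default"] == "mul_fallback"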


@ -1,5 +1,7 @@
from __future__ import annotations
from collections import defaultdict, namedtuple from collections import defaultdict, namedtuple
from typing import Any, Dict, List, Optional, Set, Tuple from typing import Any
import yaml import yaml
@ -22,7 +24,7 @@ ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indice
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"] ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]: def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
"""Given a loaded yaml representing kernel assignment information, extract the """Given a loaded yaml representing kernel assignment information, extract the
mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance) mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
@ -34,11 +36,11 @@ def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]
if (kernels := e.pop("kernels", None)) is None: if (kernels := e.pop("kernels", None)) is None:
return {} return {}
type_alias: Dict[str, List[str]] = e.pop("type_alias", {}) # type: ignore[assignment] type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment] dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
dim_order_alias.pop("__line__", None) dim_order_alias.pop("__line__", None)
kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {} kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
for entry in kernels: # type: ignore[attr-defined] for entry in kernels: # type: ignore[attr-defined]
arg_meta = entry.get("arg_meta") arg_meta = entry.get("arg_meta")
@ -76,7 +78,7 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
that should be used by the kernel key). that should be used by the kernel key).
""" """
indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {} indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
for ei in es: # type: ignore[attr-defined] for ei in es: # type: ignore[attr-defined]
e = ei.copy() e = ei.copy()
@ -95,11 +97,11 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
return ETKernelIndex(indices) return ETKernelIndex(indices)
def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]: def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
"""Given a loaded yaml representing a list of operators, extract the """Given a loaded yaml representing a list of operators, extract the
kernel key related fields indexed by the operator name. kernel key related fields indexed by the operator name.
""" """
fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict) fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
for ei in es: # type: ignore[attr-defined] for ei in es: # type: ignore[attr-defined]
funcs = ei.get("func") funcs = ei.get("func")
assert isinstance(funcs, str), f"not a str: {funcs}" assert isinstance(funcs, str), f"not a str: {funcs}"
@ -118,9 +120,9 @@ def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
def parse_et_yaml( def parse_et_yaml(
path: str, path: str,
tags_yaml_path: str, tags_yaml_path: str,
ignore_keys: Optional[Set[DispatchKey]] = None, ignore_keys: set[DispatchKey] | None = None,
skip_native_fns_gen: bool = False, skip_native_fns_gen: bool = False,
) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]: ) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
"""Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
of fields to persist from native_functions.yaml to functions.yaml of fields to persist from native_functions.yaml to functions.yaml
""" """


@ -1,23 +1,13 @@
from __future__ import annotations
import argparse import argparse
import functools import functools
import json import json
import os import os
import pathlib
from collections import defaultdict, namedtuple, OrderedDict from collections import defaultdict, namedtuple, OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ( from pathlib import Path
Any, from typing import Any, Callable, Literal, Sequence, TypeVar
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)
import yaml import yaml
@ -148,20 +138,20 @@ class LineLoader(YamlLoader):
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"]) ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {} _GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {} _GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
def parse_native_yaml_struct( def parse_native_yaml_struct(
es: object, es: object,
valid_tags: Set[str], valid_tags: set[str],
ignore_keys: Optional[Set[DispatchKey]] = None, ignore_keys: set[DispatchKey] | None = None,
path: str = "<stdin>", path: str = "<stdin>",
skip_native_fns_gen: bool = False, skip_native_fns_gen: bool = False,
) -> ParsedYaml: ) -> ParsedYaml:
assert isinstance(es, list) assert isinstance(es, list)
rs: List[NativeFunction] = [] rs: list[NativeFunction] = []
bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict) bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
for e in es: for e in es:
assert isinstance(e, dict), f"expected to be dict: {e}" assert isinstance(e, dict), f"expected to be dict: {e}"
assert isinstance(e.get("__line__"), int), e assert isinstance(e.get("__line__"), int), e
@ -174,7 +164,7 @@ def parse_native_yaml_struct(
BackendIndex.grow_index(bs, m) BackendIndex.grow_index(bs, m)
error_check_native_functions(rs) error_check_native_functions(rs)
# Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet. # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
indices: Dict[DispatchKey, BackendIndex] = defaultdict( indices: dict[DispatchKey, BackendIndex] = defaultdict(
lambda: BackendIndex( lambda: BackendIndex(
dispatch_key=DispatchKey.Undefined, dispatch_key=DispatchKey.Undefined,
use_out_as_primary=True, use_out_as_primary=True,
@ -200,9 +190,9 @@ def parse_native_yaml_struct(
return ParsedYaml(rs, indices) return ParsedYaml(rs, indices)
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]: def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
assert isinstance(es, list) assert isinstance(es, list)
rs: Set[str] = set() rs: set[str] = set()
for e in es: for e in es:
assert isinstance(e.get("__line__"), int), e assert isinstance(e.get("__line__"), int), e
loc = Location(path, e["__line__"]) loc = Location(path, e["__line__"])
@ -218,7 +208,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
@functools.lru_cache(maxsize=None) @functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> Set[str]: def parse_tags_yaml(path: str) -> set[str]:
global _GLOBAL_PARSE_TAGS_YAML_CACHE global _GLOBAL_PARSE_TAGS_YAML_CACHE
if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE: if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
with open(path) as f: with open(path) as f:
@ -231,10 +221,10 @@ def parse_tags_yaml(path: str) -> Set[str]:
def parse_native_yaml( def parse_native_yaml(
path: str, path: str,
tags_yaml_path: str, tags_yaml_path: str,
ignore_keys: Optional[Set[DispatchKey]] = None, ignore_keys: set[DispatchKey] | None = None,
*, *,
skip_native_fns_gen: bool = False, skip_native_fns_gen: bool = False,
loaded_yaml: Optional[object] = None, loaded_yaml: object | None = None,
) -> ParsedYaml: ) -> ParsedYaml:
global _GLOBAL_PARSE_NATIVE_YAML_CACHE global _GLOBAL_PARSE_NATIVE_YAML_CACHE
if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE: if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
@ -261,8 +251,8 @@ def parse_native_yaml(
# Some assertions are already performed during parsing, but those are only within a single NativeFunction. # Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions. # Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None: def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
func_map: Dict[OperatorName, NativeFunction] = {} func_map: dict[OperatorName, NativeFunction] = {}
base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list) base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
for f in funcs: for f in funcs:
func_map[f.func.name] = f func_map[f.func.name] = f
base_func_map[f.func.name.name].append(f) base_func_map[f.func.name.name].append(f)
@ -329,7 +319,7 @@ def cpp_string(s: str) -> str:
# and similar functional combinators. # and similar functional combinators.
def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]: def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
if len(backends) == 0: if len(backends) == 0:
return [] return []
else: else:
@ -343,7 +333,7 @@ def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
def get_static_dispatch_backend( def get_static_dispatch_backend(
f: NativeFunction, backend_index: BackendIndex f: NativeFunction, backend_index: BackendIndex
) -> Optional[DispatchKey]: ) -> DispatchKey | None:
if f.structured_delegate is not None or backend_index.has_kernel(f): if f.structured_delegate is not None or backend_index.has_kernel(f):
# TODO: for ops with structured_delegate it should check the dispatch table of # TODO: for ops with structured_delegate it should check the dispatch table of
# the out variant instead. For now, these structured ops all have CPU/CUDA kernels # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
@ -362,8 +352,8 @@ def get_static_dispatch_backend(
def static_dispatch_ops_header( def static_dispatch_ops_header(
f: NativeFunction, backend_index: List[BackendIndex] f: NativeFunction, backend_index: list[BackendIndex]
) -> Optional[str]: ) -> str | None:
if backend_index is None or f.manual_kernel_registration: if backend_index is None or f.manual_kernel_registration:
return None return None
@ -377,7 +367,7 @@ def static_dispatch_ops_header(
return "\n".join(output) return "\n".join(output)
def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]: def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
return [ return [
f"#include <ATen/{dispatch_key}Functions.h>" f"#include <ATen/{dispatch_key}Functions.h>"
for dispatch_key in static_dispatch_keys(backends) for dispatch_key in static_dispatch_keys(backends)
@ -388,12 +378,12 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
# Note that we have a special case for `memory_format` argument and this case is not covered by # Note that we have a special case for `memory_format` argument and this case is not covered by
# tools.codegen.api.translate() yet as its application is limited to static dispatch. # tools.codegen.api.translate() yet as its application is limited to static dispatch.
def translate_args( def translate_args(
sig: Union[CppSignature, DispatcherSignature], sig: CppSignature | DispatcherSignature,
cpp_sig: CppSignature, cpp_sig: CppSignature,
) -> str: ) -> str:
# Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]: def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
output_bindings: List[Binding] = [] output_bindings: list[Binding] = []
for binding in input_bindings: for binding in input_bindings:
if binding.name == "memory_format": if binding.name == "memory_format":
spl_mem_format_binding = Binding( spl_mem_format_binding = Binding(
@ -423,7 +413,7 @@ def translate_args(
def generate_static_dispatch_backend_call( def generate_static_dispatch_backend_call(
sig: Union[CppSignature, DispatcherSignature], sig: CppSignature | DispatcherSignature,
f: NativeFunction, f: NativeFunction,
backend_index: BackendIndex, backend_index: BackendIndex,
) -> str: ) -> str:
@ -441,9 +431,9 @@ def generate_static_dispatch_backend_call(
def generate_static_dispatch_fallback_call( def generate_static_dispatch_fallback_call(
sig: Union[CppSignature, DispatcherSignature], sig: CppSignature | DispatcherSignature,
f: NativeFunction, f: NativeFunction,
backend_indices: List[BackendIndex], backend_indices: list[BackendIndex],
) -> str: ) -> str:
cpp_sigs = CppSignatureGroup.from_native_function( cpp_sigs = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=False f, method=False, fallback_binding=False
@ -470,9 +460,9 @@ def generate_static_dispatch_fallback_call(
def static_dispatch( def static_dispatch(
sig: Union[CppSignature, DispatcherSignature], sig: CppSignature | DispatcherSignature,
f: NativeFunction, f: NativeFunction,
backend_indices: List[BackendIndex], backend_indices: list[BackendIndex],
) -> str: ) -> str:
""" """
For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
@ -512,7 +502,7 @@ def static_dispatch(
tensor_opts = f.func.arguments.tensor_options tensor_opts = f.func.arguments.tensor_options
stmts = [] stmts = []
subexprs: List[str] = [] subexprs: list[str] = []
if tensor_opts is not None: if tensor_opts is not None:
subexprs.append( subexprs.append(
"DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))" "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
@ -548,10 +538,10 @@ def static_dispatch(
@dataclass(frozen=True) @dataclass(frozen=True)
class RegisterSchema: class RegisterSchema:
selector: SelectiveBuilder selector: SelectiveBuilder
known_tags: Dict[str, int] = field(default_factory=dict) known_tags: dict[str, int] = field(default_factory=dict)
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
if not self.selector.is_native_function_selected(f): if not self.selector.is_native_function_selected(f):
return None return None
tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}" tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
@ -573,7 +563,7 @@ class RegisterSchema:
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeOperators: class ComputeOperators:
target: Literal[Target.DECLARATION, Target.DEFINITION] target: Literal[Target.DECLARATION, Target.DEFINITION]
static_dispatch_backend_indices: List[BackendIndex] static_dispatch_backend_indices: list[BackendIndex]
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> str: def __call__(self, f: NativeFunction) -> str:
@ -670,7 +660,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeFunction: class ComputeFunction:
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
sig_group = CppSignatureGroup.from_native_function( sig_group = CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=f.manual_cpp_binding f, method=False, fallback_binding=f.manual_cpp_binding
) )
@ -718,10 +708,10 @@ namespace symint {{
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeTensorMethod: class ComputeTensorMethod:
target: Literal[Target.DECLARATION, Target.DEFINITION] target: Literal[Target.DECLARATION, Target.DEFINITION]
static_dispatch_backend_indices: List[BackendIndex] static_dispatch_backend_indices: list[BackendIndex]
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
if Variant.method not in f.variants: if Variant.method not in f.variants:
return None return None
@ -764,7 +754,7 @@ inline {sig.defn(prefix="Tensor::")} const {{
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeRedispatchFunction: class ComputeRedispatchFunction:
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
# We unconditionally generate function variants of the redispatch API. # We unconditionally generate function variants of the redispatch API.
# This is mainly because we can namespace functions separately, but not methods, # This is mainly because we can namespace functions separately, but not methods,
sig_group = CppSignatureGroup.from_native_function( sig_group = CppSignatureGroup.from_native_function(
@ -798,7 +788,7 @@ def compute_aten_op(f: NativeFunction) -> str:
# Generates MetaFunctions.h # Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]: def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
if not g.structured: if not g.structured:
return None return None
with native_function_manager(g.out): with native_function_manager(g.out):
@ -943,7 +933,7 @@ class ComputeBackendSelect:
selector: SelectiveBuilder selector: SelectiveBuilder
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
if not needs_backend_select(f, self.selector): if not needs_backend_select(f, self.selector):
return None return None
@ -959,7 +949,7 @@ class ComputeBackendSelect:
dispatcher_sig = DispatcherSignature.from_schema(f.func) dispatcher_sig = DispatcherSignature.from_schema(f.func)
sig: Union[NativeSignature, DispatcherSignature] sig: NativeSignature | DispatcherSignature
sig = dispatcher_sig sig = dispatcher_sig
dispatcher_exprs = dispatcher_sig.exprs() dispatcher_exprs = dispatcher_sig.exprs()
dispatch_key = "c10::computeDispatchKey(dtype, layout, device)" dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
@ -1059,7 +1049,7 @@ def dynamic_type(t: Type) -> str:
).cpp_type() ).cpp_type()
def compute_method_of_yaml(variants: Set[Variant]) -> List[str]: def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
# This is written out explicitly to ensure that Tensor and # This is written out explicitly to ensure that Tensor and
# namespace are put into the list in the right order # namespace are put into the list in the right order
method_of = ["Type"] method_of = ["Type"]
@ -1072,7 +1062,7 @@ def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
def compute_returns_yaml( def compute_returns_yaml(
f: NativeFunction, f: NativeFunction,
) -> Tuple[List[Dict[str, str]], Dict[str, str]]: ) -> tuple[list[dict[str, str]], dict[str, str]]:
# Note [name and field_name] # Note [name and field_name]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~
# To understand name_to_field_name, we must first talk about this # To understand name_to_field_name, we must first talk about this
@ -1112,7 +1102,7 @@ def compute_returns_yaml(
# schema itself. # schema itself.
# #
# See also https://github.com/pytorch/pytorch/issues/43114 # See also https://github.com/pytorch/pytorch/issues/43114
name_to_field_name: Dict[str, str] = {} name_to_field_name: dict[str, str] = {}
# Compute the returns field of the YAML entry # Compute the returns field of the YAML entry
names = cpp.return_names(f) names = cpp.return_names(f)
@ -1141,12 +1131,12 @@ def compute_cpp_argument_yaml(
cpp_a: Binding, cpp_a: Binding,
*, *,
schema_order: bool, schema_order: bool,
kwarg_only_set: Set[str], kwarg_only_set: set[str],
out_arg_set: Set[str], out_arg_set: set[str],
name_to_field_name: Dict[str, str], name_to_field_name: dict[str, str],
) -> object: ) -> object:
if isinstance(cpp_a.argument, TensorOptionsArguments): if isinstance(cpp_a.argument, TensorOptionsArguments):
arg: Dict[str, object] = { arg: dict[str, object] = {
"annotation": None, "annotation": None,
"dynamic_type": "at::TensorOptions", "dynamic_type": "at::TensorOptions",
"is_nullable": False, "is_nullable": False,
@ -1173,11 +1163,11 @@ def compute_argument_yaml(
a: Argument, a: Argument,
*, *,
schema_order: bool, schema_order: bool,
kwarg_only_set: Set[str], kwarg_only_set: set[str],
out_arg_set: Set[str], out_arg_set: set[str],
name_to_field_name: Dict[str, str], name_to_field_name: dict[str, str],
) -> object: ) -> object:
arg: Dict[str, object] = { arg: dict[str, object] = {
"annotation": str(a.annotation) if a.annotation else None, "annotation": str(a.annotation) if a.annotation else None,
"dynamic_type": dynamic_type(a.type), "dynamic_type": dynamic_type(a.type),
"is_nullable": a.type.is_nullable(), "is_nullable": a.type.is_nullable(),
@ -1303,7 +1293,7 @@ def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
@with_native_function_and_indices @with_native_function_and_indices
def compute_registration_declarations( def compute_registration_declarations(
f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex] f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
) -> str: ) -> str:
name = dispatcher.name(f.func) name = dispatcher.name(f.func)
returns_type = dispatcher.returns_type( returns_type = dispatcher.returns_type(
@ -1311,7 +1301,7 @@ def compute_registration_declarations(
).cpp_type_registration_declarations() ).cpp_type_registration_declarations()
args = dispatcher.arguments(f.func) args = dispatcher.arguments(f.func)
args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args) args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
comment_data: Dict[str, str] = { comment_data: dict[str, str] = {
"schema": f"aten::{f.func}", "schema": f"aten::{f.func}",
# TODO: What exactly is the semantics of the 'dispatch' field? # TODO: What exactly is the semantics of the 'dispatch' field?
"dispatch": str( "dispatch": str(
@ -1337,8 +1327,8 @@ def compute_registration_declarations(
def get_custom_build_selector( def get_custom_build_selector(
provided_op_registration_allowlist: Optional[List[str]], provided_op_registration_allowlist: list[str] | None,
op_selection_yaml_path: Optional[str], op_selection_yaml_path: str | None,
) -> SelectiveBuilder: ) -> SelectiveBuilder:
assert not ( assert not (
provided_op_registration_allowlist is not None provided_op_registration_allowlist is not None
@ -1349,7 +1339,7 @@ def get_custom_build_selector(
+ "same time." + "same time."
) )
op_registration_allowlist: Optional[Set[str]] = None op_registration_allowlist: set[str] | None = None
if provided_op_registration_allowlist is not None: if provided_op_registration_allowlist is not None:
op_registration_allowlist = set(provided_op_registration_allowlist) op_registration_allowlist = set(provided_op_registration_allowlist)
@ -1369,11 +1359,11 @@ def get_custom_build_selector(
def get_grouped_by_view_native_functions( def get_grouped_by_view_native_functions(
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]: ) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
def maybe_create_view_group( def maybe_create_view_group(
d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction] d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]: ) -> list[NativeFunction | NativeFunctionsViewGroup]:
funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = [] funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
if ViewSchemaKind.aliasing in d: if ViewSchemaKind.aliasing in d:
view = d.pop(ViewSchemaKind.aliasing) view = d.pop(ViewSchemaKind.aliasing)
view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None) view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
@ -1391,8 +1381,8 @@ def get_grouped_by_view_native_functions(
funcs.extend(d.values()) funcs.extend(d.values())
return funcs return funcs
grouped_by_views: Dict[ grouped_by_views: dict[
FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction] FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
] = defaultdict(dict) ] = defaultdict(dict)
for f in native_functions: for f in native_functions:
schema = f.func.view_signature() schema = f.func.view_signature()
@ -1416,10 +1406,10 @@ def get_grouped_by_view_native_functions(
def get_grouped_native_functions( def get_grouped_native_functions(
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
def flatten_pre_group( def flatten_pre_group(
d: Dict[SchemaKind, NativeFunction] d: dict[SchemaKind, NativeFunction]
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]: ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
r = NativeFunctionsGroup.from_dict(d) r = NativeFunctionsGroup.from_dict(d)
if r is None: if r is None:
# Invariant: any NativeFunctions that are code-generated # Invariant: any NativeFunctions that are code-generated
@ -1438,13 +1428,13 @@ def get_grouped_native_functions(
def get_ns_grouped_kernels( def get_ns_grouped_kernels(
*, *,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
native_function_decl_gen: Callable[ native_function_decl_gen: Callable[
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str] [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
] = dest.compute_native_function_declaration, ] = dest.compute_native_function_declaration,
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
for f in grouped_native_functions: for f in grouped_native_functions:
native_function_namespaces = set() native_function_namespaces = set()
dispatch_keys = set() dispatch_keys = set()
@ -1467,9 +1457,9 @@ def get_ns_grouped_kernels(
def get_native_function_declarations_from_ns_grouped_kernels( def get_native_function_declarations_from_ns_grouped_kernels(
*, *,
ns_grouped_kernels: Dict[str, List[str]], ns_grouped_kernels: dict[str, list[str]],
) -> List[str]: ) -> list[str]:
declarations: List[str] = [] declarations: list[str] = []
newline = "\n" newline = "\n"
for namespace, kernels in ns_grouped_kernels.items(): for namespace, kernels in ns_grouped_kernels.items():
ns_helper = NamespaceHelper( ns_helper = NamespaceHelper(
@ -1495,12 +1485,12 @@ def get_native_function_declarations_from_ns_grouped_kernels(
# Return native function declarations grouped by their namespaces. # Return native function declarations grouped by their namespaces.
def get_native_function_declarations( def get_native_function_declarations(
*, *,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
native_function_decl_gen: Callable[ native_function_decl_gen: Callable[
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str] [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
] = dest.compute_native_function_declaration, ] = dest.compute_native_function_declaration,
) -> List[str]: ) -> list[str]:
""" """
Generate kernel declarations, in `NativeFunction(s).h`. Generate kernel declarations, in `NativeFunction(s).h`.
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`. :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
@ -1520,7 +1510,7 @@ def get_native_function_declarations(
def get_kernel_namespace( def get_kernel_namespace(
*, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
) -> str: ) -> str:
backend_metadata = backend_idx.get_kernel(f) backend_metadata = backend_idx.get_kernel(f)
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, ( assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
@ -1538,7 +1528,7 @@ def get_kernel_namespace(
def get_native_function_definitions( def get_native_function_definitions(
*, *,
fm: FileManager, fm: FileManager,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
backend_idx: BackendIndex, backend_idx: BackendIndex,
selector: SelectiveBuilder, selector: SelectiveBuilder,
@ -1546,11 +1536,11 @@ def get_native_function_definitions(
symint: bool, symint: bool,
skip_dispatcher_op_registration: bool, skip_dispatcher_op_registration: bool,
gen_dispatch_helpers: bool, gen_dispatch_helpers: bool,
) -> List[str]: ) -> list[str]:
definitions: List[str] = [] definitions: list[str] = []
ns_definitions: Dict[str, List[str]] = defaultdict(list) ns_definitions: dict[str, list[str]] = defaultdict(list)
anonymous_definitions: Dict[str, List[str]] = defaultdict(list) anonymous_definitions: dict[str, list[str]] = defaultdict(list)
registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict) registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
newline = "\n" newline = "\n"
ns_gen = dest.RegisterDispatchKey( ns_gen = dest.RegisterDispatchKey(
backend_idx, backend_idx,
@ -1640,15 +1630,15 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
# Used in CPUFunctions_inl.h and etc. # Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration( def get_namespaced_declaration(
*, *,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
backend_idx: BackendIndex, backend_idx: BackendIndex,
selector: SelectiveBuilder, selector: SelectiveBuilder,
rocm: bool, rocm: bool,
symint: bool, symint: bool,
) -> List[str]: ) -> list[str]:
declarations: List[str] = [] declarations: list[str] = []
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
newline = "\n" newline = "\n"
func = dest.RegisterDispatchKey( func = dest.RegisterDispatchKey(
backend_idx, backend_idx,
@ -1692,8 +1682,8 @@ def get_native_function_schema_registrations(
*, *,
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
schema_selector: SelectiveBuilder, schema_selector: SelectiveBuilder,
) -> Tuple[List[str], str]: ) -> tuple[list[str], str]:
ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list) ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
for native_function in native_functions: for native_function in native_functions:
ns_native_functions[native_function.namespace].append(native_function) ns_native_functions[native_function.namespace].append(native_function)
schema_registrations = "" schema_registrations = ""
@ -1727,14 +1717,14 @@ def get_native_function_schema_registrations(
def gen_aggregated_headers( def gen_aggregated_headers(
*, *,
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
structured_native_functions: Sequence[NativeFunctionsGroup], structured_native_functions: Sequence[NativeFunctionsGroup],
static_dispatch_idx: List[BackendIndex], static_dispatch_idx: list[BackendIndex],
selector: SelectiveBuilder, selector: SelectiveBuilder,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
cpu_fm: FileManager, cpu_fm: FileManager,
cuda_fm: FileManager, cuda_fm: FileManager,
functions_keys: Set[DispatchKey], functions_keys: set[DispatchKey],
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
rocm: bool, rocm: bool,
) -> None: ) -> None:
@ -1848,25 +1838,25 @@ def gen_aggregated_headers(
def gen_per_operator_headers( def gen_per_operator_headers(
*, *,
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
static_dispatch_idx: List[BackendIndex], static_dispatch_idx: list[BackendIndex],
selector: SelectiveBuilder, selector: SelectiveBuilder,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
cpu_fm: FileManager, cpu_fm: FileManager,
cuda_fm: FileManager, cuda_fm: FileManager,
ops_fm: FileManager, ops_fm: FileManager,
functions_keys: Set[DispatchKey], functions_keys: set[DispatchKey],
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
rocm: bool, rocm: bool,
) -> None: ) -> None:
# For CMake builds, split operator declarations into separate headers in # For CMake builds, split operator declarations into separate headers in
# the ATen/ops folder to split up header dependencies # the ATen/ops folder to split up header dependencies
functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list) functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
for fn in native_functions: for fn in native_functions:
functions_by_root_name[fn.root_name].append(fn) functions_by_root_name[fn.root_name].append(fn)
grouped_functions_by_root_name: Dict[ grouped_functions_by_root_name: dict[
str, List[Union[NativeFunction, NativeFunctionsGroup]] str, list[NativeFunction | NativeFunctionsGroup]
] = defaultdict(list) ] = defaultdict(list)
for group in grouped_native_functions: for group in grouped_native_functions:
name = group.root_name name = group.root_name
@ -2042,18 +2032,18 @@ def gen_per_operator_headers(
def gen_headers( def gen_headers(
*, *,
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
valid_tags: Set[str], valid_tags: set[str],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
structured_native_functions: Sequence[NativeFunctionsGroup], structured_native_functions: Sequence[NativeFunctionsGroup],
static_dispatch_idx: List[BackendIndex], static_dispatch_idx: list[BackendIndex],
selector: SelectiveBuilder, selector: SelectiveBuilder,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
core_fm: FileManager, core_fm: FileManager,
cpu_fm: FileManager, cpu_fm: FileManager,
cuda_fm: FileManager, cuda_fm: FileManager,
ops_fm: FileManager, ops_fm: FileManager,
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
functions_keys: Set[DispatchKey], functions_keys: set[DispatchKey],
rocm: bool, rocm: bool,
per_operator_headers: bool, per_operator_headers: bool,
) -> None: ) -> None:
@ -2133,8 +2123,8 @@ def gen_headers(
"VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions) "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
) )
def gen_aten_interned_strings() -> Dict[str, str]: def gen_aten_interned_strings() -> dict[str, str]:
attrs: Set[str] = set() # All function argument names attrs: set[str] = set() # All function argument names
names = set() # All ATen function names names = set() # All ATen function names
for func in native_functions: for func in native_functions:
names.add(str(func.func.name.name)) names.add(str(func.func.name.name))
@ -2171,7 +2161,7 @@ def gen_headers(
core_fm.write("aten_interned_strings.h", gen_aten_interned_strings) core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
def gen_tags_enum() -> Dict[str, str]: def gen_tags_enum() -> dict[str, str]:
return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))} return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}
core_fm.write("enum_tag.h", gen_tags_enum) core_fm.write("enum_tag.h", gen_tags_enum)
@ -2180,19 +2170,19 @@ def gen_headers(
def gen_source_files( def gen_source_files(
*, *,
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
structured_native_functions: Sequence[NativeFunctionsGroup], structured_native_functions: Sequence[NativeFunctionsGroup],
view_groups: Sequence[NativeFunctionsViewGroup], view_groups: Sequence[NativeFunctionsViewGroup],
selector: SelectiveBuilder, selector: SelectiveBuilder,
static_dispatch_idx: List[BackendIndex], static_dispatch_idx: list[BackendIndex],
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
aoti_fm: FileManager, aoti_fm: FileManager,
core_fm: FileManager, core_fm: FileManager,
cpu_fm: FileManager, cpu_fm: FileManager,
cpu_vec_fm: FileManager, cpu_vec_fm: FileManager,
cuda_fm: FileManager, cuda_fm: FileManager,
dispatch_keys: Sequence[DispatchKey], dispatch_keys: Sequence[DispatchKey],
functions_keys: Set[DispatchKey], functions_keys: set[DispatchKey],
rocm: bool, rocm: bool,
force_schema_registration: bool, force_schema_registration: bool,
per_operator_headers: bool, per_operator_headers: bool,
@ -2216,7 +2206,7 @@ def gen_source_files(
if per_operator_headers: if per_operator_headers:
def operator_headers() -> List[str]: def operator_headers() -> list[str]:
headers = [] headers = []
for g in grouped_native_functions: for g in grouped_native_functions:
is_registered = False is_registered = False
@ -2258,7 +2248,7 @@ def gen_source_files(
else: else:
def operator_headers() -> List[str]: def operator_headers() -> list[str]:
headers = ["#include <ATen/NativeFunctions.h>"] headers = ["#include <ATen/NativeFunctions.h>"]
if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional: if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
headers.append("#include <ATen/Functions.h>") headers.append("#include <ATen/Functions.h>")
@ -2449,7 +2439,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
del fm del fm
# BackendSelect is generated specially # BackendSelect is generated specially
def gen_backend_select() -> Dict[str, List[str]]: def gen_backend_select() -> dict[str, list[str]]:
relevant_fns = [ relevant_fns = [
fn for fn in native_functions if needs_backend_select(fn, selector) fn for fn in native_functions if needs_backend_select(fn, selector)
] ]
@ -2494,7 +2484,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
) )
def key_func( def key_func(
fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> str: ) -> str:
return fn.root_name return fn.root_name
@ -2536,11 +2526,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
) )
def functionalization_env_callable( def functionalization_env_callable(
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
def gen_op_headers( def gen_op_headers(
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> List[str]: ) -> list[str]:
if isinstance(g, NativeFunctionsViewGroup): if isinstance(g, NativeFunctionsViewGroup):
# view ops always get a functionalization kernel # view ops always get a functionalization kernel
headers = [ headers = [
@ -2590,8 +2580,8 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
), ),
} }
all_groups: List[ all_groups: list[
Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup] NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
] = list(structured_native_functions) + list( ] = list(structured_native_functions) + list(
view_groups # type: ignore[assignment, arg-type, operator] view_groups # type: ignore[assignment, arg-type, operator]
) )
@ -2600,11 +2590,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic) # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped. # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
# Although this could go away long-term if we add a dedicated dispatch key for decompositions. # Although this could go away long-term if we add a dedicated dispatch key for decompositions.
structured_map: Dict[OperatorName, NativeFunction] = { structured_map: dict[OperatorName, NativeFunction] = {
f.func.name: f f.func.name: f
for f in concatMap(lambda g: list(g.functions()), structured_native_functions) for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
} }
view_map: Dict[OperatorName, NativeFunction] = { view_map: dict[OperatorName, NativeFunction] = {
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups) f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
} }
for f in native_functions: for f in native_functions:
@ -2715,12 +2705,12 @@ def gen_declarations_yaml(
) )
def get_torchgen_root() -> pathlib.Path: def get_torchgen_root() -> Path:
""" """
If you're depending on torchgen out-of-tree, you can use the root to figure If you're depending on torchgen out-of-tree, you can use the root to figure
out the path to native_functions.yaml out the path to native_functions.yaml
""" """
return pathlib.Path(__file__).parent.resolve() return Path(__file__).parent.resolve()
def main() -> None: def main() -> None:
@ -2882,11 +2872,11 @@ def main() -> None:
# #
# Invalid character escape '\c'. # Invalid character escape '\c'.
core_install_dir = f"{options.install_dir}/core" core_install_dir = f"{options.install_dir}/core"
pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True) Path(core_install_dir).mkdir(parents=True, exist_ok=True)
ops_install_dir = f"{options.install_dir}/ops" ops_install_dir = f"{options.install_dir}/ops"
pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True) Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
aoti_install_dir = f"{options.aoti_install_dir}" aoti_install_dir = f"{options.aoti_install_dir}"
pathlib.Path(aoti_install_dir).mkdir(parents=True, exist_ok=True) Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
core_fm = make_file_manager(options=options, install_dir=core_install_dir) core_fm = make_file_manager(options=options, install_dir=core_install_dir)
cpu_fm = make_file_manager(options=options) cpu_fm = make_file_manager(options=options)
@ -2916,7 +2906,7 @@ def main() -> None:
if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
] ]
static_dispatch_idx: List[BackendIndex] = [] static_dispatch_idx: list[BackendIndex] = []
if options.static_dispatch_backend: if options.static_dispatch_backend:
static_dispatch_idx = [ static_dispatch_idx = [
backend_indices[DispatchKey.parse(key)] backend_indices[DispatchKey.parse(key)]
@ -2973,7 +2963,7 @@ def main() -> None:
gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm) gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
if options.output_dependencies: if options.output_dependencies:
depfile_path = pathlib.Path(options.output_dependencies).resolve() depfile_path = Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name depfile_name = depfile_path.name
depfile_stem = depfile_path.stem depfile_stem = depfile_path.stem
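Every file touched below follows the same recipe as this one: add `from __future__ import annotations` (PEP 563), then rewrite `List`/`Dict`/`Optional`/`Union` spellings as builtin generics (PEP 585) and `|` unions (PEP 604). Because postponed annotations are stored as strings and never evaluated, the new spellings import cleanly even on interpreters that predate the runtime syntax. A self-contained sketch of the idea, with invented names, not code from this diff:

```python
# Illustrative sketch only (not torchgen code): with postponed annotations the
# annotations below stay as strings and are never evaluated, so list[str],
# dict[str, int] and `int | None` work at import time even on Python 3.8,
# where they would otherwise need 3.9/3.10.
from __future__ import annotations


def count_tokens(lines: list[str]) -> dict[str, int]:
    counts: dict[str, int] = {}  # local annotations are never evaluated either
    for line in lines:
        for token in line.split():
            counts[token] = counts.get(token, 0) + 1
    return counts


def first_index(items: list[str], needle: str) -> int | None:
    for i, item in enumerate(items):
        if item == needle:
            return i
    return None


if __name__ == "__main__":
    print(count_tokens(["a b a", "b c"]))  # {'a': 2, 'b': 2, 'c': 1}
    print(first_index(["x", "y"], "z"))    # None
```

The trade-off is that anything resolving annotations at runtime (for example `typing.get_type_hints`) still needs the referenced names to be importable and the syntax to be supported, which is presumably why purely static codegen modules like these are a low-risk place to apply the change.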

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Sequence
from torchgen.api.types import DispatcherSignature from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
@ -69,7 +71,7 @@ base_type_to_callsite_expr = {
# convert args to C types, names in declarations, and expressions in function bodies # convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str], List[str], List[str]]: # type: ignore[return] def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]: # type: ignore[return]
if isinstance(typ, BaseType): if isinstance(typ, BaseType):
if typ.name in base_type_to_c_type: if typ.name in base_type_to_c_type:
return ( return (
@ -167,12 +169,12 @@ def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str
) )
def zip_type_and_name(types: List[str], names: List[str]) -> List[str]: def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
return [typ + " " + name for typ, name in zip(types, names)] return [typ + " " + name for typ, name in zip(types, names)]
# Generate argument declarations and callsite expressions # Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]: def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
types = [] types = []
new_names = [] new_names = []
callsite_exprs = [] callsite_exprs = []
@ -189,7 +191,7 @@ def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[s
# Return values are passed out as pointer arguments because all the C shim functions # Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError. # are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions # Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]: def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
types = [] types = []
names = [] names = []
for idx, ret in enumerate(schema.returns): for idx, ret in enumerate(schema.returns):
@ -222,7 +224,7 @@ def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
ret_pointer_can_be_null = True ret_pointer_can_be_null = True
break break
callsite_exprs: List[str] = [] callsite_exprs: list[str] = []
for idx, ret in enumerate(schema.returns): for idx, ret in enumerate(schema.returns):
tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)" tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
assert isinstance(ret.type, BaseType) assert isinstance(ret.type, BaseType)
@ -236,12 +238,12 @@ def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
# gen.py generates header first and then src, so caching the result here to avoid duplicate work # gen.py generates header first and then src, so caching the result here to avoid duplicate work
declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {} declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}
def gen_declaration_and_definition( def gen_declaration_and_definition(
schema: FunctionSchema, device: str, backend_call: str schema: FunctionSchema, device: str, backend_call: str
) -> Tuple[str, str]: ) -> tuple[str, str]:
func_name = schema.name.unambiguous_name() func_name = schema.name.unambiguous_name()
global declaration_definition_cache global declaration_definition_cache
@ -254,7 +256,7 @@ def gen_declaration_and_definition(
args, callsite_exprs = gen_arguments( args, callsite_exprs = gen_arguments(
[*schema.arguments.out, *schema.arguments.flat_non_out] [*schema.arguments.out, *schema.arguments.flat_non_out]
) )
ret_assignments: List[str] = [] ret_assignments: list[str] = []
else: else:
args, callsite_exprs = gen_arguments(schema.arguments.flat_all) args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
# ignore return values for inplace ops # ignore return values for inplace ops
@ -284,7 +286,7 @@ def gen_declaration_and_definition(
def gen_static_dispatch_backend_call_signature( def gen_static_dispatch_backend_call_signature(
sig: Union[CppSignature, DispatcherSignature], sig: CppSignature | DispatcherSignature,
f: NativeFunction, f: NativeFunction,
) -> CppSignature: ) -> CppSignature:
sig = DispatcherSignature.from_schema(f.func) sig = DispatcherSignature.from_schema(f.func)
@ -310,10 +312,10 @@ def gen_static_dispatch_backend_call(
def get_backend_index_for_aoti( def get_backend_index_for_aoti(
func: NativeFunction, func: NativeFunction,
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
) -> Optional[BackendIndex]: ) -> BackendIndex | None:
backend_index = None backend_index = None
if backend_indices[dispatch_key].has_kernel(func) or ( if backend_indices[dispatch_key].has_kernel(func) or (
func.structured_delegate is not None func.structured_delegate is not None
@ -341,10 +343,10 @@ def get_backend_index_for_aoti(
def get_header_for_aoti( def get_header_for_aoti(
func: NativeFunction, func: NativeFunction,
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
) -> Optional[str]: ) -> str | None:
backend_index = get_backend_index_for_aoti( backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices func, func_group_mapping, dispatch_key, backend_indices
) )
@ -365,11 +367,11 @@ def get_fallback_op_name(func: NativeFunction) -> str:
def gen_c_shim( def gen_c_shim(
func: NativeFunction, func: NativeFunction,
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
header: bool, header: bool,
) -> Optional[str]: ) -> str | None:
backend_index = get_backend_index_for_aoti( backend_index = get_backend_index_for_aoti(
func, func_group_mapping, dispatch_key, backend_indices func, func_group_mapping, dispatch_key, backend_indices
) )
@ -399,16 +401,16 @@ def gen_c_shim(
@dataclass(frozen=True) @dataclass(frozen=True)
class ShimGenerator: class ShimGenerator:
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup] func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
dispatch_key: DispatchKey dispatch_key: DispatchKey
backend_indices: Dict[DispatchKey, BackendIndex] backend_indices: dict[DispatchKey, BackendIndex]
header: bool # True to generate .h and False to generate .cpp header: bool # True to generate .h and False to generate .cpp
@method_with_native_function @method_with_native_function
def __call__( def __call__(
self, self,
func: NativeFunction, func: NativeFunction,
) -> Optional[str]: ) -> str | None:
result = gen_c_shim( result = gen_c_shim(
func, func,
self.func_group_mapping, self.func_group_mapping,
@ -421,9 +423,9 @@ class ShimGenerator:
def gen_aoti_c_shim( def gen_aoti_c_shim(
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
func_group_mapping: Dict[OperatorName, NativeFunctionsGroup], func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
header: bool, header: bool,
includes: str = "", includes: str = "",
) -> str: ) -> str:
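The `declaration_definition_cache` hunk above memoizes each generated (declaration, definition) pair under a tuple key because the header pass and the source pass both ask for the same shim. A minimal sketch of that tuple-keyed caching idiom; the helper below and its strings are invented stand-ins, not the real `gen_declaration_and_definition`:

```python
# Invented stand-in for the tuple-keyed memoization used above; not torchgen code.
from __future__ import annotations

_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def make_decl_and_def(func_name: str, device: str, backend_call: str) -> tuple[str, str]:
    key = (func_name, device, backend_call)
    if key in _cache:
        return _cache[key]  # the header pass already computed this shim

    declaration = f"AOTITorchError {func_name}(...);"
    definition = f"AOTITorchError {func_name}(...) {{ return call({backend_call}); }}"
    _cache[key] = (declaration, definition)
    return declaration, definition


# The .h generation and the .cpp generation share one computation per key:
decl, _ = make_decl_and_def("aoti_torch_cpu_foo", "cpu", "at::cpu::foo")
_, defn = make_decl_and_def("aoti_torch_cpu_foo", "cpu", "at::cpu::foo")
print(decl)
print(defn)
```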

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import argparse import argparse
import os import os
import pathlib import pathlib
import re import re
from collections import Counter, defaultdict, namedtuple from collections import Counter, defaultdict, namedtuple
from typing import Dict, List, Optional, Sequence, Set, Union from typing import Sequence
import yaml import yaml
@ -36,10 +38,10 @@ ParsedExternalYaml = namedtuple(
def parse_backend_yaml( def parse_backend_yaml(
backend_yaml_path: str, backend_yaml_path: str,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
) -> ParsedExternalYaml: ) -> ParsedExternalYaml:
native_functions_map: Dict[OperatorName, NativeFunction] = { native_functions_map: dict[OperatorName, NativeFunction] = {
f.func.name: f f.func.name: f
for f in concatMap( for f in concatMap(
lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()), lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
@ -119,14 +121,14 @@ def parse_backend_yaml(
Only the following keys are supported: {", ".join(valid_keys)}' Only the following keys are supported: {", ".join(valid_keys)}'
def create_backend_index( def create_backend_index(
backend_ops: List[str], backend_ops: list[str],
symint_ops: Set[str], symint_ops: set[str],
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
*, *,
use_out_as_primary: bool, use_out_as_primary: bool,
use_device_guard: bool, use_device_guard: bool,
) -> BackendIndex: ) -> BackendIndex:
metadata: Dict[OperatorName, BackendMetadata] = {} metadata: dict[OperatorName, BackendMetadata] = {}
for op in backend_ops: for op in backend_ops:
op_name = OperatorName.parse(op) op_name = OperatorName.parse(op)
assert ( assert (
@ -149,7 +151,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
index=metadata, index=metadata,
) )
backend_key: Optional[DispatchKey] = None backend_key: DispatchKey | None = None
if len(supported) > 0: if len(supported) > 0:
with context( with context(
lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.' lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
@ -166,7 +168,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
assert backend_key not in backend_indices assert backend_key not in backend_indices
backend_indices[backend_key] = backend_idx backend_indices[backend_key] = backend_idx
autograd_key: Optional[DispatchKey] = None autograd_key: DispatchKey | None = None
if len(supported_autograd) > 0: if len(supported_autograd) > 0:
with context( with context(
lambda: f'The "autograd" key was specified, which indicates that you would like to override \ lambda: f'The "autograd" key was specified, which indicates that you would like to override \
@ -245,12 +247,12 @@ autograd key. They cannot be mix and matched. If this is something you need, fee
def error_on_missing_kernels( def error_on_missing_kernels(
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
backend_key: DispatchKey, backend_key: DispatchKey,
autograd_key: Optional[DispatchKey], autograd_key: DispatchKey | None,
class_name: str, class_name: str,
kernel_defn_file_path: str, kernel_defn_file_path: str,
full_codegen: Optional[List[OperatorName]] = None, full_codegen: list[OperatorName] | None = None,
) -> None: ) -> None:
try: try:
with open(kernel_defn_file_path) as f: with open(kernel_defn_file_path) as f:
@ -268,7 +270,7 @@ def error_on_missing_kernels(
) )
# Quick mapping from each OperatorName used by the external backend # Quick mapping from each OperatorName used by the external backend
# to its backend kernel name # to its backend kernel name
expected_backend_op_names: Dict[OperatorName, str] = dict( expected_backend_op_names: dict[OperatorName, str] = dict(
list( list(
concatMap( concatMap(
lambda index: [ lambda index: [
@ -278,13 +280,13 @@ def error_on_missing_kernels(
) )
) )
) )
expected_backend_native_funcs: List[NativeFunction] = [ expected_backend_native_funcs: list[NativeFunction] = [
f f
for f in native_functions for f in native_functions
if f.func.name in expected_backend_op_names.keys() if f.func.name in expected_backend_op_names.keys()
and f.func.name not in full_codegen and f.func.name not in full_codegen
] ]
expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict( expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
list list
) )
for native_f in expected_backend_native_funcs: for native_f in expected_backend_native_funcs:
@ -356,10 +358,10 @@ def gen_dispatchkey_nativefunc_headers(
fm: FileManager, fm: FileManager,
class_name: str, class_name: str,
cpp_namespace: str, cpp_namespace: str,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_dispatch_key: DispatchKey, backend_dispatch_key: DispatchKey,
autograd_dispatch_key: Optional[DispatchKey], autograd_dispatch_key: DispatchKey | None,
backend_name: str = "", backend_name: str = "",
) -> None: ) -> None:
assert class_name is not None assert class_name is not None
@ -413,11 +415,11 @@ def gen_dispatcher_registrations(
fm: FileManager, fm: FileManager,
output_dir: str, output_dir: str,
class_name: str, class_name: str,
backend_indices: Dict[DispatchKey, BackendIndex], backend_indices: dict[DispatchKey, BackendIndex],
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
backend_dispatch_key: DispatchKey, backend_dispatch_key: DispatchKey,
dispatch_key: DispatchKey, dispatch_key: DispatchKey,
selector: "SelectiveBuilder", selector: SelectiveBuilder,
# build_in_tree is true for lazy TS backend and affects include paths, not used for external backends # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
build_in_tree: bool = False, build_in_tree: bool = False,
per_operator_headers: bool = False, per_operator_headers: bool = False,
@ -524,7 +526,7 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
def run( def run(
source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[str] = None source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
) -> None: ) -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
pytorch_root = pathlib.Path(__file__).parent.parent.absolute() pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
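One hunk above builds `expected_backend_kernel_name_counts` as a `defaultdict(list)` so that every native function expecting a given kernel name is grouped under that name before the implementation file is scanned. A stand-alone sketch of that grouping idiom, with made-up operator and kernel names:

```python
# Made-up data; only the defaultdict(list) grouping idiom mirrors the code above.
from __future__ import annotations

from collections import defaultdict

expected = [
    ("add.Tensor", "my_add"),
    ("add.out", "my_add"),
    ("mul.Tensor", "my_mul"),
]

kernels_to_ops: dict[str, list[str]] = defaultdict(list)
for op_name, kernel_name in expected:
    kernels_to_ops[kernel_name].append(op_name)

# Each kernel name can now be checked once against the parsed kernel file,
# and an error can list every operator that expected it:
for kernel, ops in kernels_to_ops.items():
    print(f"{kernel}: expected by {ops}")
```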

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import argparse import argparse
import os import os
import pathlib
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union from pathlib import Path
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
import yaml import yaml
@ -45,7 +47,6 @@ from torchgen.model import (
OperatorName, OperatorName,
Variant, Variant,
) )
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import ( from torchgen.utils import (
context, context,
FileManager, FileManager,
@ -55,7 +56,11 @@ from torchgen.utils import (
) )
def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str: if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
""" """
A wrapper function to basically get `sig.decl(include_context=True)`. A wrapper function to basically get `sig.decl(include_context=True)`.
For ATen kernel, the codegen has no idea about ET contextArg, so we For ATen kernel, the codegen has no idea about ET contextArg, so we
@ -72,9 +77,9 @@ def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str:
def static_dispatch( def static_dispatch(
sig: Union[CppSignature, ExecutorchCppSignature], sig: CppSignature | ExecutorchCppSignature,
f: NativeFunction, f: NativeFunction,
backend_indices: List[BackendIndex], backend_indices: list[BackendIndex],
) -> str: ) -> str:
""" """
For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
@ -113,7 +118,7 @@ TORCH_API inline {_sig_decl_wrapper(sig)} {{
# and the scaffolding to call into the dispatcher from these functions. # and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeFunction: class ComputeFunction:
static_dispatch_backend_indices: List[BackendIndex] static_dispatch_backend_indices: list[BackendIndex]
selector: SelectiveBuilder selector: SelectiveBuilder
@ -122,7 +127,7 @@ class ComputeFunction:
is_custom_op: Callable[[NativeFunction], bool] is_custom_op: Callable[[NativeFunction], bool]
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
is_method_variant = False is_method_variant = False
if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"): if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
return None return None
@ -136,7 +141,7 @@ class ComputeFunction:
f"Can't handle native function {f.func} with the following variant specification {f.variants}." f"Can't handle native function {f.func} with the following variant specification {f.variants}."
) )
sig: Union[CppSignature, ExecutorchCppSignature] = ( sig: CppSignature | ExecutorchCppSignature = (
CppSignatureGroup.from_native_function( CppSignatureGroup.from_native_function(
f, method=False, fallback_binding=f.manual_cpp_binding f, method=False, fallback_binding=f.manual_cpp_binding
).most_faithful_signature() ).most_faithful_signature()
@ -179,10 +184,10 @@ class ComputeCodegenUnboxedKernels:
@method_with_nested_native_function @method_with_nested_native_function
def __call__( def __call__(
self, self,
unbox_kernel_entry: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]], unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
) -> str: ) -> str:
f: NativeFunction = unbox_kernel_entry[0] f: NativeFunction = unbox_kernel_entry[0]
kernel_key: Union[ETKernelKey, List[ETKernelKey]] = unbox_kernel_entry[1][0] kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
kernel_meta: BackendMetadata = unbox_kernel_entry[1][1] kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]
op_name = f"{f.namespace}::{f.func.name}" op_name = f"{f.namespace}::{f.func.name}"
@ -196,7 +201,7 @@ class ComputeCodegenUnboxedKernels:
) )
if not used_kernel_keys: if not used_kernel_keys:
return "" return ""
sig: Union[CppSignature, ExecutorchCppSignature] sig: CppSignature | ExecutorchCppSignature
argument_type_gen: Callable[..., NamedCType] argument_type_gen: Callable[..., NamedCType]
return_type_gen: Callable[..., CType] return_type_gen: Callable[..., CType]
if self.use_aten_lib: if self.use_aten_lib:
@ -290,11 +295,11 @@ def gen_unboxing(
) -> None: ) -> None:
# Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata)) # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
def key_func( def key_func(
item: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]] item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
) -> str: ) -> str:
return item[0].root_name + ":" + item[1][0].to_native_string() return item[0].root_name + ":" + item[1][0].to_native_string()
items: List[Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]] = [ items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [
(native_function, (kernel_key, metadata)) (native_function, (kernel_key, metadata))
for native_function in native_functions for native_function in native_functions
for kernel_key, metadata in kernel_index.get_kernels(native_function).items() for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
@ -325,8 +330,8 @@ def gen_unboxing(
@with_native_function_and_index # type: ignore[arg-type] @with_native_function_and_index # type: ignore[arg-type]
def compute_native_function_declaration( def compute_native_function_declaration(
g: Union[NativeFunctionsGroup, NativeFunction], kernel_index: ETKernelIndex g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
) -> List[str]: ) -> list[str]:
assert isinstance(g, NativeFunction) assert isinstance(g, NativeFunction)
sig = ExecutorchCppSignature.from_native_function(f=g) sig = ExecutorchCppSignature.from_native_function(f=g)
metadata_list = kernel_index.get_kernels(g).values() metadata_list = kernel_index.get_kernels(g).values()
@ -352,7 +357,7 @@ def gen_functions_declarations(
kernel_index: ETKernelIndex, kernel_index: ETKernelIndex,
selector: SelectiveBuilder, selector: SelectiveBuilder,
use_aten_lib: bool, use_aten_lib: bool,
custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None, custom_ops_native_functions: Sequence[NativeFunction] | None = None,
) -> str: ) -> str:
""" """
Generates namespace separated C++ function API inline declaration/definitions. Generates namespace separated C++ function API inline declaration/definitions.
@ -406,13 +411,13 @@ def get_ns_grouped_kernels(
kernel_index: ETKernelIndex, kernel_index: ETKernelIndex,
native_function_decl_gen: Callable[ native_function_decl_gen: Callable[
[ [
Union[NativeFunctionsGroup, NativeFunction], NativeFunctionsGroup | NativeFunction,
ETKernelIndex, ETKernelIndex,
], ],
List[str], list[str],
], ],
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list) ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
for f in native_functions: for f in native_functions:
native_function_namespaces = set() native_function_namespaces = set()
op_kernels = kernel_index.get_kernels(f) op_kernels = kernel_index.get_kernels(f)
@ -595,7 +600,7 @@ def gen_custom_ops(
def translate_native_yaml( def translate_native_yaml(
tags_yaml_path: str, tags_yaml_path: str,
aten_yaml_path: str, aten_yaml_path: str,
native_yaml_path: Optional[str], native_yaml_path: str | None,
use_aten_lib: bool, use_aten_lib: bool,
out_file: TextIO, out_file: TextIO,
) -> None: ) -> None:
@ -646,15 +651,15 @@ def translate_native_yaml(
skip_native_fns_gen=False, skip_native_fns_gen=False,
) )
func_to_scoped_name: Dict[FunctionSchema, str] = { func_to_scoped_name: dict[FunctionSchema, str] = {
f.func: f"{f.namespace}::{f.func.name}" for f in native_functions f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
} }
op_to_scoped_name: Dict[OperatorName, str] = { op_to_scoped_name: dict[OperatorName, str] = {
func.name: name for func, name in func_to_scoped_name.items() func.name: name for func, name in func_to_scoped_name.items()
} }
schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()} schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
kernel_persist_dict: Dict[str, Dict[str, Any]] = { kernel_persist_dict: dict[str, dict[str, Any]] = {
op_to_scoped_name[op]: v for op, v in persisted_fields.items() op_to_scoped_name[op]: v for op, v in persisted_fields.items()
} }
@ -692,13 +697,13 @@ def translate_native_yaml(
def parse_yaml( def parse_yaml(
path: Optional[str], path: str | None,
tags_yaml_path: str, tags_yaml_path: str,
function_filter: Callable[[NativeFunction], bool], function_filter: Callable[[NativeFunction], bool],
skip_native_fns_gen: bool = False, skip_native_fns_gen: bool = False,
) -> Tuple[ ) -> tuple[
List[NativeFunction], list[NativeFunction],
Union[Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ETKernelIndex], dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
]: ]:
if path and os.path.exists(path) and os.stat(path).st_size > 0: if path and os.path.exists(path) and os.stat(path).st_size > 0:
with open(path) as f: with open(path) as f:
@ -735,8 +740,8 @@ def parse_yaml(
# (2) Return BackendIndices if kernel index is absent # (2) Return BackendIndices if kernel index is absent
def map_index( def map_index(
m: Dict[OperatorName, BackendMetadata] m: dict[OperatorName, BackendMetadata]
) -> Dict[OperatorName, BackendMetadata]: ) -> dict[OperatorName, BackendMetadata]:
return {op: m[op] for op in m if op in op_names} return {op: m[op] for op in m if op in op_names}
backend_indices = { backend_indices = {
@ -751,11 +756,11 @@ def parse_yaml(
def parse_yaml_files( def parse_yaml_files(
tags_yaml_path: str, tags_yaml_path: str,
aten_yaml_path: str, aten_yaml_path: str,
native_yaml_path: Optional[str], native_yaml_path: str | None,
custom_ops_yaml_path: Optional[str], custom_ops_yaml_path: str | None,
selector: SelectiveBuilder, selector: SelectiveBuilder,
use_aten_lib: bool, use_aten_lib: bool,
) -> Tuple[ETParsedYaml, Optional[ETParsedYaml]]: ) -> tuple[ETParsedYaml, ETParsedYaml | None]:
"""Parses functions.yaml and custom_ops.yaml files. """Parses functions.yaml and custom_ops.yaml files.
Args: Args:
@ -978,7 +983,7 @@ def main() -> None:
) )
if options.output_dependencies: if options.output_dependencies:
depfile_path = pathlib.Path(options.output_dependencies).resolve() depfile_path = Path(options.output_dependencies).resolve()
depfile_name = depfile_path.name depfile_name = depfile_path.name
depfile_stem = depfile_path.stem depfile_stem = depfile_path.stem
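This file, like another one later in the diff, moves the `SelectiveBuilder` import behind an `if TYPE_CHECKING:` guard: once annotations are postponed, the name is only needed by type checkers, so the runtime import (and any import cycle or import cost it carries) can be dropped. A self-contained sketch of the pattern; `heavy_module` and `ExpensiveClient` are invented names:

```python
# Sketch of TYPE_CHECKING-only imports under postponed annotations; the module
# imported below is invented and is never actually imported at runtime.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from heavy_module import ExpensiveClient  # seen by mypy/pyright only


def describe(client: ExpensiveClient) -> str:
    # At runtime the annotation is just the string "ExpensiveClient", so this
    # module imports and runs even though heavy_module does not exist here.
    return f"client={client!r}"


print(describe(object()))
```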

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union from typing import Callable, TYPE_CHECKING
from torchgen.api import cpp, dispatcher from torchgen.api import cpp, dispatcher
from torchgen.api.translate import translate from torchgen.api.translate import translate
@ -46,10 +48,13 @@ from torchgen.native_function_generation import (
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT, MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY, OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
) )
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import dataclass_repr from torchgen.utils import dataclass_repr
if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder
# Note: [Mutable Ops Not Using Functionalization] # Note: [Mutable Ops Not Using Functionalization]
# Ops in this list currently do not work with functionalization and should be fixed. # Ops in this list currently do not work with functionalization and should be fixed.
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = ( MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
@ -88,7 +93,7 @@ class GenCompositeViewCopyKernel:
backend_index: BackendIndex backend_index: BackendIndex
@method_with_native_function @method_with_native_function
def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]: def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
if g.view_copy is None: if g.view_copy is None:
return None return None
elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy": elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
@ -160,7 +165,7 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
""" """
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str: def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
assert len(rets) == len(names) assert len(rets) == len(names)
if len(rets) == 0: if len(rets) == 0:
return "" return ""
@ -184,7 +189,7 @@ def wrapper_name(func: FunctionSchema) -> str:
return cpp.name(func) return cpp.name(func)
def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> bool: def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
return isinstance(a, SelfArgument) or ( return isinstance(a, SelfArgument) or (
isinstance(a, Argument) and a.type.is_tensor_like() isinstance(a, Argument) and a.type.is_tensor_like()
) )
@ -194,7 +199,7 @@ def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) ->
# Some op schemas include non-owning types though (like TensorList), # Some op schemas include non-owning types though (like TensorList),
# and when we unwrap them we expect to get out an owning type! # and when we unwrap them we expect to get out an owning type!
# We also return a lambda that tells you how to convert the non-owning type argument into the owning type. # We also return a lambda that tells you how to convert the non-owning type argument into the owning type.
def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]: def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
if t == BaseCType(tensorListT): if t == BaseCType(tensorListT):
return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()" return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
if t == BaseCType(iTensorListRefT): if t == BaseCType(iTensorListRefT):
@ -209,9 +214,9 @@ def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
# (2) a context, to be used by translate(), with all of the relevant bindings. # (2) a context, to be used by translate(), with all of the relevant bindings.
def unwrap_tensor_args( def unwrap_tensor_args(
sig: DispatcherSignature, *, is_view_op: bool sig: DispatcherSignature, *, is_view_op: bool
) -> Tuple[str, List[Binding]]: ) -> tuple[str, list[Binding]]:
context: List[Binding] = [] context: list[Binding] = []
unwrapped_tensor_args: List[str] = [] unwrapped_tensor_args: list[str] = []
for arg in sig.arguments(): for arg in sig.arguments():
if is_tensor_like(arg.argument): if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls. # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@ -247,9 +252,9 @@ def unwrap_tensor_args(
# converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns: # converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
# (1) a string containing all of the logic that does the conversions. # (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings. # (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]: def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: List[Binding] = [] context: list[Binding] = []
unwrapped_tensor_args: List[str] = [] unwrapped_tensor_args: list[str] = []
for arg in sig.arguments(): for arg in sig.arguments():
if is_tensor_like(arg.argument): if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls. # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@ -317,7 +322,7 @@ def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:
# Detects whether any of the SymInt arguments are, in fact, symbolic values. # Detects whether any of the SymInt arguments are, in fact, symbolic values.
# This is used in the constructor of ViewMeta. # This is used in the constructor of ViewMeta.
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> Tuple[str, str]: def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
name = "has_symbolic_inputs" name = "has_symbolic_inputs"
statements = [ statements = [
f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});" f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
@ -522,7 +527,7 @@ def maybe_create_output(f: NativeFunction, var_name: str) -> str:
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function # - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
def get_mutable_redispatch_return_names( def get_mutable_redispatch_return_names(
f: NativeFunction, inner_return_var: str f: NativeFunction, inner_return_var: str
) -> Tuple[List[str], List[str]]: ) -> tuple[list[str], list[str]]:
aliased_returns = [] aliased_returns = []
non_aliased_returns = [] non_aliased_returns = []
for i, name in enumerate(f.func.aliased_return_names()): for i, name in enumerate(f.func.aliased_return_names()):
@ -751,11 +756,11 @@ def emit_inplace_functionalization_body(
# See Note [Functionalization Pass: View Inverses]. # See Note [Functionalization Pass: View Inverses].
def gen_functionalization_view_inverse_declaration( def gen_functionalization_view_inverse_declaration(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> Optional[str]: ) -> str | None:
# For every (non-composite) view op, we need a corresponding "inverse view" function. # For every (non-composite) view op, we need a corresponding "inverse view" function.
# This generates the declarations so we get a good compiler error when someone adds a new view. # This generates the declarations so we get a good compiler error when someone adds a new view.
@with_native_function @with_native_function
def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]: def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
if g.view.has_composite_implicit_autograd_kernel: if g.view.has_composite_implicit_autograd_kernel:
return None return None
view_inverse_sig = ViewInverseSignature(g) view_inverse_sig = ViewInverseSignature(g)
@ -766,9 +771,9 @@ def gen_functionalization_view_inverse_declaration(
def gen_functionalization_registration( def gen_functionalization_registration(
selector: SelectiveBuilder, selector: SelectiveBuilder,
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup], g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
composite_implicit_autograd_index: BackendIndex, composite_implicit_autograd_index: BackendIndex,
) -> List[str]: ) -> list[str]:
@with_native_function @with_native_function
def emit_registration_helper(f: NativeFunction) -> str: def emit_registration_helper(f: NativeFunction) -> str:
assert not f.has_composite_implicit_autograd_kernel assert not f.has_composite_implicit_autograd_kernel
@ -832,8 +837,8 @@ def gen_functionalization_definition(
# (and instead only need to operate on grouped NativeFunctions). # (and instead only need to operate on grouped NativeFunctions).
# The only reason currently is that we need to emit direct dispatch registrations # The only reason currently is that we need to emit direct dispatch registrations
# For CompositeImplicitAutograd operators, which are potentially ungrouped. # For CompositeImplicitAutograd operators, which are potentially ungrouped.
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup], g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> List[str]: ) -> list[str]:
# Don't generate kernels in mobile build # Don't generate kernels in mobile build
if not selector.include_all_operators: if not selector.include_all_operators:
return [] return []

View File

@ -1,19 +1,10 @@
from __future__ import annotations
import argparse import argparse
import os import os
import pathlib import pathlib
from collections import namedtuple from collections import namedtuple
from typing import ( from typing import Any, Callable, Iterable, Iterator, Sequence
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
import yaml import yaml
@ -102,8 +93,8 @@ ParsedExternalYaml = namedtuple(
def parse_native_functions_keys( def parse_native_functions_keys(
backend_yaml_path: str, backend_yaml_path: str,
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]], grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]: ) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
with open(backend_yaml_path) as f: with open(backend_yaml_path) as f:
yaml_values = yaml.load(f, Loader=YamlLoader) yaml_values = yaml.load(f, Loader=YamlLoader)
assert isinstance(yaml_values, dict) assert isinstance(yaml_values, dict)
@ -120,7 +111,7 @@ def parse_native_functions_keys(
def validate_shape_inference_header( def validate_shape_inference_header(
shape_inference_hdr: str, expected_shape_infr_decls: List[str] shape_inference_hdr: str, expected_shape_infr_decls: list[str]
) -> None: ) -> None:
try: try:
with open(shape_inference_hdr) as f: with open(shape_inference_hdr) as f:
@ -180,12 +171,12 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
class default_args: class default_args:
node_base: str = "Node" node_base: str = "Node"
node_base_hdr: Optional[str] = None node_base_hdr: str | None = None
shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h" shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
tensor_class: str = "torch::lazy::LazyTensor" tensor_class: str = "torch::lazy::LazyTensor"
tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h" tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
lazy_ir_generator: Type[GenLazyIR] = GenLazyIR lazy_ir_generator: type[GenLazyIR] = GenLazyIR
native_func_definition_generator: Type[ native_func_definition_generator: type[
GenLazyNativeFuncDefinition GenLazyNativeFuncDefinition
] = GenLazyNativeFuncDefinition ] = GenLazyNativeFuncDefinition
backend_name: str = "TorchScript" backend_name: str = "TorchScript"
@ -263,10 +254,10 @@ def main() -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
torch_root = pathlib.Path(__file__).parent.parent.parent.absolute() torch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
aten_path = str(torch_root / "aten" / "src" / "ATen") aten_path = str(torch_root / "aten" / "src" / "ATen")
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
if options.gen_ts_lowerings: if options.gen_ts_lowerings:
lazy_ir_generator = GenTSLazyIR lazy_ir_generator = GenTSLazyIR
native_func_definition_generator: Type[ native_func_definition_generator: type[
GenLazyNativeFuncDefinition GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator ] = default_args.native_func_definition_generator
@ -292,14 +283,14 @@ def run_gen_lazy_tensor(
source_yaml: str, source_yaml: str,
output_dir: str, output_dir: str,
dry_run: bool, dry_run: bool,
impl_path: Optional[str], impl_path: str | None,
node_base: str = default_args.node_base, node_base: str = default_args.node_base,
node_base_hdr: Optional[str] = default_args.node_base_hdr, node_base_hdr: str | None = default_args.node_base_hdr,
tensor_class: str = default_args.tensor_class, tensor_class: str = default_args.tensor_class,
tensor_class_hdr: str = default_args.tensor_class_hdr, tensor_class_hdr: str = default_args.tensor_class_hdr,
shape_inference_hdr: str = default_args.shape_inference_hdr, shape_inference_hdr: str = default_args.shape_inference_hdr,
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator, lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
native_func_definition_generator: Type[ native_func_definition_generator: type[
GenLazyNativeFuncDefinition GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator, ] = default_args.native_func_definition_generator,
# build_in_tree is true for TS backend and affects include paths # build_in_tree is true for TS backend and affects include paths
@ -347,7 +338,7 @@ def run_gen_lazy_tensor(
) )
grouped_native_functions = get_grouped_native_functions(native_functions) grouped_native_functions = get_grouped_native_functions(native_functions)
def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str: def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
""" """
We sort the native function because of the note in concat_map_codegen. We sort the native function because of the note in concat_map_codegen.
TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly. TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
@ -377,8 +368,8 @@ def run_gen_lazy_tensor(
def concat_map_codegen( def concat_map_codegen(
func: Callable[[NativeFunction], Sequence[str]], func: Callable[[NativeFunction], Sequence[str]],
xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]], xs: Iterable[NativeFunctionsGroup | NativeFunction],
ops_list: List[OperatorName] = full_codegen, ops_list: list[OperatorName] = full_codegen,
) -> Iterator[str]: ) -> Iterator[str]:
""" """
We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple from typing import Sequence
from torchgen.api.translate import translate from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature from torchgen.api.types import DispatcherSignature
@ -32,7 +34,7 @@ def is_tensor_list(typ: Type) -> bool:
return isinstance(typ, ListType) and is_tensor(typ.elem) return isinstance(typ, ListType) and is_tensor(typ.elem)
def unwrap_tensor(name: str, cur_level_var: str) -> List[str]: def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
result = f"""\ result = f"""\
Tensor {name}_value; Tensor {name}_value;
optional<int64_t> {name}_bdim; optional<int64_t> {name}_bdim;
@ -40,7 +42,7 @@ def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
return textwrap.dedent(result).split("\n") return textwrap.dedent(result).split("\n")
def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]: def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
result = f"""\ result = f"""\
optional<Tensor> {name}_value; optional<Tensor> {name}_value;
optional<int64_t> {name}_bdim; optional<int64_t> {name}_bdim;
@ -52,7 +54,7 @@ def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
def gen_unwraps( def gen_unwraps(
flat_arguments: Sequence[Argument], cur_level_var: str flat_arguments: Sequence[Argument], cur_level_var: str
) -> Tuple[str, List[str]]: ) -> tuple[str, list[str]]:
arg_names = [a.name for a in flat_arguments] arg_names = [a.name for a in flat_arguments]
arg_types = [a.type for a in flat_arguments] arg_types = [a.type for a in flat_arguments]
@ -99,7 +101,7 @@ if ({' && '.join(conditions)}) {{
def gen_returns( def gen_returns(
returns: Tuple[Return, ...], cur_level_var: str, results_var: str returns: tuple[Return, ...], cur_level_var: str, results_var: str
) -> str: ) -> str:
idx = 0 idx = 0
wrapped_returns = [] wrapped_returns = []
@ -132,7 +134,7 @@ def is_mutated_arg(argument: Argument) -> bool:
return argument.annotation is not None and argument.annotation.is_write return argument.annotation is not None and argument.annotation.is_write
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]: def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
# Assumptions: # Assumptions:
# - only one argument is being modified in-place # - only one argument is being modified in-place
# - the argument that is being modified in-place is the first argument # - the argument that is being modified in-place is the first argument
@ -197,7 +199,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
}}""" }}"""
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]: def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
schema = native_function.func schema = native_function.func
sig = DispatcherSignature.from_schema(schema) sig = DispatcherSignature.from_schema(schema)
returns = schema.returns returns = schema.returns
@ -244,7 +246,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
@dataclass(frozen=True) @dataclass(frozen=True)
class ComputeBatchRulePlumbing: class ComputeBatchRulePlumbing:
@method_with_native_function @method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]: def __call__(self, f: NativeFunction) -> str | None:
result = gen_vmap_plumbing(f) result = gen_vmap_plumbing(f)
return result return result

View File

@ -1,6 +1,8 @@
from __future__ import annotations
import threading import threading
from contextlib import contextmanager from contextlib import contextmanager
from typing import Iterator, Optional from typing import Iterator
# Simple dynamic scoping implementation. The name "parametrize" comes # Simple dynamic scoping implementation. The name "parametrize" comes
@ -17,8 +19,8 @@ from typing import Iterator, Optional
class Locals(threading.local): class Locals(threading.local):
use_const_ref_for_mutable_tensors: Optional[bool] = None use_const_ref_for_mutable_tensors: bool | None = None
use_ilistref_for_tensor_lists: Optional[bool] = None use_ilistref_for_tensor_lists: bool | None = None
_locals = Locals() _locals = Locals()
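The small module above keeps per-thread flags on a `threading.local` subclass; its comment describes the scheme as a simple dynamic-scoping implementation driven by a `parametrize` helper. A rough reconstruction of that pattern under those assumptions, simplified and not quoted from the diff:

```python
# Rough sketch of dynamic scoping via threading.local and a context manager,
# reconstructed from the fields shown above; not the actual torchgen.local API.
from __future__ import annotations

import threading
from contextlib import contextmanager
from typing import Iterator


class _Locals(threading.local):
    use_const_ref_for_mutable_tensors: bool | None = None


_locals = _Locals()


def use_const_ref_for_mutable_tensors() -> bool:
    assert _locals.use_const_ref_for_mutable_tensors is not None, (
        "flag must be set via parametrize() before it is read"
    )
    return _locals.use_const_ref_for_mutable_tensors


@contextmanager
def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]:
    # Save the previous value so nested scopes restore it on exit.
    old = _locals.use_const_ref_for_mutable_tensors
    try:
        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
        yield
    finally:
        _locals.use_const_ref_for_mutable_tensors = old


with parametrize(use_const_ref_for_mutable_tensors=True):
    print(use_const_ref_for_mutable_tensors())  # True
```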

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import dataclasses import dataclasses
import itertools import itertools
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import auto, Enum from enum import auto, Enum
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union from typing import Callable, Iterator, Sequence
from torchgen.utils import assert_never, NamespaceHelper, OrderedSet from torchgen.utils import assert_never, NamespaceHelper, OrderedSet
@ -229,7 +231,7 @@ class DispatchKey(Enum):
return str(self).lower() return str(self).lower()
@staticmethod @staticmethod
def parse(value: str) -> "DispatchKey": def parse(value: str) -> DispatchKey:
for k, v in DispatchKey.__members__.items(): for k, v in DispatchKey.__members__.items():
if k == value: if k == value:
return v return v
@ -350,20 +352,20 @@ class ScalarType(Enum):
return self.name return self.name
@staticmethod @staticmethod
def maybe_parse(value: str) -> Optional["ScalarType"]: def maybe_parse(value: str) -> ScalarType | None:
for k, v in ScalarType.__members__.items(): for k, v in ScalarType.__members__.items():
if k == value: if k == value:
return v return v
return None return None
@staticmethod @staticmethod
def parse(value: str) -> "ScalarType": def parse(value: str) -> ScalarType:
mb_r = ScalarType.maybe_parse(value) mb_r = ScalarType.maybe_parse(value)
assert mb_r is not None, f"unknown dtype {value}" assert mb_r is not None, f"unknown dtype {value}"
return mb_r return mb_r
@staticmethod @staticmethod
def parse_set(values: str) -> OrderedSet["ScalarType"]: def parse_set(values: str) -> OrderedSet[ScalarType]:
dtypes: OrderedSet[ScalarType] = OrderedSet() dtypes: OrderedSet[ScalarType] = OrderedSet()
for value in values.split(", "): for value in values.split(", "):
if value in DTYPE_CLASSES: if value in DTYPE_CLASSES:
@ -373,7 +375,7 @@ class ScalarType(Enum):
return dtypes return dtypes
DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {} DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
# NB: Integral doesn't include boolean # NB: Integral doesn't include boolean
DTYPE_CLASSES["Integral"] = OrderedSet( DTYPE_CLASSES["Integral"] = OrderedSet(
[ [
@ -419,7 +421,7 @@ class UfuncKey(Enum):
return self.name return self.name
@staticmethod @staticmethod
def parse(value: str) -> "UfuncKey": def parse(value: str) -> UfuncKey:
for k, v in UfuncKey.__members__.items(): for k, v in UfuncKey.__members__.items():
if k == value: if k == value:
return v return v
@ -462,7 +464,7 @@ class NativeFunction:
# (This type is quoted as we are forward referencing a type # (This type is quoted as we are forward referencing a type
# defined later in the file. I opted for this ordering of the # defined later in the file. I opted for this ordering of the
# classes for expository clarity.) # classes for expository clarity.)
func: "FunctionSchema" func: FunctionSchema
# Whether or not to generate mutable tensor arguments like regular # Whether or not to generate mutable tensor arguments like regular
# ones # ones
@ -475,14 +477,14 @@ class NativeFunction:
device_check: DeviceCheckType device_check: DeviceCheckType
# What python module to put the function in # What python module to put the function in
python_module: Optional[str] python_module: str | None
# TODO: figure out what this does # TODO: figure out what this does
category_override: Optional[str] category_override: str | None
# If no variants are specified in native_functions.yaml, this is # If no variants are specified in native_functions.yaml, this is
# assumed to be {'function'}. # assumed to be {'function'}.
variants: Set[Variant] variants: set[Variant]
# Whether or not we should skip generating registrations for # Whether or not we should skip generating registrations for
# this kernel. This is a bit of a double-edged sword, as manual # this kernel. This is a bit of a double-edged sword, as manual
@ -497,7 +499,7 @@ class NativeFunction:
# The location in the YAML file where this native function entry was # The location in the YAML file where this native function entry was
# defined. This is for conveniently reporting error messages! # defined. This is for conveniently reporting error messages!
loc: "Location" loc: Location
# A list of operators that are expected to be auto-generated for this NativeFunction. # A list of operators that are expected to be auto-generated for this NativeFunction.
# Note: This list isn't actually directly used by the codegen to generate anything. # Note: This list isn't actually directly used by the codegen to generate anything.
@ -505,11 +507,11 @@ class NativeFunction:
# function schema, and uses the autogen declarations to error check. # function schema, and uses the autogen declarations to error check.
# We expect every NativeFunction that gets auto-generated be explicitly called out # We expect every NativeFunction that gets auto-generated be explicitly called out
# in native_functions.yaml # in native_functions.yaml
autogen: List["OperatorName"] autogen: list[OperatorName]
# If non-empty, this kernel is subject to ufunc codegen. # If non-empty, this kernel is subject to ufunc codegen.
# Sorted by ufunc_key # Sorted by ufunc_key
ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"] ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
# Whether or not this out function is a "structured kernel". Structured # Whether or not this out function is a "structured kernel". Structured
# kernels are defined a little differently from normal kernels; in # kernels are defined a little differently from normal kernels; in
@ -522,13 +524,13 @@ class NativeFunction:
# Whether or not this non-out function is a structured kernel, defined # Whether or not this non-out function is a structured kernel, defined
# in terms of the out kernel referenced by the string here. # in terms of the out kernel referenced by the string here.
structured_delegate: Optional["OperatorName"] structured_delegate: OperatorName | None
# Only valid for structured kernels. Specifies alternative of what # Only valid for structured kernels. Specifies alternative of what
# to inherit from when defining the meta class for the structured # to inherit from when defining the meta class for the structured
# operator. This will usually be TensorIteratorBase. This also # operator. This will usually be TensorIteratorBase. This also
# changes the semantics of set_output to call the parent class. # changes the semantics of set_output to call the parent class.
structured_inherits: Optional[str] structured_inherits: str | None
# Structured kernels can declare elements as "precomputed". These elements # Structured kernels can declare elements as "precomputed". These elements
# are returned by the meta function in one struct and passed to the impl # are returned by the meta function in one struct and passed to the impl
@ -536,11 +538,11 @@ class NativeFunction:
# elements supersede. Information about the names and types of these # elements supersede. Information about the names and types of these
# precomputed elements and how they correspond to kernel arguments is stored # precomputed elements and how they correspond to kernel arguments is stored
# in this member, if applicable. # in this member, if applicable.
precomputed: Optional["Precompute"] precomputed: Precompute | None
# Argument names whose default should be excluded from the C++ interface. # Argument names whose default should be excluded from the C++ interface.
# Intended for resolving overload ambiguities between signatures. # Intended for resolving overload ambiguities between signatures.
cpp_no_default_args: Set[str] cpp_no_default_args: set[str]
# Note [Abstract ATen methods] # Note [Abstract ATen methods]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -560,7 +562,7 @@ class NativeFunction:
# Tags are used to describe semantic information about (groups of) operators, # Tags are used to describe semantic information about (groups of) operators,
# That aren't easily inferrable directly from the operator's schema. # That aren't easily inferrable directly from the operator's schema.
tags: Set[str] tags: set[str]
# NB: The benefit of defining a dataclass is that we automatically get # NB: The benefit of defining a dataclass is that we automatically get
# a constructor defined for all the fields we specify. No need # a constructor defined for all the fields we specify. No need
@ -569,13 +571,11 @@ class NativeFunction:
# We parse both the NativeFunction + backend-specific information about it, which is stored in a corresponding BackendIndex. # We parse both the NativeFunction + backend-specific information about it, which is stored in a corresponding BackendIndex.
@staticmethod @staticmethod
def from_yaml( def from_yaml(
ei: Dict[str, object], ei: dict[str, object],
loc: "Location", loc: Location,
valid_tags: Set[str], valid_tags: set[str],
ignore_keys: Optional[Set[DispatchKey]] = None, ignore_keys: set[DispatchKey] | None = None,
) -> Tuple[ ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
"NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]
]:
""" """
Parse a NativeFunction from a dictionary as directly parsed Parse a NativeFunction from a dictionary as directly parsed
from native_functions.yaml from native_functions.yaml
@ -602,7 +602,7 @@ class NativeFunction:
variants_s = e.pop("variants", "function") variants_s = e.pop("variants", "function")
assert isinstance(variants_s, str) assert isinstance(variants_s, str)
variants: Set[Variant] = set() variants: set[Variant] = set()
for v in variants_s.split(", "): for v in variants_s.split(", "):
if v == "function": if v == "function":
variants.add(Variant.function) variants.add(Variant.function)
@ -646,7 +646,7 @@ class NativeFunction:
"namespace is not supported in structured delegate," "namespace is not supported in structured delegate,"
" using the same namespace as the native function" " using the same namespace as the native function"
) )
structured_delegate: Optional[OperatorName] = None structured_delegate: OperatorName | None = None
if structured_delegate_s is not None: if structured_delegate_s is not None:
structured_delegate = OperatorName.parse(structured_delegate_s) structured_delegate = OperatorName.parse(structured_delegate_s)
@ -685,7 +685,7 @@ class NativeFunction:
if namespace == "aten" and "pt2_compliant_tag" in valid_tags: if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
tags_inp.append("pt2_compliant_tag") tags_inp.append("pt2_compliant_tag")
tags: Set[str] = set() tags: set[str] = set()
for t in tags_inp: for t in tags_inp:
assert len(valid_tags) > 0 assert len(valid_tags) > 0
# TODO: verify that the tag is valid and has an entry in tags.yaml # TODO: verify that the tag is valid and has an entry in tags.yaml
@ -698,7 +698,7 @@ class NativeFunction:
raw_dispatch = e.pop("dispatch", None) raw_dispatch = e.pop("dispatch", None)
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
dispatch: Dict[DispatchKey, BackendMetadata] = {} dispatch: dict[DispatchKey, BackendMetadata] = {}
num_dispatch_keys: int = 0 num_dispatch_keys: int = 0
if raw_dispatch is not None: if raw_dispatch is not None:
assert not manual_kernel_registration, ( assert not manual_kernel_registration, (
@ -1081,8 +1081,8 @@ class SchemaKind(Enum):
@dataclass(frozen=True) @dataclass(frozen=True)
class NativeFunctionsGroup: class NativeFunctionsGroup:
functional: NativeFunction functional: NativeFunction
inplace: Optional[NativeFunction] inplace: NativeFunction | None
mutable: Optional[NativeFunction] mutable: NativeFunction | None
out: NativeFunction out: NativeFunction
@property @property
@ -1136,7 +1136,7 @@ class NativeFunctionsGroup:
[str(f.func.name) for f in self.functions() if "generated" in f.tags] [str(f.func.name) for f in self.functions() if "generated" in f.tags]
) )
generated_fns_str = ", ".join(str(x) for x in generated_fns) generated_fns_str = ", ".join(str(x) for x in generated_fns)
expected_generated_fns: Set[str] = set() expected_generated_fns: set[str] = set()
for f in self.functions(): for f in self.functions():
expected_generated_fns.update(str(op) for op in f.autogen) expected_generated_fns.update(str(op) for op in f.autogen)
expected_generated_fns_str = ", ".join( expected_generated_fns_str = ", ".join(
@ -1155,7 +1155,7 @@ class NativeFunctionsGroup:
f" Instead, it found 'autogen: {expected_generated_fns_str}'" f" Instead, it found 'autogen: {expected_generated_fns_str}'"
) )
def signature(self) -> "FunctionSchema": def signature(self) -> FunctionSchema:
return self.out.func.signature() return self.out.func.signature()
def functions(self) -> Iterator[NativeFunction]: def functions(self) -> Iterator[NativeFunction]:
@ -1171,9 +1171,7 @@ class NativeFunctionsGroup:
return self.functional.root_name return self.functional.root_name
@staticmethod @staticmethod
def from_dict( def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
d: Dict[SchemaKind, NativeFunction]
) -> Optional["NativeFunctionsGroup"]:
assert d assert d
if len(d) == 1: if len(d) == 1:
return None return None
@ -1229,7 +1227,7 @@ class UfuncInnerLoop:
ufunc_key: UfuncKey ufunc_key: UfuncKey
@staticmethod @staticmethod
def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop": def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
name, supported_dtypes_str = value.split(" ", 1) name, supported_dtypes_str = value.split(" ", 1)
assert supported_dtypes_str[0] == "(" assert supported_dtypes_str[0] == "("
assert supported_dtypes_str[-1] == ")" assert supported_dtypes_str[-1] == ")"
@ -1261,12 +1259,12 @@ class BackendIndex:
# Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA) # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
external: bool external: bool
# Other backend-specific information that is on a per-operator basis # Other backend-specific information that is on a per-operator basis
index: Dict["OperatorName", BackendMetadata] index: dict[OperatorName, BackendMetadata]
@staticmethod @staticmethod
def grow_index( def grow_index(
parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]], parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]], child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None: ) -> None:
for k, v in child_index.items(): for k, v in child_index.items():
for op_name, metadata in v.items(): for op_name, metadata in v.items():
@ -1281,13 +1279,13 @@ class BackendIndex:
else: else:
return g.functional return g.functional
def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool: def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
m = self.get_kernel(g) m = self.get_kernel(g)
return m is not None return m is not None
def get_kernel( def get_kernel(
self, g: Union[NativeFunction, NativeFunctionsGroup] self, g: NativeFunction | NativeFunctionsGroup
) -> Optional[BackendMetadata]: ) -> BackendMetadata | None:
if isinstance(g, NativeFunction): if isinstance(g, NativeFunction):
f = g f = g
elif isinstance(g, NativeFunctionsGroup): elif isinstance(g, NativeFunctionsGroup):
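
One caveat worth keeping in mind for hunks like this: the lazy rule applies only to annotations. `NativeFunction | NativeFunctionsGroup` in a signature is never executed, but the same expression in runtime code (for example inside `isinstance`) would still require Python 3.10, which is why the body keeps plain `isinstance(g, NativeFunction)` checks. A sketch of that distinction, assuming a pre-3.10 interpreter and illustrative names:

    from __future__ import annotations

    class A: ...
    class B: ...

    def describe(x: A | B) -> str:  # annotation: never evaluated, safe on 3.7+
        if isinstance(x, (A, B)):   # runtime expression: use a tuple here
            return type(x).__name__
        raise TypeError(x)

    print(describe(A()))  # prints "A"
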
@ -1298,7 +1296,7 @@ class BackendIndex:
return None return None
return self.index[f.func.name] return self.index[f.func.name]
def native_function_class_name(self) -> Optional[str]: def native_function_class_name(self) -> str | None:
if self.external: if self.external:
return f"{str(self.dispatch_key)}NativeFunctions" return f"{str(self.dispatch_key)}NativeFunctions"
else: else:
@ -1364,16 +1362,16 @@ class BackendIndex:
@dataclass(frozen=True) @dataclass(frozen=True)
class FunctionSchema: class FunctionSchema:
# The name of the operator this function schema describes. # The name of the operator this function schema describes.
name: "OperatorName" name: OperatorName
arguments: "Arguments" arguments: Arguments
# TODO: Need to handle collisions with argument names at some point # TODO: Need to handle collisions with argument names at some point
returns: Tuple["Return", ...] returns: tuple[Return, ...]
@property @property
def is_mutable(self) -> bool: def is_mutable(self) -> bool:
def is_write(arg: "Argument") -> bool: def is_write(arg: Argument) -> bool:
if arg.annotation is None: if arg.annotation is None:
return False return False
return arg.annotation.is_write return arg.annotation.is_write
@ -1382,7 +1380,7 @@ class FunctionSchema:
# See aten/src/ATen/core/function_schema.h (keep these in sync) # See aten/src/ATen/core/function_schema.h (keep these in sync)
return any(is_write(a) for a in self.arguments.flat_all) return any(is_write(a) for a in self.arguments.flat_all)
def schema_order_arguments(self) -> Iterator["Argument"]: def schema_order_arguments(self) -> Iterator[Argument]:
return itertools.chain( return itertools.chain(
self.arguments.flat_positional, self.arguments.flat_positional,
self.arguments.flat_kwarg_only, self.arguments.flat_kwarg_only,
@ -1392,7 +1390,7 @@ class FunctionSchema:
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)") decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")
@staticmethod @staticmethod
def parse(func: str) -> "FunctionSchema": def parse(func: str) -> FunctionSchema:
# We should probably get a proper parser here # We should probably get a proper parser here
decls = FunctionSchema.decl_re.findall(func) decls = FunctionSchema.decl_re.findall(func)
assert len(decls) == 1, f"Invalid function schema: {func}" assert len(decls) == 1, f"Invalid function schema: {func}"
@ -1587,8 +1585,8 @@ class FunctionSchema:
# - If the return aliases an input, we return the input name # - If the return aliases an input, we return the input name
# - Otherwise, we return None. # - Otherwise, we return None.
# If return names were enforced to be consistent with aliasing information, then we wouldn't need this. # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
def aliased_return_names(self) -> List[Optional[str]]: def aliased_return_names(self) -> list[str | None]:
outs: List[Optional[str]] = [] outs: list[str | None] = []
for r in self.returns: for r in self.returns:
aliased_args = [ aliased_args = [
a a
@ -1612,7 +1610,7 @@ class FunctionSchema:
strip_default: bool = False, strip_default: bool = False,
strip_view_copy_name: bool = False, strip_view_copy_name: bool = False,
keep_return_names: bool = False, keep_return_names: bool = False,
) -> "FunctionSchema": ) -> FunctionSchema:
""" """
Certain schemas are 'related', in that they are simply Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method inplace/out/functional versions of the same function. This method
@ -1709,10 +1707,10 @@ class FunctionSchema:
returns=returns, returns=returns,
) )
def view_signature(self) -> "FunctionSchema": def view_signature(self) -> FunctionSchema:
return self.signature(strip_view_copy_name=True) return self.signature(strip_view_copy_name=True)
def with_name(self, name: "OperatorName") -> "FunctionSchema": def with_name(self, name: OperatorName) -> FunctionSchema:
return FunctionSchema( return FunctionSchema(
name=name, name=name,
arguments=self.arguments, arguments=self.arguments,
@ -1747,12 +1745,12 @@ class FunctionSchema:
class Annotation: class Annotation:
# Typically only has one element. Not actually a set so # Typically only has one element. Not actually a set so
# we can conveniently assume it is canonically ordered # we can conveniently assume it is canonically ordered
alias_set: Tuple[str, ...] alias_set: tuple[str, ...]
is_write: bool is_write: bool
alias_set_after: Tuple[str, ...] alias_set_after: tuple[str, ...]
@staticmethod @staticmethod
def parse(ann: str) -> "Annotation": def parse(ann: str) -> Annotation:
# TODO: implement a proper parser if this gets more ugly # TODO: implement a proper parser if this gets more ugly
# Regex Explanation: # Regex Explanation:
# Example: "a! -> a|b" # Example: "a! -> a|b"
@ -1805,13 +1803,13 @@ class Annotation:
@dataclass(frozen=True) @dataclass(frozen=True)
class Type: class Type:
@staticmethod @staticmethod
def parse(t: str) -> "Type": def parse(t: str) -> Type:
r = Type._parse(t) r = Type._parse(t)
assert str(r) == t, f"{r} != {t}" assert str(r) == t, f"{r} != {t}"
return r return r
@staticmethod @staticmethod
def _parse(t: str) -> "Type": def _parse(t: str) -> Type:
m = re.match(r"^(.+)\?$", t) m = re.match(r"^(.+)\?$", t)
if m is not None: if m is not None:
return OptionalType(Type.parse(m.group(1))) return OptionalType(Type.parse(m.group(1)))
@ -1837,7 +1835,7 @@ class Type:
# so we can conveniently generate legacy Declarations.yaml but # so we can conveniently generate legacy Declarations.yaml but
# really we should probably just remove these at some point # really we should probably just remove these at some point
def is_base_ty_like(self, base_ty: "BaseTy") -> bool: def is_base_ty_like(self, base_ty: BaseTy) -> bool:
raise NotImplementedError raise NotImplementedError
def is_tensor_like(self) -> bool: def is_tensor_like(self) -> bool:
@ -1852,7 +1850,7 @@ class Type:
def is_nullable(self) -> bool: def is_nullable(self) -> bool:
raise NotImplementedError raise NotImplementedError
def is_list_like(self) -> Optional["ListType"]: def is_list_like(self) -> ListType | None:
raise NotImplementedError raise NotImplementedError
@ -1892,7 +1890,7 @@ class BaseType(Type):
def is_nullable(self) -> bool: def is_nullable(self) -> bool:
return False return False
def is_list_like(self) -> Optional["ListType"]: def is_list_like(self) -> ListType | None:
return None return None
def is_symint_like(self) -> bool: def is_symint_like(self) -> bool:
@ -1916,7 +1914,7 @@ class OptionalType(Type):
def is_nullable(self) -> bool: def is_nullable(self) -> bool:
return True return True
def is_list_like(self) -> Optional["ListType"]: def is_list_like(self) -> ListType | None:
return self.elem.is_list_like() return self.elem.is_list_like()
@ -1943,7 +1941,7 @@ class CustomClassType(Type):
""" """
return False return False
def is_list_like(self) -> Optional["ListType"]: def is_list_like(self) -> ListType | None:
return None return None
@ -1957,7 +1955,7 @@ class CustomClassType(Type):
@dataclass(frozen=True) @dataclass(frozen=True)
class ListType(Type): class ListType(Type):
elem: Type elem: Type
size: Optional[int] size: int | None
def __str__(self) -> str: def __str__(self) -> str:
size = f"{self.size}" if self.size else "" size = f"{self.size}" if self.size else ""
@ -1972,7 +1970,7 @@ class ListType(Type):
def is_nullable(self) -> bool: def is_nullable(self) -> bool:
return self.elem.is_nullable() return self.elem.is_nullable()
def is_list_like(self) -> Optional["ListType"]: def is_list_like(self) -> ListType | None:
return self return self
@ -1983,7 +1981,7 @@ class Argument:
name: str name: str
type: Type type: Type
default: Optional[str] default: str | None
# The semantics of the annotation field are a little strange. # The semantics of the annotation field are a little strange.
# #
@ -2004,16 +2002,16 @@ class Argument:
# structure of annotated types is very simple. So we just hard # structure of annotated types is very simple. So we just hard
# code it here. But if we ever do get anything more complex, this # code it here. But if we ever do get anything more complex, this
# model will have to change! # model will have to change!
annotation: Optional[Annotation] annotation: Annotation | None
@property @property
def alias_info(self) -> Optional[Annotation]: def alias_info(self) -> Annotation | None:
return self.annotation return self.annotation
@staticmethod @staticmethod
def parse(arg: str) -> "Argument": def parse(arg: str) -> Argument:
name: str name: str
default: Optional[str] default: str | None
assert " " in arg, f"illegal argument '{arg}'" assert " " in arg, f"illegal argument '{arg}'"
type_and_annot, name_and_default = arg.rsplit(" ", 1) type_and_annot, name_and_default = arg.rsplit(" ", 1)
if "=" in name_and_default: if "=" in name_and_default:
@ -2026,7 +2024,7 @@ class Argument:
default = None default = None
# TODO: deduplicate annotation matching with Return # TODO: deduplicate annotation matching with Return
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
annotation: Optional[Annotation] annotation: Annotation | None
if match: if match:
# If you update this, make sure the __str__ still works too # If you update this, make sure the __str__ still works too
assert match.group(2) in [ assert match.group(2) in [
@ -2069,24 +2067,24 @@ class Argument:
@dataclass(frozen=True) @dataclass(frozen=True)
class Return: class Return:
name: Optional[str] name: str | None
type: Type type: Type
annotation: Optional[Annotation] annotation: Annotation | None
@property @property
def alias_info(self) -> Optional[Annotation]: def alias_info(self) -> Annotation | None:
return self.annotation return self.annotation
@staticmethod @staticmethod
def parse(arg: str) -> "Return": def parse(arg: str) -> Return:
name: Optional[str] name: str | None
if " " in arg: if " " in arg:
type_and_annot, name = arg.rsplit(" ", 1) type_and_annot, name = arg.rsplit(" ", 1)
else: else:
type_and_annot = arg type_and_annot = arg
name = None name = None
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
annotation: Optional[Annotation] annotation: Annotation | None
if match: if match:
# If you update this, make sure the __str__ still works too # If you update this, make sure the __str__ still works too
assert match.group(2) in [ assert match.group(2) in [
@ -2148,34 +2146,34 @@ class Arguments:
# pre_self_positional is usually empty, but is notably non-empty # pre_self_positional is usually empty, but is notably non-empty
# for where.self, where the condition argument comes before the # for where.self, where the condition argument comes before the
# self argument # self argument
pre_self_positional: Tuple[Argument, ...] pre_self_positional: tuple[Argument, ...]
self_arg: Optional[SelfArgument] self_arg: SelfArgument | None
post_self_positional: Tuple[Argument, ...] post_self_positional: tuple[Argument, ...]
pre_tensor_options_kwarg_only: Tuple[Argument, ...] pre_tensor_options_kwarg_only: tuple[Argument, ...]
tensor_options: Optional[TensorOptionsArguments] tensor_options: TensorOptionsArguments | None
# post_tensor_options is typically memory format, which should be # post_tensor_options is typically memory format, which should be
# part of tensor options but isn't right now, and is usually # part of tensor options but isn't right now, and is usually
# placed after the tensor options arguments # placed after the tensor options arguments
post_tensor_options_kwarg_only: Tuple[Argument, ...] post_tensor_options_kwarg_only: tuple[Argument, ...]
# Unlike in the previous codegen, we have factored out 'out' arguments # Unlike in the previous codegen, we have factored out 'out' arguments
# in the canonical representation, removing them from kwarg # in the canonical representation, removing them from kwarg
# arguments. This choice is justified by numerous downstream # arguments. This choice is justified by numerous downstream
# transformations which treat out arguments specially; additionally, # transformations which treat out arguments specially; additionally,
# you can see that canonicity is not violated! # you can see that canonicity is not violated!
out: Tuple[Argument, ...] # these are also kwarg-only out: tuple[Argument, ...] # these are also kwarg-only
@property @property
def flat_non_out(self) -> Sequence[Argument]: def flat_non_out(self) -> Sequence[Argument]:
ret: List[Argument] = [] ret: list[Argument] = []
ret.extend(self.flat_positional) ret.extend(self.flat_positional)
ret.extend(self.flat_kwarg_only) ret.extend(self.flat_kwarg_only)
return ret return ret
@property @property
def flat_positional(self) -> Sequence[Argument]: def flat_positional(self) -> Sequence[Argument]:
ret: List[Argument] = [] ret: list[Argument] = []
ret.extend(self.pre_self_positional) ret.extend(self.pre_self_positional)
if self.self_arg is not None: if self.self_arg is not None:
ret.append(self.self_arg.argument) ret.append(self.self_arg.argument)
@ -2189,7 +2187,7 @@ class Arguments:
# NB: doesn't contain out arguments # NB: doesn't contain out arguments
@property @property
def flat_kwarg_only(self) -> Sequence[Argument]: def flat_kwarg_only(self) -> Sequence[Argument]:
ret: List[Argument] = [] ret: list[Argument] = []
ret.extend(self.pre_tensor_options_kwarg_only) ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None: if self.tensor_options is not None:
ret.extend(self.tensor_options.all()) ret.extend(self.tensor_options.all())
@ -2198,7 +2196,7 @@ class Arguments:
@property @property
def flat_all(self) -> Sequence[Argument]: def flat_all(self) -> Sequence[Argument]:
ret: List[Argument] = [] ret: list[Argument] = []
ret.extend(self.flat_positional) ret.extend(self.flat_positional)
ret.extend(self.flat_kwarg_only) ret.extend(self.flat_kwarg_only)
ret.extend(self.out) ret.extend(self.out)
@ -2207,15 +2205,15 @@ class Arguments:
@property @property
def non_out( def non_out(
self, self,
) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: ) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
ret.extend(self.positional) ret.extend(self.positional)
ret.extend(self.kwarg_only) ret.extend(self.kwarg_only)
return ret return ret
@property @property
def positional(self) -> Sequence[Union[Argument, SelfArgument]]: def positional(self) -> Sequence[Argument | SelfArgument]:
ret: List[Union[Argument, SelfArgument]] = [] ret: list[Argument | SelfArgument] = []
ret.extend(self.pre_self_positional) ret.extend(self.pre_self_positional)
if self.self_arg is not None: if self.self_arg is not None:
ret.append(self.self_arg) ret.append(self.self_arg)
@ -2223,8 +2221,8 @@ class Arguments:
return ret return ret
@property @property
def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]: def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
ret: List[Union[Argument, TensorOptionsArguments]] = [] ret: list[Argument | TensorOptionsArguments] = []
ret.extend(self.pre_tensor_options_kwarg_only) ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None: if self.tensor_options is not None:
ret.append(self.tensor_options) ret.append(self.tensor_options)
@ -2232,14 +2230,14 @@ class Arguments:
return ret return ret
@property @property
def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
ret.extend(self.positional) ret.extend(self.positional)
ret.extend(self.kwarg_only) ret.extend(self.kwarg_only)
ret.extend(self.out) ret.extend(self.out)
return ret return ret
def mutable_arg_names(self) -> List[str]: def mutable_arg_names(self) -> list[str]:
return [ return [
a.name a.name
for a in self.flat_all for a in self.flat_all
@ -2255,7 +2253,7 @@ class Arguments:
def has_generator_arg(self) -> bool: def has_generator_arg(self) -> bool:
return any(a.type.is_generator_like() for a in self.flat_non_out) return any(a.type.is_generator_like() for a in self.flat_non_out)
def signature(self, *, strip_default: bool = False) -> "Arguments": def signature(self, *, strip_default: bool = False) -> Arguments:
# dataclasses.replace could be used here, but it is less # dataclasses.replace could be used here, but it is less
# type safe so for now I've opted to type everything out # type safe so for now I've opted to type everything out
def strip_arg_annotation(a: Argument) -> Argument: def strip_arg_annotation(a: Argument) -> Argument:
@ -2290,7 +2288,7 @@ class Arguments:
out=(), out=(),
) )
def remove_self_annotation(self) -> "Arguments": def remove_self_annotation(self) -> Arguments:
assert self.self_arg is not None assert self.self_arg is not None
return dataclasses.replace( return dataclasses.replace(
self, self,
@ -2299,7 +2297,7 @@ class Arguments:
), ),
) )
def with_out_args(self, outs: List[Argument]) -> "Arguments": def with_out_args(self, outs: list[Argument]) -> Arguments:
assert len(self.out) == 0 assert len(self.out) == 0
return dataclasses.replace( return dataclasses.replace(
self, self,
@ -2307,10 +2305,10 @@ class Arguments:
) )
@staticmethod @staticmethod
def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]: def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
positional: List[Argument] = [] positional: list[Argument] = []
kwarg_only: List[Argument] = [] kwarg_only: list[Argument] = []
out: List[Argument] = [] out: list[Argument] = []
arguments_acc = positional arguments_acc = positional
# TODO: Use a real parser here; this will get bamboozled # TODO: Use a real parser here; this will get bamboozled
@ -2343,7 +2341,7 @@ class Arguments:
return positional, kwarg_only, out return positional, kwarg_only, out
@staticmethod @staticmethod
def parse(args: str) -> "Arguments": def parse(args: str) -> Arguments:
""" """
Input: 'int x, int y, int z' Input: 'int x, int y, int z'
""" """
@ -2361,9 +2359,9 @@ class Arguments:
if a.name == "self": if a.name == "self":
self_ix = i self_ix = i
break break
pre_self_positional: List[Argument] pre_self_positional: list[Argument]
self_arg: Optional[SelfArgument] self_arg: SelfArgument | None
post_self_positional: List[Argument] post_self_positional: list[Argument]
if self_ix is not None: if self_ix is not None:
pre_self_positional = positional[:self_ix] pre_self_positional = positional[:self_ix]
self_arg = SelfArgument(positional[self_ix]) self_arg = SelfArgument(positional[self_ix])
@ -2374,9 +2372,9 @@ class Arguments:
post_self_positional = positional post_self_positional = positional
# Group tensor options arguments # Group tensor options arguments
pre_tensor_options_kwarg_only: List[Argument] = [] pre_tensor_options_kwarg_only: list[Argument] = []
tensor_options: Optional[TensorOptionsArguments] = None tensor_options: TensorOptionsArguments | None = None
post_tensor_options_kwarg_only: List[Argument] = [] post_tensor_options_kwarg_only: list[Argument] = []
kwarg_only_acc = pre_tensor_options_kwarg_only kwarg_only_acc = pre_tensor_options_kwarg_only
def pred(name: str, ty: Type) -> Callable[[Argument], bool]: def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
@ -2423,7 +2421,7 @@ class Arguments:
) )
def __str__(self) -> str: def __str__(self) -> str:
all_arguments: List[str] = [] all_arguments: list[str] = []
all_arguments.extend(map(str, self.flat_positional)) all_arguments.extend(map(str, self.flat_positional))
if self.flat_kwarg_only or self.out: if self.flat_kwarg_only or self.out:
all_arguments.append("*") all_arguments.append("*")
@ -2502,7 +2500,7 @@ class BaseOperatorName:
functional_overload: bool = False functional_overload: bool = False
@staticmethod @staticmethod
def parse(op: str) -> "BaseOperatorName": def parse(op: str) -> BaseOperatorName:
assert op != "" assert op != ""
assert not op.endswith("_out"), ( assert not op.endswith("_out"), (
"_out suffix is reserved and not permitted for operator names; " "_out suffix is reserved and not permitted for operator names; "
@ -2574,7 +2572,7 @@ class OperatorName:
overload_name: str overload_name: str
@staticmethod @staticmethod
def parse(op_name: str) -> "OperatorName": def parse(op_name: str) -> OperatorName:
if "." in op_name: if "." in op_name:
name, overload_name = op_name.split(".", 1) name, overload_name = op_name.split(".", 1)
else: else:
@ -2601,7 +2599,7 @@ class OperatorName:
else: else:
return f"{self.name}" return f"{self.name}"
def remove_inplace(self) -> "OperatorName": def remove_inplace(self) -> OperatorName:
return OperatorName( return OperatorName(
name=BaseOperatorName( name=BaseOperatorName(
base=self.name.base, base=self.name.base,
@ -2611,7 +2609,7 @@ class OperatorName:
overload_name=self.overload_name, overload_name=self.overload_name,
) )
def with_overload(self, overload: str) -> "OperatorName": def with_overload(self, overload: str) -> OperatorName:
return OperatorName( return OperatorName(
name=BaseOperatorName( name=BaseOperatorName(
base=self.name.base, base=self.name.base,
@ -2649,9 +2647,9 @@ class NativeFunctionsViewGroup:
# Note: the {view}_copy operator is optional because we currently don't generate copy variants # Note: the {view}_copy operator is optional because we currently don't generate copy variants
# for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
# (we already get them "for free" through decomposition) # (we already get them "for free" through decomposition)
view_copy: Optional[NativeFunction] view_copy: NativeFunction | None
# view_inplace ops are also optional, but every view_inplace op should have out-of-place variant. # view_inplace ops are also optional, but every view_inplace op should have out-of-place variant.
view_inplace: Optional[NativeFunction] view_inplace: NativeFunction | None
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self.view.is_view_op assert self.view.is_view_op
@ -2731,7 +2729,7 @@ def gets_generated_view_copy(f: NativeFunction) -> bool:
# Given a NativeFunction that corresponds to a view op, # Given a NativeFunction that corresponds to a view op,
# returns the OperatorName of the corresponding "copy" variant of the op. # returns the OperatorName of the corresponding "copy" variant of the op.
def get_view_copy_name(f: NativeFunction) -> "OperatorName": def get_view_copy_name(f: NativeFunction) -> OperatorName:
# Right now, when asking for a view op's corresponding "view_copy" name # Right now, when asking for a view op's corresponding "view_copy" name
# we assert for sanity that the op is allowed to have a generated view_copy variant. # we assert for sanity that the op is allowed to have a generated view_copy variant.
# (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op). # (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op).
@ -2755,7 +2753,7 @@ def get_view_copy_name(f: NativeFunction) -> "OperatorName":
# Helper functions for parsing argument lists (both inputs and returns) # Helper functions for parsing argument lists (both inputs and returns)
def parse_returns(return_decl: str) -> Tuple[Return, ...]: def parse_returns(return_decl: str) -> tuple[Return, ...]:
""" """
Input: '()' Input: '()'
Output: [] Output: []
@ -2774,12 +2772,12 @@ def parse_returns(return_decl: str) -> Tuple[Return, ...]:
class Precompute: class Precompute:
# A map from kernel argument name -> a list of precomputed # A map from kernel argument name -> a list of precomputed
# elements that replaces/supersedes it. # elements that replaces/supersedes it.
replace: Dict[str, List[Argument]] replace: dict[str, list[Argument]]
# List of precomputed args added without replacement # List of precomputed args added without replacement
add: List[Argument] add: list[Argument]
@staticmethod @staticmethod
def parse(src: object) -> "Precompute": def parse(src: object) -> Precompute:
assert isinstance(src, list) assert isinstance(src, list)
# src is a list of strings of the format: # src is a list of strings of the format:
@ -2824,7 +2822,7 @@ class Precompute:
for a in args: for a in args:
assert a.name.upper() != a.name assert a.name.upper() != a.name
def to_list(self) -> List[str]: def to_list(self) -> list[str]:
replace_list = [] replace_list = []
for kernel_param, replacement_params in self.replace.items(): for kernel_param, replacement_params in self.replace.items():
replacements = ", ".join(str(param) for param in replacement_params) replacements = ", ".join(str(param) for param in replacement_params)


@ -1,5 +1,7 @@
from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Sequence
import torchgen.api.dispatcher as dispatcher import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate from torchgen.api.translate import translate
@ -101,9 +103,9 @@ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# But have differing SchemaKinds. # But have differing SchemaKinds.
def pre_group_native_functions( def pre_group_native_functions(
native_functions: Sequence[NativeFunction], native_functions: Sequence[NativeFunction],
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]: ) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: Dict[ pre_grouped_native_functions: dict[
FunctionSchema, Dict[SchemaKind, NativeFunction] FunctionSchema, dict[SchemaKind, NativeFunction]
] = defaultdict(dict) ] = defaultdict(dict)
for f in native_functions: for f in native_functions:
d = pre_grouped_native_functions[f.func.signature()] d = pre_grouped_native_functions[f.func.signature()]
@ -113,7 +115,7 @@ def pre_group_native_functions(
# Returns the out variant overload name given a base function overload name # Returns the out variant overload name given a base function overload name
def get_expected_out_variant_overload_name(overload_name: Optional[str]) -> str: def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
return "out" if not overload_name else f"{overload_name}_out" return "out" if not overload_name else f"{overload_name}_out"
@ -178,7 +180,7 @@ def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations. # Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations.
def generate_out_args_from_schema( def generate_out_args_from_schema(
func: FunctionSchema, func: FunctionSchema,
) -> Tuple[List[Return], List[Argument]]: ) -> tuple[list[Return], list[Argument]]:
# More of a sanity check - our existing restrictions on schemas should enforce that # More of a sanity check - our existing restrictions on schemas should enforce that
# mutable schema kinds never return their mutable arguments. # mutable schema kinds never return their mutable arguments.
assert not any( assert not any(
@ -198,11 +200,11 @@ def generate_out_args_from_schema(
all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns) all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
new_out_args: List[Argument] = [] new_out_args: list[Argument] = []
# The end result of new_returns is that: # The end result of new_returns is that:
# - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added. # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
# - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any). # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
new_returns: List[Return] = [] new_returns: list[Return] = []
for i, r in enumerate(func.returns): for i, r in enumerate(func.returns):
if r.type.is_tensor_like(): if r.type.is_tensor_like():
new_out = Argument( new_out = Argument(
@ -266,7 +268,7 @@ def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Details are in the function, but we only generate composite kernels (in some cases) today. # Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function( def generate_function(
f: NativeFunction, k: SchemaKind f: NativeFunction, k: SchemaKind
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]: ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
from torchgen.api import cpp from torchgen.api import cpp
if k == SchemaKind.functional: if k == SchemaKind.functional:
@ -375,8 +377,8 @@ def generate_function(
# Note: this function *mutates* its two inputs, # Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them # adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions( def add_generated_native_functions(
rs: List[NativeFunction], rs: list[NativeFunction],
indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
) -> None: ) -> None:
# The main code for generating new NativeFunctions # The main code for generating new NativeFunctions
# First we group of NativeFunctions by schema kind, # First we group of NativeFunctions by schema kind,
@ -497,7 +499,7 @@ out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THA
rs.append(fn) rs.append(fn)
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str: def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
assert len(rets) == len(names) assert len(rets) == len(names)
if len(rets) == 0: if len(rets) == 0:
return "" return ""
@ -509,7 +511,7 @@ def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
# Given a function, and the name of a variable corresponding to the output of that function, # Given a function, and the name of a variable corresponding to the output of that function,
# gather up all of the individual returns that are not aliased # gather up all of the individual returns that are not aliased
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]: def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
aliased_rets = func.aliased_return_names() aliased_rets = func.aliased_return_names()
non_aliased_names = [] non_aliased_names = []
is_out_var_a_tuple = len(func.returns) > 1 is_out_var_a_tuple = len(func.returns) > 1
@ -524,7 +526,7 @@ def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str
# Generates functional kernels in terms of their inplace.mutable counterparts. # Generates functional kernels in terms of their inplace.mutable counterparts.
# We only do this for "generated" NativeFunctions # We only do this for "generated" NativeFunctions
@with_native_function @with_native_function
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]: def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
# We should only be generating these for code-generated NativeFunctions # We should only be generating these for code-generated NativeFunctions
if "generated" not in g.functional.tags: if "generated" not in g.functional.tags:
return None return None
@ -541,7 +543,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
sig = DispatcherSignature(g.functional.func) sig = DispatcherSignature(g.functional.func)
target_sig = DispatcherSignature(target_f.func) target_sig = DispatcherSignature(target_f.func)
context: List[Union[Binding, Expr]] = [] context: list[Binding | Expr] = []
clone_mutable_inputs = [] clone_mutable_inputs = []
cloned_return_names = [] cloned_return_names = []
# We can't just directly pass all of the arguments from the functional op into the mutating op. # We can't just directly pass all of the arguments from the functional op into the mutating op.
@ -587,7 +589,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# Generates out= kernels in terms of their functional counterparts. # Generates out= kernels in terms of their functional counterparts.
# We only do this for "generated" NativeFunctions # We only do this for "generated" NativeFunctions
@with_native_function @with_native_function
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]: def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
# We should only be generating these for code-generated NativeFunctions # We should only be generating these for code-generated NativeFunctions
if "generated" not in g.out.tags: if "generated" not in g.out.tags:
return None return None


@ -1,9 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from __future__ import annotations
import os import os
from enum import Enum from enum import Enum
from operator import itemgetter from operator import itemgetter
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any
import torch import torch
from torch.jit.generate_bytecode import generate_upgraders_bytecode from torch.jit.generate_bytecode import generate_upgraders_bytecode
@ -185,7 +188,7 @@ PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
) )
def construct_instruction(instruction_list_from_yaml: List[Any]) -> str: def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
instruction_list_part = [] instruction_list_part = []
for instruction in instruction_list_from_yaml: for instruction in instruction_list_from_yaml:
instruction_list_part.append( instruction_list_part.append(
@ -200,7 +203,7 @@ def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
) )
def construct_constants(constants_list_from_yaml: List[Any]) -> str: def construct_constants(constants_list_from_yaml: list[Any]) -> str:
constants_list_part = [] constants_list_part = []
for constant_from_yaml in constants_list_from_yaml: for constant_from_yaml in constants_list_from_yaml:
convert_constant = None convert_constant = None
@ -226,7 +229,7 @@ def construct_constants(constants_list_from_yaml: List[Any]) -> str:
) )
def construct_operators(operator_list_from_yaml: List[Any]) -> str: def construct_operators(operator_list_from_yaml: list[Any]) -> str:
operator_list_part = [] operator_list_part = []
for operator in operator_list_from_yaml: for operator in operator_list_from_yaml:
operator_list_part.append( operator_list_part.append(
@ -241,7 +244,7 @@ def construct_operators(operator_list_from_yaml: List[Any]) -> str:
) )
def construct_types(types_tr_list_from_yaml: List[Any]) -> str: def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
types_tr_list_part = [] types_tr_list_part = []
for types_tr in types_tr_list_from_yaml: for types_tr in types_tr_list_from_yaml:
types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr)) types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr))
@ -260,7 +263,7 @@ def construct_register_size(register_size_from_yaml: int) -> str:
def construct_version_maps( def construct_version_maps(
upgrader_bytecode_function_to_index_map: Dict[str, Any] upgrader_bytecode_function_to_index_map: dict[str, Any]
) -> str: ) -> str:
version_map = torch._C._get_operator_version_map() version_map = torch._C._get_operator_version_map()
sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return]
@ -302,8 +305,8 @@ def construct_version_maps(
def get_upgrader_bytecode_function_to_index_map( def get_upgrader_bytecode_function_to_index_map(
upgrader_dict: List[Dict[str, Any]] upgrader_dict: list[dict[str, Any]]
) -> Dict[str, Any]: ) -> dict[str, Any]:
upgrader_bytecode_function_to_index_map = {} upgrader_bytecode_function_to_index_map = {}
index = 0 index = 0
for upgrader_bytecode in upgrader_dict: for upgrader_bytecode in upgrader_dict:
@ -315,7 +318,7 @@ def get_upgrader_bytecode_function_to_index_map(
return upgrader_bytecode_function_to_index_map return upgrader_bytecode_function_to_index_map
def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None: def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
body_parts = [] body_parts = []
upgrader_bytecode_function_to_index_map = ( upgrader_bytecode_function_to_index_map = (
get_upgrader_bytecode_function_to_index_map(upgrader_dict) get_upgrader_bytecode_function_to_index_map(upgrader_dict)
@ -370,7 +373,7 @@ def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
out_file.write(upgrader_file_content.encode("utf-8")) out_file.write(upgrader_file_content.encode("utf-8"))
def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
sorted_upgrader_list = sorted( sorted_upgrader_list = sorted(
upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader)) upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
) )


@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple
# This class holds information about a single operator used to determine # This class holds information about a single operator used to determine
@ -46,12 +47,12 @@ class SelectiveBuildOperator:
include_all_overloads: bool include_all_overloads: bool
# Debug Information at the operator level # Debug Information at the operator level
_debug_info: Optional[Tuple[str, ...]] _debug_info: tuple[str, ...] | None
@staticmethod @staticmethod
def from_yaml_dict( def from_yaml_dict(
op_name: str, op_info: Dict[str, object] op_name: str, op_info: dict[str, object]
) -> "SelectiveBuildOperator": ) -> SelectiveBuildOperator:
allowed_keys = { allowed_keys = {
"name", "name",
"is_root_operator", "is_root_operator",
@ -79,7 +80,7 @@ class SelectiveBuildOperator:
include_all_overloads = op_info.get("include_all_overloads", True) include_all_overloads = op_info.get("include_all_overloads", True)
assert isinstance(include_all_overloads, bool) assert isinstance(include_all_overloads, bool)
debug_info: Optional[Tuple[str, ...]] = None debug_info: tuple[str, ...] | None = None
if "debug_info" in op_info: if "debug_info" in op_info:
di_list = op_info["debug_info"] di_list = op_info["debug_info"]
assert isinstance(di_list, list) assert isinstance(di_list, list)
@ -96,7 +97,7 @@ class SelectiveBuildOperator:
@staticmethod @staticmethod
def from_legacy_operator_name_without_overload( def from_legacy_operator_name_without_overload(
name: str, name: str,
) -> "SelectiveBuildOperator": ) -> SelectiveBuildOperator:
return SelectiveBuildOperator( return SelectiveBuildOperator(
name=name, name=name,
is_root_operator=True, is_root_operator=True,
@ -105,8 +106,8 @@ class SelectiveBuildOperator:
_debug_info=None, _debug_info=None,
) )
def to_dict(self) -> Dict[str, object]: def to_dict(self) -> dict[str, object]:
ret: Dict[str, object] = { ret: dict[str, object] = {
"is_root_operator": self.is_root_operator, "is_root_operator": self.is_root_operator,
"is_used_for_training": self.is_used_for_training, "is_used_for_training": self.is_used_for_training,
"include_all_overloads": self.include_all_overloads, "include_all_overloads": self.include_all_overloads,
@ -118,9 +119,9 @@ class SelectiveBuildOperator:
def merge_debug_info( def merge_debug_info(
lhs: Optional[Tuple[str, ...]], lhs: tuple[str, ...] | None,
rhs: Optional[Tuple[str, ...]], rhs: tuple[str, ...] | None,
) -> Optional[Tuple[str, ...]]: ) -> tuple[str, ...] | None:
# Ensure that when merging, each entry shows up just once. # Ensure that when merging, each entry shows up just once.
if lhs is None and rhs is None: if lhs is None and rhs is None:
return None return None
@ -129,8 +130,8 @@ def merge_debug_info(
def combine_operators( def combine_operators(
lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator" lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
) -> "SelectiveBuildOperator": ) -> SelectiveBuildOperator:
if str(lhs.name) != str(rhs.name): if str(lhs.name) != str(rhs.name):
raise Exception( # noqa: TRY002 raise Exception( # noqa: TRY002
f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead" f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
@ -152,10 +153,10 @@ def combine_operators(
def merge_operator_dicts( def merge_operator_dicts(
lhs: Dict[str, SelectiveBuildOperator], lhs: dict[str, SelectiveBuildOperator],
rhs: Dict[str, SelectiveBuildOperator], rhs: dict[str, SelectiveBuildOperator],
) -> Dict[str, SelectiveBuildOperator]: ) -> dict[str, SelectiveBuildOperator]:
operators: Dict[str, SelectiveBuildOperator] = {} operators: dict[str, SelectiveBuildOperator] = {}
for op_name, op in list(lhs.items()) + list(rhs.items()): for op_name, op in list(lhs.items()) + list(rhs.items()):
new_op = op new_op = op
if op_name in operators: if op_name in operators:


@ -1,11 +1,12 @@
from __future__ import annotations
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple from typing import TYPE_CHECKING
import yaml import yaml
from torchgen.model import NativeFunction
from torchgen.selective_build.operator import ( from torchgen.selective_build.operator import (
merge_debug_info, merge_debug_info,
merge_operator_dicts, merge_operator_dicts,
@ -14,6 +15,10 @@ from torchgen.selective_build.operator import (
) )
if TYPE_CHECKING:
from torchgen.model import NativeFunction
# A SelectiveBuilder holds information extracted from the selective build # A SelectiveBuilder holds information extracted from the selective build
# YAML specification. # YAML specification.
# #
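
Moving the `NativeFunction` import under `if TYPE_CHECKING:` is the companion pattern: since annotations are no longer evaluated at runtime, the name only needs to exist for the type checker, so the runtime import (and any import-cycle risk it carries) can be dropped. A minimal sketch; `heavy_module` and `HeavyThing` are placeholders, not real torchgen dependencies:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # imported only during static type checking, never at runtime
        from heavy_module import HeavyThing

    def use(thing: HeavyThing) -> str:  # stays a string at runtime, so no NameError
        return repr(thing)
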
@ -28,10 +33,10 @@ class SelectiveBuilder:
include_all_operators: bool include_all_operators: bool
# Debug Information at the selective/custom build level. # Debug Information at the selective/custom build level.
_debug_info: Optional[Tuple[str, ...]] _debug_info: tuple[str, ...] | None
# A dictionary of operator -> operator metadata. # A dictionary of operator -> operator metadata.
operators: Dict[str, SelectiveBuildOperator] operators: dict[str, SelectiveBuildOperator]
# A dictionary of selected kernel tags and dtypes. Typically a # A dictionary of selected kernel tags and dtypes. Typically a
# PyTorch Operator Kernel (function) may have many code paths # PyTorch Operator Kernel (function) may have many code paths
@ -39,22 +44,22 @@ class SelectiveBuilder:
# one per kernel function, but there could be many per kernel # one per kernel function, but there could be many per kernel
# function. The tag isn't a kernel function name, but some fragment # function. The tag isn't a kernel function name, but some fragment
# of the kernel function implementation itself. # of the kernel function implementation itself.
kernel_metadata: Dict[str, List[str]] kernel_metadata: dict[str, list[str]]
# ExecuTorch only. A dictionary of kernel tag -> list of (list of input # ExecuTorch only. A dictionary of kernel tag -> list of (list of input
# dtypes for tensor-like input args). # dtypes for tensor-like input args).
# This is from selective.yaml # This is from selective.yaml
et_kernel_metadata: Dict[str, List[str]] et_kernel_metadata: dict[str, list[str]]
# A set of all the custom torch bind classes used by the selected models # A set of all the custom torch bind classes used by the selected models
# Stored as a set internally to remove duplicates proactively, but written # Stored as a set internally to remove duplicates proactively, but written
# as a list to yamls # as a list to yamls
custom_classes: Set[str] custom_classes: set[str]
# A set of all the build features used by the selected models # A set of all the build features used by the selected models
# Stored as a set internally to remove duplicates proactively, but written # Stored as a set internally to remove duplicates proactively, but written
# as a list to yamls # as a list to yamls
build_features: Set[str] build_features: set[str]
# If true, then fragments for all dtypes for all kernel functions # If true, then fragments for all dtypes for all kernel functions
# are included as well as all custom classes. This is typically set when any one of the # are included as well as all custom classes. This is typically set when any one of the
@ -63,11 +68,11 @@ class SelectiveBuilder:
include_all_non_op_selectives: bool include_all_non_op_selectives: bool
@staticmethod @staticmethod
def get_nop_selector() -> "SelectiveBuilder": def get_nop_selector() -> SelectiveBuilder:
return SelectiveBuilder.from_yaml_dict({"include_all_operators": True}) return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})
@staticmethod @staticmethod
def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder": def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
valid_top_level_keys = { valid_top_level_keys = {
"include_all_non_op_selectives", "include_all_non_op_selectives",
"include_all_operators", "include_all_operators",
@ -135,20 +140,20 @@ class SelectiveBuilder:
) )
@staticmethod @staticmethod
def from_yaml_str(config_contents: str) -> "SelectiveBuilder": def from_yaml_str(config_contents: str) -> SelectiveBuilder:
contents = yaml.safe_load(config_contents) contents = yaml.safe_load(config_contents)
return SelectiveBuilder.from_yaml_dict(contents) return SelectiveBuilder.from_yaml_dict(contents)
@staticmethod @staticmethod
def from_yaml_path(config_path: str) -> "SelectiveBuilder": def from_yaml_path(config_path: str) -> SelectiveBuilder:
with open(config_path) as f: with open(config_path) as f:
contents = yaml.safe_load(f) contents = yaml.safe_load(f)
return SelectiveBuilder.from_yaml_dict(contents) return SelectiveBuilder.from_yaml_dict(contents)
@staticmethod @staticmethod
def from_legacy_op_registration_allow_list( def from_legacy_op_registration_allow_list(
allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
) -> "SelectiveBuilder": ) -> SelectiveBuilder:
operators = {} operators = {}
for op in allow_list: for op in allow_list:
operators[op] = { operators[op] = {
@ -231,7 +236,7 @@ class SelectiveBuilder:
and dtype in self.kernel_metadata[kernel_tag] and dtype in self.kernel_metadata[kernel_tag]
) )
def et_get_selected_kernels(self, op_name: str, kernel_key: List[str]) -> List[str]: def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
""" """
Return a list of kernel keys that cover the used ops Return a list of kernel keys that cover the used ops
""" """
@ -261,8 +266,8 @@ class SelectiveBuilder:
return list(result_set) return list(result_set)
def to_dict(self) -> Dict[str, object]: def to_dict(self) -> dict[str, object]:
ret: Dict[str, object] = { ret: dict[str, object] = {
"include_all_non_op_selectives": self.include_all_non_op_selectives, "include_all_non_op_selectives": self.include_all_non_op_selectives,
"include_all_operators": self.include_all_operators, "include_all_operators": self.include_all_operators,
} }
@ -288,10 +293,10 @@ class SelectiveBuilder:
def merge_kernel_metadata( def merge_kernel_metadata(
lhs: Dict[str, List[str]], lhs: dict[str, list[str]],
rhs: Dict[str, List[str]], rhs: dict[str, list[str]],
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
kernel_metadata: Dict[str, List[str]] = {} kernel_metadata: dict[str, list[str]] = {}
for tag_name, dtypes in list(lhs.items()) + list(rhs.items()): for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
dtypes_copy = set(dtypes) dtypes_copy = set(dtypes)
if tag_name in kernel_metadata: if tag_name in kernel_metadata:
@ -303,10 +308,10 @@ def merge_kernel_metadata(
def merge_et_kernel_metadata( def merge_et_kernel_metadata(
lhs: Dict[str, List[str]], lhs: dict[str, list[str]],
rhs: Dict[str, List[str]], rhs: dict[str, list[str]],
) -> Dict[str, List[str]]: ) -> dict[str, list[str]]:
merge_et_kernel_metadata: Dict[str, Set[str]] = defaultdict(set) merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set)
for op in list(lhs.keys()) + list(rhs.keys()): for op in list(lhs.keys()) + list(rhs.keys()):
merge_et_kernel_metadata[op].update(lhs.get(op, [])) merge_et_kernel_metadata[op].update(lhs.get(op, []))
merge_et_kernel_metadata[op].update(rhs.get(op, [])) merge_et_kernel_metadata[op].update(rhs.get(op, []))
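merge_et_kernel_metadata above unions the per-operator kernel-key lists from both inputs through a defaultdict of sets. A small sketch of the same idiom; the operator names and kernel keys below are made-up values:

from collections import defaultdict

lhs = {"aten::add.out": ["k1"], "aten::mul.out": ["k1"]}
rhs = {"aten::add.out": ["k2"]}

merged: dict[str, set[str]] = defaultdict(set)
for op in list(lhs) + list(rhs):
    merged[op].update(lhs.get(op, []))
    merged[op].update(rhs.get(op, []))

# Sorting produces a deterministic list form suitable for writing back to YAML.
result = {op: sorted(keys) for op, keys in merged.items()}
# {'aten::add.out': ['k1', 'k2'], 'aten::mul.out': ['k1']}
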
@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import importlib.util
import os import os
import sys import sys
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain from itertools import chain
from pathlib import Path from pathlib import Path
@ -18,9 +18,9 @@ you are in the root directory of the Pytorch git repo"""
if not file_path.exists(): if not file_path.exists():
raise Exception(err_msg) # noqa: TRY002 raise Exception(err_msg) # noqa: TRY002
spec = importlib.util.spec_from_file_location(module_name, file_path) spec = spec_from_file_location(module_name, file_path)
assert spec is not None assert spec is not None
module = importlib.util.module_from_spec(spec) module = module_from_spec(spec)
sys.modules[module_name] = module sys.modules[module_name] = module
assert spec.loader is not None assert spec.loader is not None
assert module is not None assert module is not None
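The script above imports spec_from_file_location and module_from_spec directly from importlib.util to load a module by file path. A self-contained sketch of that loading sequence; load_module_from_path is an illustrative name, not a function from this repo:

import sys
from importlib.util import module_from_spec, spec_from_file_location


def load_module_from_path(module_name: str, file_path: str):
    # Build an import spec from a source file and execute the module it describes.
    spec = spec_from_file_location(module_name, file_path)
    assert spec is not None and spec.loader is not None
    module = module_from_spec(spec)
    # Register before executing so the module can be found during its own import.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
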
@ -1,9 +1,9 @@
from typing import Dict, Union from __future__ import annotations
from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup
def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str: def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
if isinstance(g, NativeFunctionsGroup): if isinstance(g, NativeFunctionsGroup):
return str(g.functional.func.name.name.base) return str(g.functional.func.name.name.base)
else: else:
@ -55,12 +55,12 @@ is_hand_written_ops_ = frozenset(
) )
def is_hand_written(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
name_base = func_name_base_str(g) name_base = func_name_base_str(g)
return name_base in is_hand_written_ops_ return name_base in is_hand_written_ops_
def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None: def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
assert index == 0 or index == 1 assert index == 0 or index == 1
if op_name == "addr": if op_name == "addr":
if index == 0: if index == 0:
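The signatures above trade Union[...] for the PEP 604 "X | Y" spelling. Because from __future__ import annotations keeps every annotation as a string, the union is never evaluated at runtime and the syntax works even on interpreters without native "|" support for types. A minimal illustration; classes A and B are placeholders:

from __future__ import annotations


class A: ...


class B: ...


def kind_of(x: A | B) -> str:
    # The A | B annotation is stored as a string and never executed.
    return type(x).__name__
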
@ -1,3 +1,5 @@
from __future__ import annotations
import argparse import argparse
import itertools import itertools
import os import os
@ -28,7 +30,7 @@ def group_functions_by_op_name(
return [] return []
groups = [] groups = []
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
with native_function_manager(g): with native_function_manager(g):
return generator.is_supported(g) return generator.is_supported(g)
@ -1,7 +1,9 @@
from __future__ import annotations
import json import json
import logging import logging
import math import math
from typing import Dict, List, Optional, Sequence, Tuple, Union from typing import Sequence
import torchgen.api.cpp as cpp import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager from torchgen.context import native_function_manager
@ -25,7 +27,7 @@ logger: logging.Logger = logging.getLogger()
def has_alias( def has_alias(
arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]] arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
) -> bool: ) -> bool:
for arg in arguments: for arg in arguments:
annotation = getattr(arg, "annotation", None) annotation = getattr(arg, "annotation", None)
@ -237,7 +239,7 @@ BLOCKED_OPS = frozenset(
) )
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool: def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
base_op_name = "" base_op_name = ""
func = None func = None
if isinstance(g, NativeFunctionsViewGroup): if isinstance(g, NativeFunctionsViewGroup):
@ -298,8 +300,8 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool
def ivalue_type_conversion_method( def ivalue_type_conversion_method(
arg_type: Union[BaseType, OptionalType, Type] arg_type: BaseType | OptionalType | Type,
) -> Optional[Tuple[bool, str]]: ) -> tuple[bool, str] | None:
""" """
Return the method call expression of `c10::ivalue' to convert its contained value to Return the method call expression of `c10::ivalue' to convert its contained value to
the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor, the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
@ -394,7 +396,7 @@ def test_tensor_dim(op_name: str) -> int:
test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}' test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string) test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
def test_tensor_shape(op_name: str) -> str: def test_tensor_shape(op_name: str) -> str:
@ -405,7 +407,7 @@ def test_tensor_shape(op_name: str) -> str:
def test_value_expression( def test_value_expression(
arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str arg_type: BaseType | OptionalType | Type, index: int, op_name: str
) -> str: ) -> str:
tensor_size_ex = test_tensor_shape(op_name) tensor_size_ex = test_tensor_shape(op_name)
if tensor_size_ex == "": if tensor_size_ex == "":
@ -475,8 +477,8 @@ generate_test_ir_arguments_base_ty_to_type_str_ = {
def generate_test_ir_arguments( def generate_test_ir_arguments(
schema: FunctionSchema, schema: FunctionSchema,
) -> List[Tuple[str, Optional[str]]]: ) -> list[tuple[str, str | None]]:
def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]: def ir_argument(arg: Argument) -> tuple[str, str | None]:
t = arg.type t = arg.type
add_optional = False add_optional = False
if isinstance(t, OptionalType): if isinstance(t, OptionalType):
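generate_test_ir_arguments now returns list[tuple[str, str | None]]: one (argument name, IR type string or None) pair per argument. A hypothetical helper with the same return shape, just to make the nested annotation concrete; ir_pairs is not part of torchgen:

from __future__ import annotations


def ir_pairs(names: list[str], types: dict[str, str]) -> list[tuple[str, str | None]]:
    # None marks an argument whose IR type string is not known.
    return [(name, types.get(name)) for name in names]


ir_pairs(["self", "dim"], {"self": "Tensor"})
# [('self', 'Tensor'), ('dim', None)]
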
@ -1,3 +1,5 @@
from __future__ import annotations
import contextlib import contextlib
import functools import functools
import hashlib import hashlib
@ -5,31 +7,29 @@ import os
import re import re
import sys import sys
import textwrap import textwrap
from argparse import Namespace
from dataclasses import fields, is_dataclass from dataclasses import fields, is_dataclass
from enum import auto, Enum from enum import auto, Enum
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
Generic, Generic,
Iterable, Iterable,
Iterator, Iterator,
List,
Literal, Literal,
NoReturn, NoReturn,
Optional,
Sequence, Sequence,
Set, TYPE_CHECKING,
Tuple,
TypeVar, TypeVar,
Union,
) )
from typing_extensions import Self from typing_extensions import Self
from torchgen.code_template import CodeTemplate from torchgen.code_template import CodeTemplate
if TYPE_CHECKING:
from argparse import Namespace
# Many of these functions share logic for defining both the definition # Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so # and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which # we organize them into one function that takes a Target to say which
@ -57,7 +57,7 @@ IDENT_REGEX = r"(^|\W){}($|\W)"
# TODO: Use a real parser here; this will get bamboozled # TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]: def split_name_params(schema: str) -> tuple[str, list[str]]:
m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema) m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
if m is None: if m is None:
raise RuntimeError(f"Unsupported function schema: {schema}") raise RuntimeError(f"Unsupported function schema: {schema}")
@ -73,7 +73,7 @@ S = TypeVar("S")
# Map over function that may return None; omit Nones from output sequence # Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]: def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
for x in xs: for x in xs:
r = func(x) r = func(x)
if r is not None: if r is not None:
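mapMaybe keeps only the non-None results of func, as its comment says. A short usage sketch, assuming mapMaybe yields each retained value; parse_int is a throwaway helper:

from __future__ import annotations


def parse_int(s: str) -> int | None:
    return int(s) if s.isdigit() else None


list(mapMaybe(parse_int, ["1", "x", "3"]))  # [1, 3]
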
@ -127,7 +127,7 @@ class FileManager:
install_dir: str install_dir: str
template_dir: str template_dir: str
dry_run: bool dry_run: bool
filenames: Set[str] filenames: set[str]
def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None: def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
self.install_dir = install_dir self.install_dir = install_dir
@ -136,7 +136,7 @@ class FileManager:
self.dry_run = dry_run self.dry_run = dry_run
def _write_if_changed(self, filename: str, contents: str) -> None: def _write_if_changed(self, filename: str, contents: str) -> None:
old_contents: Optional[str] old_contents: str | None
try: try:
with open(filename) as f: with open(filename) as f:
old_contents = f.read() old_contents = f.read()
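_write_if_changed reads the existing file first so that unchanged outputs keep their timestamps and incremental builds are not invalidated. A minimal sketch of the idiom, not the FileManager implementation itself:

def write_if_changed(path: str, contents: str) -> None:
    try:
        with open(path) as f:
            old = f.read()
    except OSError:
        old = None
    if old != contents:
        with open(path, "w") as f:
            f.write(contents)
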
@ -150,7 +150,7 @@ class FileManager:
# Read from template file and replace pattern with callable (type could be dict or str). # Read from template file and replace pattern with callable (type could be dict or str).
def substitute_with_template( def substitute_with_template(
self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]] self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
) -> str: ) -> str:
template_path = os.path.join(self.template_dir, template_fn) template_path = os.path.join(self.template_dir, template_fn)
env = env_callable() env = env_callable()
@ -171,7 +171,7 @@ class FileManager:
self, self,
filename: str, filename: str,
template_fn: str, template_fn: str,
env_callable: Callable[[], Union[str, Dict[str, Any]]], env_callable: Callable[[], str | dict[str, Any]],
) -> None: ) -> None:
filename = f"{self.install_dir}/{filename}" filename = f"{self.install_dir}/{filename}"
assert filename not in self.filenames, "duplicate file write {filename}" assert filename not in self.filenames, "duplicate file write {filename}"
@ -186,7 +186,7 @@ class FileManager:
def write( def write(
self, self,
filename: str, filename: str,
env_callable: Callable[[], Union[str, Dict[str, Any]]], env_callable: Callable[[], str | dict[str, Any]],
) -> None: ) -> None:
self.write_with_template(filename, filename, env_callable) self.write_with_template(filename, filename, env_callable)
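write_with_template joins template_fn onto the template directory and substitutes the environment returned by env_callable into it. A hypothetical call, assuming a template file named ops.cpp with generated_comment and definitions placeholders exists under templates/:

fm = FileManager(install_dir="out", template_dir="templates", dry_run=False)
fm.write_with_template(
    filename="ops.cpp",
    template_fn="ops.cpp",
    env_callable=lambda: {"generated_comment": "// @generated", "definitions": []},
)
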
@ -196,13 +196,13 @@ class FileManager:
items: Iterable[T], items: Iterable[T],
*, *,
key_fn: Callable[[T], str], key_fn: Callable[[T], str],
env_callable: Callable[[T], Dict[str, List[str]]], env_callable: Callable[[T], dict[str, list[str]]],
num_shards: int, num_shards: int,
base_env: Optional[Dict[str, Any]] = None, base_env: dict[str, Any] | None = None,
sharded_keys: Set[str], sharded_keys: set[str],
) -> None: ) -> None:
everything: Dict[str, Any] = {"shard_id": "Everything"} everything: dict[str, Any] = {"shard_id": "Everything"}
shards: List[Dict[str, Any]] = [ shards: list[dict[str, Any]] = [
{"shard_id": f"_{i}"} for i in range(num_shards) {"shard_id": f"_{i}"} for i in range(num_shards)
] ]
all_shards = [everything] + shards all_shards = [everything] + shards
@ -221,7 +221,7 @@ class FileManager:
else: else:
shard[key] = [] shard[key] = []
def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None: def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
for k, v in from_.items(): for k, v in from_.items():
assert k in sharded_keys, f"undeclared sharded key {k}" assert k in sharded_keys, f"undeclared sharded key {k}"
into[k] += v into[k] += v
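write_sharded spreads items over num_shards partial outputs chosen by key_fn, alongside an "Everything" shard that aggregates all of them. One common way to pick a stable shard index is to hash the key, as sketched below; this only illustrates that choice and is not necessarily the exact hashing torchgen uses:

import hashlib


def shard_index(key: str, num_shards: int) -> int:
    # sha1 is stable across processes, unlike Python's randomized hash() for str.
    digest = hashlib.sha1(key.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_shards
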
@ -275,7 +275,7 @@ class FileManager:
# Helper function to generate file manager # Helper function to generate file manager
def make_file_manager( def make_file_manager(
options: Namespace, install_dir: Optional[str] = None options: Namespace, install_dir: str | None = None
) -> FileManager: ) -> FileManager:
template_dir = os.path.join(options.source_path, "templates") template_dir = os.path.join(options.source_path, "templates")
install_dir = install_dir if install_dir else options.install_dir install_dir = install_dir if install_dir else options.install_dir
@ -335,7 +335,7 @@ def _pformat(
def _format_dict( def _format_dict(
attr: Dict[Any, Any], attr: dict[Any, Any],
indent: int, indent: int,
width: int, width: int,
curr_indent: int, curr_indent: int,
@ -355,7 +355,7 @@ def _format_dict(
def _format_list( def _format_list(
attr: Union[List[Any], Set[Any], Tuple[Any, ...]], attr: list[Any] | set[Any] | tuple[Any, ...],
indent: int, indent: int,
width: int, width: int,
curr_indent: int, curr_indent: int,
@ -370,7 +370,7 @@ def _format_list(
def _format( def _format(
fields_str: List[str], fields_str: list[str],
indent: int, indent: int,
width: int, width: int,
curr_indent: int, curr_indent: int,
@ -402,7 +402,9 @@ class NamespaceHelper:
} // namespace torch } // namespace torch
""" """
def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2): def __init__(
self, namespace_str: str, entity_name: str = "", max_level: int = 2
) -> None:
# cpp_namespace can be a colon joined string such as torch::lazy # cpp_namespace can be a colon joined string such as torch::lazy
cpp_namespaces = namespace_str.split("::") cpp_namespaces = namespace_str.split("::")
assert ( assert (
@ -419,7 +421,7 @@ class NamespaceHelper:
@staticmethod @staticmethod
def from_namespaced_entity( def from_namespaced_entity(
namespaced_entity: str, max_level: int = 2 namespaced_entity: str, max_level: int = 2
) -> "NamespaceHelper": ) -> NamespaceHelper:
""" """
Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add" Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
""" """
@ -452,9 +454,9 @@ class NamespaceHelper:
class OrderedSet(Generic[T]): class OrderedSet(Generic[T]):
storage: Dict[T, Literal[None]] storage: dict[T, Literal[None]]
def __init__(self, iterable: Optional[Iterable[T]] = None): def __init__(self, iterable: Iterable[T] | None = None) -> None:
if iterable is None: if iterable is None:
self.storage = {} self.storage = {}
else: else:
@ -466,28 +468,28 @@ class OrderedSet(Generic[T]):
def __iter__(self) -> Iterator[T]: def __iter__(self) -> Iterator[T]:
return iter(self.storage.keys()) return iter(self.storage.keys())
def update(self, items: "OrderedSet[T]") -> None: def update(self, items: OrderedSet[T]) -> None:
self.storage.update(items.storage) self.storage.update(items.storage)
def add(self, item: T) -> None: def add(self, item: T) -> None:
self.storage[item] = None self.storage[item] = None
def copy(self) -> "OrderedSet[T]": def copy(self) -> OrderedSet[T]:
ret: OrderedSet[T] = OrderedSet() ret: OrderedSet[T] = OrderedSet()
ret.storage = self.storage.copy() ret.storage = self.storage.copy()
return ret return ret
@staticmethod @staticmethod
def union(*args: "OrderedSet[T]") -> "OrderedSet[T]": def union(*args: OrderedSet[T]) -> OrderedSet[T]:
ret = args[0].copy() ret = args[0].copy()
for s in args[1:]: for s in args[1:]:
ret.update(s) ret.update(s)
return ret return ret
def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]": def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
return OrderedSet.union(self, other) return OrderedSet.union(self, other)
def __ior__(self, other: "OrderedSet[T]") -> Self: def __ior__(self, other: OrderedSet[T]) -> Self:
self.update(other) self.update(other)
return self return self
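OrderedSet keeps its members as keys of a dict, so iteration follows insertion order while duplicates collapse. A short usage sketch, assuming the constructor de-duplicates an iterable the way the dict-backed storage suggests:

s: OrderedSet[str] = OrderedSet(["b", "a", "b"])
list(s)                      # ['b', 'a']
t = s | OrderedSet(["c"])
list(t)                      # ['b', 'a', 'c']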