mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

commit 9120992c72 (parent 8a67daf283), committed by PyTorch MergeBot

[BE][Easy] enable postponed annotations in torchgen (#129376)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129376
Approved by: https://github.com/ezyang
ghstack dependencies: #129375
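The change below is mechanical and identical across all of the torchgen modules in this diff: each file gains `from __future__ import annotations` (PEP 563), which makes every annotation a lazily evaluated string, and the `typing` aliases `List`/`Dict`/`Set`/`Tuple`/`Optional`/`Union`/`Match` are then rewritten as PEP 585 builtin generics (`list`, `dict`, ...), PEP 604 unions (`X | None`), and `re.Match`. A minimal runnable sketch of the pattern, with illustrative names rather than torchgen code:

```python
# Previously the module would have read:
#     from typing import Dict, List, Optional
#     def find(table: Dict[str, List[int]], key: str) -> Optional[List[int]]: ...
from __future__ import annotations  # must be the first statement in the file


def find(table: dict[str, list[int]], key: str) -> list[int] | None:
    # The annotations above are stored as strings and never evaluated,
    # so this runs on Python 3.8+ even though `list[int] | None` would be
    # a TypeError as an ordinary runtime expression before Python 3.10.
    return table.get(key)


print(find({"a": [1, 2]}, "a"))  # [1, 2]
```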
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import re
 from dataclasses import dataclass
-from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
+from typing import cast, Sequence
 
 from torchgen import local
 from torchgen.api import cpp
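Note that the `typing` import shrinks rather than disappears: `cast` survives because it is a real runtime call that postponed annotations cannot eliminate, and `Sequence` survives because the string annotations still refer to it by name and type checkers resolve that name against the module's imports. A small sketch of the distinction, under the same future import:

```python
from __future__ import annotations

from typing import Sequence, cast


def head(xs: Sequence[int]) -> int:
    # The annotation "Sequence[int]" stays an unevaluated string, but
    # cast() executes at runtime, so its import cannot be dropped.
    return cast(int, xs[0])


print(head([3, 1, 2]))  # 3
```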
@@ -48,16 +50,16 @@ class Derivative:
     original_formula: str
 
     # Names of the arguments for which this formula calculates derivatives.
-    var_names: Tuple[str, ...]
+    var_names: tuple[str, ...]
 
     # Saved inputs that are referenced by the formula.
-    saved_inputs: Tuple[SavedAttribute, ...]
+    saved_inputs: tuple[SavedAttribute, ...]
 
     # Saved outputs that are referenced by the formula.
-    saved_outputs: Tuple[SavedAttribute, ...]
+    saved_outputs: tuple[SavedAttribute, ...]
 
     # Gradients that are referenced by name in the formula.
-    named_gradients: Set[str]
+    named_gradients: set[str]
 
 
 # Represents a forward formula that calculates forward derivatives
@@ -71,17 +73,17 @@ class ForwardDerivative:
 
     # Name of the output arguments for which this formula calculates forward
     # derivatives
-    var_names: Tuple[str, ...]
+    var_names: tuple[str, ...]
 
     # Type of the output arguments for which this formula calculates forward
     # derivatives
-    var_types: Tuple[Type, ...]
+    var_types: tuple[Type, ...]
 
     # Inputs for which the forward derivatives are required for this formula
-    required_inputs_fw_grad: Optional[Tuple[str, ...]]
+    required_inputs_fw_grad: tuple[str, ...] | None
 
     # Inputs for which the primal is required for this formula
-    required_inputs_primal: Optional[Tuple[str, ...]]
+    required_inputs_primal: tuple[str, ...] | None
 
     # Flag to specify if this formula requires the original value of self
     # This is only used by inplace operations
@@ -116,7 +118,7 @@ class DifferentiabilityInfo:
     # The name of the generated autograd function.
     # It's set only if we will calculate a derivative, i.e.
     # 'args_with_derivatives' is not empty.
-    op: Optional[str]
+    op: str | None
 
     # The derivatives formulae for this function.
     # Note that the length of this sequence is the number of differentiable inputs
@@ -138,7 +140,7 @@ class DifferentiabilityInfo:
 
     # The named gradients that are used in any of the derivatives.
     # Invariant: all(name in available_named_gradients for name in used_named_gradients)
-    used_named_gradients: Set[str]
+    used_named_gradients: set[str]
 
     # The function's input arguments for which it calculates derivatives.
     # It's the union of 'var_names' of all 'derivatives', sorted by the
@@ -149,7 +151,7 @@ class DifferentiabilityInfo:
     non_differentiable_arg_names: Sequence[str]
 
     # Raw data read from derivatives.yaml.
-    output_differentiability: Optional[List[bool]]
+    output_differentiability: list[bool] | None
 
     # output_differentiability in derivatives.yaml can be a list of
     # conditions that express if the output is differentiable. In this case,
@@ -157,7 +159,7 @@ class DifferentiabilityInfo:
     # (NB: we only support one condition right now).
     # output_differentiability gets populated with True for each condition,
     # while output_differentiability_conditions gets populated with the conditions
-    output_differentiability_conditions: Optional[List[str]]
+    output_differentiability_conditions: list[str] | None
 
     @property
     def has_derivatives(self) -> bool:
@@ -170,7 +172,7 @@ class DifferentiabilityInfo:
     # See Note [Codegen'd {view}_copy Operators]
     def create_view_copy_from_view_derivative(
         self, g: NativeFunctionsViewGroup
-    ) -> Optional["DifferentiabilityInfo"]:
+    ) -> DifferentiabilityInfo | None:
         if g.view_copy is None:
             return None
         f = g.view_copy
@@ -201,7 +203,7 @@ class DifferentiabilityInfo:
         )
 
 
-def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
+def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
     if info is None:
         return False
     for derivative in info.derivatives:
@@ -211,11 +213,11 @@ def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
     return False
 
 
-def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
+def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
     return uses_ident(info, "retain_variables")
 
 
-def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
+def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
     return uses_ident(info, "grad")
 
 
@@ -253,8 +255,8 @@ class DifferentiableOutput:
 @dataclass(frozen=True)
 class NativeFunctionWithDifferentiabilityInfo:
     func: NativeFunction
-    info: Optional[Dict[str, DifferentiabilityInfo]]
-    fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
+    info: dict[str, DifferentiabilityInfo] | None
+    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
 
 
 # TODO: Update comment below since it is out of date.
@@ -363,19 +365,19 @@ def is_reference_for_foreach(
 # TODO(crcrpar): Avoid hard coding "Default" ideally.
 def gen_foreach_derivativeinfo(
     foreach_function: NativeFunction,
-    functional_info_by_signature: Dict[
-        FunctionSchema, Dict[str, DifferentiabilityInfo]
+    functional_info_by_signature: dict[
+        FunctionSchema, dict[str, DifferentiabilityInfo]
     ],
-    non_functional_info_by_signature: Dict[
-        FunctionSchema, Dict[str, DifferentiabilityInfo]
+    non_functional_info_by_signature: dict[
+        FunctionSchema, dict[str, DifferentiabilityInfo]
     ],
     dispatch_key: str = "Default",
-) -> Tuple[Optional[DifferentiabilityInfo], bool]:
+) -> tuple[DifferentiabilityInfo | None, bool]:
     """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
 
     The second return value indicates whether the info is generated in this function.
     """
-    ref_diff_info: Optional[DifferentiabilityInfo] = None
+    ref_diff_info: DifferentiabilityInfo | None = None
 
     for function_schema, diff_info in functional_info_by_signature.items():
         if not is_reference_for_foreach(foreach_function, function_schema):
@@ -485,13 +487,13 @@ def gen_foreach_derivativeinfo(
         if arg.name in all_var_names
     ]
 
-    forward_derivatives: List[ForwardDerivative] = []
+    forward_derivatives: list[ForwardDerivative] = []
     fw_derivative: ForwardDerivative
     for fw_derivative in ref_diff_info.forward_derivatives:
-        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
-        var_types: List[Type] = list(fw_derivative.var_types)
-        required_inputs_fw_grad: List[str] = []
-        required_inputs_primal: List[str] = []
+        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
+        var_types: list[Type] = list(fw_derivative.var_types)
+        required_inputs_fw_grad: list[str] = []
+        required_inputs_primal: list[str] = []
         if fw_derivative.required_inputs_fw_grad is not None:
             required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
         if fw_derivative.required_inputs_primal:
@@ -578,9 +580,9 @@ def gen_foreach_derivativeinfo(
 
 
 def match_differentiability_info(
-    native_functions: List[NativeFunction],
-    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
-) -> List[NativeFunctionWithDifferentiabilityInfo]:
+    native_functions: list[NativeFunction],
+    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
+) -> list[NativeFunctionWithDifferentiabilityInfo]:
     """Sets the "derivative" key on declarations to matching autograd function
     In-place functions will use the out-of-place derivative definition if there
     is no in-place specific derivative.
@@ -599,7 +601,7 @@ def match_differentiability_info(
 
     def find_info(
         f: NativeFunction,
-    ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
+    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
         # Don't bother matching info to generated out= variants
         if "generated" in f.tags and f.func.kind() == SchemaKind.out:
             return None, False
@@ -653,7 +655,7 @@ Attempted to convert a derivative formula for a mutable operator
 
         return None, False
 
-    result: List[NativeFunctionWithDifferentiabilityInfo] = []
+    result: list[NativeFunctionWithDifferentiabilityInfo] = []
     for f in native_functions:
         info_dict, is_exact_match = find_info(f)
 
@@ -677,7 +679,7 @@ Attempted to convert a derivative formula for a mutable operator
             )
             continue
 
-        fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
+        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
         for key, info in info_dict.items():
             if not info.forward_derivatives:
                 fw_derivative_dict[key] = []
@@ -713,7 +715,7 @@ Attempted to convert a derivative formula for a mutable operator
                 formula = fw_info.formula
 
                 def replace_self_with_original_self(formula: str, postfix: str) -> str:
-                    def repl(m: Match[str]) -> str:
+                    def repl(m: re.Match[str]) -> str:
                         return f"{m.group(1)}original_self{postfix}{m.group(2)}"
 
                     return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
@@ -734,7 +736,7 @@ Attempted to convert a derivative formula for a mutable operator
                     formula = replace_self_with_original_self(formula, "_t")
 
                 # replace "result" from the formula by "self_p"
-                def repl(m: Match[str]) -> str:
+                def repl(m: re.Match[str]) -> str:
                     return f"{m.group(1)}self_p{m.group(2)}"
 
                 formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
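`typing.Match`, deprecated since Python 3.8 and removed in 3.12, is replaced here by `re.Match[str]`, the type that `re.sub` actually passes to a callable replacement. A self-contained sketch of the same identifier-rewriting pattern used by these `repl` callbacks (the regex mirrors what IDENT_REGEX appears to do, but is an assumption, not torchgen's definition):

```python
from __future__ import annotations

import re

# Match a whole identifier: groups 1 and 2 capture the non-word boundaries.
IDENT_REGEX = r"(^|\W){}($|\W)"


def replace_ident(formula: str, old: str, new: str) -> str:
    def repl(m: re.Match[str]) -> str:
        return f"{m.group(1)}{new}{m.group(2)}"

    return re.sub(IDENT_REGEX.format(re.escape(old)), repl, formula)


print(replace_ident("grad * self_t.conj()", "self_t", "original_self_t"))
# grad * original_self_t.conj()
```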
@@ -758,8 +760,8 @@ Attempted to convert a derivative formula for a mutable operator
                 # If there is a need, we can relax (2) to allow any op that has an in-place variant
                 is_single_method_on_self_t = False
                 directly_do_inplace = False
-                op_name: Optional[str] = None
-                between_parens: Optional[str] = None
+                op_name: str | None = None
+                between_parens: str | None = None
                 match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
                 if match:
                     op_name, between_parens = match.group(1), match.group(2)
@@ -823,7 +825,7 @@ Attempted to convert a derivative formula for a mutable operator
 
 
 def is_differentiable(
-    name: str, type: Type, info: Optional[DifferentiabilityInfo]
+    name: str, type: Type, info: DifferentiabilityInfo | None
 ) -> bool:
     return type.is_tensor_like() and (
         info is None or name not in info.non_differentiable_arg_names
@@ -832,10 +834,10 @@ def is_differentiable(
 
 def gen_differentiable_outputs(
     fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
-) -> List[DifferentiableOutput]:
+) -> list[DifferentiableOutput]:
     f = fn.func
     info = fn.info[key] if fn.info else None
-    outputs: List[DifferentiableOutput] = [
+    outputs: list[DifferentiableOutput] = [
         DifferentiableOutput(
             name=name,
             type=ret.type,
@@ -850,7 +852,7 @@ def gen_differentiable_outputs(
                 f"The length of output_differentiability ({len(output_differentiability)}), "
                 f"does not match the number of outputs ({len(outputs)})."
             )
-        differentiable_outputs: List[DifferentiableOutput] = []
+        differentiable_outputs: list[DifferentiableOutput] = []
         if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
             raise RuntimeError(
                 "output_differentiability=False for inplace operation (version_counter won't get updated)"
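Every field annotation in the dataclasses above (`var_names: tuple[str, ...]`, `op: str | None`, and so on) is safe to rewrite because class-level annotations become plain strings under PEP 563 and are never evaluated when the class is created. The one caveat, shown in this hedged sketch with a simplified stand-in class: code that later evaluates the annotations, such as `typing.get_type_hints`, still needs the names importable, and `|` unions only evaluate on Python 3.10+.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Formula:  # simplified stand-in, not torchgen's Derivative
    text: str
    var_names: tuple[str, ...]
    saved: tuple[str, ...] | None = None


f = Formula("grad * 2", ("self",))
print(Formula.__annotations__["saved"])  # the raw string "tuple[str, ...] | None"
# typing.get_type_hints(Formula) would eval that string, and the `|` union then
# raises TypeError on interpreters older than 3.10, so this style assumes the
# annotations are never evaluated at runtime.
```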
@@ -1,4 +1,6 @@
-from typing import List, Optional, Sequence, Set, Union
+from __future__ import annotations
 
+from typing import Sequence
+
 from torchgen import local
 from torchgen.api.types import (
@@ -94,7 +96,7 @@ def valuetype_type(
     binds: ArgName,
     remove_non_owning_ref_types: bool = False,
     symint: bool = False,
-) -> Optional[NamedCType]:
+) -> NamedCType | None:
     if isinstance(t, BaseType):
         if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
             return None
@@ -279,7 +281,7 @@ def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
 
 
 def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
-    returns: List[str] = []
+    returns: list[str] = []
     for i, r in enumerate(f.func.returns):
         # If we have an inplace function, the return argument is
         # implicitly named self.
@@ -368,17 +370,17 @@ def default_expr(d: str, t: Type, *, symint: bool) -> str:
 
 
 def argument(
-    a: Union[Argument, TensorOptionsArguments, SelfArgument],
+    a: Argument | TensorOptionsArguments | SelfArgument,
     *,
-    cpp_no_default_args: Set[str],
+    cpp_no_default_args: set[str],
     method: bool,
     faithful: bool,
     symint: bool = False,
     has_tensor_options: bool,
-) -> List[Binding]:
+) -> list[Binding]:
     def sub_argument(
-        a: Union[Argument, TensorOptionsArguments, SelfArgument]
-    ) -> List[Binding]:
+        a: Argument | TensorOptionsArguments | SelfArgument,
+    ) -> list[Binding]:
         return argument(
             a,
             cpp_no_default_args=cpp_no_default_args,
@@ -394,7 +396,7 @@ def argument(
             binds = SpecialArgName.possibly_redundant_memory_format
         else:
             binds = a.name
-        default: Optional[str] = None
+        default: str | None = None
         if a.name not in cpp_no_default_args and a.default is not None:
             default = default_expr(a.default, a.type, symint=symint)
         return [
@@ -445,9 +447,9 @@ def arguments(
     faithful: bool,
     symint: bool = False,
     method: bool,
-    cpp_no_default_args: Set[str],
-) -> List[Binding]:
-    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+    cpp_no_default_args: set[str],
+) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
     if faithful:
         args.extend(arguments.non_out)
         args.extend(arguments.out)
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import itertools
-from typing import List, Sequence, Union
+from typing import Sequence
 
 from torchgen.api import cpp
 from torchgen.api.types import ArgName, Binding, CType, NamedCType
@@ -76,10 +78,10 @@ def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
     return cpp.returns_type(rs, symint=symint)
 
 
-def jit_arguments(func: FunctionSchema) -> List[Argument]:
+def jit_arguments(func: FunctionSchema) -> list[Argument]:
     def to_argument(
-        a: Union[Argument, TensorOptionsArguments, SelfArgument]
-    ) -> List[Argument]:
+        a: Argument | TensorOptionsArguments | SelfArgument,
+    ) -> list[Argument]:
         if isinstance(a, Argument):
             return [a]
         elif isinstance(a, SelfArgument):
@@ -114,5 +116,5 @@ def argument(
     )
 
 
-def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
+def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
     return [argument(a, symint=symint) for a in jit_arguments(func)]
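`to_argument` above flattens the three argument wrappers into plain `Argument`s via `isinstance` dispatch; only the annotation changes spelling, while the runtime checks still name each concrete class. A minimal sketch of the pattern with stand-in classes (torchgen's real classes carry more structure):

```python
from __future__ import annotations


class Argument: ...


class SelfArgument:
    def __init__(self, argument: Argument) -> None:
        self.argument = argument


class TensorOptionsArguments:
    def __init__(self, *args: Argument) -> None:
        self.all = list(args)


def to_argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Argument]:
    # The `|` union is only a string at runtime; isinstance() below is
    # ordinary dispatch on concrete classes and works on any version.
    if isinstance(a, Argument):
        return [a]
    elif isinstance(a, SelfArgument):
        return [a.argument]
    else:
        return list(a.all)


print(len(to_argument(TensorOptionsArguments(Argument(), Argument()))))  # 2
```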
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from __future__ import annotations
 
 from torchgen.api import dispatcher
 from torchgen.api.types import (
@@ -93,7 +93,7 @@ def name(
     *,
     is_reverse: bool,
     include_namespace: bool,
-    reapply_views: Optional[bool] = None,
+    reapply_views: bool | None = None,
 ) -> str:
     if reapply_views is None:
         # reapply_views is only important for the fwd lambda,
@@ -124,7 +124,7 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
     return f"{api_name}_inverse"
 
 
-def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
+def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
     # capture arguments include all arguments except `self`.
     # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
     # So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
@@ -152,14 +152,14 @@ def returns_type(func: FunctionSchema) -> CType:
     return BaseCType(tensorT)
 
 
-def outer_arguments(*, is_reverse: bool) -> List[Binding]:
+def outer_arguments(*, is_reverse: bool) -> list[Binding]:
     if is_reverse:
         return [base_binding, mutated_view_binding, mutated_view_idx_binding]
     else:
         return [base_binding, mutated_view_idx_binding]
 
 
-def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
+def inner_call_index(func: FunctionSchema) -> Binding | None:
     # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
     # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
     if len(func.returns) > 1 or (
@@ -169,7 +169,7 @@ def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
     return None
 
 
-def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
+def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
     args = func.arguments.flat_all
     assert args[0].type == BaseType(BaseTy.Tensor)
     non_self_args = args[1:]
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from __future__ import annotations
 
+from typing import Any
+
 from torchgen.api.types import (
     BaseCppType,
@@ -34,7 +36,7 @@ from torchgen.model import (
 )
 
 
-_valueT: Optional[BaseCppType] = None
+_valueT: BaseCppType | None = None
 
 
 # A ValueT is an IR type which represents the computation of a Tensor. In other
@@ -66,8 +68,8 @@ tensorListValueT = BaseCppType("torch::lazy", "Value")
 
 
 def process_ir_type(
-    typ: Type, properties: "LazyIrProperties", *, symint: bool
-) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
+    typ: Type, properties: LazyIrProperties, *, symint: bool
+) -> BaseCType | VectorCType | OptionalCType | ListCType:
     """
     This function takes a type from NativeFunctions and converts it for use with
     lazy tensor codegen.
@@ -147,7 +149,7 @@ def process_ir_type(
 #
 # Invariant: passed typ should be an *owning* CType (e.g., we will report
 # that ArrayRef<Value> is NOT a value type)
-def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
+def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
     """
     Given a type, determine if it is a Value-like type. This is equivalent to
     being Tensor-like, but assumes the type has already been transformed.
@@ -202,7 +204,7 @@ def isGeneratorType(typ: Type) -> bool:
 class LazyArgument:
     name: str
     orig_type: Type
-    lazy_type_: Optional[CType]
+    lazy_type_: CType | None
     is_wrapped_scalar: bool
     is_generator: bool
     # TODO: this is lies, it is false for symint list
@@ -214,7 +216,9 @@ class LazyArgument:
     # true if this argument is or contains a lazy IR value
     is_lazy_value: bool
 
-    def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
+    def __init__(
+        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
+    ) -> None:
         self.name = arg.name
         self.orig_type = arg.type
         self.symint = symint
@@ -248,7 +252,7 @@ class LazyIrProperties:
     attributes. The mutual exclusivity is automatically handled.
     """
 
-    Properties: Tuple[Tuple[str, ...], ...] = (
+    Properties: tuple[tuple[str, ...], ...] = (
         (
             "ShapePrecompute",  # Assume shape has been precomputed
             "ShapeCompute",  # Need to compute the shape on construction
@@ -271,8 +275,8 @@ class LazyIrProperties:
         ),
     )
 
-    def __init__(self, *default_properties: str):
-        properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
+    def __init__(self, *default_properties: str) -> None:
+        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
             LazyIrProperties.Properties
         )
         self.__dict__["properties"] = properties
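`dict.fromkeys(LazyIrProperties.Properties)` seeds one `None`-valued slot per mutually exclusive group; the group tuples themselves are the dictionary keys, which works because tuples of strings are hashable. The sketch below is a guess at how a class like this can route attribute access through such a table to get the advertised mutual exclusivity; the real `LazyIrProperties` may differ in its details:

```python
from __future__ import annotations


class ExclusiveProps:  # hypothetical stand-in for LazyIrProperties
    Properties: tuple[tuple[str, ...], ...] = (
        ("ShapePrecompute", "ShapeCompute"),
        ("Lower", "CanBeReused"),
    )

    def __init__(self, *default_properties: str) -> None:
        # One slot per group; each group tuple serves as the dict key.
        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
            ExclusiveProps.Properties
        )
        self.__dict__["properties"] = properties  # bypass __setattr__
        for name in default_properties:
            setattr(self, name, True)

    def __getattr__(self, name: str) -> bool:
        return any(v == name for v in self.__dict__["properties"].values())

    def __setattr__(self, name: str, value: bool) -> None:
        for group in ExclusiveProps.Properties:
            if name in group:  # enabling one member displaces its siblings
                self.__dict__["properties"][group] = name if value else None
                return
        raise AttributeError(name)


p = ExclusiveProps("ShapePrecompute")
p.ShapeCompute = True
print(p.ShapePrecompute, p.ShapeCompute)  # False True
```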
@@ -305,17 +309,17 @@ class LazyIrProperties:
 # TODO: This is not idiomatic with how other torchgen APIs transform on schema.
 class LazyIrSchema:
     # The name of the operator this function schema describes.
-    name: "OperatorName"
+    name: OperatorName
 
-    positional_args: Tuple[LazyArgument, ...]
-    keyword_args: Tuple[LazyArgument, ...]
+    positional_args: tuple[LazyArgument, ...]
+    keyword_args: tuple[LazyArgument, ...]
 
     # TODO: Need to handle collisions with argument names at some point
-    returns: Tuple["Return", ...]
+    returns: tuple[Return, ...]
 
     # if this schema has a Generator arg, list its orig ctype/name but don't
     # build a LazyArgument since lazy IR doesn't support it
-    generator_arg: Optional[NamedCType] = None
+    generator_arg: NamedCType | None = None
 
     # original function schema
     func: FunctionSchema
@@ -329,21 +333,21 @@ class LazyIrSchema:
         "Lower",
         "CanBeReused",
     )
-    opkind: Optional[str] = None
+    opkind: str | None = None
 
     def __init__(
         self,
         func: FunctionSchema,
-        properties: Optional[LazyIrProperties] = None,
+        properties: LazyIrProperties | None = None,
         *,
         symint: bool,
-    ):
+    ) -> None:
         if properties:
             self.properties = properties
 
         self.func = func
         self.symint = symint
-        positional_args: List[LazyArgument] = []
+        positional_args: list[LazyArgument] = []
         for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
             if arg_field == "self_arg" and func.arguments.self_arg is not None:
                 arg = func.arguments.self_arg.argument
@@ -357,7 +361,7 @@ class LazyIrSchema:
             )
         self.positional_args = tuple(positional_args)
 
-        keyword_args: List[LazyArgument] = []
+        keyword_args: list[LazyArgument] = []
         for arg_field in [
             "pre_tensor_options_kwarg_only",
             "tensor_options",
@@ -411,13 +415,13 @@ class LazyIrSchema:
         values: bool = True,
         scalars: bool = True,
         generator: bool = True,
-    ) -> List[LazyArgument]:
+    ) -> list[LazyArgument]:
         # This function maintains the sorted order of arguments but provides different filtered views.
         # Some parts of the code care about kwargs vs args (TS lowerings),
         # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
         # Generators are special cased, as they are needed for fallback/shape-inference but not supported
         # in TS lowerings and therefore also omitted from lazy IR.
-        args: List[LazyArgument] = []
+        args: list[LazyArgument] = []
         if positional:
             args.extend(self.positional_args)
         if keyword:
@@ -439,25 +443,25 @@ class LazyIrSchema:
         return []
 
     @property
-    def positional_values(self) -> List[LazyArgument]:
+    def positional_values(self) -> list[LazyArgument]:
         return self.filtered_args(
             positional=True, keyword=False, values=True, scalars=False
         )
 
     @property
-    def positional_scalars(self) -> List[LazyArgument]:
+    def positional_scalars(self) -> list[LazyArgument]:
         return self.filtered_args(
             positional=True, keyword=False, values=False, scalars=True
         )
 
     @property
-    def keyword_values(self) -> List[LazyArgument]:
+    def keyword_values(self) -> list[LazyArgument]:
         return self.filtered_args(
             positional=False, keyword=True, values=True, scalars=False
         )
 
     @property
-    def keyword_scalars(self) -> List[LazyArgument]:
+    def keyword_scalars(self) -> list[LazyArgument]:
         return self.filtered_args(
             positional=False, keyword=True, values=False, scalars=True
         )
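The four properties above are views over one sorted argument list: `filtered_args` keeps a single source of truth for ordering and filters it along the positional/keyword and values/scalars axes. A condensed sketch of that design, assuming a simple `is_value` flag in place of torchgen's real predicates:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Arg:
    name: str
    is_value: bool  # stand-in for "is or contains a lazy IR value"


class Schema:  # hypothetical stand-in for LazyIrSchema
    def __init__(self, positional: list[Arg], keyword: list[Arg]) -> None:
        self.positional = positional
        self.keyword = keyword

    def filtered_args(
        self, *, positional: bool, keyword: bool, values: bool, scalars: bool
    ) -> list[Arg]:
        args: list[Arg] = []
        if positional:
            args.extend(self.positional)
        if keyword:
            args.extend(self.keyword)
        # Order is preserved; the flags only select, never re-sort.
        return [a for a in args if (values and a.is_value) or (scalars and not a.is_value)]

    @property
    def positional_values(self) -> list[Arg]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )


s = Schema([Arg("self", True), Arg("alpha", False)], [Arg("out", True)])
print([a.name for a in s.positional_values])  # ['self']
```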
@@ -1,4 +1,6 @@
-from typing import List, Optional, Sequence, Union
+from __future__ import annotations
 
+from typing import Sequence
+
 from torchgen import local
 from torchgen.api import cpp
@@ -81,11 +83,11 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
 
 
 def argument(
-    a: Union[Argument, SelfArgument, TensorOptionsArguments],
+    a: Argument | SelfArgument | TensorOptionsArguments,
     *,
     is_out: bool,
     symint: bool,
-) -> List[Binding]:
+) -> list[Binding]:
     # Ideally, we NEVER default native functions. However, there are a number
     # of functions that call native:: directly and rely on the defaulting
     # existing. So for BC, we generate defaults for non-out variants (but not
@@ -93,7 +95,7 @@ def argument(
     # default)
     should_default = not is_out
     if isinstance(a, Argument):
-        default: Optional[str] = None
+        default: str | None = None
         if should_default and a.default is not None:
             default = cpp.default_expr(a.default, a.type, symint=symint)
         return [
@@ -144,8 +146,8 @@ def argument(
         assert_never(a)
 
 
-def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
-    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
     args.extend(func.arguments.non_out)
     args.extend(func.arguments.out)
     return [
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Sequence
 
 from torchgen.api import cpp
 from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
@@ -197,14 +199,14 @@ from torchgen.model import (
 
 @dataclass(frozen=True)
 class PythonReturns:
-    returns: Tuple[Return, ...]
+    returns: tuple[Return, ...]
 
 
 @dataclass(frozen=True)
 class PythonArgument:
     name: str
     type: Type
-    default: Optional[str]
+    default: str | None
 
     # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
     #
@@ -212,7 +214,7 @@ class PythonArgument:
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                            ^
    #                            +--- default_init str
-    default_init: Optional[str]
+    default_init: str | None
 
     # Compute argument formal for python argument parsing.
     # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
@@ -300,12 +302,10 @@ class PythonOutArgument(PythonArgument):
     # 'auto out = _r.tensorlist_n<2>(2);',
     # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
     # TODO: maybe don't need keep scattered out fields for python signature?
-    outputs: Tuple[PythonArgument, ...]
+    outputs: tuple[PythonArgument, ...]
 
     @staticmethod
-    def from_outputs(
-        outputs: Tuple[PythonArgument, ...]
-    ) -> Optional["PythonOutArgument"]:
+    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
         if not outputs:
             return None
 
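The `from_outputs` rewrite also shows why the future import pays off for methods that mention their own class: before, the return type had to be the quoted forward reference `Optional["PythonOutArgument"]`, because the class name is not bound yet while its body executes; with postponed evaluation the bare name and `| None` both work. A minimal sketch:

```python
from __future__ import annotations


class Node:
    def __init__(self, value: int) -> None:
        self.value = value

    @staticmethod
    def from_value(value: int) -> Node | None:
        # Without the future import this would need quotes ("Node"):
        # Node does not exist yet when the method is being defined.
        return Node(value) if value >= 0 else None


n = Node.from_value(3)
print(n.value if n is not None else None)  # 3
```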
@@ -339,13 +339,13 @@ class PythonSignature:
 
     # Positional arguments.
     # TODO: create a dedicated SelfArgument type for 'self'?
-    input_args: Tuple[PythonArgument, ...]
+    input_args: tuple[PythonArgument, ...]
 
     # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
     # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
-    input_kwargs: Tuple[PythonArgument, ...]
+    input_kwargs: tuple[PythonArgument, ...]
 
-    output_args: Optional[PythonOutArgument]
+    output_args: PythonOutArgument | None
 
     # Return types, which are only used by pyi
     returns: PythonReturns
@@ -356,7 +356,7 @@ class PythonSignature:
     # for out variant), in which case they will be used as scattered fields without
     # being packed into 'options'.
     # TODO: maybe create a PythonTensorOptionsArgument?
-    tensor_options_args: Tuple[PythonArgument, ...]
+    tensor_options_args: tuple[PythonArgument, ...]
 
     # method or function signature?
     method: bool
@@ -367,8 +367,8 @@ class PythonSignature:
 
     def arguments(
         self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
-    ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
-        result: List[Union[PythonArgument, PythonOutArgument]] = []
+    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
+        result: list[PythonArgument | PythonOutArgument] = []
         result.extend(self.input_args)
         result.extend(self.input_kwargs)
         if self.output_args is not None and not skip_outputs:
@@ -394,7 +394,7 @@ class PythonSignature:
     # signature_str_pyi().
     def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = [
+        schema_formals: list[str] = [
             a.argument_str(method=self.method, symint=symint) for a in args
         ]
         positional_argc = len(self.input_args)
@@ -405,7 +405,7 @@ class PythonSignature:
 
     def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = [
+        schema_formals: list[str] = [
             a.argument_str_pyi(method=self.method) for a in args
         ]
         positional_argc = len(self.input_args)
@@ -419,10 +419,10 @@ class PythonSignature:
             schema_formals.insert(0, "self")
         return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
 
-    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
         # only pyi uses vararg signatures
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = [
+        schema_formals: list[str] = [
             a.argument_str_pyi(method=self.method) for a in args
         ]
         # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
@@ -470,7 +470,7 @@ class PythonSignatureDeprecated(PythonSignature):
     # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
     # [func call]: self.addmm(mat1, mat2, beta, 1)
     # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
-    deprecated_args_exprs: Tuple[str, ...]
+    deprecated_args_exprs: tuple[str, ...]
 
     @property
     def deprecated(self) -> bool:
@@ -486,7 +486,7 @@ class PythonSignatureDeprecated(PythonSignature):
 
     def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: List[str] = [
+        schema_formals: list[str] = [
             a.argument_str_pyi(method=self.method, deprecated=True) for a in args
         ]
         positional_argc = len(self.input_args)
@@ -496,7 +496,7 @@ class PythonSignatureDeprecated(PythonSignature):
         returns_str = returns_str_pyi(self)
         return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
 
-    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
         # the codegen doesn't include vararg variants for deprecated signatures
         return None
 
@@ -530,14 +530,14 @@ class PythonSignatureGroup:
     base: NativeFunction
 
     # The out variant (e.g. conv2d_out)
-    outplace: Optional[NativeFunction]
+    outplace: NativeFunction | None
 
     @classmethod
     def from_pairs(
         cls,
         functional: PythonSignatureNativeFunctionPair,
-        out: Optional[PythonSignatureNativeFunctionPair],
-    ) -> "PythonSignatureGroup":
+        out: PythonSignatureNativeFunctionPair | None,
+    ) -> PythonSignatureGroup:
         if out is None:
             return PythonSignatureGroup(
                 signature=functional.signature,
@@ -716,7 +716,7 @@ def argument_type_str(
     raise RuntimeError(f"unrecognized type {repr(t)}")
 
 
-def argument_type_size(t: Type) -> Optional[int]:
+def argument_type_size(t: Type) -> int | None:
     l = t.is_list_like()
     if l is not None and str(l.elem) != "bool":
         return l.size
@@ -750,11 +750,11 @@ def signature(
 def signature_from_schema(
     func: FunctionSchema,
     *,
-    category_override: Optional[str],
+    category_override: str | None,
     method: bool = False,
     pyi: bool = False,
 ) -> PythonSignature:
-    args: List[Argument] = []
+    args: list[Argument] = []
     args.extend(func.arguments.pre_self_positional)
     # Skip SelfArgument if this is method.
     if not method and func.arguments.self_arg is not None:
@@ -807,10 +807,10 @@ def signature_from_schema(
     )
     is_dummy_function = category_override == "dummy"
 
-    tensor_options_args: List[PythonArgument] = []
+    tensor_options_args: list[PythonArgument] = []
     if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
 
-        def topt_default_init(name: str) -> Optional[str]:
+        def topt_default_init(name: str) -> str | None:
             topt_args = func.arguments.tensor_options
             if topt_args is None:
                 return None
@@ -891,7 +891,7 @@ def signature_from_schema(
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
 
-def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
+def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
     if len(returns) <= 1 or all(r.name is None for r in returns):
         return []
     else:
@@ -1002,7 +1002,7 @@ def return_type_str_pyi(t: Type) -> str:
     return argument_type_str_pyi(t)
 
 
-def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
+def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
     python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
     structseq_name = signature.name
     field_names = structseq_fieldnames(signature.returns.returns)
@@ -1104,7 +1104,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:
 
 def dispatch_lambda_args(
     ps: PythonSignature, f: NativeFunction, symint: bool = True
-) -> Tuple[DispatchLambdaArgument, ...]:
+) -> tuple[DispatchLambdaArgument, ...]:
     if isinstance(ps, PythonSignatureDeprecated):
         schema = ps.deprecated_schema
     else:
@@ -1118,7 +1118,7 @@ def dispatch_lambda_args(
         method=False,
         cpp_no_default_args=f.cpp_no_default_args,
     )
-    out_args: Set[str] = {a.name for a in schema.arguments.out}
+    out_args: set[str] = {a.name for a in schema.arguments.out}
 
     # Convert from cpp argument to lambda argument
     def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
@@ -1224,11 +1224,11 @@ def cpp_dispatch_target(f: NativeFunction) -> str:
 def cpp_dispatch_exprs(
     f: NativeFunction,
     *,
-    python_signature: Optional[PythonSignature] = None,
-) -> Tuple[str, ...]:
+    python_signature: PythonSignature | None = None,
+) -> tuple[str, ...]:
     cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
 
-    exprs: Tuple[str, ...] = tuple()
+    exprs: tuple[str, ...] = tuple()
     if not isinstance(python_signature, PythonSignatureDeprecated):
         # By default the exprs are consistent with the C++ signature.
         exprs = tuple(a.name for a in cpp_args)
@@ -1262,7 +1262,7 @@ def cpp_dispatch_exprs(
 # For certain cases it is intentionally more restrictive than necessary,
 # e.g.: it doesn't accepts doublelist with definite size.
 def arg_parser_unpack_method(
-    t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
+    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
 ) -> str:
     has_default_init = default_init is not None
     if has_default_init and str(t) not in (
@ -1377,7 +1377,7 @@ def arg_parser_output_expr(
|
|||||||
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
|
# Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
|
||||||
def arg_parser_output_exprs(
|
def arg_parser_output_exprs(
|
||||||
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
|
ps: PythonSignature, f: NativeFunction, *, symint: bool = True
|
||||||
) -> Dict[str, PythonArgParserOutputExpr]:
|
) -> dict[str, PythonArgParserOutputExpr]:
|
||||||
return {
|
return {
|
||||||
e.name: e
|
e.name: e
|
||||||
for i, a in enumerate(ps.arguments())
|
for i, a in enumerate(ps.arguments())
|
||||||
@ -1404,8 +1404,8 @@ def dispatch_lambda_exprs(
|
|||||||
# outputs.
|
# outputs.
|
||||||
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
|
arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
|
||||||
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
|
lambda_args = dispatch_lambda_args(ps, f, symint=symint)
|
||||||
inits: List[str] = []
|
inits: list[str] = []
|
||||||
lambda_args_exprs: Dict[str, str] = {}
|
lambda_args_exprs: dict[str, str] = {}
|
||||||
|
|
||||||
has_toptions = has_tensor_options(f)
|
has_toptions = has_tensor_options(f)
|
||||||
|
|
||||||
|
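The rewrite above relies on PEP 563: with `from __future__ import annotations`, annotations are stored as plain strings and never evaluated at import time, so PEP 585 builtin generics (`tuple[str, str]`) and PEP 604 unions (`X | None`) parse and run even on interpreters that predate that syntax. A minimal standalone sketch, not part of the patch:

    from __future__ import annotations


    def returns(ok: bool) -> tuple[str, str] | None:
        # The annotation above is never evaluated; on Python 3.8 the
        # expression tuple[str, str] | None would raise a TypeError if run.
        return ("out", "str") if ok else None


    # Postponed annotations are stored verbatim as strings:
    print(returns.__annotations__)
    # {'ok': 'bool', 'return': 'tuple[str, str] | None'}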
@@ -1,4 +1,4 @@
-from typing import List, Union
+from __future__ import annotations

 from torchgen.api import cpp
 from torchgen.api.types import (
@@ -97,7 +97,7 @@ def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:


 # Structured kernels are never defaulted
-def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
+def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
     if isinstance(a, Argument):
         return [
             Binding(
@@ -115,15 +115,15 @@ def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[B
         assert_never(a)


-def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
-    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []

     if g.out.precomputed:
         # A list of parameters for the impl function with
         # certain parameters replaced with precomputed counterparts
         # as specified in native_functions.yaml.
-        non_out_args_replaced: List[
-            Union[Argument, TensorOptionsArguments, SelfArgument]
+        non_out_args_replaced: list[
+            Argument | TensorOptionsArguments | SelfArgument
         ] = []
         for a in g.out.func.arguments.non_out:
             if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
@@ -145,13 +145,13 @@ def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
     return [r for arg in args for r in argument(arg)]


-def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
-    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
     args.extend(g.functional.func.arguments.non_out)
     return [r for arg in args for r in argument(arg)]


-def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
-    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
     args.extend(g.out.func.arguments.out)
     return [r for arg in args for r in argument(arg)]
@@ -1,4 +1,6 @@
-from typing import Dict, List, NoReturn, Sequence, Union
+from __future__ import annotations
+
+from typing import NoReturn, Sequence

 from torchgen.api.types import (
     ArrayRefCType,
@@ -95,13 +97,13 @@ class UnsatError(RuntimeError):
 # something more complicated, e.g., tracking the set of bindings in a context,
 # you may find using these smaller types more convenient.
 def translate(
-    bindings: Sequence[Union[Expr, Binding]],
-    goals: Sequence[Union[NamedCType, Binding]],
+    bindings: Sequence[Expr | Binding],
+    goals: Sequence[NamedCType | Binding],
     *,
     method: bool = False,
     allow_expensive_conversions: bool = False,
-) -> List[Expr]:
-    binding_exprs: List[Expr] = []
+) -> list[Expr]:
+    binding_exprs: list[Expr] = []
     for b in bindings:
         if isinstance(b, Binding):
             binding_exprs.append(
@@ -113,7 +115,7 @@ def translate(
         else:
             binding_exprs.append(b)

-    goal_ctypes: List[NamedCType] = []
+    goal_ctypes: list[NamedCType] = []
     for g in goals:
         if isinstance(g, Binding):
             goal_ctypes.append(g.nctype)
@@ -121,7 +123,7 @@ def translate(
             goal_ctypes.append(g)

     # Add all the bindings to the context
-    ctx: Dict[NamedCType, str] = {}
+    ctx: dict[NamedCType, str] = {}
     for b in binding_exprs:
         ctx[b.type] = b.expr
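Variable annotations such as `binding_exprs: list[Expr] = []` above are postponed too; only expressions outside annotation position are still evaluated eagerly. A short standalone sketch of that boundary, with illustrative names:

    from __future__ import annotations

    # A variable annotation is postponed: dict[str, int] below is never
    # built at runtime, so this line is fine even on Python 3.8.
    ctx: dict[str, int] = {}
    ctx["answer"] = 42

    # The same expression outside annotation position is evaluated eagerly;
    # it works on 3.9+ but raises TypeError on 3.8.
    alias = dict[str, int]

    print(ctx, alias)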
@@ -1,7 +1,12 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union
+from typing import Iterator, Sequence, TYPE_CHECKING

 from torchgen.api.types.types_base import Binding, CType, Expr

-from torchgen.model import (
-    BackendIndex,
-    FunctionSchema,
+
+if TYPE_CHECKING:
+    from torchgen.model import (
+        BackendIndex,
+        FunctionSchema,
@@ -38,7 +43,7 @@ class CppSignature:
     symint: bool

     # The set of C++ arguments which should not have defaults applied to them
-    cpp_no_default_args: Set[str]
+    cpp_no_default_args: set[str]

     # Is this a fallback C++ binding? Fallback bindings are enabled by
     # manual_cpp_binding: True and are alternate, non-public API that
@@ -72,7 +77,7 @@ class CppSignature:
     def decl(
         self,
         *,
-        name: Optional[str] = None,
+        name: str | None = None,
         prefix: str = "",
         is_redispatching_fn: bool = False,
         suppress_symint_suffix: bool = False,
@@ -93,7 +98,7 @@ class CppSignature:
     def defn(
         self,
         *,
-        name: Optional[str] = None,
+        name: str | None = None,
         prefix: str = "",
         is_redispatching_fn: bool = False,
     ) -> str:
@@ -126,9 +131,9 @@ class CppSignature:
 class CppSignatureGroup:
     func: FunctionSchema
     signature: CppSignature
-    faithful_signature: Optional[CppSignature]
-    symint_signature: Optional[CppSignature]
-    symint_faithful_signature: Optional[CppSignature]
+    faithful_signature: CppSignature | None
+    symint_signature: CppSignature | None
+    symint_faithful_signature: CppSignature | None

     def most_faithful_signature(self) -> CppSignature:
         if self.faithful_signature:
@@ -149,7 +154,7 @@ class CppSignatureGroup:
     @staticmethod
     def from_native_function(
         f: NativeFunction, *, method: bool, fallback_binding: bool = False
-    ) -> "CppSignatureGroup":
+    ) -> CppSignatureGroup:
         func = f.func

         def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
@@ -162,16 +167,16 @@ class CppSignatureGroup:
                 cpp_no_default_args=f.cpp_no_default_args,
             )

-        def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]:
-            faithful_signature: Optional[CppSignature] = None
+        def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
+            faithful_signature: CppSignature | None = None
             if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
                 faithful_signature = make_sig(faithful=True, symint=symint)
             signature = make_sig(faithful=False, symint=symint)
             return signature, faithful_signature

         signature, faithful_signature = make_sigs(symint=False)
-        symint_signature: Optional[CppSignature] = None
-        symint_faithful_signature: Optional[CppSignature] = None
+        symint_signature: CppSignature | None = None
+        symint_faithful_signature: CppSignature | None = None
         if func.has_symint():
             symint_signature, symint_faithful_signature = make_sigs(symint=True)

@@ -196,20 +201,20 @@ class DispatcherSignature:

     symint: bool = True

-    def arguments(self) -> List[Binding]:
+    def arguments(self) -> list[Binding]:
         return dispatcher.arguments(self.func, symint=self.symint)

     def name(self) -> str:
         return self.prefix + dispatcher.name(self.func)

-    def decl(self, name: Optional[str] = None) -> str:
+    def decl(self, name: str | None = None) -> str:
         args_str = ", ".join(a.decl() for a in self.arguments())
         if name is None:
             name = self.name()
         return f"{self.returns_type().cpp_type()} {name}({args_str})"

     def defn(
-        self, name: Optional[str] = None, *, is_redispatching_fn: bool = False
+        self, name: str | None = None, *, is_redispatching_fn: bool = False
     ) -> str:
         args = [a.defn() for a in self.arguments()]
         if is_redispatching_fn:
@@ -219,7 +224,7 @@ class DispatcherSignature:
             name = self.name()
         return f"{self.returns_type().cpp_type()} {name}({args_str})"

-    def exprs(self) -> List[Expr]:
+    def exprs(self) -> list[Expr]:
         return [Expr(a.name, a.nctype) for a in self.arguments()]

     def returns_type(self) -> CType:
@@ -237,7 +242,7 @@ class DispatcherSignature:
     @staticmethod
     def from_schema(
         func: FunctionSchema, *, prefix: str = "", symint: bool = True
-    ) -> "DispatcherSignature":
+    ) -> DispatcherSignature:
         return DispatcherSignature(func, prefix, symint)


@@ -253,13 +258,13 @@ class NativeSignature:
     def name(self) -> str:
         return self.prefix + native.name(self.func)

-    def decl(self, name: Optional[str] = None) -> str:
+    def decl(self, name: str | None = None) -> str:
         args_str = ", ".join(a.decl() for a in self.arguments())
         if name is None:
             name = self.name()
         return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"

-    def defn(self, name: Optional[str] = None) -> str:
+    def defn(self, name: str | None = None) -> str:
         args_str = ", ".join(a.defn() for a in self.arguments())
         if name is None:
             name = self.name()
@@ -270,13 +275,13 @@ class NativeSignature:
         args_str = ", ".join(a.defn() for a in self.arguments())
         return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"

-    def arguments(self) -> List[Binding]:
+    def arguments(self) -> list[Binding]:
         return native.arguments(self.func, symint=self.symint)

     def returns_type(self) -> CType:
         return native.returns_type(self.func.returns, symint=self.symint)

-    def dispatcher_exprs(self) -> List[Expr]:
+    def dispatcher_exprs(self) -> list[Expr]:
         return translate.translate(
             self.arguments(), dispatcher.arguments(self.func), method=False
         )
@@ -307,7 +312,7 @@ class FunctionalizationLambda:
     # are we generating the forward lambda or the reverse lambda?
     is_reverse: bool

-    def captures(self) -> List[Expr]:
+    def captures(self) -> list[Expr]:
         # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
         # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
         # and plumb it into the lambda.
@@ -336,7 +341,7 @@ class FunctionalizationLambda:
         ]
         return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"

-    def inner_call(self, *, reapply_views: Optional[bool] = None) -> str:
+    def inner_call(self, *, reapply_views: bool | None = None) -> str:
         inner_call_name = functionalization.name(
             self.g,
             is_reverse=self.is_reverse,
@@ -366,7 +371,7 @@ class FunctionalizationLambda:
     @staticmethod
     def from_func(
         g: NativeFunctionsViewGroup, *, is_reverse: bool
-    ) -> "FunctionalizationLambda":
+    ) -> FunctionalizationLambda:
         return FunctionalizationLambda(g, is_reverse)


@@ -375,11 +380,11 @@ class StructuredImplSignature:
     g: NativeFunctionsGroup
     name: str

-    def defn(self, name: Optional[str] = None) -> str:
+    def defn(self, name: str | None = None) -> str:
         args_str = ", ".join(a.defn() for a in self.arguments())
         return f"TORCH_IMPL_FUNC({self.name})({args_str})"

-    def arguments(self) -> List[Binding]:
+    def arguments(self) -> list[Binding]:
         return structured.impl_arguments(self.g)


@@ -388,7 +393,7 @@ class StructuredImplSignature:

 def kernel_signature(
     f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
-) -> Union["NativeSignature", "DispatcherSignature"]:
+) -> NativeSignature | DispatcherSignature:
     # Note [External Backends Follow Dispatcher API]
     # Kernel signatures for in-tree backends follow the "native" API,
     # while kernels for out-of-tree backends follow the dispatcher API.
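Two patterns in this file depend on the future import: unquoted self-references such as `-> CppSignatureGroup` inside the class body itself, and imports moved under `if TYPE_CHECKING:` so they are resolved only by the type checker. A self-contained sketch of both, with a hypothetical `Signature` class and the stdlib `decimal` module standing in for the real torchgen types:

    from __future__ import annotations

    from dataclasses import dataclass
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Resolved by the type checker only; never imported at runtime,
        # which avoids import cycles and startup cost.
        from decimal import Decimal  # stand-in for a heavyweight import


    @dataclass
    class Signature:
        name: str

        @staticmethod
        def from_name(name: str) -> Signature:  # unquoted self-reference
            return Signature(name)

        def scaled(self, factor: Decimal) -> Signature:
            # Decimal appears only in annotations, which stay unevaluated.
            return Signature(self.name * int(factor))


    print(Signature.from_name("add"))  # Signature(name='add')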
@@ -12,8 +12,10 @@ if we want to generate code for another C++ library.
 Add new types to `types.py` if these types are ATen/c10 related.
 Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
 """

+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict

 from torchgen.api.types.types_base import (
     BaseCppType,
@@ -83,7 +85,7 @@ symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
 scalar_t = BaseCppType("", "scalar_t")
 opmath_t = BaseCppType("", "opmath_t")

-ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
+ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
     ScalarType.Byte: byteT,
     ScalarType.Char: charT,
     ScalarType.Short: shortT,
@@ -102,7 +104,7 @@ ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
     ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
 }

-BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
+BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
     BaseTy.int: longT,
     BaseTy.float: doubleT,
     BaseTy.bool: boolT,
@@ -128,7 +130,7 @@ BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {

 @dataclass(frozen=True)
 class OptionalCType(CType):
-    elem: "CType"
+    elem: CType

     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -137,13 +139,13 @@ class OptionalCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return OptionalCType(self.elem.remove_const_ref())


 @dataclass(frozen=True)
 class ListCType(CType):
-    elem: "CType"
+    elem: CType

     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -152,13 +154,13 @@ class ListCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return ListCType(self.elem.remove_const_ref())


 @dataclass(frozen=True)
 class ArrayRefCType(CType):
-    elem: "CType"
+    elem: CType

     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -167,7 +169,7 @@ class ArrayRefCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return ArrayRefCType(self.elem.remove_const_ref())


@@ -185,5 +187,5 @@ class VectorizedCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         raise NotImplementedError

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return self
@@ -12,11 +12,16 @@ if we want to generate code for another C++ library.
 Add new types to `types.py` if these types are ATen/c10 related.
 Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
 """

+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import auto, Enum
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, Union


-from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
+if TYPE_CHECKING:
+    from torchgen.model import Argument, SelfArgument, TensorOptionsArguments


@@ -36,7 +41,7 @@ ArgName = Union[str, SpecialArgName]
 # This class shouldn't be created directly; instead, use/create one of the singletons below.
 @dataclass(frozen=True)
 class BaseCppType:
-    ns: Optional[str]
+    ns: str | None
     name: str

     def __str__(self) -> str:
@@ -71,7 +76,7 @@ class CType(ABC):
         raise NotImplementedError

     @abstractmethod
-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return self


@@ -87,13 +92,13 @@ class BaseCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return str(self.type).replace("at::", "")

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return self


 @dataclass(frozen=True)
 class ConstRefCType(CType):
-    elem: "CType"
+    elem: CType

     def cpp_type(self, *, strip_ref: bool = False) -> str:
         if strip_ref:
@@ -103,13 +108,13 @@ class ConstRefCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"const {self.elem.cpp_type_registration_declarations()} &"

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return self.elem.remove_const_ref()


 @dataclass(frozen=True)
 class VectorCType(CType):
-    elem: "CType"
+    elem: CType

     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -118,13 +123,13 @@ class VectorCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return VectorCType(self.elem.remove_const_ref())


 @dataclass(frozen=True)
 class ArrayCType(CType):
-    elem: "CType"
+    elem: CType
     size: int

     def cpp_type(self, *, strip_ref: bool = False) -> str:
@@ -134,13 +139,13 @@ class ArrayCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return ArrayCType(self.elem.remove_const_ref(), self.size)


 @dataclass(frozen=True)
 class TupleCType(CType):
-    elems: List["CType"]
+    elems: list[CType]

     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -149,13 +154,13 @@ class TupleCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return TupleCType([e.remove_const_ref() for e in self.elems])


 @dataclass(frozen=True)
 class MutRefCType(CType):
-    elem: "CType"
+    elem: CType

     def cpp_type(self, *, strip_ref: bool = False) -> str:
         if strip_ref:
@@ -165,7 +170,7 @@ class MutRefCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"{self.elem.cpp_type_registration_declarations()} &"

-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return self.elem.remove_const_ref()


@@ -190,10 +195,10 @@ class NamedCType:
     def cpp_type_registration_declarations(self) -> str:
         return self.type.cpp_type_registration_declarations()

-    def remove_const_ref(self) -> "NamedCType":
+    def remove_const_ref(self) -> NamedCType:
         return NamedCType(self.name, self.type.remove_const_ref())

-    def with_name(self, name: str) -> "NamedCType":
+    def with_name(self, name: str) -> NamedCType:
         return NamedCType(name, self.type)


@@ -208,11 +213,11 @@ class NamedCType:
 class Binding:
     name: str
     nctype: NamedCType
-    argument: Union[Argument, TensorOptionsArguments, SelfArgument]
+    argument: Argument | TensorOptionsArguments | SelfArgument
     # TODO: maybe don't represent default here
-    default: Optional[str] = None
+    default: str | None = None

-    def rename(self, name: str) -> "Binding":
+    def rename(self, name: str) -> Binding:
         return Binding(
             name=name,
             nctype=self.nctype,
@@ -224,7 +229,7 @@ class Binding:
     def type(self) -> str:
         return self.nctype.cpp_type()

-    def no_default(self) -> "Binding":
+    def no_default(self) -> Binding:
         return Binding(
             name=self.name,
             nctype=self.nctype,
@@ -255,7 +260,7 @@ class Binding:
     def defn(self) -> str:
         return f"{self.type} {self.name}"

-    def with_name(self, name: str) -> "Binding":
+    def with_name(self, name: str) -> Binding:
         return Binding(
             name=name, nctype=self.nctype, argument=self.argument, default=self.default
         )
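The quoted forward references (`elem: "CType"`) this file previously needed are exactly what postponed evaluation removes: a recursive wrapper can mention its own class without quoting. A simplified standalone sketch of the same recursive shape, with illustrative classes rather than the real hierarchy:

    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class BaseT:
        name: str

        def strip(self) -> BaseT:  # previously would have been "BaseT"
            return self


    @dataclass(frozen=True)
    class ConstRefT:
        elem: BaseT | ConstRefT  # forward self-reference, unquoted

        def strip(self) -> BaseT | ConstRefT:
            # Unwrap nested const-ref wrappers down to the base type.
            return self.elem.strip() if isinstance(self.elem, ConstRefT) else self.elem


    print(ConstRefT(ConstRefT(BaseT("Tensor"))).strip())  # BaseT(name='Tensor')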
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import List, Optional

 import torchgen.api.types as api_types
 from torchgen.api import cpp, structured
@@ -38,7 +39,7 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
 # argument registers)
 #
 # NB: used for CPU only
-def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
+def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
     # Dispatch stubs are always plain ints
     r = cpp.valuetype_type(t, binds=binds, symint=False)
     if r is not None:
@@ -134,8 +135,8 @@ def ufunc_argument(a: Argument, compute_t: CType) -> Binding:

 @dataclass(frozen=True)
 class UfunctorBindings:
-    ctor: List[Binding]
-    apply: List[Binding]
+    ctor: list[Binding]
+    apply: list[Binding]


 # ufunctors are a CUDA-only concept representing functors that take some of
@@ -156,7 +157,7 @@ class UfunctorBindings:
 # The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
 # to the operator() definition
 def ufunctor_arguments(
-    g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
+    g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
 ) -> UfunctorBindings:
     ctor = []
     apply = []
@@ -185,7 +186,7 @@ def ufunctor_arguments(
 # }
 #
 # In this file, we refer to T as compute_t which is bound by caller
-def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
+def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
     return [
         ufunc_argument(a, compute_t=compute_t)
         for a in g.functional.func.arguments.flat_non_out
@@ -197,7 +198,7 @@ def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Bindin
 #
 # using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
 # DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
-def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
     # stubs drop all tensor arguments (they are implicit in the TensorIterator
     # argument and keep everything else)
     return [
@@ -1,4 +1,4 @@
-from typing import List, Tuple
+from __future__ import annotations

 from torchgen.api import cpp
 from torchgen.api.types import Binding, CppSignatureGroup, CType
@@ -103,7 +103,7 @@ def name(f: NativeFunction) -> str:


 # Convert all the arguments in a NativeFunction to C++ code
-def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
+def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
     # we need the 'self' argument so method needs to be False
     args = (
         CppSignatureGroup.from_native_function(f, method=False)
@@ -138,7 +138,7 @@ def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
 # (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
 def argumenttype_ivalue_convert(
     t: Type, arg_name: str, *, mutable: bool = False
-) -> Tuple[str, CType, List[str], List[str]]:
+) -> tuple[str, CType, list[str], list[str]]:
     # Unboxing is for mobile, which doesn't care about SymInts
     ctype = cpp.argumenttype_type(
         t=t, mutable=mutable, binds=arg_name, symint=False
@@ -172,7 +172,7 @@ def argumenttype_ivalue_convert(

 def _gen_code_base_type(
     arg_name: str, out_name: str, ctype: CType
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
     return [
         f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
     ], []
@@ -180,7 +180,7 @@ def _gen_code_base_type(

 def _gen_code_optional_type(
     arg_name: str, out_name: str, t: OptionalType, ctype: CType
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
     in_name = f"{arg_name}_opt_in"
     res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
     return (
@@ -203,7 +203,7 @@ if ({arg_name}_opt.has_value()) {{

 def _gen_code_list_type(
     arg_name: str, out_name: str, t: ListType, ctype: CType
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
     in_name = f"{arg_name}_list_in"
     elem_name = f"{arg_name}_elem"
     code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import re
-from typing import Mapping, Match, Optional, Sequence
+from typing import Mapping, Sequence


 # match $identifier or ${identifier} and replace with value in env
@@ -20,7 +22,7 @@ class CodeTemplate:
     filename: str

     @staticmethod
-    def from_file(filename: str) -> "CodeTemplate":
+    def from_file(filename: str) -> CodeTemplate:
         with open(filename) as f:
             return CodeTemplate(f.read(), filename)

@@ -29,7 +31,7 @@ class CodeTemplate:
         self.filename = filename

     def substitute(
-        self, env: Optional[Mapping[str, object]] = None, **kwargs: object
+        self, env: Mapping[str, object] | None = None, **kwargs: object
     ) -> str:
         if env is None:
             env = {}
@@ -43,7 +45,7 @@ class CodeTemplate:
                 [indent + l + "\n" for e in v for l in str(e).splitlines()]
             ).rstrip()

-        def replace(match: Match[str]) -> str:
+        def replace(match: re.Match[str]) -> str:
             indent = match.group(1)
             key = match.group(2)
             comma_before = ""
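`typing.Match` and `typing.Pattern` were deprecated aliases and are removed in Python 3.12, so `re.Match[str]` is the durable spelling; under the future import it is also harmless on 3.8, where `re.Match` is not subscriptable at runtime. A standalone sketch mirroring the template-substitution idea above:

    from __future__ import annotations

    import re

    # match $identifier or ${identifier}, as in CodeTemplate
    _SUBST = re.compile(r"\$\{?(\w+)\}?")


    def replace(match: re.Match[str]) -> str:  # typing.Match is gone in 3.12
        return {"name": "torchgen"}.get(match.group(1), match.group(0))


    print(_SUBST.sub(replace, "hello $name"))  # hello torchgen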
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import contextlib
 import functools
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union

 import torchgen.local as local
 from torchgen.model import (
@@ -38,7 +40,7 @@ F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])

 @contextlib.contextmanager
 def native_function_manager(
-    g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
+    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
 ) -> Iterator[None]:
     if isinstance(g, NativeFunctionsGroup):
         # By default, we associate all errors with structured native functions
@@ -118,10 +120,10 @@ def with_native_function_and_index(

 # Convenience decorator for functions that explicitly take in a Dict of BackendIndices
 def with_native_function_and_indices(
-    func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
-) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
+    func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
+) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
     @functools.wraps(func)
-    def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
+    def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
         with native_function_manager(f):
             return func(f, backend_indices)
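Note what this hunk does not change: `List`, `Optional`, `Tuple`, and `Union` stay imported because they feed runtime expressions, for example the `TypeVar` value constraints in `F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])`, which the future import does not postpone. A short sketch of that boundary, with illustrative names:

    from __future__ import annotations

    from typing import List, Tuple, TypeVar

    # Runtime expression: evaluated when TypeVar() is called, so on
    # Python 3.8 it must use the typing aliases, not builtin generics.
    F = TypeVar("F", Tuple[str, int], List[str])


    def first(xs: list[str]) -> str | None:  # annotation: safely postponed
        return xs[0] if xs else None


    print(first(["ok"]))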
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import itertools
 from abc import ABC
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any

 import torchgen.api.dispatcher as dispatcher
 from torchgen.api.lazy import (
@@ -109,7 +111,7 @@ def node_ctor_inputs(schema: LazyIrSchema) -> str:

 def gen_fallback_code(
     schema: LazyIrSchema,
-    sig: Union[DispatcherSignature, NativeSignature],
+    sig: DispatcherSignature | NativeSignature,
     overload_name: str,
 ) -> str:
     """
@@ -147,9 +149,9 @@ def aten_symbol(schema: LazyIrSchema) -> str:
 # converts all tensor-like arguments to meta tensors. Returns:
 # (1) a string containing all of the logic that does the conversions.
 # (2) a context, to be used by translate(), with all of the relevant bindings.
-def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
-    context: List[Binding] = []
-    unwrapped_tensor_args: List[str] = []
+def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
     for arg in sig.arguments():
         if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
             unwrapped_name = f"{arg.name}_meta"
@@ -171,7 +173,7 @@ class GenLazyIR(ABC):
     use_lazy_shape: bool

     @method_with_native_function
-    def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
+    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
         func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
         metadata = self.backend_index.get_kernel(
             f.functional if isinstance(f, NativeFunctionsGroup) else f
@@ -236,7 +238,7 @@ class GenLazyIR(ABC):
             /* num_outputs */ {len(schema.returns)},
             torch::lazy::MHash({scalar_hashes}))"""

-    def gen(self, schema: LazyIrSchema) -> List[str]:
+    def gen(self, schema: LazyIrSchema) -> list[str]:
         opkind = schema.opkind or aten_symbol(schema)

         # for now, we just want one IR class decl and soon after also the method defs
@@ -413,7 +415,7 @@ class GenLazyNativeFuncDefinition:
     def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
         value_args = schema.filtered_args(values=True, scalars=False)
         # Generates lazy_{name} variables for LazyTensors wrapping input tensors
-        lazy_tensor_decls: List[str] = []
+        lazy_tensor_decls: list[str] = []
         for arg in value_args:
             if arg.is_wrapped_scalar:
                 if isinstance(arg.lazy_type, OptionalCType):
@@ -460,7 +462,7 @@ class GenLazyNativeFuncDefinition:
         func: NativeFunction,
         schema: LazyIrSchema,
         metadata: BackendMetadata,
-        sig: Union[DispatcherSignature, NativeSignature],
+        sig: DispatcherSignature | NativeSignature,
     ) -> str:
         if self.gen_forced_fallback_code:
             return gen_fallback_code(
@@ -574,7 +576,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
 }}
 """

-    def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
+    def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
         # xla uses an instance method for tensor creation, for the time being
         if self.create_from_first_tensor:
             # TODO(whc) remove this if XLA switches to using static method for creation
@@ -615,7 +617,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
         return bridge_str

     @method_with_native_function
-    def __call__(self, func: NativeFunction) -> List[str]:
+    def __call__(self, func: NativeFunction) -> list[str]:
         sig = kernel_signature(func, self.backend_index)
         metadata = self.backend_index.get_kernel(func)
         assert metadata is not None
@@ -639,7 +641,7 @@ class ComputeShapeSignature:
     Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
     """

-    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool):
+    def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
         self.__schema = LazyIrSchema(f.func, symint=symint)
         self.__dispatch_args = ", ".join(
             [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
@@ -670,7 +672,7 @@ class GenLazyShapeInferenceDefinition:
     tensor_class: str

     @method_with_native_function
-    def __call__(self, f: NativeFunction) -> List[str]:
+    def __call__(self, f: NativeFunction) -> list[str]:
         metadata = self.backend_index.get_kernel(f)
         assert metadata is not None

@@ -687,8 +689,8 @@ class GenLazyShapeInferenceDefinition:


 def generate_non_native_lazy_ir_nodes(
-    non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
-) -> List[str]:
+    non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
+) -> list[str]:
     """Generate the non-native lazy IR node classes"""
     nodes = []
     for op in non_native:
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from __future__ import annotations

 import torchgen.api.meta as meta
 import torchgen.api.structured as structured
@@ -9,7 +9,7 @@ from torchgen.utils import mapMaybe


 @with_native_function_and_index
-def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
+def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
     sig = kernel_signature(f, backend_index)
     metadata = backend_index.get_kernel(f)
     if metadata is None:
@@ -22,7 +22,7 @@ def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional


 @with_native_function_and_index
-def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
+def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
     meta_name = meta.name(g)
     out_args = structured.impl_arguments(g)
     metadata = backend_index.get_kernel(g)
@@ -42,8 +42,8 @@ void impl({', '.join(a.decl() for a in out_args)});
 # actual kernel definitions we keep in aten/src/ATen/native/
 @with_native_function_and_index
 def compute_native_function_declaration(
-    g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
-) -> List[str]:
+    g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
+) -> list[str]:
     metadata = backend_index.get_kernel(g)
     if isinstance(g, NativeFunctionsGroup):
         if metadata is not None and metadata.structured:
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import itertools
 import textwrap
 from dataclasses import dataclass
-from typing import List, Literal, Optional, Tuple, Union
+from typing import Literal, TYPE_CHECKING

 import torchgen.api.cpp as cpp
 import torchgen.api.meta as meta
@@ -34,15 +36,18 @@ from torchgen.model import (
     SchemaKind,
     TensorOptionsArguments,
 )
-from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import assert_never, mapMaybe, Target


+if TYPE_CHECKING:
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
 def gen_registration_headers(
     backend_index: BackendIndex,
     per_operator_headers: bool,
     rocm: bool,
-) -> List[str]:
+) -> list[str]:
     if per_operator_headers:
         headers = ["#include <ATen/ops/as_strided_native.h>"]
     else:
@@ -73,7 +78,7 @@ def gen_registration_headers(

 def gen_empty_impl_names(
     backend_index: BackendIndex,
-) -> Tuple[Optional[str], Optional[str]]:
+) -> tuple[str | None, str | None]:
     empty_impl = None
     empty_strided_impl = None

@@ -97,7 +102,7 @@ def gen_empty_impl_names(
     return empty_impl, empty_strided_impl


-def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
+def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
     if backend_index.dispatch_key == DispatchKey.Meta:
         empty_options = "options.device(at::kMeta)"
     else:
@@ -120,7 +125,7 @@ Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &o
     ]


-def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
+def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
     _, empty_strided_impl = gen_empty_impl_names(backend_index)
     return (
         []
@@ -138,7 +143,7 @@ std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, I
     )


-def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
+def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
    if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
         # The function isn't used by this key (since only functional ops have a kernel for this key),
         # so we need to not include it to avoid a defined-but-not-used error.
@@ -168,7 +173,7 @@ void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const
     ]


-def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
+def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
     return [
         """
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
@@ -191,7 +196,7 @@ void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &o
     ]


-def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
+def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
     return [
         'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
         *gen_create_out_helper(backend_index),
|
*gen_create_out_helper(backend_index),
|
||||||
@ -249,7 +254,7 @@ class RegisterDispatchKey:
|
|||||||
# Finally, this field is currently Optional because it is only used by external backends.
|
# Finally, this field is currently Optional because it is only used by external backends.
|
||||||
# It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
|
# It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
|
||||||
# all of the existing kernel signatures scattered across aten/src/ATen/native.
|
# all of the existing kernel signatures scattered across aten/src/ATen/native.
|
||||||
class_method_name: Optional[str]
|
class_method_name: str | None
|
||||||
|
|
||||||
# Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
|
# Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
|
||||||
# operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
|
# operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
|
||||||
@ -257,7 +262,7 @@ class RegisterDispatchKey:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def gen_device_check(
|
def gen_device_check(
|
||||||
type: DeviceCheckType, args: List[Argument], method_name: str
|
type: DeviceCheckType, args: list[Argument], method_name: str
|
||||||
) -> str:
|
) -> str:
|
||||||
if type == DeviceCheckType.NoCheck:
|
if type == DeviceCheckType.NoCheck:
|
||||||
return " // No device check\n"
|
return " // No device check\n"
|
||||||
@ -272,7 +277,7 @@ class RegisterDispatchKey:
|
|||||||
return device_check
|
return device_check
|
||||||
|
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
|
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
|
||||||
if isinstance(f, NativeFunctionsGroup):
|
if isinstance(f, NativeFunctionsGroup):
|
||||||
g: NativeFunctionsGroup = f
|
g: NativeFunctionsGroup = f
|
||||||
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
|
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
|
||||||
@ -291,7 +296,7 @@ class RegisterDispatchKey:
|
|||||||
|
|
||||||
def wrapper_kernel_sig(
|
def wrapper_kernel_sig(
|
||||||
self, f: NativeFunction
|
self, f: NativeFunction
|
||||||
) -> Union[NativeSignature, DispatcherSignature]:
|
) -> NativeSignature | DispatcherSignature:
|
||||||
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
|
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
|
||||||
return DispatcherSignature.from_schema(
|
return DispatcherSignature.from_schema(
|
||||||
f.func,
|
f.func,
|
||||||
@ -300,8 +305,8 @@ class RegisterDispatchKey:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def gen_out_inplace_wrapper(
|
def gen_out_inplace_wrapper(
|
||||||
self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
|
self, f: NativeFunction, g: NativeFunctionsGroup | None
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
if g is None:
|
if g is None:
|
||||||
return None
|
return None
|
||||||
k = f.func.kind()
|
k = f.func.kind()
|
||||||
@ -350,7 +355,7 @@ class RegisterDispatchKey:
|
|||||||
}}
|
}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
|
def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
|
||||||
metadata = self.backend_index.get_kernel(g)
|
metadata = self.backend_index.get_kernel(g)
|
||||||
if self.backend_index.dispatch_key == DispatchKey.Meta:
|
if self.backend_index.dispatch_key == DispatchKey.Meta:
|
||||||
assert not self.backend_index.has_kernel(g.out), (
|
assert not self.backend_index.has_kernel(g.out), (
|
||||||
@ -380,8 +385,8 @@ class RegisterDispatchKey:
|
|||||||
return list(mapMaybe(structured_gen.gen_one, g.functions()))
|
return list(mapMaybe(structured_gen.gen_one, g.functions()))
|
||||||
|
|
||||||
def gen_unstructured(
|
def gen_unstructured(
|
||||||
self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
|
self, f: NativeFunction, g: NativeFunctionsGroup | None = None
|
||||||
) -> Optional[str]:
|
) -> str | None:
|
||||||
with native_function_manager(f):
|
with native_function_manager(f):
|
||||||
inplace_meta = False
|
inplace_meta = False
|
||||||
gets_out_inplace_wrapper = False
|
gets_out_inplace_wrapper = False
|
||||||
@ -732,7 +737,7 @@ resize_out(out, sizes, strides, options);
|
|||||||
return "\n".join(line for line in lines if line)
|
return "\n".join(line for line in lines if line)
|
||||||
|
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def gen_one(self, f: NativeFunction) -> Optional[str]:
|
def gen_one(self, f: NativeFunction) -> str | None:
|
||||||
assert not f.manual_kernel_registration
|
assert not f.manual_kernel_registration
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -806,7 +811,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
|
|||||||
sig_body = []
|
sig_body = []
|
||||||
# We'll use context to keep track of any variables we've brought
|
# We'll use context to keep track of any variables we've brought
|
||||||
# into scope while generating code
|
# into scope while generating code
|
||||||
context: List[Union[Binding, Expr]] = list(sig.arguments())
|
context: list[Binding | Expr] = list(sig.arguments())
|
||||||
|
|
||||||
# Initialize the class corresponding to this structured
|
# Initialize the class corresponding to this structured
|
||||||
# operator; feeding it the output argument(s) if it is known
|
# operator; feeding it the output argument(s) if it is known
|
||||||
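The `if TYPE_CHECKING:` blocks introduced above exploit the same mechanism: a type that appears only in annotations can be imported solely for the type checker, which breaks import cycles (here, on SelectiveBuilder) and trims import-time work. A hedged sketch with a made-up module name:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only while mypy/pyright analyze the file; at runtime the
        # block is skipped, so no circular import can materialize.
        # "heavy_module" is a hypothetical name, not a real dependency.
        from heavy_module import HeavyType


    def describe(x: HeavyType) -> str:  # stays an unevaluated string at runtime
        return repr(x)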
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Sequence, TYPE_CHECKING
 
 import torchgen.api.ufunc as ufunc
 from torchgen.api.translate import translate
@@ -14,7 +16,6 @@ from torchgen.api.types import (
     StructuredImplSignature,
     VectorizedCType,
 )
-from torchgen.api.ufunc import UfunctorBindings
 from torchgen.context import with_native_function
 from torchgen.model import (
     Argument,
@@ -28,6 +29,10 @@ from torchgen.model import (
 from torchgen.utils import OrderedSet
 
 
+if TYPE_CHECKING:
+    from torchgen.api.ufunc import UfunctorBindings
+
+
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 #
 #                                 CUDA STUFF
@@ -60,7 +65,7 @@ from torchgen.utils import OrderedSet
 @dataclass(frozen=True)
 class UfunctorSignature:
     g: NativeFunctionsGroup
-    scalar_tensor_idx: Optional[int]
+    scalar_tensor_idx: int | None
     name: str
 
     def arguments(self) -> UfunctorBindings:
@@ -68,7 +73,7 @@ class UfunctorSignature:
             self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
         )
 
-    def fields(self) -> List[Binding]:
+    def fields(self) -> list[Binding]:
         # fields are renamed to have a trailing underscore, as is conventional
         return [b.rename(f"{b.name}_") for b in self.arguments().ctor]
 
@@ -98,10 +103,10 @@ class UfuncSignature:
     name: str
     compute_t: CType
 
-    def arguments(self) -> List[Binding]:
+    def arguments(self) -> list[Binding]:
         return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)
 
-    def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
+    def call(self, ctx: Sequence[Binding | Expr]) -> str:
         return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
 
 
@@ -132,10 +137,10 @@ def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
 
 def compute_ufunc_cuda_functors(
     g: NativeFunctionsGroup,
-) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
+) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
     # First, build the functors.
-    ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
-    ufunctors: List[str] = []
+    ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
+    ufunctors: list[str] = []
     loops = g.out.ufunc_inner_loop
     scalar_tensor_idx_lookup = {
         UfuncKey.CUDAFunctorOnSelf: 1,
@@ -237,7 +242,7 @@ BinaryScalarSpecializationConfigs = [
 def compute_ufunc_cuda_dtype_body(
     g: NativeFunctionsGroup,
     dtype: ScalarType,
-    inner_loops: Dict[UfuncKey, UfunctorSignature],
+    inner_loops: dict[UfuncKey, UfunctorSignature],
     parent_ctx: Sequence[Binding],
 ) -> str:
     body = "using opmath_t = at::opmath_type<scalar_t>;"
@@ -249,7 +254,7 @@ def compute_ufunc_cuda_dtype_body(
         scalar_idx = config.scalar_idx + 1
         # Make a copy and at the same time widen the type (not permissible
         # without copy; we don't want to mutate the input argument anyway)
-        ctx: List[Union[Expr, Binding]] = list(parent_ctx)
+        ctx: list[Expr | Binding] = list(parent_ctx)
         ctx.append(
             Expr(
                 expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
@@ -346,7 +351,7 @@ class StubSignature:
     def type_name(self) -> str:
         return f"{str(self.g.functional.func.name.name)}_fn"
 
-    def arguments(self) -> List[Binding]:
+    def arguments(self) -> list[Binding]:
         return ufunc.stub_arguments(self.g)
 
     def type(self) -> str:
@@ -393,7 +398,7 @@ def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
 def compute_ufunc_cpu_dtype_body(
     g: NativeFunctionsGroup,
     dtype: ScalarType,
-    inner_loops: Dict[UfuncKey, UfuncSignature],
+    inner_loops: dict[UfuncKey, UfuncSignature],
     parent_ctx: Sequence[Binding],
 ) -> str:
     assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
@@ -459,8 +464,8 @@ def compute_ufunc_cpu_dtype_body(
         )
     )
 
-    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
-        r: List[Union[Expr, Binding]] = []
+    def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
+        r: list[Expr | Binding] = []
         r.extend(ctx)
         r.extend(b)
         return r
@@ -489,7 +494,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
 
     # Reindex the ufunc by dtypes; processing generic/scalaronly as well
     loops = g.out.ufunc_inner_loop
-    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
+    ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
     for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
         lks = []
         # ORDER MATTERS: this specifies overriding precedence
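Variable annotations such as `ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}` are covered by the future import as well: module- and class-level annotations are stored as strings, and function-local annotations are never evaluated at all, so the builtin generics cost nothing at runtime. A small runnable sketch (hypothetical helper, not from this commit):

    from __future__ import annotations


    def collect(pairs: list[tuple[str, int]]) -> dict[str, int]:
        # The local annotation below is documentation only; Python does not
        # execute it, so dict[str, int] is fine even on Python 3.8.
        out: dict[str, int] = {}
        for key, value in pairs:
            out[key] = value
        return out


    assert collect([("a", 1), ("b", 2)]) == {"a": 1, "b": 2}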
@@ -1,24 +1,29 @@
+from __future__ import annotations
+
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Sequence, TYPE_CHECKING
 
 from torchgen import dest
 
 # disable import sorting to avoid circular dependency.
 from torchgen.api.types import DispatcherSignature  # usort:skip
 from torchgen.context import method_with_native_function
-from torchgen.executorch.model import ETKernelIndex
 from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
-from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import concatMap, Target
 
 
+if TYPE_CHECKING:
+    from torchgen.executorch.model import ETKernelIndex
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
 # Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
 # model authoring side.
 @dataclass(frozen=True)
 class ComputeNativeFunctionStub:
     @method_with_native_function
-    def __call__(self, f: NativeFunction) -> Optional[str]:
+    def __call__(self, f: NativeFunction) -> str | None:
         if Variant.function not in f.variants:
             return None
 
@@ -80,7 +85,7 @@ def gen_custom_ops_registration(
     selector: SelectiveBuilder,
     kernel_index: ETKernelIndex,
     rocm: bool,
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     """
     Generate custom ops registration code for dest.RegisterDispatchKey.
 
@@ -97,7 +102,7 @@ def gen_custom_ops_registration(
     dispatch_key = DispatchKey.CPU
     backend_index = kernel_index._to_backend_index()
     static_init_dispatch_registrations = ""
-    ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
+    ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
     for native_function in native_functions:
         ns_grouped_native_functions[native_function.namespace].append(native_function)
 
@@ -1,4 +1,6 @@
-from typing import List, Optional, Sequence, Set, Union
+from __future__ import annotations
+
+from typing import Sequence
 
 from torchgen import local
 from torchgen.api.types import (
@@ -63,7 +65,7 @@ def valuetype_type(
     *,
     binds: ArgName,
     remove_non_owning_ref_types: bool = False,
-) -> Optional[NamedCType]:
+) -> NamedCType | None:
     if isinstance(t, BaseType):
         if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
             return None
@@ -209,7 +211,7 @@ def returns_type(rs: Sequence[Return]) -> CType:
 
 
 def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
-    returns: List[str] = []
+    returns: list[str] = []
     for i, r in enumerate(f.func.returns):
         # If we have an inplace function, the return argument is
         # implicitly named self.
@@ -295,16 +297,16 @@ def default_expr(d: str, t: Type) -> str:
 
 
 def argument(
-    a: Union[Argument, TensorOptionsArguments, SelfArgument],
+    a: Argument | TensorOptionsArguments | SelfArgument,
     *,
-    cpp_no_default_args: Set[str],
+    cpp_no_default_args: set[str],
     method: bool,
     faithful: bool,
     has_tensor_options: bool,
-) -> List[Binding]:
+) -> list[Binding]:
     def sub_argument(
-        a: Union[Argument, TensorOptionsArguments, SelfArgument]
-    ) -> List[Binding]:
+        a: Argument | TensorOptionsArguments | SelfArgument,
+    ) -> list[Binding]:
         return argument(
             a,
             cpp_no_default_args=cpp_no_default_args,
@@ -319,7 +321,7 @@ def argument(
         binds = SpecialArgName.possibly_redundant_memory_format
     else:
         binds = a.name
-    default: Optional[str] = None
+    default: str | None = None
     if a.name not in cpp_no_default_args and a.default is not None:
         default = default_expr(a.default, a.type)
     return [
@@ -347,9 +349,9 @@ def arguments(
     *,
     faithful: bool,
     method: bool,
-    cpp_no_default_args: Set[str],
-) -> List[Binding]:
-    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
+    cpp_no_default_args: set[str],
+) -> list[Binding]:
+    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
     if faithful:
         args.extend(arguments.non_out)
         args.extend(arguments.out)
@@ -1,9 +1,14 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import List, Optional, Set
+from typing import TYPE_CHECKING
 
 import torchgen.api.cpp as aten_cpp
-from torchgen.api.types import Binding, CType
 from torchgen.executorch.api.types.types import contextArg
-from torchgen.model import FunctionSchema, NativeFunction
+
+
+if TYPE_CHECKING:
+    from torchgen.api.types import Binding, CType
+    from torchgen.model import FunctionSchema, NativeFunction
 
 
@@ -20,14 +25,14 @@ class ExecutorchCppSignature:
     func: FunctionSchema
 
     # The set of C++ arguments which should not have defaults applied to them
-    cpp_no_default_args: Set[str]
+    cpp_no_default_args: set[str]
 
     # Allows you to prepend an arbitrary prefix to the signature name.
     # This is useful for parts of the codegen that generate wrappers around kernels,
     # and need to avoid naming collisions.
     prefix: str = ""
 
-    def arguments(self, *, include_context: bool = True) -> List[Binding]:
+    def arguments(self, *, include_context: bool = True) -> list[Binding]:
         return ([contextArg] if include_context else []) + et_cpp.arguments(
             self.func.arguments,
             faithful=True,  # always faithful, out argument at the end
@@ -41,7 +46,7 @@ class ExecutorchCppSignature:
             faithful_name_for_out_overloads=True,
         )
 
-    def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str:
+    def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
         args_str = ", ".join(
             a.decl() for a in self.arguments(include_context=include_context)
         )
@@ -49,7 +54,7 @@ class ExecutorchCppSignature:
             name = self.name()
         return f"{self.returns_type().cpp_type()} {name}({args_str})"
 
-    def defn(self, name: Optional[str] = None) -> str:
+    def defn(self, name: str | None = None) -> str:
         args = [a.defn() for a in self.arguments()]
         args_str = ", ".join(args)
         if name is None:
@@ -62,7 +67,7 @@ class ExecutorchCppSignature:
     @staticmethod
     def from_native_function(
         f: NativeFunction, *, prefix: str = ""
-    ) -> "ExecutorchCppSignature":
+    ) -> ExecutorchCppSignature:
         return ExecutorchCppSignature(
             func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
         )
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict
 
 from torchgen.api.types import (
     BaseCppType,
@@ -40,7 +41,7 @@ contextArg = Binding(
     default=None,
 )
 
-BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
+BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
     BaseTy.int: longT,
     BaseTy.float: doubleT,
     BaseTy.bool: boolT,
@@ -54,7 +55,7 @@ BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
 
 @dataclass(frozen=True)
 class OptionalCType(CType):
-    elem: "CType"
+    elem: CType
 
     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -63,13 +64,13 @@ class OptionalCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"
 
-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return OptionalCType(self.elem.remove_const_ref())
 
 
 @dataclass(frozen=True)
 class ArrayRefCType(CType):
-    elem: "CType"
+    elem: CType
 
     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -78,5 +79,5 @@ class ArrayRefCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
 
-    def remove_const_ref(self) -> "CType":
+    def remove_const_ref(self) -> CType:
         return ArrayRefCType(self.elem.remove_const_ref())
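The quoted forward references dropped here (`elem: "CType"` to `elem: CType`, and `-> "ExecutorchCppSignature"` in the previous file) follow from the same rule: under PEP 563 every annotation is effectively a forward reference already, so a class can mention itself or a later definition without string quoting. A runnable sketch (hypothetical class, not from this commit):

    from __future__ import annotations

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class Node:
        value: int
        next: Node | None = None  # self-reference needs no quotes now

        def prepend(self, value: int) -> Node:  # ditto for return types
            return Node(value, self)


    chain = Node(2).prepend(1)
    assert chain.next is not None and chain.next.value == 2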
@@ -1,7 +1,8 @@
-from dataclasses import dataclass
-from typing import Callable, List, Sequence, Tuple
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Callable, Sequence, TYPE_CHECKING
 
-from torchgen.api.types import Binding, CType, NamedCType
 from torchgen.model import (
     Argument,
     BaseTy,
@@ -13,6 +14,10 @@ from torchgen.model import (
 )
 
 
+if TYPE_CHECKING:
+    from torchgen.api.types import Binding, CType, NamedCType
+
+
 connector = "\n\t"
 
 
@@ -52,7 +57,7 @@ class Unboxing:
     # Convert all the arguments in a NativeFunction to C++ code
     def convert_arguments(
         self, args: Sequence[Binding]
-    ) -> Tuple[List[Binding], List[str]]:
+    ) -> tuple[list[Binding], list[str]]:
         code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
         binding_list = []
         for arg in args:
@@ -72,7 +77,7 @@ class Unboxing:
 
     def argumenttype_evalue_convert(
         self, t: Type, arg_name: str, *, mutable: bool = False
-    ) -> Tuple[str, CType, List[str], List[str]]:
+    ) -> tuple[str, CType, list[str], list[str]]:
         """
         Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
         (1) the C++ code necessary to unbox the argument
@@ -107,14 +112,14 @@ class Unboxing:
 
     def _gen_code_base_type(
         self, arg_name: str, out_name: str, ctype: CType
-    ) -> Tuple[List[str], List[str]]:
+    ) -> tuple[list[str], list[str]]:
         return [
             f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
         ], []
 
     def _gen_code_optional_type(
         self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
-    ) -> Tuple[List[str], List[str]]:
+    ) -> tuple[list[str], list[str]]:
         in_name = f"{arg_name}_opt_in"
         res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
             t.elem, in_name
@@ -130,7 +135,7 @@ class Unboxing:
 
     def _gen_code_list_type(
         self, arg_name: str, out_name: str, t: ListType, ctype: CType
-    ) -> Tuple[List[str], List[str]]:
+    ) -> tuple[list[str], list[str]]:
         in_name = f"{arg_name}_list_in"
         elem_name = f"{arg_name}_elem"
         code = []
@@ -1,11 +1,12 @@
 # Represents all kernels used by an Executorch model.
 # It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.
 
+from __future__ import annotations
+
 import itertools
 from collections import defaultdict, namedtuple
 from dataclasses import dataclass
 from enum import IntEnum
-from typing import Dict, List, Tuple, Union
 
 from torchgen.model import (
     BackendIndex,
@@ -41,7 +42,7 @@ class ETKernelKeyOpArgMeta:
     arg_name: str
     dtype: str
     # The order of the dimensions if entry is a Tensor
-    dim_order: Tuple[int, ...]
+    dim_order: tuple[int, ...]
 
     def to_native_string(self) -> str:
         dtype_str = ScalarType[self.dtype].value
@@ -52,7 +53,7 @@ class ETKernelKeyOpArgMeta:
 @dataclass(frozen=True)
 class ETKernelKey:
     # Field undefined is default = True
-    arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()
+    arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
 
     # Indicator for this kernel being used as a catch all
     default: bool = False
@@ -61,10 +62,10 @@ class ETKernelKey:
 
     @staticmethod
     def gen_from_yaml(
-        args: Dict[str, Tuple[str, str]],
-        type_alias_map: Dict[str, List[str]],  # TODO: Support unwrapped str val
-        dim_order_alias_map: Dict[str, List[int]],
-    ) -> List["ETKernelKey"]:
+        args: dict[str, tuple[str, str]],
+        type_alias_map: dict[str, list[str]],  # TODO: Support unwrapped str val
+        dim_order_alias_map: dict[str, list[int]],
+    ) -> list[ETKernelKey]:
         """Generate ETKernelKeys from arg kernel specs
         Multiple ETKernelKeys are returned due to dtype permutations from utilizing
         type_alias_map (actualizing each potential type permutation as a KernelKey)
@@ -137,15 +138,15 @@ class ETKernelKey:
 
 @dataclass(frozen=True)
 class ETKernelIndex:
-    index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]
+    index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
 
-    def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
+    def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
         m = self.get_kernels(g)
         return m is not None
 
     def get_kernels(
-        self, g: Union[NativeFunction, NativeFunctionsGroup]
-    ) -> Dict[ETKernelKey, BackendMetadata]:
+        self, g: NativeFunction | NativeFunctionsGroup
+    ) -> dict[ETKernelKey, BackendMetadata]:
         if isinstance(g, NativeFunction):
             f = g
         elif isinstance(g, NativeFunctionsGroup):
@@ -158,8 +159,8 @@ class ETKernelIndex:
 
     @staticmethod
     def grow_from_backend_indices(
-        kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
-        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
+        kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
+        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
     ) -> None:
         for dk in backend_indices:
             index = backend_indices[dk]
@@ -171,17 +172,17 @@ class ETKernelIndex:
 
     @staticmethod
     def from_backend_indices(
-        backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
-    ) -> "ETKernelIndex":
-        kernel_index: Dict[
-            OperatorName, Dict[ETKernelKey, BackendMetadata]
+        backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
+    ) -> ETKernelIndex:
+        kernel_index: dict[
+            OperatorName, dict[ETKernelKey, BackendMetadata]
         ] = defaultdict(dict)
         ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
         return ETKernelIndex(kernel_index)
 
     def grow(
-        self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
-    ) -> "ETKernelIndex":
+        self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
+    ) -> ETKernelIndex:
         ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
         return self
 
@@ -189,7 +190,7 @@ class ETKernelIndex:
         """
        WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
         """
-        index: Dict[OperatorName, BackendMetadata] = {}
+        index: dict[OperatorName, BackendMetadata] = {}
         for op in self.index:
             kernel_dict = self.index[op]
             assert (
@@ -209,9 +210,7 @@ class ETKernelIndex:
 
     # Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
     @staticmethod
-    def merge_indices(
-        index_a: "ETKernelIndex", index_b: "ETKernelIndex"
-    ) -> "ETKernelIndex":
+    def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
         combined = defaultdict(dict, index_a.index.copy())
 
         for op, entry in index_b.index.items():
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from collections import defaultdict, namedtuple
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any
 
 import yaml
 
@@ -22,7 +24,7 @@ ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indice
 ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]
 
 
-def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]:
+def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
     """Given a loaded yaml representing kernel assignment information, extract the
     mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)
 
@@ -34,11 +36,11 @@ def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]
     if (kernels := e.pop("kernels", None)) is None:
         return {}
 
-    type_alias: Dict[str, List[str]] = e.pop("type_alias", {})  # type: ignore[assignment]
-    dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {})  # type: ignore[assignment]
+    type_alias: dict[str, list[str]] = e.pop("type_alias", {})  # type: ignore[assignment]
+    dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {})  # type: ignore[assignment]
     dim_order_alias.pop("__line__", None)
 
-    kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {}
+    kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
 
     for entry in kernels:  # type: ignore[attr-defined]
         arg_meta = entry.get("arg_meta")
@@ -76,7 +78,7 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
     of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
     that should be used by the kernel key).
     """
-    indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {}
+    indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
     for ei in es:  # type: ignore[attr-defined]
         e = ei.copy()
 
@@ -95,11 +97,11 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
     return ETKernelIndex(indices)
 
 
-def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
+def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
     """Given a loaded yaml representing a list of operators, extract the
     kernel key related fields indexed by the operator name.
     """
-    fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict)
+    fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
     for ei in es:  # type: ignore[attr-defined]
         funcs = ei.get("func")
         assert isinstance(funcs, str), f"not a str: {funcs}"
@@ -118,9 +120,9 @@ def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
 def parse_et_yaml(
     path: str,
     tags_yaml_path: str,
-    ignore_keys: Optional[Set[DispatchKey]] = None,
+    ignore_keys: set[DispatchKey] | None = None,
     skip_native_fns_gen: bool = False,
-) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]:
+) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
     """Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
     of fields to persist from native_functions.yaml to functions.yaml
     """
torchgen/gen.py (278 changed lines)
@@ -1,23 +1,13 @@
+from __future__ import annotations
+
 import argparse
 import functools
 import json
 import os
-import pathlib
 from collections import defaultdict, namedtuple, OrderedDict
 from dataclasses import dataclass, field
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from pathlib import Path
+from typing import Any, Callable, Literal, Sequence, TypeVar
 
 import yaml
 
@@ -148,20 +138,20 @@ class LineLoader(YamlLoader):
 ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
 
 
-_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {}
-_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {}
+_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
+_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
 
 
 def parse_native_yaml_struct(
     es: object,
-    valid_tags: Set[str],
-    ignore_keys: Optional[Set[DispatchKey]] = None,
+    valid_tags: set[str],
+    ignore_keys: set[DispatchKey] | None = None,
     path: str = "<stdin>",
     skip_native_fns_gen: bool = False,
 ) -> ParsedYaml:
     assert isinstance(es, list)
-    rs: List[NativeFunction] = []
-    bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
+    rs: list[NativeFunction] = []
+    bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
     for e in es:
         assert isinstance(e, dict), f"expected to be dict: {e}"
         assert isinstance(e.get("__line__"), int), e
@@ -174,7 +164,7 @@ def parse_native_yaml_struct(
         BackendIndex.grow_index(bs, m)
     error_check_native_functions(rs)
     # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
-    indices: Dict[DispatchKey, BackendIndex] = defaultdict(
+    indices: dict[DispatchKey, BackendIndex] = defaultdict(
         lambda: BackendIndex(
             dispatch_key=DispatchKey.Undefined,
             use_out_as_primary=True,
@@ -200,9 +190,9 @@ def parse_native_yaml_struct(
     return ParsedYaml(rs, indices)
 
 
-def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
+def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
     assert isinstance(es, list)
-    rs: Set[str] = set()
+    rs: set[str] = set()
     for e in es:
         assert isinstance(e.get("__line__"), int), e
         loc = Location(path, e["__line__"])
@@ -218,7 +208,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
 
 
 @functools.lru_cache(maxsize=None)
-def parse_tags_yaml(path: str) -> Set[str]:
+def parse_tags_yaml(path: str) -> set[str]:
     global _GLOBAL_PARSE_TAGS_YAML_CACHE
     if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
         with open(path) as f:
@@ -231,10 +221,10 @@ def parse_tags_yaml(path: str) -> Set[str]:
 def parse_native_yaml(
     path: str,
     tags_yaml_path: str,
-    ignore_keys: Optional[Set[DispatchKey]] = None,
+    ignore_keys: set[DispatchKey] | None = None,
     *,
     skip_native_fns_gen: bool = False,
-    loaded_yaml: Optional[object] = None,
+    loaded_yaml: object | None = None,
 ) -> ParsedYaml:
     global _GLOBAL_PARSE_NATIVE_YAML_CACHE
     if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
@@ -261,8 +251,8 @@ def parse_native_yaml(
 # Some assertions are already performed during parsing, but those are only within a single NativeFunction.
 # Assertions here are meant to be performed across NativeFunctions.
 def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
-    func_map: Dict[OperatorName, NativeFunction] = {}
-    base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
+    func_map: dict[OperatorName, NativeFunction] = {}
+    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
     for f in funcs:
         func_map[f.func.name] = f
         base_func_map[f.func.name.name].append(f)
@@ -329,7 +319,7 @@ def cpp_string(s: str) -> str:
 # and similar functional combinators.
 
 
-def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
+def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
     if len(backends) == 0:
         return []
     else:
@@ -343,7 +333,7 @@ def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
 
 def get_static_dispatch_backend(
     f: NativeFunction, backend_index: BackendIndex
-) -> Optional[DispatchKey]:
+) -> DispatchKey | None:
     if f.structured_delegate is not None or backend_index.has_kernel(f):
         # TODO: for ops with structured_delegate it should check the dispatch table of
         # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
@@ -362,8 +352,8 @@ def get_static_dispatch_backend(
 
 
 def static_dispatch_ops_header(
-    f: NativeFunction, backend_index: List[BackendIndex]
-) -> Optional[str]:
+    f: NativeFunction, backend_index: list[BackendIndex]
+) -> str | None:
     if backend_index is None or f.manual_kernel_registration:
         return None
 
@@ -377,7 +367,7 @@ def static_dispatch_ops_header(
     return "\n".join(output)
 
 
-def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
+def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
     return [
         f"#include <ATen/{dispatch_key}Functions.h>"
         for dispatch_key in static_dispatch_keys(backends)
@@ -388,12 +378,12 @@ def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
 # Note that we have a special case for `memory_format` argument and this case is not covered by
 # tools.codegen.api.translate() yet as its application is limited to static dispatch.
 def translate_args(
-    sig: Union[CppSignature, DispatcherSignature],
+    sig: CppSignature | DispatcherSignature,
     cpp_sig: CppSignature,
 ) -> str:
     # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
-    def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]:
-        output_bindings: List[Binding] = []
+    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
+        output_bindings: list[Binding] = []
         for binding in input_bindings:
             if binding.name == "memory_format":
                 spl_mem_format_binding = Binding(
@@ -423,7 +413,7 @@ def translate_args(
 
 
 def generate_static_dispatch_backend_call(
-    sig: Union[CppSignature, DispatcherSignature],
+    sig: CppSignature | DispatcherSignature,
     f: NativeFunction,
     backend_index: BackendIndex,
 ) -> str:
@@ -441,9 +431,9 @@ def generate_static_dispatch_backend_call(
 
 
 def generate_static_dispatch_fallback_call(
-    sig: Union[CppSignature, DispatcherSignature],
+    sig: CppSignature | DispatcherSignature,
     f: NativeFunction,
-    backend_indices: List[BackendIndex],
+    backend_indices: list[BackendIndex],
 ) -> str:
     cpp_sigs = CppSignatureGroup.from_native_function(
         f, method=False, fallback_binding=False
@@ -470,9 +460,9 @@ def generate_static_dispatch_fallback_call(
 
 
 def static_dispatch(
-    sig: Union[CppSignature, DispatcherSignature],
+    sig: CppSignature | DispatcherSignature,
     f: NativeFunction,
-    backend_indices: List[BackendIndex],
+    backend_indices: list[BackendIndex],
 ) -> str:
     """
     For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
@@ -512,7 +502,7 @@ def static_dispatch(
     tensor_opts = f.func.arguments.tensor_options
 
     stmts = []
-    subexprs: List[str] = []
+    subexprs: list[str] = []
     if tensor_opts is not None:
         subexprs.append(
             "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
@@ -548,10 +538,10 @@ def static_dispatch(
 @dataclass(frozen=True)
 class RegisterSchema:
     selector: SelectiveBuilder
-    known_tags: Dict[str, int] = field(default_factory=dict)
+    known_tags: dict[str, int] = field(default_factory=dict)
 
     @method_with_native_function
-    def __call__(self, f: NativeFunction) -> Optional[str]:
+    def __call__(self, f: NativeFunction) -> str | None:
         if not self.selector.is_native_function_selected(f):
             return None
         tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
@@ -573,7 +563,7 @@ class RegisterSchema:
 @dataclass(frozen=True)
 class ComputeOperators:
     target: Literal[Target.DECLARATION, Target.DEFINITION]
-    static_dispatch_backend_indices: List[BackendIndex]
+    static_dispatch_backend_indices: list[BackendIndex]
 
     @method_with_native_function
     def __call__(self, f: NativeFunction) -> str:
@@ -670,7 +660,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
 @dataclass(frozen=True)
 class ComputeFunction:
     @method_with_native_function
-    def __call__(self, f: NativeFunction) -> Optional[str]:
+    def __call__(self, f: NativeFunction) -> str | None:
         sig_group = CppSignatureGroup.from_native_function(
|
sig_group = CppSignatureGroup.from_native_function(
|
||||||
f, method=False, fallback_binding=f.manual_cpp_binding
|
f, method=False, fallback_binding=f.manual_cpp_binding
|
||||||
)
|
)
|
||||||
@ -718,10 +708,10 @@ namespace symint {{
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ComputeTensorMethod:
|
class ComputeTensorMethod:
|
||||||
target: Literal[Target.DECLARATION, Target.DEFINITION]
|
target: Literal[Target.DECLARATION, Target.DEFINITION]
|
||||||
static_dispatch_backend_indices: List[BackendIndex]
|
static_dispatch_backend_indices: list[BackendIndex]
|
||||||
|
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def __call__(self, f: NativeFunction) -> Optional[str]:
|
def __call__(self, f: NativeFunction) -> str | None:
|
||||||
if Variant.method not in f.variants:
|
if Variant.method not in f.variants:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -764,7 +754,7 @@ inline {sig.defn(prefix="Tensor::")} const {{
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ComputeRedispatchFunction:
|
class ComputeRedispatchFunction:
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def __call__(self, f: NativeFunction) -> Optional[str]:
|
def __call__(self, f: NativeFunction) -> str | None:
|
||||||
# We unconditionally generate function variants of the redispatch API.
|
# We unconditionally generate function variants of the redispatch API.
|
||||||
# This is mainly because we can namespace functions separately, but not methods,
|
# This is mainly because we can namespace functions separately, but not methods,
|
||||||
sig_group = CppSignatureGroup.from_native_function(
|
sig_group = CppSignatureGroup.from_native_function(
|
||||||
@ -798,7 +788,7 @@ def compute_aten_op(f: NativeFunction) -> str:
|
|||||||
|
|
||||||
|
|
||||||
# Generates MetaFunctions.h
|
# Generates MetaFunctions.h
|
||||||
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
|
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
|
||||||
if not g.structured:
|
if not g.structured:
|
||||||
return None
|
return None
|
||||||
with native_function_manager(g.out):
|
with native_function_manager(g.out):
|
||||||
@ -943,7 +933,7 @@ class ComputeBackendSelect:
|
|||||||
selector: SelectiveBuilder
|
selector: SelectiveBuilder
|
||||||
|
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def __call__(self, f: NativeFunction) -> Optional[str]:
|
def __call__(self, f: NativeFunction) -> str | None:
|
||||||
if not needs_backend_select(f, self.selector):
|
if not needs_backend_select(f, self.selector):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -959,7 +949,7 @@ class ComputeBackendSelect:
|
|||||||
|
|
||||||
dispatcher_sig = DispatcherSignature.from_schema(f.func)
|
dispatcher_sig = DispatcherSignature.from_schema(f.func)
|
||||||
|
|
||||||
sig: Union[NativeSignature, DispatcherSignature]
|
sig: NativeSignature | DispatcherSignature
|
||||||
sig = dispatcher_sig
|
sig = dispatcher_sig
|
||||||
dispatcher_exprs = dispatcher_sig.exprs()
|
dispatcher_exprs = dispatcher_sig.exprs()
|
||||||
dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
|
dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
|
||||||
@ -1059,7 +1049,7 @@ def dynamic_type(t: Type) -> str:
|
|||||||
).cpp_type()
|
).cpp_type()
|
||||||
|
|
||||||
|
|
||||||
def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
|
def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
|
||||||
# This is written out explicitly to ensure that Tensor and
|
# This is written out explicitly to ensure that Tensor and
|
||||||
# namespace are put into the list in the right order
|
# namespace are put into the list in the right order
|
||||||
method_of = ["Type"]
|
method_of = ["Type"]
|
||||||
@ -1072,7 +1062,7 @@ def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
|
|||||||
|
|
||||||
def compute_returns_yaml(
|
def compute_returns_yaml(
|
||||||
f: NativeFunction,
|
f: NativeFunction,
|
||||||
) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
|
) -> tuple[list[dict[str, str]], dict[str, str]]:
|
||||||
# Note [name and field_name]
|
# Note [name and field_name]
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# To understand name_to_field_name, we must first talk about this
|
# To understand name_to_field_name, we must first talk about this
|
||||||
@ -1112,7 +1102,7 @@ def compute_returns_yaml(
|
|||||||
# schema itself.
|
# schema itself.
|
||||||
#
|
#
|
||||||
# See also https://github.com/pytorch/pytorch/issues/43114
|
# See also https://github.com/pytorch/pytorch/issues/43114
|
||||||
name_to_field_name: Dict[str, str] = {}
|
name_to_field_name: dict[str, str] = {}
|
||||||
|
|
||||||
# Compute the returns field of the YAML entry
|
# Compute the returns field of the YAML entry
|
||||||
names = cpp.return_names(f)
|
names = cpp.return_names(f)
|
||||||
@ -1141,12 +1131,12 @@ def compute_cpp_argument_yaml(
|
|||||||
cpp_a: Binding,
|
cpp_a: Binding,
|
||||||
*,
|
*,
|
||||||
schema_order: bool,
|
schema_order: bool,
|
||||||
kwarg_only_set: Set[str],
|
kwarg_only_set: set[str],
|
||||||
out_arg_set: Set[str],
|
out_arg_set: set[str],
|
||||||
name_to_field_name: Dict[str, str],
|
name_to_field_name: dict[str, str],
|
||||||
) -> object:
|
) -> object:
|
||||||
if isinstance(cpp_a.argument, TensorOptionsArguments):
|
if isinstance(cpp_a.argument, TensorOptionsArguments):
|
||||||
arg: Dict[str, object] = {
|
arg: dict[str, object] = {
|
||||||
"annotation": None,
|
"annotation": None,
|
||||||
"dynamic_type": "at::TensorOptions",
|
"dynamic_type": "at::TensorOptions",
|
||||||
"is_nullable": False,
|
"is_nullable": False,
|
||||||
@ -1173,11 +1163,11 @@ def compute_argument_yaml(
|
|||||||
a: Argument,
|
a: Argument,
|
||||||
*,
|
*,
|
||||||
schema_order: bool,
|
schema_order: bool,
|
||||||
kwarg_only_set: Set[str],
|
kwarg_only_set: set[str],
|
||||||
out_arg_set: Set[str],
|
out_arg_set: set[str],
|
||||||
name_to_field_name: Dict[str, str],
|
name_to_field_name: dict[str, str],
|
||||||
) -> object:
|
) -> object:
|
||||||
arg: Dict[str, object] = {
|
arg: dict[str, object] = {
|
||||||
"annotation": str(a.annotation) if a.annotation else None,
|
"annotation": str(a.annotation) if a.annotation else None,
|
||||||
"dynamic_type": dynamic_type(a.type),
|
"dynamic_type": dynamic_type(a.type),
|
||||||
"is_nullable": a.type.is_nullable(),
|
"is_nullable": a.type.is_nullable(),
|
||||||
@ -1303,7 +1293,7 @@ def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
|
|||||||
|
|
||||||
@with_native_function_and_indices
|
@with_native_function_and_indices
|
||||||
def compute_registration_declarations(
|
def compute_registration_declarations(
|
||||||
f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
|
f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
|
||||||
) -> str:
|
) -> str:
|
||||||
name = dispatcher.name(f.func)
|
name = dispatcher.name(f.func)
|
||||||
returns_type = dispatcher.returns_type(
|
returns_type = dispatcher.returns_type(
|
||||||
@ -1311,7 +1301,7 @@ def compute_registration_declarations(
|
|||||||
).cpp_type_registration_declarations()
|
).cpp_type_registration_declarations()
|
||||||
args = dispatcher.arguments(f.func)
|
args = dispatcher.arguments(f.func)
|
||||||
args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
|
args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
|
||||||
comment_data: Dict[str, str] = {
|
comment_data: dict[str, str] = {
|
||||||
"schema": f"aten::{f.func}",
|
"schema": f"aten::{f.func}",
|
||||||
# TODO: What exactly is the semantics of the 'dispatch' field?
|
# TODO: What exactly is the semantics of the 'dispatch' field?
|
||||||
"dispatch": str(
|
"dispatch": str(
|
||||||
@ -1337,8 +1327,8 @@ def compute_registration_declarations(
|
|||||||
|
|
||||||
|
|
||||||
def get_custom_build_selector(
|
def get_custom_build_selector(
|
||||||
provided_op_registration_allowlist: Optional[List[str]],
|
provided_op_registration_allowlist: list[str] | None,
|
||||||
op_selection_yaml_path: Optional[str],
|
op_selection_yaml_path: str | None,
|
||||||
) -> SelectiveBuilder:
|
) -> SelectiveBuilder:
|
||||||
assert not (
|
assert not (
|
||||||
provided_op_registration_allowlist is not None
|
provided_op_registration_allowlist is not None
|
||||||
@ -1349,7 +1339,7 @@ def get_custom_build_selector(
|
|||||||
+ "same time."
|
+ "same time."
|
||||||
)
|
)
|
||||||
|
|
||||||
op_registration_allowlist: Optional[Set[str]] = None
|
op_registration_allowlist: set[str] | None = None
|
||||||
if provided_op_registration_allowlist is not None:
|
if provided_op_registration_allowlist is not None:
|
||||||
op_registration_allowlist = set(provided_op_registration_allowlist)
|
op_registration_allowlist = set(provided_op_registration_allowlist)
|
||||||
|
|
||||||
@ -1369,11 +1359,11 @@ def get_custom_build_selector(
|
|||||||
|
|
||||||
def get_grouped_by_view_native_functions(
|
def get_grouped_by_view_native_functions(
|
||||||
native_functions: Sequence[NativeFunction],
|
native_functions: Sequence[NativeFunction],
|
||||||
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
|
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
|
||||||
def maybe_create_view_group(
|
def maybe_create_view_group(
|
||||||
d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
|
d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
|
||||||
) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
|
) -> list[NativeFunction | NativeFunctionsViewGroup]:
|
||||||
funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
|
funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
|
||||||
if ViewSchemaKind.aliasing in d:
|
if ViewSchemaKind.aliasing in d:
|
||||||
view = d.pop(ViewSchemaKind.aliasing)
|
view = d.pop(ViewSchemaKind.aliasing)
|
||||||
view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
|
view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
|
||||||
@ -1391,8 +1381,8 @@ def get_grouped_by_view_native_functions(
|
|||||||
funcs.extend(d.values())
|
funcs.extend(d.values())
|
||||||
return funcs
|
return funcs
|
||||||
|
|
||||||
grouped_by_views: Dict[
|
grouped_by_views: dict[
|
||||||
FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
|
FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
|
||||||
] = defaultdict(dict)
|
] = defaultdict(dict)
|
||||||
for f in native_functions:
|
for f in native_functions:
|
||||||
schema = f.func.view_signature()
|
schema = f.func.view_signature()
|
||||||
@ -1416,10 +1406,10 @@ def get_grouped_by_view_native_functions(
|
|||||||
|
|
||||||
def get_grouped_native_functions(
|
def get_grouped_native_functions(
|
||||||
native_functions: Sequence[NativeFunction],
|
native_functions: Sequence[NativeFunction],
|
||||||
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
|
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
|
||||||
def flatten_pre_group(
|
def flatten_pre_group(
|
||||||
d: Dict[SchemaKind, NativeFunction]
|
d: dict[SchemaKind, NativeFunction]
|
||||||
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
|
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
|
||||||
r = NativeFunctionsGroup.from_dict(d)
|
r = NativeFunctionsGroup.from_dict(d)
|
||||||
if r is None:
|
if r is None:
|
||||||
# Invariant: any NativeFunctions that are code-generated
|
# Invariant: any NativeFunctions that are code-generated
|
||||||
@ -1438,13 +1428,13 @@ def get_grouped_native_functions(
|
|||||||
|
|
||||||
def get_ns_grouped_kernels(
|
def get_ns_grouped_kernels(
|
||||||
*,
|
*,
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
backend_indices: dict[DispatchKey, BackendIndex],
|
||||||
native_function_decl_gen: Callable[
|
native_function_decl_gen: Callable[
|
||||||
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
|
[NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
|
||||||
] = dest.compute_native_function_declaration,
|
] = dest.compute_native_function_declaration,
|
||||||
) -> Dict[str, List[str]]:
|
) -> dict[str, list[str]]:
|
||||||
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
|
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
|
||||||
for f in grouped_native_functions:
|
for f in grouped_native_functions:
|
||||||
native_function_namespaces = set()
|
native_function_namespaces = set()
|
||||||
dispatch_keys = set()
|
dispatch_keys = set()
|
||||||
@ -1467,9 +1457,9 @@ def get_ns_grouped_kernels(
|
|||||||
|
|
||||||
def get_native_function_declarations_from_ns_grouped_kernels(
|
def get_native_function_declarations_from_ns_grouped_kernels(
|
||||||
*,
|
*,
|
||||||
ns_grouped_kernels: Dict[str, List[str]],
|
ns_grouped_kernels: dict[str, list[str]],
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
declarations: List[str] = []
|
declarations: list[str] = []
|
||||||
newline = "\n"
|
newline = "\n"
|
||||||
for namespace, kernels in ns_grouped_kernels.items():
|
for namespace, kernels in ns_grouped_kernels.items():
|
||||||
ns_helper = NamespaceHelper(
|
ns_helper = NamespaceHelper(
|
||||||
@ -1495,12 +1485,12 @@ def get_native_function_declarations_from_ns_grouped_kernels(
|
|||||||
# Return native function declarations grouped by their namespaces.
|
# Return native function declarations grouped by their namespaces.
|
||||||
def get_native_function_declarations(
|
def get_native_function_declarations(
|
||||||
*,
|
*,
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
backend_indices: dict[DispatchKey, BackendIndex],
|
||||||
native_function_decl_gen: Callable[
|
native_function_decl_gen: Callable[
|
||||||
[Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
|
[NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
|
||||||
] = dest.compute_native_function_declaration,
|
] = dest.compute_native_function_declaration,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Generate kernel declarations, in `NativeFunction(s).h`.
|
Generate kernel declarations, in `NativeFunction(s).h`.
|
||||||
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
|
:param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
|
||||||
@ -1520,7 +1510,7 @@ def get_native_function_declarations(
|
|||||||
|
|
||||||
|
|
||||||
def get_kernel_namespace(
|
def get_kernel_namespace(
|
||||||
*, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
|
*, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
|
||||||
) -> str:
|
) -> str:
|
||||||
backend_metadata = backend_idx.get_kernel(f)
|
backend_metadata = backend_idx.get_kernel(f)
|
||||||
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
|
assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
|
||||||
@ -1538,7 +1528,7 @@ def get_kernel_namespace(
|
|||||||
def get_native_function_definitions(
|
def get_native_function_definitions(
|
||||||
*,
|
*,
|
||||||
fm: FileManager,
|
fm: FileManager,
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
dispatch_key: DispatchKey,
|
dispatch_key: DispatchKey,
|
||||||
backend_idx: BackendIndex,
|
backend_idx: BackendIndex,
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
@ -1546,11 +1536,11 @@ def get_native_function_definitions(
|
|||||||
symint: bool,
|
symint: bool,
|
||||||
skip_dispatcher_op_registration: bool,
|
skip_dispatcher_op_registration: bool,
|
||||||
gen_dispatch_helpers: bool,
|
gen_dispatch_helpers: bool,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
definitions: List[str] = []
|
definitions: list[str] = []
|
||||||
ns_definitions: Dict[str, List[str]] = defaultdict(list)
|
ns_definitions: dict[str, list[str]] = defaultdict(list)
|
||||||
anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
|
anonymous_definitions: dict[str, list[str]] = defaultdict(list)
|
||||||
registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
|
registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
|
||||||
newline = "\n"
|
newline = "\n"
|
||||||
ns_gen = dest.RegisterDispatchKey(
|
ns_gen = dest.RegisterDispatchKey(
|
||||||
backend_idx,
|
backend_idx,
|
||||||
@ -1640,15 +1630,15 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
|
|||||||
# Used in CPUFunctions_inl.h and etc.
|
# Used in CPUFunctions_inl.h and etc.
|
||||||
def get_namespaced_declaration(
|
def get_namespaced_declaration(
|
||||||
*,
|
*,
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
dispatch_key: DispatchKey,
|
dispatch_key: DispatchKey,
|
||||||
backend_idx: BackendIndex,
|
backend_idx: BackendIndex,
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
rocm: bool,
|
rocm: bool,
|
||||||
symint: bool,
|
symint: bool,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
declarations: List[str] = []
|
declarations: list[str] = []
|
||||||
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
|
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
|
||||||
newline = "\n"
|
newline = "\n"
|
||||||
func = dest.RegisterDispatchKey(
|
func = dest.RegisterDispatchKey(
|
||||||
backend_idx,
|
backend_idx,
|
||||||
@ -1692,8 +1682,8 @@ def get_native_function_schema_registrations(
|
|||||||
*,
|
*,
|
||||||
native_functions: Sequence[NativeFunction],
|
native_functions: Sequence[NativeFunction],
|
||||||
schema_selector: SelectiveBuilder,
|
schema_selector: SelectiveBuilder,
|
||||||
) -> Tuple[List[str], str]:
|
) -> tuple[list[str], str]:
|
||||||
ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
|
ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
|
||||||
for native_function in native_functions:
|
for native_function in native_functions:
|
||||||
ns_native_functions[native_function.namespace].append(native_function)
|
ns_native_functions[native_function.namespace].append(native_function)
|
||||||
schema_registrations = ""
|
schema_registrations = ""
|
||||||
@ -1727,14 +1717,14 @@ def get_native_function_schema_registrations(
|
|||||||
def gen_aggregated_headers(
|
def gen_aggregated_headers(
|
||||||
*,
|
*,
|
||||||
native_functions: Sequence[NativeFunction],
|
native_functions: Sequence[NativeFunction],
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
structured_native_functions: Sequence[NativeFunctionsGroup],
|
structured_native_functions: Sequence[NativeFunctionsGroup],
|
||||||
static_dispatch_idx: List[BackendIndex],
|
static_dispatch_idx: list[BackendIndex],
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
backend_indices: dict[DispatchKey, BackendIndex],
|
||||||
cpu_fm: FileManager,
|
cpu_fm: FileManager,
|
||||||
cuda_fm: FileManager,
|
cuda_fm: FileManager,
|
||||||
functions_keys: Set[DispatchKey],
|
functions_keys: set[DispatchKey],
|
||||||
dispatch_keys: Sequence[DispatchKey],
|
dispatch_keys: Sequence[DispatchKey],
|
||||||
rocm: bool,
|
rocm: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -1848,25 +1838,25 @@ def gen_aggregated_headers(
|
|||||||
def gen_per_operator_headers(
|
def gen_per_operator_headers(
|
||||||
*,
|
*,
|
||||||
native_functions: Sequence[NativeFunction],
|
native_functions: Sequence[NativeFunction],
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
static_dispatch_idx: List[BackendIndex],
|
static_dispatch_idx: list[BackendIndex],
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
backend_indices: dict[DispatchKey, BackendIndex],
|
||||||
cpu_fm: FileManager,
|
cpu_fm: FileManager,
|
||||||
cuda_fm: FileManager,
|
cuda_fm: FileManager,
|
||||||
ops_fm: FileManager,
|
ops_fm: FileManager,
|
||||||
functions_keys: Set[DispatchKey],
|
functions_keys: set[DispatchKey],
|
||||||
dispatch_keys: Sequence[DispatchKey],
|
dispatch_keys: Sequence[DispatchKey],
|
||||||
rocm: bool,
|
rocm: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
# For CMake builds, split operator declarations into separate headers in
|
# For CMake builds, split operator declarations into separate headers in
|
||||||
# the ATen/ops folder to split up header dependencies
|
# the ATen/ops folder to split up header dependencies
|
||||||
functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list)
|
functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
|
||||||
for fn in native_functions:
|
for fn in native_functions:
|
||||||
functions_by_root_name[fn.root_name].append(fn)
|
functions_by_root_name[fn.root_name].append(fn)
|
||||||
|
|
||||||
grouped_functions_by_root_name: Dict[
|
grouped_functions_by_root_name: dict[
|
||||||
str, List[Union[NativeFunction, NativeFunctionsGroup]]
|
str, list[NativeFunction | NativeFunctionsGroup]
|
||||||
] = defaultdict(list)
|
] = defaultdict(list)
|
||||||
for group in grouped_native_functions:
|
for group in grouped_native_functions:
|
||||||
name = group.root_name
|
name = group.root_name
|
||||||
@ -2042,18 +2032,18 @@ def gen_per_operator_headers(
|
|||||||
def gen_headers(
|
def gen_headers(
|
||||||
*,
|
*,
|
||||||
native_functions: Sequence[NativeFunction],
|
native_functions: Sequence[NativeFunction],
|
||||||
valid_tags: Set[str],
|
valid_tags: set[str],
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
structured_native_functions: Sequence[NativeFunctionsGroup],
|
structured_native_functions: Sequence[NativeFunctionsGroup],
|
||||||
static_dispatch_idx: List[BackendIndex],
|
static_dispatch_idx: list[BackendIndex],
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
backend_indices: dict[DispatchKey, BackendIndex],
|
||||||
core_fm: FileManager,
|
core_fm: FileManager,
|
||||||
cpu_fm: FileManager,
|
cpu_fm: FileManager,
|
||||||
cuda_fm: FileManager,
|
cuda_fm: FileManager,
|
||||||
ops_fm: FileManager,
|
ops_fm: FileManager,
|
||||||
dispatch_keys: Sequence[DispatchKey],
|
dispatch_keys: Sequence[DispatchKey],
|
||||||
functions_keys: Set[DispatchKey],
|
functions_keys: set[DispatchKey],
|
||||||
rocm: bool,
|
rocm: bool,
|
||||||
per_operator_headers: bool,
|
per_operator_headers: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -2133,8 +2123,8 @@ def gen_headers(
|
|||||||
"VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
|
"VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen_aten_interned_strings() -> Dict[str, str]:
|
def gen_aten_interned_strings() -> dict[str, str]:
|
||||||
attrs: Set[str] = set() # All function argument names
|
attrs: set[str] = set() # All function argument names
|
||||||
names = set() # All ATen function names
|
names = set() # All ATen function names
|
||||||
for func in native_functions:
|
for func in native_functions:
|
||||||
names.add(str(func.func.name.name))
|
names.add(str(func.func.name.name))
|
||||||
@ -2171,7 +2161,7 @@ def gen_headers(
|
|||||||
|
|
||||||
core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
|
core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)
|
||||||
|
|
||||||
def gen_tags_enum() -> Dict[str, str]:
|
def gen_tags_enum() -> dict[str, str]:
|
||||||
return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}
|
return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}
|
||||||
|
|
||||||
core_fm.write("enum_tag.h", gen_tags_enum)
|
core_fm.write("enum_tag.h", gen_tags_enum)
|
||||||
@ -2180,19 +2170,19 @@ def gen_headers(
|
|||||||
def gen_source_files(
|
def gen_source_files(
|
||||||
*,
|
*,
|
||||||
native_functions: Sequence[NativeFunction],
|
native_functions: Sequence[NativeFunction],
|
||||||
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
|
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
|
||||||
structured_native_functions: Sequence[NativeFunctionsGroup],
|
structured_native_functions: Sequence[NativeFunctionsGroup],
|
||||||
view_groups: Sequence[NativeFunctionsViewGroup],
|
view_groups: Sequence[NativeFunctionsViewGroup],
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
static_dispatch_idx: List[BackendIndex],
|
static_dispatch_idx: list[BackendIndex],
|
||||||
backend_indices: Dict[DispatchKey, BackendIndex],
|
backend_indices: dict[DispatchKey, BackendIndex],
|
||||||
aoti_fm: FileManager,
|
aoti_fm: FileManager,
|
||||||
core_fm: FileManager,
|
core_fm: FileManager,
|
||||||
cpu_fm: FileManager,
|
cpu_fm: FileManager,
|
||||||
cpu_vec_fm: FileManager,
|
cpu_vec_fm: FileManager,
|
||||||
cuda_fm: FileManager,
|
cuda_fm: FileManager,
|
||||||
dispatch_keys: Sequence[DispatchKey],
|
dispatch_keys: Sequence[DispatchKey],
|
||||||
functions_keys: Set[DispatchKey],
|
functions_keys: set[DispatchKey],
|
||||||
rocm: bool,
|
rocm: bool,
|
||||||
force_schema_registration: bool,
|
force_schema_registration: bool,
|
||||||
per_operator_headers: bool,
|
per_operator_headers: bool,
|
||||||
@ -2216,7 +2206,7 @@ def gen_source_files(
|
|||||||
|
|
||||||
if per_operator_headers:
|
if per_operator_headers:
|
||||||
|
|
||||||
def operator_headers() -> List[str]:
|
def operator_headers() -> list[str]:
|
||||||
headers = []
|
headers = []
|
||||||
for g in grouped_native_functions:
|
for g in grouped_native_functions:
|
||||||
is_registered = False
|
is_registered = False
|
||||||
@ -2258,7 +2248,7 @@ def gen_source_files(
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def operator_headers() -> List[str]:
|
def operator_headers() -> list[str]:
|
||||||
headers = ["#include <ATen/NativeFunctions.h>"]
|
headers = ["#include <ATen/NativeFunctions.h>"]
|
||||||
if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
|
if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
|
||||||
headers.append("#include <ATen/Functions.h>")
|
headers.append("#include <ATen/Functions.h>")
|
||||||
@ -2449,7 +2439,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
|
|||||||
del fm
|
del fm
|
||||||
|
|
||||||
# BackendSelect is generated specially
|
# BackendSelect is generated specially
|
||||||
def gen_backend_select() -> Dict[str, List[str]]:
|
def gen_backend_select() -> dict[str, list[str]]:
|
||||||
relevant_fns = [
|
relevant_fns = [
|
||||||
fn for fn in native_functions if needs_backend_select(fn, selector)
|
fn for fn in native_functions if needs_backend_select(fn, selector)
|
||||||
]
|
]
|
||||||
@ -2494,7 +2484,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
|
|||||||
)
|
)
|
||||||
|
|
||||||
def key_func(
|
def key_func(
|
||||||
fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
|
fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
|
||||||
) -> str:
|
) -> str:
|
||||||
return fn.root_name
|
return fn.root_name
|
||||||
|
|
||||||
@ -2536,11 +2526,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
|
|||||||
)
|
)
|
||||||
|
|
||||||
def functionalization_env_callable(
|
def functionalization_env_callable(
|
||||||
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
|
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
|
||||||
) -> Dict[str, List[str]]:
|
) -> dict[str, list[str]]:
|
||||||
def gen_op_headers(
|
def gen_op_headers(
|
||||||
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
|
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
if isinstance(g, NativeFunctionsViewGroup):
|
if isinstance(g, NativeFunctionsViewGroup):
|
||||||
# view ops always get a functionalization kernel
|
# view ops always get a functionalization kernel
|
||||||
headers = [
|
headers = [
|
||||||
@ -2590,8 +2580,8 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
all_groups: List[
|
all_groups: list[
|
||||||
Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
|
NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
|
||||||
] = list(structured_native_functions) + list(
|
] = list(structured_native_functions) + list(
|
||||||
view_groups # type: ignore[assignment, arg-type, operator]
|
view_groups # type: ignore[assignment, arg-type, operator]
|
||||||
)
|
)
|
||||||
@ -2600,11 +2590,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
|
|||||||
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
|
# (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
|
||||||
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
|
# (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
|
||||||
# Although this could go away long-term if we add a dedicated dispatch key for decompositions.
|
# Although this could go away long-term if we add a dedicated dispatch key for decompositions.
|
||||||
structured_map: Dict[OperatorName, NativeFunction] = {
|
structured_map: dict[OperatorName, NativeFunction] = {
|
||||||
f.func.name: f
|
f.func.name: f
|
||||||
for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
|
for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
|
||||||
}
|
}
|
||||||
view_map: Dict[OperatorName, NativeFunction] = {
|
view_map: dict[OperatorName, NativeFunction] = {
|
||||||
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
|
f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
|
||||||
}
|
}
|
||||||
for f in native_functions:
|
for f in native_functions:
|
||||||
@ -2715,12 +2705,12 @@ def gen_declarations_yaml(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_torchgen_root() -> pathlib.Path:
|
def get_torchgen_root() -> Path:
|
||||||
"""
|
"""
|
||||||
If you're depending on torchgen out-of-tree, you can use the root to figure
|
If you're depending on torchgen out-of-tree, you can use the root to figure
|
||||||
out the path to native_functions.yaml
|
out the path to native_functions.yaml
|
||||||
"""
|
"""
|
||||||
return pathlib.Path(__file__).parent.resolve()
|
return Path(__file__).parent.resolve()
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
@ -2882,11 +2872,11 @@ def main() -> None:
|
|||||||
#
|
#
|
||||||
# Invalid character escape '\c'.
|
# Invalid character escape '\c'.
|
||||||
core_install_dir = f"{options.install_dir}/core"
|
core_install_dir = f"{options.install_dir}/core"
|
||||||
pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
|
Path(core_install_dir).mkdir(parents=True, exist_ok=True)
|
||||||
ops_install_dir = f"{options.install_dir}/ops"
|
ops_install_dir = f"{options.install_dir}/ops"
|
||||||
pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
|
Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
|
||||||
aoti_install_dir = f"{options.aoti_install_dir}"
|
aoti_install_dir = f"{options.aoti_install_dir}"
|
||||||
pathlib.Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
|
Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
core_fm = make_file_manager(options=options, install_dir=core_install_dir)
|
core_fm = make_file_manager(options=options, install_dir=core_install_dir)
|
||||||
cpu_fm = make_file_manager(options=options)
|
cpu_fm = make_file_manager(options=options)
|
||||||
@ -2916,7 +2906,7 @@ def main() -> None:
|
|||||||
if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
|
if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
|
||||||
]
|
]
|
||||||
|
|
||||||
static_dispatch_idx: List[BackendIndex] = []
|
static_dispatch_idx: list[BackendIndex] = []
|
||||||
if options.static_dispatch_backend:
|
if options.static_dispatch_backend:
|
||||||
static_dispatch_idx = [
|
static_dispatch_idx = [
|
||||||
backend_indices[DispatchKey.parse(key)]
|
backend_indices[DispatchKey.parse(key)]
|
||||||
@ -2973,7 +2963,7 @@ def main() -> None:
|
|||||||
gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
|
gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
|
||||||
|
|
||||||
if options.output_dependencies:
|
if options.output_dependencies:
|
||||||
depfile_path = pathlib.Path(options.output_dependencies).resolve()
|
depfile_path = Path(options.output_dependencies).resolve()
|
||||||
depfile_name = depfile_path.name
|
depfile_name = depfile_path.name
|
||||||
depfile_stem = depfile_path.stem
|
depfile_stem = depfile_path.stem
|
||||||
|
|
||||||
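Note: every hunk above follows the same recipe. `from __future__ import annotations` (PEP 563) turns all annotations into lazily-evaluated strings, which is what lets the codegen adopt PEP 585 builtin generics (`list[str]`, `dict[str, int]`) and PEP 604 unions (`str | None`) while still running on interpreters older than Python 3.9/3.10. A minimal sketch of the mechanism, separate from the patch itself:

```python
from __future__ import annotations  # PEP 563: annotations become lazy strings


def lookup(keys: list[str], table: dict[str, int]) -> int | None:
    # The PEP 585/604 syntax above is never evaluated at runtime, so this
    # parses and runs on Python 3.7+, not just 3.9/3.10.
    for key in keys:
        if key in table:
            return table[key]
    return None


print(lookup(["b", "c"], {"c": 3}))     # 3
print(lookup.__annotations__["table"])  # 'dict[str, int]' -- a plain string
```

Because nothing is evaluated at annotation time, tools that need real type objects (for example `typing.get_type_hints`) must resolve the strings explicitly.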
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import textwrap
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Sequence

 from torchgen.api.types import DispatcherSignature
 from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
@@ -69,7 +71,7 @@ base_type_to_callsite_expr = {


 # convert args to C types, names in declarations, and expressions in function bodies
-def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str], List[str], List[str]]:  # type: ignore[return]
+def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]:  # type: ignore[return]
     if isinstance(typ, BaseType):
         if typ.name in base_type_to_c_type:
             return (
@@ -167,12 +169,12 @@ def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str
     )


-def zip_type_and_name(types: List[str], names: List[str]) -> List[str]:
+def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
     return [typ + " " + name for typ, name in zip(types, names)]


 # Generate argument declarations and callsite expressions
-def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]:
+def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
     types = []
     new_names = []
     callsite_exprs = []
@@ -189,7 +191,7 @@ def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[s
 # Return values are passed out as pointer arguments because all the C shim functions
 # are expected to return AOTITorchError.
 # Generate returns as declarations and callsite expressions
-def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
+def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
     types = []
     names = []
     for idx, ret in enumerate(schema.returns):
@@ -222,7 +224,7 @@ def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
             ret_pointer_can_be_null = True
             break

-    callsite_exprs: List[str] = []
+    callsite_exprs: list[str] = []
     for idx, ret in enumerate(schema.returns):
         tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
         assert isinstance(ret.type, BaseType)
@@ -236,12 +238,12 @@ def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:


 # gen.py generates header first and then src, so caching the result here to avoid duplicate work
-declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {}
+declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


 def gen_declaration_and_definition(
     schema: FunctionSchema, device: str, backend_call: str
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     func_name = schema.name.unambiguous_name()

     global declaration_definition_cache
@@ -254,7 +256,7 @@ def gen_declaration_and_definition(
         args, callsite_exprs = gen_arguments(
             [*schema.arguments.out, *schema.arguments.flat_non_out]
         )
-        ret_assignments: List[str] = []
+        ret_assignments: list[str] = []
     else:
         args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
         # ignore return values for inplace ops
@@ -284,7 +286,7 @@ def gen_declaration_and_definition(


 def gen_static_dispatch_backend_call_signature(
-    sig: Union[CppSignature, DispatcherSignature],
+    sig: CppSignature | DispatcherSignature,
     f: NativeFunction,
 ) -> CppSignature:
     sig = DispatcherSignature.from_schema(f.func)
@@ -310,10 +312,10 @@ def gen_static_dispatch_backend_call(

 def get_backend_index_for_aoti(
     func: NativeFunction,
-    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
-    backend_indices: Dict[DispatchKey, BackendIndex],
-) -> Optional[BackendIndex]:
+    backend_indices: dict[DispatchKey, BackendIndex],
+) -> BackendIndex | None:
     backend_index = None
     if backend_indices[dispatch_key].has_kernel(func) or (
         func.structured_delegate is not None
@@ -341,10 +343,10 @@ def get_backend_index_for_aoti(

 def get_header_for_aoti(
     func: NativeFunction,
-    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
-    backend_indices: Dict[DispatchKey, BackendIndex],
-) -> Optional[str]:
+    backend_indices: dict[DispatchKey, BackendIndex],
+) -> str | None:
     backend_index = get_backend_index_for_aoti(
         func, func_group_mapping, dispatch_key, backend_indices
     )
@@ -365,11 +367,11 @@ def get_fallback_op_name(func: NativeFunction) -> str:

 def gen_c_shim(
     func: NativeFunction,
-    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
-    backend_indices: Dict[DispatchKey, BackendIndex],
+    backend_indices: dict[DispatchKey, BackendIndex],
     header: bool,
-) -> Optional[str]:
+) -> str | None:
     backend_index = get_backend_index_for_aoti(
         func, func_group_mapping, dispatch_key, backend_indices
     )
@@ -399,16 +401,16 @@ def gen_c_shim(

 @dataclass(frozen=True)
 class ShimGenerator:
-    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup]
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
     dispatch_key: DispatchKey
-    backend_indices: Dict[DispatchKey, BackendIndex]
+    backend_indices: dict[DispatchKey, BackendIndex]
     header: bool  # True to generate .h and False to generate .cpp

     @method_with_native_function
     def __call__(
         self,
         func: NativeFunction,
-    ) -> Optional[str]:
+    ) -> str | None:
         result = gen_c_shim(
             func,
             self.func_group_mapping,
@@ -421,9 +423,9 @@ class ShimGenerator:

 def gen_aoti_c_shim(
     native_functions: Sequence[NativeFunction],
-    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
+    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
     dispatch_key: DispatchKey,
-    backend_indices: Dict[DispatchKey, BackendIndex],
+    backend_indices: dict[DispatchKey, BackendIndex],
     header: bool,
     includes: str = "",
 ) -> str:
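The comment above `declaration_definition_cache` notes that gen.py renders the header before the source, so each shim would otherwise be computed twice with identical inputs. A hedged sketch of that memoization pattern (the names and generated strings below are illustrative stand-ins, not the real generator):

```python
from __future__ import annotations

# Module-level memo, mirroring declaration_definition_cache in the diff.
_decl_defn_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def declaration_and_definition(
    func_name: str, device: str, backend_call: str
) -> tuple[str, str]:
    key = (func_name, device, backend_call)
    if key not in _decl_defn_cache:
        # The first pass (header generation) pays the cost once...
        decl = f"AOTITorchError {func_name}(/* args elided */);"
        defn = f"AOTITorchError {func_name}(/* args elided */) {{ /* calls {backend_call} */ }}"
        _decl_defn_cache[key] = (decl, defn)
    # ...and the second pass (source generation) reuses the cached pair.
    return _decl_defn_cache[key]
```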
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import argparse
 import os
 import pathlib
 import re
 from collections import Counter, defaultdict, namedtuple
-from typing import Dict, List, Optional, Sequence, Set, Union
+from typing import Sequence

 import yaml

@@ -36,10 +38,10 @@ ParsedExternalYaml = namedtuple(

 def parse_backend_yaml(
     backend_yaml_path: str,
-    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
-    backend_indices: Dict[DispatchKey, BackendIndex],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+    backend_indices: dict[DispatchKey, BackendIndex],
 ) -> ParsedExternalYaml:
-    native_functions_map: Dict[OperatorName, NativeFunction] = {
+    native_functions_map: dict[OperatorName, NativeFunction] = {
         f.func.name: f
         for f in concatMap(
             lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
@@ -119,14 +121,14 @@ def parse_backend_yaml(
 Only the following keys are supported: {", ".join(valid_keys)}'

     def create_backend_index(
-        backend_ops: List[str],
-        symint_ops: Set[str],
+        backend_ops: list[str],
+        symint_ops: set[str],
         dispatch_key: DispatchKey,
         *,
         use_out_as_primary: bool,
         use_device_guard: bool,
     ) -> BackendIndex:
-        metadata: Dict[OperatorName, BackendMetadata] = {}
+        metadata: dict[OperatorName, BackendMetadata] = {}
         for op in backend_ops:
             op_name = OperatorName.parse(op)
             assert (
@@ -149,7 +151,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
             index=metadata,
         )

-    backend_key: Optional[DispatchKey] = None
+    backend_key: DispatchKey | None = None
     if len(supported) > 0:
         with context(
             lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
@@ -166,7 +168,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
         assert backend_key not in backend_indices
         backend_indices[backend_key] = backend_idx

-    autograd_key: Optional[DispatchKey] = None
+    autograd_key: DispatchKey | None = None
     if len(supported_autograd) > 0:
         with context(
             lambda: f'The "autograd" key was specified, which indicates that you would like to override \
@@ -245,12 +247,12 @@ autograd key. They cannot be mix and matched. If this is something you need, fee

 def error_on_missing_kernels(
     native_functions: Sequence[NativeFunction],
-    backend_indices: Dict[DispatchKey, BackendIndex],
+    backend_indices: dict[DispatchKey, BackendIndex],
     backend_key: DispatchKey,
-    autograd_key: Optional[DispatchKey],
+    autograd_key: DispatchKey | None,
     class_name: str,
     kernel_defn_file_path: str,
-    full_codegen: Optional[List[OperatorName]] = None,
+    full_codegen: list[OperatorName] | None = None,
 ) -> None:
     try:
         with open(kernel_defn_file_path) as f:
@@ -268,7 +270,7 @@ def error_on_missing_kernels(
     )
     # Quick mapping from each OperatorName used by the external backend
    # to its backend kernel name
-    expected_backend_op_names: Dict[OperatorName, str] = dict(
+    expected_backend_op_names: dict[OperatorName, str] = dict(
         list(
             concatMap(
                 lambda index: [
@@ -278,13 +280,13 @@ def error_on_missing_kernels(
             )
         )
     )
-    expected_backend_native_funcs: List[NativeFunction] = [
+    expected_backend_native_funcs: list[NativeFunction] = [
         f
         for f in native_functions
         if f.func.name in expected_backend_op_names.keys()
         and f.func.name not in full_codegen
     ]
-    expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict(
+    expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
         list
     )
     for native_f in expected_backend_native_funcs:
@@ -356,10 +358,10 @@ def gen_dispatchkey_nativefunc_headers(
     fm: FileManager,
     class_name: str,
     cpp_namespace: str,
-    backend_indices: Dict[DispatchKey, BackendIndex],
-    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
+    backend_indices: dict[DispatchKey, BackendIndex],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
     backend_dispatch_key: DispatchKey,
-    autograd_dispatch_key: Optional[DispatchKey],
+    autograd_dispatch_key: DispatchKey | None,
     backend_name: str = "",
 ) -> None:
     assert class_name is not None
@@ -413,11 +415,11 @@ def gen_dispatcher_registrations(
     fm: FileManager,
     output_dir: str,
     class_name: str,
-    backend_indices: Dict[DispatchKey, BackendIndex],
-    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
+    backend_indices: dict[DispatchKey, BackendIndex],
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
     backend_dispatch_key: DispatchKey,
     dispatch_key: DispatchKey,
-    selector: "SelectiveBuilder",
+    selector: SelectiveBuilder,
     # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
     build_in_tree: bool = False,
     per_operator_headers: bool = False,
@@ -524,7 +526,7 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {


 def run(
-    source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[str] = None
+    source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
 ) -> None:
     # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
     pytorch_root = pathlib.Path(__file__).parent.parent.absolute()
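One caveat worth noting for signatures like `impl_path: str | None = None` above: postponed annotations only cover annotation positions, so a `|` union still cannot be used as a runtime value before Python 3.10. An illustrative sketch (the `run` parameters are borrowed from the diff; the body is ours):

```python
from __future__ import annotations


def run(source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None) -> None:
    # Fine on 3.8: the annotation is a lazy string. But `isinstance(x, str | None)`
    # would raise TypeError before 3.10 -- a tuple of types is the portable
    # runtime spelling of the same check.
    if not isinstance(impl_path, (str, type(None))):
        raise TypeError("impl_path must be a str or None")


run("native_functions.yaml", "out", dry_run=True)
```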
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import argparse
 import os
-import pathlib
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
+from pathlib import Path
+from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING

 import yaml

@@ -45,7 +47,6 @@ from torchgen.model import (
     OperatorName,
     Variant,
 )
-from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import (
     context,
     FileManager,
@@ -55,7 +56,11 @@ from torchgen.utils import (
 )


-def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str:
+if TYPE_CHECKING:
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
+def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
     """
     A wrapper function to basically get `sig.decl(include_context=True)`.
     For ATen kernel, the codegen has no idea about ET contextArg, so we
@ -72,9 +77,9 @@ def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def static_dispatch(
|
def static_dispatch(
|
||||||
sig: Union[CppSignature, ExecutorchCppSignature],
|
sig: CppSignature | ExecutorchCppSignature,
|
||||||
f: NativeFunction,
|
f: NativeFunction,
|
||||||
backend_indices: List[BackendIndex],
|
backend_indices: list[BackendIndex],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
|
For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
|
||||||
@ -113,7 +118,7 @@ TORCH_API inline {_sig_decl_wrapper(sig)} {{
|
|||||||
# and the scaffolding to call into the dispatcher from these functions.
|
# and the scaffolding to call into the dispatcher from these functions.
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ComputeFunction:
|
class ComputeFunction:
|
||||||
static_dispatch_backend_indices: List[BackendIndex]
|
static_dispatch_backend_indices: list[BackendIndex]
|
||||||
|
|
||||||
selector: SelectiveBuilder
|
selector: SelectiveBuilder
|
||||||
|
|
||||||
@ -122,7 +127,7 @@ class ComputeFunction:
|
|||||||
is_custom_op: Callable[[NativeFunction], bool]
|
is_custom_op: Callable[[NativeFunction], bool]
|
||||||
|
|
||||||
@method_with_native_function
|
@method_with_native_function
|
||||||
def __call__(self, f: NativeFunction) -> Optional[str]:
|
def __call__(self, f: NativeFunction) -> str | None:
|
||||||
is_method_variant = False
|
is_method_variant = False
|
||||||
if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
|
if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
|
||||||
return None
|
return None
|
||||||
@ -136,7 +141,7 @@ class ComputeFunction:
|
|||||||
f"Can't handle native function {f.func} with the following variant specification {f.variants}."
|
f"Can't handle native function {f.func} with the following variant specification {f.variants}."
|
||||||
)
|
)
|
||||||
|
|
||||||
sig: Union[CppSignature, ExecutorchCppSignature] = (
|
sig: CppSignature | ExecutorchCppSignature = (
|
||||||
CppSignatureGroup.from_native_function(
|
CppSignatureGroup.from_native_function(
|
||||||
f, method=False, fallback_binding=f.manual_cpp_binding
|
f, method=False, fallback_binding=f.manual_cpp_binding
|
||||||
).most_faithful_signature()
|
).most_faithful_signature()
|
||||||
@ -179,10 +184,10 @@ class ComputeCodegenUnboxedKernels:
|
|||||||
@method_with_nested_native_function
|
@method_with_nested_native_function
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
unbox_kernel_entry: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]],
|
unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
|
||||||
) -> str:
|
) -> str:
|
||||||
f: NativeFunction = unbox_kernel_entry[0]
|
f: NativeFunction = unbox_kernel_entry[0]
|
||||||
kernel_key: Union[ETKernelKey, List[ETKernelKey]] = unbox_kernel_entry[1][0]
|
kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
|
||||||
kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]
|
kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]
|
||||||
|
|
||||||
op_name = f"{f.namespace}::{f.func.name}"
|
op_name = f"{f.namespace}::{f.func.name}"
|
||||||
@ -196,7 +201,7 @@ class ComputeCodegenUnboxedKernels:
|
|||||||
)
|
)
|
||||||
if not used_kernel_keys:
|
if not used_kernel_keys:
|
||||||
return ""
|
return ""
|
||||||
sig: Union[CppSignature, ExecutorchCppSignature]
|
sig: CppSignature | ExecutorchCppSignature
|
||||||
argument_type_gen: Callable[..., NamedCType]
|
argument_type_gen: Callable[..., NamedCType]
|
||||||
return_type_gen: Callable[..., CType]
|
return_type_gen: Callable[..., CType]
|
||||||
if self.use_aten_lib:
|
if self.use_aten_lib:
|
||||||
@ -290,11 +295,11 @@ def gen_unboxing(
|
|||||||
) -> None:
|
) -> None:
|
||||||
# Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
|
# Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
|
||||||
def key_func(
|
def key_func(
|
||||||
item: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]
|
item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
|
||||||
) -> str:
|
) -> str:
|
||||||
return item[0].root_name + ":" + item[1][0].to_native_string()
|
return item[0].root_name + ":" + item[1][0].to_native_string()
|
||||||
|
|
||||||
items: List[Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]] = [
|
items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [
|
||||||
(native_function, (kernel_key, metadata))
|
(native_function, (kernel_key, metadata))
|
||||||
for native_function in native_functions
|
for native_function in native_functions
|
||||||
for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
|
for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
|
||||||
@ -325,8 +330,8 @@ def gen_unboxing(
|
|||||||
|
|
||||||
@with_native_function_and_index # type: ignore[arg-type]
|
@with_native_function_and_index # type: ignore[arg-type]
|
||||||
def compute_native_function_declaration(
|
def compute_native_function_declaration(
|
||||||
g: Union[NativeFunctionsGroup, NativeFunction], kernel_index: ETKernelIndex
|
g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
|
||||||
) -> List[str]:
|
) -> list[str]:
|
||||||
assert isinstance(g, NativeFunction)
|
assert isinstance(g, NativeFunction)
|
||||||
sig = ExecutorchCppSignature.from_native_function(f=g)
|
sig = ExecutorchCppSignature.from_native_function(f=g)
|
||||||
metadata_list = kernel_index.get_kernels(g).values()
|
metadata_list = kernel_index.get_kernels(g).values()
|
||||||
@ -352,7 +357,7 @@ def gen_functions_declarations(
|
|||||||
kernel_index: ETKernelIndex,
|
kernel_index: ETKernelIndex,
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
use_aten_lib: bool,
|
use_aten_lib: bool,
|
||||||
custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None,
|
custom_ops_native_functions: Sequence[NativeFunction] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generates namespace separated C++ function API inline declaration/definitions.
|
Generates namespace separated C++ function API inline declaration/definitions.
|
||||||
@ -406,13 +411,13 @@ def get_ns_grouped_kernels(
|
|||||||
kernel_index: ETKernelIndex,
|
kernel_index: ETKernelIndex,
|
||||||
native_function_decl_gen: Callable[
|
native_function_decl_gen: Callable[
|
||||||
[
|
[
|
||||||
Union[NativeFunctionsGroup, NativeFunction],
|
NativeFunctionsGroup | NativeFunction,
|
||||||
ETKernelIndex,
|
ETKernelIndex,
|
||||||
],
|
],
|
||||||
List[str],
|
list[str],
|
||||||
],
|
],
|
||||||
) -> Dict[str, List[str]]:
|
) -> dict[str, list[str]]:
|
||||||
ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
|
ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
|
||||||
for f in native_functions:
|
for f in native_functions:
|
||||||
native_function_namespaces = set()
|
native_function_namespaces = set()
|
||||||
op_kernels = kernel_index.get_kernels(f)
|
op_kernels = kernel_index.get_kernels(f)
|
||||||
@ -595,7 +600,7 @@ def gen_custom_ops(
|
|||||||
def translate_native_yaml(
|
def translate_native_yaml(
|
||||||
tags_yaml_path: str,
|
tags_yaml_path: str,
|
||||||
aten_yaml_path: str,
|
aten_yaml_path: str,
|
||||||
native_yaml_path: Optional[str],
|
native_yaml_path: str | None,
|
||||||
use_aten_lib: bool,
|
use_aten_lib: bool,
|
||||||
out_file: TextIO,
|
out_file: TextIO,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -646,15 +651,15 @@ def translate_native_yaml(
|
|||||||
skip_native_fns_gen=False,
|
skip_native_fns_gen=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
func_to_scoped_name: Dict[FunctionSchema, str] = {
|
func_to_scoped_name: dict[FunctionSchema, str] = {
|
||||||
f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
|
f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
|
||||||
}
|
}
|
||||||
op_to_scoped_name: Dict[OperatorName, str] = {
|
op_to_scoped_name: dict[OperatorName, str] = {
|
||||||
func.name: name for func, name in func_to_scoped_name.items()
|
func.name: name for func, name in func_to_scoped_name.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
|
schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
|
||||||
kernel_persist_dict: Dict[str, Dict[str, Any]] = {
|
kernel_persist_dict: dict[str, dict[str, Any]] = {
|
||||||
op_to_scoped_name[op]: v for op, v in persisted_fields.items()
|
op_to_scoped_name[op]: v for op, v in persisted_fields.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -692,13 +697,13 @@ def translate_native_yaml(
|
|||||||
|
|
||||||
|
|
||||||
def parse_yaml(
|
def parse_yaml(
|
||||||
path: Optional[str],
|
path: str | None,
|
||||||
tags_yaml_path: str,
|
tags_yaml_path: str,
|
||||||
function_filter: Callable[[NativeFunction], bool],
|
function_filter: Callable[[NativeFunction], bool],
|
||||||
skip_native_fns_gen: bool = False,
|
skip_native_fns_gen: bool = False,
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
List[NativeFunction],
|
list[NativeFunction],
|
||||||
Union[Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ETKernelIndex],
|
dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
|
||||||
]:
|
]:
|
||||||
if path and os.path.exists(path) and os.stat(path).st_size > 0:
|
if path and os.path.exists(path) and os.stat(path).st_size > 0:
|
||||||
with open(path) as f:
|
with open(path) as f:
|
||||||
@ -735,8 +740,8 @@ def parse_yaml(
|
|||||||
|
|
||||||
# (2) Return BackendIndices if kernel index is absent
|
# (2) Return BackendIndices if kernel index is absent
|
||||||
def map_index(
|
def map_index(
|
||||||
m: Dict[OperatorName, BackendMetadata]
|
m: dict[OperatorName, BackendMetadata]
|
||||||
) -> Dict[OperatorName, BackendMetadata]:
|
) -> dict[OperatorName, BackendMetadata]:
|
||||||
return {op: m[op] for op in m if op in op_names}
|
return {op: m[op] for op in m if op in op_names}
|
||||||
|
|
||||||
backend_indices = {
|
backend_indices = {
|
||||||
@ -751,11 +756,11 @@ def parse_yaml(
|
|||||||
def parse_yaml_files(
|
def parse_yaml_files(
|
||||||
tags_yaml_path: str,
|
tags_yaml_path: str,
|
||||||
aten_yaml_path: str,
|
aten_yaml_path: str,
|
||||||
native_yaml_path: Optional[str],
|
native_yaml_path: str | None,
|
||||||
custom_ops_yaml_path: Optional[str],
|
custom_ops_yaml_path: str | None,
|
||||||
selector: SelectiveBuilder,
|
selector: SelectiveBuilder,
|
||||||
use_aten_lib: bool,
|
use_aten_lib: bool,
|
||||||
) -> Tuple[ETParsedYaml, Optional[ETParsedYaml]]:
|
) -> tuple[ETParsedYaml, ETParsedYaml | None]:
|
||||||
"""Parses functions.yaml and custom_ops.yaml files.
|
"""Parses functions.yaml and custom_ops.yaml files.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -978,7 +983,7 @@ def main() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if options.output_dependencies:
|
if options.output_dependencies:
|
||||||
depfile_path = pathlib.Path(options.output_dependencies).resolve()
|
depfile_path = Path(options.output_dependencies).resolve()
|
||||||
depfile_name = depfile_path.name
|
depfile_name = depfile_path.name
|
||||||
depfile_stem = depfile_path.stem
|
depfile_stem = depfile_path.stem
|
||||||
|
|
||||||
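
A side effect of lazy annotations, visible in the hunks above: an import that is needed only by annotations (here `SelectiveBuilder`) can move under `typing.TYPE_CHECKING`, which is `False` at runtime. Static checkers still see the name, but the module is never imported when the program runs, avoiding import cost and cycles. A hedged sketch of the pattern, where `heavy_module` is a placeholder for any annotation-only dependency:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; never executed at runtime.
    # `heavy_module` / `ExpensiveClass` are hypothetical names.
    from heavy_module import ExpensiveClass


def process(obj: ExpensiveClass) -> str:
    # Fine at runtime: the annotation is a string, so ExpensiveClass
    # does not need to be importable when this module loads.
    return repr(obj)
```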
--- a/torchgen/gen_functionalization_type.py
+++ b/torchgen/gen_functionalization_type.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, TYPE_CHECKING

 from torchgen.api import cpp, dispatcher
 from torchgen.api.translate import translate
@@ -46,10 +48,13 @@ from torchgen.native_function_generation import (
     MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
     OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
 )
-from torchgen.selective_build.selector import SelectiveBuilder
 from torchgen.utils import dataclass_repr


+if TYPE_CHECKING:
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
 # Note: [Mutable Ops Not Using Functionalization]
 # Ops in this list currently do not work with functionalization and should be fixed.
 MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
@@ -88,7 +93,7 @@ class GenCompositeViewCopyKernel:
     backend_index: BackendIndex

     @method_with_native_function
-    def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]:
+    def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
         if g.view_copy is None:
             return None
         elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
@@ -160,7 +165,7 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
 """


-def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
+def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
     assert len(rets) == len(names)
     if len(rets) == 0:
         return ""
@@ -184,7 +189,7 @@ def wrapper_name(func: FunctionSchema) -> str:
         return cpp.name(func)


-def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> bool:
+def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
     return isinstance(a, SelfArgument) or (
         isinstance(a, Argument) and a.type.is_tensor_like()
     )
@@ -194,7 +199,7 @@ def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) ->
 # Some op schemas include non-owning types though (like TensorList),
 # and when we unwrap them we expect to get out an owning type!.
 # We also return a lambda that tells you how to conver the non-owning type argument into the owning type.
-def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
+def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
     if t == BaseCType(tensorListT):
         return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
     if t == BaseCType(iTensorListRefT):
@@ -209,9 +214,9 @@ def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
 # (2) a context, to be used by translate(), with all of the relevant bindings.
 def unwrap_tensor_args(
     sig: DispatcherSignature, *, is_view_op: bool
-) -> Tuple[str, List[Binding]]:
-    context: List[Binding] = []
-    unwrapped_tensor_args: List[str] = []
+) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
     for arg in sig.arguments():
         if is_tensor_like(arg.argument):
             # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@@ -247,9 +252,9 @@ def unwrap_tensor_args(
 # converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
 # (1) a string containing all of the logic that does the conversions.
 # (2) a context, to be used by translate(), with all of the relevant bindings.
-def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
-    context: List[Binding] = []
-    unwrapped_tensor_args: List[str] = []
+def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
+    context: list[Binding] = []
+    unwrapped_tensor_args: list[str] = []
     for arg in sig.arguments():
         if is_tensor_like(arg.argument):
             # for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@@ -317,7 +322,7 @@ def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:

 # Detects whether any of the SymInt arguments are, in fact, symbolic values.
 # This is used in the constructor of ViewMeta.
-def emit_has_symbolic_inputs(sig: DispatcherSignature) -> Tuple[str, str]:
+def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
     name = "has_symbolic_inputs"
     statements = [
         f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
@@ -522,7 +527,7 @@ def maybe_create_output(f: NativeFunction, var_name: str) -> str:
 # - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
 def get_mutable_redispatch_return_names(
     f: NativeFunction, inner_return_var: str
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
     aliased_returns = []
     non_aliased_returns = []
     for i, name in enumerate(f.func.aliased_return_names()):
@@ -751,11 +756,11 @@ def emit_inplace_functionalization_body(
 # See Note [Functionalization Pass: View Inverses].
 def gen_functionalization_view_inverse_declaration(
     selector: SelectiveBuilder, g: NativeFunctionsViewGroup
-) -> Optional[str]:
+) -> str | None:
     # For every (non-composite) view op, we need a corresponding "inverse view" function.
     # This generates the declarations so we get a good compiler error when someone adds a new view.
     @with_native_function
-    def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]:
+    def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
         if g.view.has_composite_implicit_autograd_kernel:
             return None
         view_inverse_sig = ViewInverseSignature(g)
@@ -766,9 +771,9 @@ def gen_functionalization_view_inverse_declaration(

 def gen_functionalization_registration(
     selector: SelectiveBuilder,
-    g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
+    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
     composite_implicit_autograd_index: BackendIndex,
-) -> List[str]:
+) -> list[str]:
     @with_native_function
     def emit_registration_helper(f: NativeFunction) -> str:
         assert not f.has_composite_implicit_autograd_kernel
@@ -832,8 +837,8 @@ def gen_functionalization_definition(
     # (and instead only need to operate on grouped NativeFunctions).
     # The only reason currently is because we need to emit direct dispatch registrations
     # For CompositeImplicitAutograd operators, which are potentially ungrouped.
-    g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
-) -> List[str]:
+    g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
+) -> list[str]:
     # Don't generate kernels in mobile build
     if not selector.include_all_operators:
         return []
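
`Optional[str]` and `str | None` denote exactly the same type; the `|` spelling (PEP 604) only became evaluable at runtime in Python 3.10, which is why the future import is what makes it usable in these files' annotations. Where a union must actually be evaluated at runtime, the older spelling may still be required on older interpreters; in pure annotations the two are interchangeable. A small illustration (standalone sketch):

```python
from __future__ import annotations

from typing import get_type_hints


def wrapper_name(name: str, overload: str | None = None) -> str:
    # `str | None` here is equivalent to Optional[str].
    return name if overload is None else f"{name}_{overload}"


# get_type_hints() re-evaluates the string annotations. On Python 3.10+
# this succeeds; on 3.8/3.9 evaluating "str | None" would raise, which is
# the one caveat of writing PEP 604 unions under postponed annotations.
print(get_type_hints(wrapper_name))
```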
--- a/torchgen/gen_lazy_tensor.py
+++ b/torchgen/gen_lazy_tensor.py
@@ -1,19 +1,10 @@
+from __future__ import annotations
+
 import argparse
 import os
 import pathlib
 from collections import namedtuple
-from typing import (
-    Any,
-    Callable,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    Union,
-)
+from typing import Any, Callable, Iterable, Iterator, Sequence

 import yaml

@@ -102,8 +93,8 @@ ParsedExternalYaml = namedtuple(

 def parse_native_functions_keys(
     backend_yaml_path: str,
-    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
-) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]:
+    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
+) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
     with open(backend_yaml_path) as f:
         yaml_values = yaml.load(f, Loader=YamlLoader)
     assert isinstance(yaml_values, dict)
@@ -120,7 +111,7 @@ def parse_native_functions_keys(


 def validate_shape_inference_header(
-    shape_inference_hdr: str, expected_shape_infr_decls: List[str]
+    shape_inference_hdr: str, expected_shape_infr_decls: list[str]
 ) -> None:
     try:
         with open(shape_inference_hdr) as f:
@@ -180,12 +171,12 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {

 class default_args:
     node_base: str = "Node"
-    node_base_hdr: Optional[str] = None
+    node_base_hdr: str | None = None
     shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
     tensor_class: str = "torch::lazy::LazyTensor"
     tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
-    lazy_ir_generator: Type[GenLazyIR] = GenLazyIR
-    native_func_definition_generator: Type[
+    lazy_ir_generator: type[GenLazyIR] = GenLazyIR
+    native_func_definition_generator: type[
         GenLazyNativeFuncDefinition
     ] = GenLazyNativeFuncDefinition
     backend_name: str = "TorchScript"
@@ -263,10 +254,10 @@ def main() -> None:
     # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
     torch_root = pathlib.Path(__file__).parent.parent.parent.absolute()
     aten_path = str(torch_root / "aten" / "src" / "ATen")
-    lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator
+    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
     if options.gen_ts_lowerings:
         lazy_ir_generator = GenTSLazyIR
-    native_func_definition_generator: Type[
+    native_func_definition_generator: type[
         GenLazyNativeFuncDefinition
     ] = default_args.native_func_definition_generator

@@ -292,14 +283,14 @@ def run_gen_lazy_tensor(
     source_yaml: str,
     output_dir: str,
     dry_run: bool,
-    impl_path: Optional[str],
+    impl_path: str | None,
     node_base: str = default_args.node_base,
-    node_base_hdr: Optional[str] = default_args.node_base_hdr,
+    node_base_hdr: str | None = default_args.node_base_hdr,
     tensor_class: str = default_args.tensor_class,
     tensor_class_hdr: str = default_args.tensor_class_hdr,
     shape_inference_hdr: str = default_args.shape_inference_hdr,
-    lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator,
-    native_func_definition_generator: Type[
+    lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
+    native_func_definition_generator: type[
         GenLazyNativeFuncDefinition
     ] = default_args.native_func_definition_generator,
     # build_in_tree is true for TS backend and affects include paths
@@ -347,7 +338,7 @@ def run_gen_lazy_tensor(
     )
     grouped_native_functions = get_grouped_native_functions(native_functions)

-    def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
+    def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
         """
         We sort the native function because of the note in concat_map_codegen.
         TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
@@ -377,8 +368,8 @@ def run_gen_lazy_tensor(

     def concat_map_codegen(
         func: Callable[[NativeFunction], Sequence[str]],
-        xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
-        ops_list: List[OperatorName] = full_codegen,
+        xs: Iterable[NativeFunctionsGroup | NativeFunction],
+        ops_list: list[OperatorName] = full_codegen,
     ) -> Iterator[str]:
         """
         We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
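
`default_args.lazy_ir_generator` holds a class, not an instance, which is what `type[GenLazyIR]` expresses: any class object that is `GenLazyIR` or a subclass. The builtin `type` subscript replaces `typing.Type` (deprecated since Python 3.9). A standalone sketch of the same shape, with hypothetical class names echoing the ones in this file:

```python
from __future__ import annotations


class GenIR:  # stand-in for GenLazyIR
    def lower(self) -> str:
        return "base IR"


class GenTSIR(GenIR):  # stand-in for GenTSLazyIR
    def lower(self) -> str:
        return "TorchScript IR"


def run_codegen(ir_generator: type[GenIR] = GenIR) -> str:
    # The parameter is a class object; it is instantiated here,
    # so callers can swap in any subclass.
    return ir_generator().lower()


print(run_codegen())         # base IR
print(run_codegen(GenTSIR))  # TorchScript IR
```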
--- a/torchgen/gen_vmap_plumbing.py
+++ b/torchgen/gen_vmap_plumbing.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import textwrap
 from dataclasses import dataclass
-from typing import List, Optional, Sequence, Tuple
+from typing import Sequence

 from torchgen.api.translate import translate
 from torchgen.api.types import DispatcherSignature
@@ -32,7 +34,7 @@ def is_tensor_list(typ: Type) -> bool:
     return isinstance(typ, ListType) and is_tensor(typ.elem)


-def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
+def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
     result = f"""\
 Tensor {name}_value;
 optional<int64_t> {name}_bdim;
@@ -40,7 +42,7 @@ def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
     return textwrap.dedent(result).split("\n")


-def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
+def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
     result = f"""\
 optional<Tensor> {name}_value;
 optional<int64_t> {name}_bdim;
@@ -52,7 +54,7 @@ def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:

 def gen_unwraps(
     flat_arguments: Sequence[Argument], cur_level_var: str
-) -> Tuple[str, List[str]]:
+) -> tuple[str, list[str]]:
     arg_names = [a.name for a in flat_arguments]
     arg_types = [a.type for a in flat_arguments]

@@ -99,7 +101,7 @@ if ({' && '.join(conditions)}) {{


 def gen_returns(
-    returns: Tuple[Return, ...], cur_level_var: str, results_var: str
+    returns: tuple[Return, ...], cur_level_var: str, results_var: str
 ) -> str:
     idx = 0
     wrapped_returns = []
@@ -132,7 +134,7 @@ def is_mutated_arg(argument: Argument) -> bool:
     return argument.annotation is not None and argument.annotation.is_write


-def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
+def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
     # Assumptions:
     # - only one argument is being modified in-place
     # - the argument that is being modified in-place is the first argument
@@ -197,7 +199,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
 }}"""


-def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
+def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
     schema = native_function.func
     sig = DispatcherSignature.from_schema(schema)
     returns = schema.returns
@@ -244,7 +246,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
 @dataclass(frozen=True)
 class ComputeBatchRulePlumbing:
     @method_with_native_function
-    def __call__(self, f: NativeFunction) -> Optional[str]:
+    def __call__(self, f: NativeFunction) -> str | None:
         result = gen_vmap_plumbing(f)
         return result

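
Since PEP 585 the builtin containers accept subscripts directly (`list[str]`, `tuple[Return, ...]`, `dict[str, int]`), making the `typing.List`/`typing.Tuple` aliases redundant. The subscripted builtins are evaluable at runtime from Python 3.9, and under postponed annotations they are legal in annotations on 3.7+ as well. A small demonstration, loosely mirroring this file's unwrap helpers (a sketch, not the patch's code):

```python
from __future__ import annotations


def gen_unwraps(names: list[str]) -> tuple[str, list[str]]:
    # Builtin generics: no typing.List / typing.Tuple needed.
    lines = [f"Tensor {n}_value;" for n in names]
    return "\n".join(lines), lines


text, lines = gen_unwraps(["self", "other"])
print(text)
```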
--- a/torchgen/local.py
+++ b/torchgen/local.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
 import threading
 from contextlib import contextmanager
-from typing import Iterator, Optional
+from typing import Iterator


 # Simple dynamic scoping implementation. The name "parametrize" comes
@@ -17,8 +19,8 @@ from typing import Iterator, Optional

 
 class Locals(threading.local):
-    use_const_ref_for_mutable_tensors: Optional[bool] = None
-    use_ilistref_for_tensor_lists: Optional[bool] = None
+    use_const_ref_for_mutable_tensors: bool | None = None
+    use_ilistref_for_tensor_lists: bool | None = None


 _locals = Locals()
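
One behavioral note on class attributes like those of `Locals`: under the future import their annotations live in `__annotations__` as plain strings, so any code that introspects annotations at runtime must resolve them explicitly. torchgen only reads the attribute values, never the annotation objects, so the change is safe here. A sketch of what actually changes:

```python
from __future__ import annotations

import threading


class Locals(threading.local):
    use_const_ref_for_mutable_tensors: bool | None = None


# With postponed evaluation the raw annotation is a string:
print(Locals.__annotations__)
# {'use_const_ref_for_mutable_tensors': 'bool | None'}
```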
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -1,9 +1,11 @@
+from __future__ import annotations
+
 import dataclasses
 import itertools
 import re
 from dataclasses import dataclass
 from enum import auto, Enum
-from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
+from typing import Callable, Iterator, Sequence

 from torchgen.utils import assert_never, NamespaceHelper, OrderedSet

@@ -229,7 +231,7 @@ class DispatchKey(Enum):
         return str(self).lower()

     @staticmethod
-    def parse(value: str) -> "DispatchKey":
+    def parse(value: str) -> DispatchKey:
         for k, v in DispatchKey.__members__.items():
             if k == value:
                 return v
@@ -350,20 +352,20 @@ class ScalarType(Enum):
         return self.name

     @staticmethod
-    def maybe_parse(value: str) -> Optional["ScalarType"]:
+    def maybe_parse(value: str) -> ScalarType | None:
         for k, v in ScalarType.__members__.items():
             if k == value:
                 return v
         return None

     @staticmethod
-    def parse(value: str) -> "ScalarType":
+    def parse(value: str) -> ScalarType:
         mb_r = ScalarType.maybe_parse(value)
         assert mb_r is not None, f"unknown dtype {value}"
         return mb_r

     @staticmethod
-    def parse_set(values: str) -> OrderedSet["ScalarType"]:
+    def parse_set(values: str) -> OrderedSet[ScalarType]:
         dtypes: OrderedSet[ScalarType] = OrderedSet()
         for value in values.split(", "):
             if value in DTYPE_CLASSES:
@@ -373,7 +375,7 @@ class ScalarType(Enum):
         return dtypes


-DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {}
+DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
 # NB: Integral doesn't include boolean
 DTYPE_CLASSES["Integral"] = OrderedSet(
     [
@@ -419,7 +421,7 @@ class UfuncKey(Enum):
         return self.name

     @staticmethod
-    def parse(value: str) -> "UfuncKey":
+    def parse(value: str) -> UfuncKey:
         for k, v in UfuncKey.__members__.items():
             if k == value:
                 return v
@@ -462,7 +464,7 @@ class NativeFunction:
     # (This type is quoted as we are forward referencing a type
     # defined later in the file. I opted for this ordering of the
     # classes for expository clarity.)
-    func: "FunctionSchema"
+    func: FunctionSchema

     # Whether or not to generate mutable tensor arguments like regular
     # ones
@@ -475,14 +477,14 @@ class NativeFunction:
     device_check: DeviceCheckType

     # What python module to put the function in
-    python_module: Optional[str]
+    python_module: str | None

     # TODO: figure out what this does
-    category_override: Optional[str]
+    category_override: str | None

     # If no variants are specified in native_functions.yaml, this is
     # assumed to be {'function'}.
-    variants: Set[Variant]
+    variants: set[Variant]

     # Whether or not we should skip generating registrations for
     # this kernel. This is a bit of a double-edged sword, as manual
@@ -497,7 +499,7 @@ class NativeFunction:

     # The location in the YAML file were this native function entry was
     # defined. This is for conveniently reporting error messages!
-    loc: "Location"
+    loc: Location

     # A list of operators that are expected to be auto-generated for this NativeFunction.
     # Note: This list isn't actually directly used by the codegen to generate anything.
@@ -505,11 +507,11 @@ class NativeFunction:
     # function schema, and uses the autogen declarations to error check.
     # We expect every NativeFunction that gets auto-generated be explicitly called out
     # in native_functions.yaml
-    autogen: List["OperatorName"]
+    autogen: list[OperatorName]

     # If non-empty, this kernel is subject to ufunc codegen.
     # Sorted by ufunc_key
-    ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"]
+    ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]

     # Whether or not this out functions is a "structured kernel". Structured
     # kernels are defined a little differently from normal kernels; in
@@ -522,13 +524,13 @@ class NativeFunction:

     # Whether or not this non-out function is a structured kernel, defined
     # in terms of the out kernel referenced by the string here.
-    structured_delegate: Optional["OperatorName"]
+    structured_delegate: OperatorName | None

     # Only valid for structured kernels. Specifies alternative of what
     # to inherit from when defining the meta class for the structured
     # operator. This will usually be TensorIteratorBase. This also
     # changes the semantics of set_output to call the parent class.
-    structured_inherits: Optional[str]
+    structured_inherits: str | None

     # Structured kernels can declare elements as "precomputed". These elements
     # are returned by the meta function in one struct and passed to the impl
@@ -536,11 +538,11 @@ class NativeFunction:
     # elements supersede. Information about the names and types of these
     # precomputed elements and how they correspond to kernel arguments is stored
     # in this member, if applicable.
-    precomputed: Optional["Precompute"]
+    precomputed: Precompute | None

     # Argument names whose default should be excluded from the C++ interface.
     # Intended for resolving overload ambiguities between signatures.
-    cpp_no_default_args: Set[str]
+    cpp_no_default_args: set[str]

     # Note [Abstract ATen methods]
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -560,7 +562,7 @@ class NativeFunction:

     # Tags are used to describe semantic information about (groups of) operators,
     # That aren't easily inferrable directly from the operator's schema.
-    tags: Set[str]
+    tags: set[str]

     # NB: The benefit of defining a dataclass is that we automatically get
     # a constructor defined for all the fields we specify. No need
@@ -569,13 +571,11 @@ class NativeFunction:
     # We parse both the NativeFunction + backend-specific information about it, which it stored in a corresponding BackendIndex.
     @staticmethod
     def from_yaml(
-        ei: Dict[str, object],
-        loc: "Location",
-        valid_tags: Set[str],
-        ignore_keys: Optional[Set[DispatchKey]] = None,
-    ) -> Tuple[
-        "NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]
-    ]:
+        ei: dict[str, object],
+        loc: Location,
+        valid_tags: set[str],
+        ignore_keys: set[DispatchKey] | None = None,
+    ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
         """
         Parse a NativeFunction from a dictionary as directly parsed
         from native_functions.yaml
@@ -602,7 +602,7 @@ class NativeFunction:

         variants_s = e.pop("variants", "function")
         assert isinstance(variants_s, str)
-        variants: Set[Variant] = set()
+        variants: set[Variant] = set()
         for v in variants_s.split(", "):
             if v == "function":
                 variants.add(Variant.function)
@@ -646,7 +646,7 @@ class NativeFunction:
                 "namespace is not supported in structured delegate,"
                 " using the same namespace as the native function"
             )
-        structured_delegate: Optional[OperatorName] = None
+        structured_delegate: OperatorName | None = None
         if structured_delegate_s is not None:
             structured_delegate = OperatorName.parse(structured_delegate_s)

@@ -685,7 +685,7 @@ class NativeFunction:
         if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
             tags_inp.append("pt2_compliant_tag")

-        tags: Set[str] = set()
+        tags: set[str] = set()
         for t in tags_inp:
             assert len(valid_tags) > 0
             # TODO: verify that the tag is valid and has an entry in tags.yaml
@@ -698,7 +698,7 @@ class NativeFunction:

         raw_dispatch = e.pop("dispatch", None)
         assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
-        dispatch: Dict[DispatchKey, BackendMetadata] = {}
+        dispatch: dict[DispatchKey, BackendMetadata] = {}
         num_dispatch_keys: int = 0
         if raw_dispatch is not None:
             assert not manual_kernel_registration, (
@@ -1081,8 +1081,8 @@ class SchemaKind(Enum):
 @dataclass(frozen=True)
 class NativeFunctionsGroup:
     functional: NativeFunction
-    inplace: Optional[NativeFunction]
-    mutable: Optional[NativeFunction]
+    inplace: NativeFunction | None
+    mutable: NativeFunction | None
     out: NativeFunction

     @property
@@ -1136,7 +1136,7 @@ class NativeFunctionsGroup:
             [str(f.func.name) for f in self.functions() if "generated" in f.tags]
         )
         generated_fns_str = ", ".join(str(x) for x in generated_fns)
-        expected_generated_fns: Set[str] = set()
+        expected_generated_fns: set[str] = set()
         for f in self.functions():
             expected_generated_fns.update(str(op) for op in f.autogen)
         expected_generated_fns_str = ", ".join(
@@ -1155,7 +1155,7 @@ class NativeFunctionsGroup:
                 f" Instead, it found 'autogen: {expected_generated_fns_str}'"
             )

-    def signature(self) -> "FunctionSchema":
+    def signature(self) -> FunctionSchema:
         return self.out.func.signature()

     def functions(self) -> Iterator[NativeFunction]:
@@ -1171,9 +1171,7 @@ class NativeFunctionsGroup:
         return self.functional.root_name

     @staticmethod
-    def from_dict(
-        d: Dict[SchemaKind, NativeFunction]
-    ) -> Optional["NativeFunctionsGroup"]:
+    def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
         assert d
         if len(d) == 1:
             return None
@@ -1229,7 +1227,7 @@ class UfuncInnerLoop:
     ufunc_key: UfuncKey

     @staticmethod
-    def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop":
+    def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
         name, supported_dtypes_str = value.split(" ", 1)
         assert supported_dtypes_str[0] == "("
         assert supported_dtypes_str[-1] == ")"
@@ -1261,12 +1259,12 @@ class BackendIndex:
     # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
     external: bool
     # Other backend-specific information that is on a per-operator basis
-    index: Dict["OperatorName", BackendMetadata]
+    index: dict[OperatorName, BackendMetadata]

     @staticmethod
     def grow_index(
-        parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
-        child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
+        parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
+        child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
     ) -> None:
         for k, v in child_index.items():
             for op_name, metadata in v.items():
@@ -1281,13 +1279,13 @@ class BackendIndex:
         else:
             return g.functional

-    def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
+    def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
         m = self.get_kernel(g)
         return m is not None

     def get_kernel(
-        self, g: Union[NativeFunction, NativeFunctionsGroup]
-    ) -> Optional[BackendMetadata]:
+        self, g: NativeFunction | NativeFunctionsGroup
+    ) -> BackendMetadata | None:
         if isinstance(g, NativeFunction):
             f = g
         elif isinstance(g, NativeFunctionsGroup):
@@ -1298,7 +1296,7 @@ class BackendIndex:
             return None
         return self.index[f.func.name]

-    def native_function_class_name(self) -> Optional[str]:
+    def native_function_class_name(self) -> str | None:
         if self.external:
             return f"{str(self.dispatch_key)}NativeFunctions"
         else:
@@ -1364,16 +1362,16 @@ class BackendIndex:
 @dataclass(frozen=True)
 class FunctionSchema:
     # The name of the operator this function schema describes.
-    name: "OperatorName"
+    name: OperatorName

-    arguments: "Arguments"
+    arguments: Arguments

     # TODO: Need to handle collisions with argument names at some point
-    returns: Tuple["Return", ...]
+    returns: tuple[Return, ...]

     @property
     def is_mutable(self) -> bool:
-        def is_write(arg: "Argument") -> bool:
+        def is_write(arg: Argument) -> bool:
             if arg.annotation is None:
                 return False
             return arg.annotation.is_write
@@ -1382,7 +1380,7 @@ class FunctionSchema:
         # See aten/src/ATen/core/function_schema.h (keep these in sync)
         return any(is_write(a) for a in self.arguments.flat_all)

-    def schema_order_arguments(self) -> Iterator["Argument"]:
+    def schema_order_arguments(self) -> Iterator[Argument]:
         return itertools.chain(
             self.arguments.flat_positional,
             self.arguments.flat_kwarg_only,
@@ -1392,7 +1390,7 @@ class FunctionSchema:
     decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

     @staticmethod
-    def parse(func: str) -> "FunctionSchema":
+    def parse(func: str) -> FunctionSchema:
         # We should probably get a proper parser here
         decls = FunctionSchema.decl_re.findall(func)
         assert len(decls) == 1, f"Invalid function schema: {func}"
@@ -1587,8 +1585,8 @@ class FunctionSchema:
     # - If the return aliases an input, we return the input name
     # - Otherwise, we return None.
     # If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
-    def aliased_return_names(self) -> List[Optional[str]]:
-        outs: List[Optional[str]] = []
+    def aliased_return_names(self) -> list[str | None]:
+        outs: list[str | None] = []
         for r in self.returns:
             aliased_args = [
                 a
@@ -1612,7 +1610,7 @@ class FunctionSchema:
         strip_default: bool = False,
         strip_view_copy_name: bool = False,
         keep_return_names: bool = False,
-    ) -> "FunctionSchema":
+    ) -> FunctionSchema:
         """
         Certain schemas are 'related', in that they are simply
         inplace/out/functional versions of the same function. This method
@@ -1709,10 +1707,10 @@ class FunctionSchema:
             returns=returns,
         )

-    def view_signature(self) -> "FunctionSchema":
+    def view_signature(self) -> FunctionSchema:
         return self.signature(strip_view_copy_name=True)

-    def with_name(self, name: "OperatorName") -> "FunctionSchema":
+    def with_name(self, name: OperatorName) -> FunctionSchema:
         return FunctionSchema(
             name=name,
             arguments=self.arguments,
@@ -1747,12 +1745,12 @@ class FunctionSchema:
 class Annotation:
     # Typically only has one element. Not actually a set so
     # we can conveniently assume it is canonically ordered
-    alias_set: Tuple[str, ...]
+    alias_set: tuple[str, ...]
     is_write: bool
-    alias_set_after: Tuple[str, ...]
+    alias_set_after: tuple[str, ...]

     @staticmethod
-    def parse(ann: str) -> "Annotation":
+    def parse(ann: str) -> Annotation:
         # TODO: implement a proper parser if this gets more ugly
         # Regex Explanation:
         # Example: "a! -> a|b"
@@ -1805,13 +1803,13 @@ class Annotation:
 @dataclass(frozen=True)
 class Type:
     @staticmethod
-    def parse(t: str) -> "Type":
+    def parse(t: str) -> Type:
         r = Type._parse(t)
         assert str(r) == t, f"{r} != {t}"
         return r

     @staticmethod
-    def _parse(t: str) -> "Type":
+    def _parse(t: str) -> Type:
         m = re.match(r"^(.+)\?$", t)
         if m is not None:
             return OptionalType(Type.parse(m.group(1)))
@@ -1837,7 +1835,7 @@ class Type:
     # so we can conveniently generate legacy Declarations.yaml but
|
# so we can conveniently generate legacy Declarations.yaml but
|
||||||
# really we should probably just remove these at some point
|
# really we should probably just remove these at some point
|
||||||
|
|
||||||
def is_base_ty_like(self, base_ty: "BaseTy") -> bool:
|
def is_base_ty_like(self, base_ty: BaseTy) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def is_tensor_like(self) -> bool:
|
def is_tensor_like(self) -> bool:
|
||||||
@ -1852,7 +1850,7 @@ class Type:
|
|||||||
def is_nullable(self) -> bool:
|
def is_nullable(self) -> bool:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def is_list_like(self) -> Optional["ListType"]:
|
def is_list_like(self) -> ListType | None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@ -1892,7 +1890,7 @@ class BaseType(Type):
|
|||||||
def is_nullable(self) -> bool:
|
def is_nullable(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def is_list_like(self) -> Optional["ListType"]:
|
def is_list_like(self) -> ListType | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def is_symint_like(self) -> bool:
|
def is_symint_like(self) -> bool:
|
||||||
@ -1916,7 +1914,7 @@ class OptionalType(Type):
|
|||||||
def is_nullable(self) -> bool:
|
def is_nullable(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def is_list_like(self) -> Optional["ListType"]:
|
def is_list_like(self) -> ListType | None:
|
||||||
return self.elem.is_list_like()
|
return self.elem.is_list_like()
|
||||||
|
|
||||||
|
|
||||||
@ -1943,7 +1941,7 @@ class CustomClassType(Type):
|
|||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def is_list_like(self) -> Optional["ListType"]:
|
def is_list_like(self) -> ListType | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -1957,7 +1955,7 @@ class CustomClassType(Type):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ListType(Type):
|
class ListType(Type):
|
||||||
elem: Type
|
elem: Type
|
||||||
size: Optional[int]
|
size: int | None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
size = f"{self.size}" if self.size else ""
|
size = f"{self.size}" if self.size else ""
|
||||||
@ -1972,7 +1970,7 @@ class ListType(Type):
|
|||||||
def is_nullable(self) -> bool:
|
def is_nullable(self) -> bool:
|
||||||
return self.elem.is_nullable()
|
return self.elem.is_nullable()
|
||||||
|
|
||||||
def is_list_like(self) -> Optional["ListType"]:
|
def is_list_like(self) -> ListType | None:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@ -1983,7 +1981,7 @@ class Argument:
|
|||||||
|
|
||||||
name: str
|
name: str
|
||||||
type: Type
|
type: Type
|
||||||
default: Optional[str]
|
default: str | None
|
||||||
|
|
||||||
# The semantics of the annotation field are a little strange.
|
# The semantics of the annotation field are a little strange.
|
||||||
#
|
#
|
||||||
@ -2004,16 +2002,16 @@ class Argument:
|
|||||||
# structure of annotated types is very simple. So we just hard
|
# structure of annotated types is very simple. So we just hard
|
||||||
# code it here. But if we ever do get anything more complex, this
|
# code it here. But if we ever do get anything more complex, this
|
||||||
# model will have to change!
|
# model will have to change!
|
||||||
annotation: Optional[Annotation]
|
annotation: Annotation | None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alias_info(self) -> Optional[Annotation]:
|
def alias_info(self) -> Annotation | None:
|
||||||
return self.annotation
|
return self.annotation
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(arg: str) -> "Argument":
|
def parse(arg: str) -> Argument:
|
||||||
name: str
|
name: str
|
||||||
default: Optional[str]
|
default: str | None
|
||||||
assert " " in arg, f"illegal argument '{arg}'"
|
assert " " in arg, f"illegal argument '{arg}'"
|
||||||
type_and_annot, name_and_default = arg.rsplit(" ", 1)
|
type_and_annot, name_and_default = arg.rsplit(" ", 1)
|
||||||
if "=" in name_and_default:
|
if "=" in name_and_default:
|
||||||
@ -2026,7 +2024,7 @@ class Argument:
|
|||||||
default = None
|
default = None
|
||||||
# TODO: deduplicate annotation matching with Return
|
# TODO: deduplicate annotation matching with Return
|
||||||
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
|
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
|
||||||
annotation: Optional[Annotation]
|
annotation: Annotation | None
|
||||||
if match:
|
if match:
|
||||||
# If you update this, make sure the __str__ still works too
|
# If you update this, make sure the __str__ still works too
|
||||||
assert match.group(2) in [
|
assert match.group(2) in [
|
||||||
@ -2069,24 +2067,24 @@ class Argument:
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Return:
|
class Return:
|
||||||
name: Optional[str]
|
name: str | None
|
||||||
type: Type
|
type: Type
|
||||||
annotation: Optional[Annotation]
|
annotation: Annotation | None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alias_info(self) -> Optional[Annotation]:
|
def alias_info(self) -> Annotation | None:
|
||||||
return self.annotation
|
return self.annotation
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(arg: str) -> "Return":
|
def parse(arg: str) -> Return:
|
||||||
name: Optional[str]
|
name: str | None
|
||||||
if " " in arg:
|
if " " in arg:
|
||||||
type_and_annot, name = arg.rsplit(" ", 1)
|
type_and_annot, name = arg.rsplit(" ", 1)
|
||||||
else:
|
else:
|
||||||
type_and_annot = arg
|
type_and_annot = arg
|
||||||
name = None
|
name = None
|
||||||
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
|
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
|
||||||
annotation: Optional[Annotation]
|
annotation: Annotation | None
|
||||||
if match:
|
if match:
|
||||||
# If you update this, make sure the __str__ still works too
|
# If you update this, make sure the __str__ still works too
|
||||||
assert match.group(2) in [
|
assert match.group(2) in [
|
||||||
@ -2148,34 +2146,34 @@ class Arguments:
|
|||||||
# pre_self_positional is usually empty, but is notably non-empty
|
# pre_self_positional is usually empty, but is notably non-empty
|
||||||
# for where.self, where the condition argument comes before the
|
# for where.self, where the condition argument comes before the
|
||||||
# self argument
|
# self argument
|
||||||
pre_self_positional: Tuple[Argument, ...]
|
pre_self_positional: tuple[Argument, ...]
|
||||||
self_arg: Optional[SelfArgument]
|
self_arg: SelfArgument | None
|
||||||
post_self_positional: Tuple[Argument, ...]
|
post_self_positional: tuple[Argument, ...]
|
||||||
|
|
||||||
pre_tensor_options_kwarg_only: Tuple[Argument, ...]
|
pre_tensor_options_kwarg_only: tuple[Argument, ...]
|
||||||
tensor_options: Optional[TensorOptionsArguments]
|
tensor_options: TensorOptionsArguments | None
|
||||||
# post_tensor_options is typically memory format, which should be
|
# post_tensor_options is typically memory format, which should be
|
||||||
# part of tensor options but isn't right now, and is usually
|
# part of tensor options but isn't right now, and is usually
|
||||||
# placed after the tensor options arguments
|
# placed after the tensor options arguments
|
||||||
post_tensor_options_kwarg_only: Tuple[Argument, ...]
|
post_tensor_options_kwarg_only: tuple[Argument, ...]
|
||||||
|
|
||||||
# Unlike in the previous codegen, we have factored out 'out' arguments
|
# Unlike in the previous codegen, we have factored out 'out' arguments
|
||||||
# in the canonical representation, removing them from kwarg
|
# in the canonical representation, removing them from kwarg
|
||||||
# arguments. This choice is justified by numerous downstream
|
# arguments. This choice is justified by numerous downstream
|
||||||
# transformations which treat out arguments specially; additionally,
|
# transformations which treat out arguments specially; additionally,
|
||||||
# you can see that canonicity is not violated!
|
# you can see that canonicity is not violated!
|
||||||
out: Tuple[Argument, ...] # these are also kwarg-only
|
out: tuple[Argument, ...] # these are also kwarg-only
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def flat_non_out(self) -> Sequence[Argument]:
|
def flat_non_out(self) -> Sequence[Argument]:
|
||||||
ret: List[Argument] = []
|
ret: list[Argument] = []
|
||||||
ret.extend(self.flat_positional)
|
ret.extend(self.flat_positional)
|
||||||
ret.extend(self.flat_kwarg_only)
|
ret.extend(self.flat_kwarg_only)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def flat_positional(self) -> Sequence[Argument]:
|
def flat_positional(self) -> Sequence[Argument]:
|
||||||
ret: List[Argument] = []
|
ret: list[Argument] = []
|
||||||
ret.extend(self.pre_self_positional)
|
ret.extend(self.pre_self_positional)
|
||||||
if self.self_arg is not None:
|
if self.self_arg is not None:
|
||||||
ret.append(self.self_arg.argument)
|
ret.append(self.self_arg.argument)
|
||||||
@ -2189,7 +2187,7 @@ class Arguments:
|
|||||||
# NB: doesn't contain out arguments
|
# NB: doesn't contain out arguments
|
||||||
@property
|
@property
|
||||||
def flat_kwarg_only(self) -> Sequence[Argument]:
|
def flat_kwarg_only(self) -> Sequence[Argument]:
|
||||||
ret: List[Argument] = []
|
ret: list[Argument] = []
|
||||||
ret.extend(self.pre_tensor_options_kwarg_only)
|
ret.extend(self.pre_tensor_options_kwarg_only)
|
||||||
if self.tensor_options is not None:
|
if self.tensor_options is not None:
|
||||||
ret.extend(self.tensor_options.all())
|
ret.extend(self.tensor_options.all())
|
||||||
@ -2198,7 +2196,7 @@ class Arguments:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def flat_all(self) -> Sequence[Argument]:
|
def flat_all(self) -> Sequence[Argument]:
|
||||||
ret: List[Argument] = []
|
ret: list[Argument] = []
|
||||||
ret.extend(self.flat_positional)
|
ret.extend(self.flat_positional)
|
||||||
ret.extend(self.flat_kwarg_only)
|
ret.extend(self.flat_kwarg_only)
|
||||||
ret.extend(self.out)
|
ret.extend(self.out)
|
||||||
@ -2207,15 +2205,15 @@ class Arguments:
|
|||||||
@property
|
@property
|
||||||
def non_out(
|
def non_out(
|
||||||
self,
|
self,
|
||||||
) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
|
) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
|
||||||
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
|
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
|
||||||
ret.extend(self.positional)
|
ret.extend(self.positional)
|
||||||
ret.extend(self.kwarg_only)
|
ret.extend(self.kwarg_only)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def positional(self) -> Sequence[Union[Argument, SelfArgument]]:
|
def positional(self) -> Sequence[Argument | SelfArgument]:
|
||||||
ret: List[Union[Argument, SelfArgument]] = []
|
ret: list[Argument | SelfArgument] = []
|
||||||
ret.extend(self.pre_self_positional)
|
ret.extend(self.pre_self_positional)
|
||||||
if self.self_arg is not None:
|
if self.self_arg is not None:
|
||||||
ret.append(self.self_arg)
|
ret.append(self.self_arg)
|
||||||
@ -2223,8 +2221,8 @@ class Arguments:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
|
def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
|
||||||
ret: List[Union[Argument, TensorOptionsArguments]] = []
|
ret: list[Argument | TensorOptionsArguments] = []
|
||||||
ret.extend(self.pre_tensor_options_kwarg_only)
|
ret.extend(self.pre_tensor_options_kwarg_only)
|
||||||
if self.tensor_options is not None:
|
if self.tensor_options is not None:
|
||||||
ret.append(self.tensor_options)
|
ret.append(self.tensor_options)
|
||||||
@ -2232,14 +2230,14 @@ class Arguments:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
|
def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
|
||||||
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
|
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
|
||||||
ret.extend(self.positional)
|
ret.extend(self.positional)
|
||||||
ret.extend(self.kwarg_only)
|
ret.extend(self.kwarg_only)
|
||||||
ret.extend(self.out)
|
ret.extend(self.out)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def mutable_arg_names(self) -> List[str]:
|
def mutable_arg_names(self) -> list[str]:
|
||||||
return [
|
return [
|
||||||
a.name
|
a.name
|
||||||
for a in self.flat_all
|
for a in self.flat_all
|
||||||
@ -2255,7 +2253,7 @@ class Arguments:
|
|||||||
def has_generator_arg(self) -> bool:
|
def has_generator_arg(self) -> bool:
|
||||||
return any(a.type.is_generator_like() for a in self.flat_non_out)
|
return any(a.type.is_generator_like() for a in self.flat_non_out)
|
||||||
|
|
||||||
def signature(self, *, strip_default: bool = False) -> "Arguments":
|
def signature(self, *, strip_default: bool = False) -> Arguments:
|
||||||
# dataclasses.replace could be used here, but it is less
|
# dataclasses.replace could be used here, but it is less
|
||||||
# type safe so for now I've opted to type everything out
|
# type safe so for now I've opted to type everything out
|
||||||
def strip_arg_annotation(a: Argument) -> Argument:
|
def strip_arg_annotation(a: Argument) -> Argument:
|
||||||
@ -2290,7 +2288,7 @@ class Arguments:
|
|||||||
out=(),
|
out=(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def remove_self_annotation(self) -> "Arguments":
|
def remove_self_annotation(self) -> Arguments:
|
||||||
assert self.self_arg is not None
|
assert self.self_arg is not None
|
||||||
return dataclasses.replace(
|
return dataclasses.replace(
|
||||||
self,
|
self,
|
||||||
@ -2299,7 +2297,7 @@ class Arguments:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def with_out_args(self, outs: List[Argument]) -> "Arguments":
|
def with_out_args(self, outs: list[Argument]) -> Arguments:
|
||||||
assert len(self.out) == 0
|
assert len(self.out) == 0
|
||||||
return dataclasses.replace(
|
return dataclasses.replace(
|
||||||
self,
|
self,
|
||||||
@ -2307,10 +2305,10 @@ class Arguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
|
def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
|
||||||
positional: List[Argument] = []
|
positional: list[Argument] = []
|
||||||
kwarg_only: List[Argument] = []
|
kwarg_only: list[Argument] = []
|
||||||
out: List[Argument] = []
|
out: list[Argument] = []
|
||||||
arguments_acc = positional
|
arguments_acc = positional
|
||||||
|
|
||||||
# TODO: Use a real parser here; this will get bamboozled
|
# TODO: Use a real parser here; this will get bamboozled
|
||||||
@ -2343,7 +2341,7 @@ class Arguments:
|
|||||||
return positional, kwarg_only, out
|
return positional, kwarg_only, out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(args: str) -> "Arguments":
|
def parse(args: str) -> Arguments:
|
||||||
"""
|
"""
|
||||||
Input: 'int x, int y, int z'
|
Input: 'int x, int y, int z'
|
||||||
"""
|
"""
|
||||||
@ -2361,9 +2359,9 @@ class Arguments:
|
|||||||
if a.name == "self":
|
if a.name == "self":
|
||||||
self_ix = i
|
self_ix = i
|
||||||
break
|
break
|
||||||
pre_self_positional: List[Argument]
|
pre_self_positional: list[Argument]
|
||||||
self_arg: Optional[SelfArgument]
|
self_arg: SelfArgument | None
|
||||||
post_self_positional: List[Argument]
|
post_self_positional: list[Argument]
|
||||||
if self_ix is not None:
|
if self_ix is not None:
|
||||||
pre_self_positional = positional[:self_ix]
|
pre_self_positional = positional[:self_ix]
|
||||||
self_arg = SelfArgument(positional[self_ix])
|
self_arg = SelfArgument(positional[self_ix])
|
||||||
@ -2374,9 +2372,9 @@ class Arguments:
|
|||||||
post_self_positional = positional
|
post_self_positional = positional
|
||||||
|
|
||||||
# Group tensor options arguments
|
# Group tensor options arguments
|
||||||
pre_tensor_options_kwarg_only: List[Argument] = []
|
pre_tensor_options_kwarg_only: list[Argument] = []
|
||||||
tensor_options: Optional[TensorOptionsArguments] = None
|
tensor_options: TensorOptionsArguments | None = None
|
||||||
post_tensor_options_kwarg_only: List[Argument] = []
|
post_tensor_options_kwarg_only: list[Argument] = []
|
||||||
kwarg_only_acc = pre_tensor_options_kwarg_only
|
kwarg_only_acc = pre_tensor_options_kwarg_only
|
||||||
|
|
||||||
def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
|
def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
|
||||||
@ -2423,7 +2421,7 @@ class Arguments:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
all_arguments: List[str] = []
|
all_arguments: list[str] = []
|
||||||
all_arguments.extend(map(str, self.flat_positional))
|
all_arguments.extend(map(str, self.flat_positional))
|
||||||
if self.flat_kwarg_only or self.out:
|
if self.flat_kwarg_only or self.out:
|
||||||
all_arguments.append("*")
|
all_arguments.append("*")
|
||||||
@ -2502,7 +2500,7 @@ class BaseOperatorName:
|
|||||||
functional_overload: bool = False
|
functional_overload: bool = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(op: str) -> "BaseOperatorName":
|
def parse(op: str) -> BaseOperatorName:
|
||||||
assert op != ""
|
assert op != ""
|
||||||
assert not op.endswith("_out"), (
|
assert not op.endswith("_out"), (
|
||||||
"_out suffix is reserved and not permitted for operator names; "
|
"_out suffix is reserved and not permitted for operator names; "
|
||||||
@ -2574,7 +2572,7 @@ class OperatorName:
|
|||||||
overload_name: str
|
overload_name: str
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(op_name: str) -> "OperatorName":
|
def parse(op_name: str) -> OperatorName:
|
||||||
if "." in op_name:
|
if "." in op_name:
|
||||||
name, overload_name = op_name.split(".", 1)
|
name, overload_name = op_name.split(".", 1)
|
||||||
else:
|
else:
|
||||||
@ -2601,7 +2599,7 @@ class OperatorName:
|
|||||||
else:
|
else:
|
||||||
return f"{self.name}"
|
return f"{self.name}"
|
||||||
|
|
||||||
def remove_inplace(self) -> "OperatorName":
|
def remove_inplace(self) -> OperatorName:
|
||||||
return OperatorName(
|
return OperatorName(
|
||||||
name=BaseOperatorName(
|
name=BaseOperatorName(
|
||||||
base=self.name.base,
|
base=self.name.base,
|
||||||
@ -2611,7 +2609,7 @@ class OperatorName:
|
|||||||
overload_name=self.overload_name,
|
overload_name=self.overload_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
def with_overload(self, overload: str) -> "OperatorName":
|
def with_overload(self, overload: str) -> OperatorName:
|
||||||
return OperatorName(
|
return OperatorName(
|
||||||
name=BaseOperatorName(
|
name=BaseOperatorName(
|
||||||
base=self.name.base,
|
base=self.name.base,
|
||||||
@ -2649,9 +2647,9 @@ class NativeFunctionsViewGroup:
|
|||||||
# Note: the {view}_copy operator is optional because we currently don't generate copy variants
|
# Note: the {view}_copy operator is optional because we currently don't generate copy variants
|
||||||
# for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
|
# for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
|
||||||
# (we already get them "for free" through decomposition)
|
# (we already get them "for free" through decomposition)
|
||||||
view_copy: Optional[NativeFunction]
|
view_copy: NativeFunction | None
|
||||||
# view_inplace ops are also optional, but every view_inplace op should have out-of-place variant.
|
# view_inplace ops are also optional, but every view_inplace op should have out-of-place variant.
|
||||||
view_inplace: Optional[NativeFunction]
|
view_inplace: NativeFunction | None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
assert self.view.is_view_op
|
assert self.view.is_view_op
|
||||||
@ -2731,7 +2729,7 @@ def gets_generated_view_copy(f: NativeFunction) -> bool:
|
|||||||
|
|
||||||
# Given a NativeFunction that corresponds to a view op,
|
# Given a NativeFunction that corresponds to a view op,
|
||||||
# returns the OperatorName of the corresponding "copy" variant of the op.
|
# returns the OperatorName of the corresponding "copy" variant of the op.
|
||||||
def get_view_copy_name(f: NativeFunction) -> "OperatorName":
|
def get_view_copy_name(f: NativeFunction) -> OperatorName:
|
||||||
# Right now, when asking for a view op's corresponding "view_copy" name
|
# Right now, when asking for a view op's corresponding "view_copy" name
|
||||||
# we assert for sanity that the op is allowed to have a generated view_copy variant.
|
# we assert for sanity that the op is allowed to have a generated view_copy variant.
|
||||||
# (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op).
|
# (We can do this because "gets_generated_view_copy()" tell us which ops get a generated view_copy op).
|
||||||
@ -2755,7 +2753,7 @@ def get_view_copy_name(f: NativeFunction) -> "OperatorName":
|
|||||||
# Helper functions for parsing argument lists (both inputs and returns)
|
# Helper functions for parsing argument lists (both inputs and returns)
|
||||||
|
|
||||||
|
|
||||||
def parse_returns(return_decl: str) -> Tuple[Return, ...]:
|
def parse_returns(return_decl: str) -> tuple[Return, ...]:
|
||||||
"""
|
"""
|
||||||
Input: '()'
|
Input: '()'
|
||||||
Output: []
|
Output: []
|
||||||
@ -2774,12 +2772,12 @@ def parse_returns(return_decl: str) -> Tuple[Return, ...]:
|
|||||||
class Precompute:
|
class Precompute:
|
||||||
# A map from kernel argument name -> a list of precomputed
|
# A map from kernel argument name -> a list of precomputed
|
||||||
# elements that replaces/supersedes it.
|
# elements that replaces/supersedes it.
|
||||||
replace: Dict[str, List[Argument]]
|
replace: dict[str, list[Argument]]
|
||||||
# List of precomputed args added without replacement
|
# List of precomputed args added without replacement
|
||||||
add: List[Argument]
|
add: list[Argument]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse(src: object) -> "Precompute":
|
def parse(src: object) -> Precompute:
|
||||||
assert isinstance(src, list)
|
assert isinstance(src, list)
|
||||||
|
|
||||||
# src is a list of strings of the format:
|
# src is a list of strings of the format:
|
||||||
@ -2824,7 +2822,7 @@ class Precompute:
|
|||||||
for a in args:
|
for a in args:
|
||||||
assert a.name.upper() != a.name
|
assert a.name.upper() != a.name
|
||||||
|
|
||||||
def to_list(self) -> List[str]:
|
def to_list(self) -> list[str]:
|
||||||
replace_list = []
|
replace_list = []
|
||||||
for kernel_param, replacement_params in self.replace.items():
|
for kernel_param, replacement_params in self.replace.items():
|
||||||
replacements = ", ".join(str(param) for param in replacement_params)
|
replacements = ", ".join(str(param) for param in replacement_params)
|
||||||
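Note: every hunk in this file applies the same recipe, enabled by the `from __future__ import annotations` import this PR adds at the top of each module: under PEP 563 every annotation is stored as a string and never evaluated at definition time, so quoted forward references such as `"FunctionSchema"` can drop their quotes, and the PEP 585/604 spellings `tuple[Return, ...]` and `BackendMetadata | None` are safe even on interpreters older than Python 3.9/3.10. A minimal, self-contained sketch of why this works (the `Node` class is illustrative, not part of torchgen):

    from __future__ import annotations  # must be the first statement

    from dataclasses import dataclass


    @dataclass(frozen=True)
    class Node:
        # An unquoted forward reference to Node plus PEP 585/604 syntax:
        # legal because both annotations stay as uninspected strings.
        children: tuple[Node, ...]
        parent: Node | None = None


    root = Node(children=())
    leaf = Node(children=(), parent=root)
    assert leaf.parent is root

    # Caveat: anything that evaluates annotations at runtime, e.g.
    # typing.get_type_hints(Node), must resolve those strings, and the
    # string "Node | None" only evaluates cleanly on Python 3.10+.

That runtime-evaluation caveat is also why type-only imports can move under `TYPE_CHECKING`, as the selective-build hunks further down do.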
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from collections import defaultdict
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Sequence

 import torchgen.api.dispatcher as dispatcher
 from torchgen.api.translate import translate
@@ -101,9 +103,9 @@ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
 # But have differing SchemaKinds.
 def pre_group_native_functions(
    native_functions: Sequence[NativeFunction],
-) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
-    pre_grouped_native_functions: Dict[
-        FunctionSchema, Dict[SchemaKind, NativeFunction]
+) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
+    pre_grouped_native_functions: dict[
+        FunctionSchema, dict[SchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        d = pre_grouped_native_functions[f.func.signature()]
@@ -113,7 +115,7 @@ def pre_group_native_functions(


 # Returns the out variant overload name given a base function overload name
-def get_expected_out_variant_overload_name(overload_name: Optional[str]) -> str:
+def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
    return "out" if not overload_name else f"{overload_name}_out"


@@ -178,7 +180,7 @@ def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
 # Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations.
 def generate_out_args_from_schema(
    func: FunctionSchema,
-) -> Tuple[List[Return], List[Argument]]:
+) -> tuple[list[Return], list[Argument]]:
    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(
@@ -198,11 +200,11 @@ def generate_out_args_from_schema(

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

-    new_out_args: List[Argument] = []
+    new_out_args: list[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
-    new_returns: List[Return] = []
+    new_returns: list[Return] = []
    for i, r in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
@@ -266,7 +268,7 @@ def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
 # Details are in the function, but we only generate composite kernels (in some cases) today.
 def generate_function(
    f: NativeFunction, k: SchemaKind
-) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
+) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
    from torchgen.api import cpp

    if k == SchemaKind.functional:
@@ -375,8 +377,8 @@ def generate_function(
 # Note: this function *mutates* its two inputs,
 # adding the new NativeFunctions / BackendMetadata to them
 def add_generated_native_functions(
-    rs: List[NativeFunction],
-    indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
+    rs: list[NativeFunction],
+    indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
 ) -> None:
    # The main code for generating new NativeFunctions
    # First we group of NativeFunctions by schema kind,
@@ -497,7 +499,7 @@ out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THA
        rs.append(fn)


-def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
+def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
    assert len(rets) == len(names)
    if len(rets) == 0:
        return ""
@@ -509,7 +511,7 @@ def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:

 # Given a function, and the name of a variable corresponding to the output of that function,
 # gather up all of the individual returns that are not aliased
-def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]:
+def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
    aliased_rets = func.aliased_return_names()
    non_aliased_names = []
    is_out_var_a_tuple = len(func.returns) > 1
@@ -524,7 +526,7 @@ def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str
 # Generates functional kernels in terms of their inplace.mutable counterparts.
 # We only do this for "generated" NativeFunctions
 @with_native_function
-def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
+def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.functional.tags:
        return None
@@ -541,7 +543,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
    sig = DispatcherSignature(g.functional.func)
    target_sig = DispatcherSignature(target_f.func)

-    context: List[Union[Binding, Expr]] = []
+    context: list[Binding | Expr] = []
    clone_mutable_inputs = []
    cloned_return_names = []
    # We can't just directly pass all of the arguments from the functional op into the mutating op.
@@ -587,7 +589,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
 # Generates out= kernels in terms of their functional counterparts.
 # We only do this for "generated" NativeFunctions
 @with_native_function
-def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
+def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
    # We should only be generating these for code-generated NativeFunctions
    if "generated" not in g.out.tags:
        return None
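Note: `pre_group_native_functions` above leans on `defaultdict(dict)`: indexing with a not-yet-seen schema signature materializes an empty inner map, so the loop never has to check for key presence. A freestanding sketch of the same pattern, where `Record` and its string fields are stand-ins for torchgen's `NativeFunction`, `FunctionSchema`, and `SchemaKind`:

    from __future__ import annotations

    from collections import defaultdict
    from dataclasses import dataclass


    @dataclass(frozen=True)
    class Record:
        signature: str  # stand-in for FunctionSchema.signature()
        kind: str       # stand-in for SchemaKind


    def pre_group(records: list[Record]) -> dict[str, dict[str, Record]]:
        grouped: dict[str, dict[str, Record]] = defaultdict(dict)
        for r in records:
            d = grouped[r.signature]  # inner dict created on first access
            assert r.kind not in d, "one function per kind per signature"
            d[r.kind] = r
        return grouped


    groups = pre_group(
        [
            Record("add(Tensor, Tensor)", "functional"),
            Record("add(Tensor, Tensor)", "out"),
            Record("mul(Tensor, Tensor)", "functional"),
        ]
    )
    assert sorted(groups["add(Tensor, Tensor)"]) == ["functional", "out"]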
@@ -1,9 +1,12 @@
 #!/usr/bin/env python3

+from __future__ import annotations
+
 import os
 from enum import Enum
 from operator import itemgetter
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any

 import torch
 from torch.jit.generate_bytecode import generate_upgraders_bytecode
@@ -185,7 +188,7 @@ PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
 )


-def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
+def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
    instruction_list_part = []
    for instruction in instruction_list_from_yaml:
        instruction_list_part.append(
@@ -200,7 +203,7 @@ def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
    )


-def construct_constants(constants_list_from_yaml: List[Any]) -> str:
+def construct_constants(constants_list_from_yaml: list[Any]) -> str:
    constants_list_part = []
    for constant_from_yaml in constants_list_from_yaml:
        convert_constant = None
@@ -226,7 +229,7 @@ def construct_constants(constants_list_from_yaml: List[Any]) -> str:
    )


-def construct_operators(operator_list_from_yaml: List[Any]) -> str:
+def construct_operators(operator_list_from_yaml: list[Any]) -> str:
    operator_list_part = []
    for operator in operator_list_from_yaml:
        operator_list_part.append(
@@ -241,7 +244,7 @@ def construct_operators(operator_list_from_yaml: List[Any]) -> str:
    )


-def construct_types(types_tr_list_from_yaml: List[Any]) -> str:
+def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
    types_tr_list_part = []
    for types_tr in types_tr_list_from_yaml:
        types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr))
@@ -260,7 +263,7 @@ def construct_register_size(register_size_from_yaml: int) -> str:


 def construct_version_maps(
-    upgrader_bytecode_function_to_index_map: Dict[str, Any]
+    upgrader_bytecode_function_to_index_map: dict[str, Any]
 ) -> str:
    version_map = torch._C._get_operator_version_map()
    sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0))  # type: ignore[no-any-return]
@@ -302,8 +305,8 @@ def construct_version_maps(


 def get_upgrader_bytecode_function_to_index_map(
-    upgrader_dict: List[Dict[str, Any]]
-) -> Dict[str, Any]:
+    upgrader_dict: list[dict[str, Any]]
+) -> dict[str, Any]:
    upgrader_bytecode_function_to_index_map = {}
    index = 0
    for upgrader_bytecode in upgrader_dict:
@@ -315,7 +318,7 @@ def get_upgrader_bytecode_function_to_index_map(
    return upgrader_bytecode_function_to_index_map


-def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
+def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
    body_parts = []
    upgrader_bytecode_function_to_index_map = (
        get_upgrader_bytecode_function_to_index_map(upgrader_dict)
@@ -370,7 +373,7 @@ def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
    out_file.write(upgrader_file_content.encode("utf-8"))


-def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
    sorted_upgrader_list = sorted(
        upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
    )
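Note: two small sorts in this file are easy to misread: `construct_version_maps` orders the operator version map by operator name with `operator.itemgetter(0)`, and `sort_upgrader` orders a list of single-key dicts by each dict's lone key via `next(iter(...))`. A tiny sketch with made-up operator names and versions:

    from operator import itemgetter

    version_map = {"aten::linspace": 8, "aten::div": 4}  # illustrative values
    assert sorted(version_map.items(), key=itemgetter(0)) == [
        ("aten::div", 4),
        ("aten::linspace", 8),
    ]

    upgraders = [{"linspace_0_7": []}, {"div_Tensor_0_3": []}]  # illustrative
    upgraders.sort(key=lambda u: next(iter(u)))  # sort by the single key
    assert [next(iter(u)) for u in upgraders] == ["div_Tensor_0_3", "linspace_0_7"]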
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple


 # This class holds information about a single operator used to determine
@@ -46,12 +47,12 @@ class SelectiveBuildOperator:
    include_all_overloads: bool

    # Debug Information at the operator level
-    _debug_info: Optional[Tuple[str, ...]]
+    _debug_info: tuple[str, ...] | None

    @staticmethod
    def from_yaml_dict(
-        op_name: str, op_info: Dict[str, object]
-    ) -> "SelectiveBuildOperator":
+        op_name: str, op_info: dict[str, object]
+    ) -> SelectiveBuildOperator:
        allowed_keys = {
            "name",
            "is_root_operator",
@@ -79,7 +80,7 @@ class SelectiveBuildOperator:
        include_all_overloads = op_info.get("include_all_overloads", True)
        assert isinstance(include_all_overloads, bool)

-        debug_info: Optional[Tuple[str, ...]] = None
+        debug_info: tuple[str, ...] | None = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            assert isinstance(di_list, list)
@@ -96,7 +97,7 @@ class SelectiveBuildOperator:
    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
-    ) -> "SelectiveBuildOperator":
+    ) -> SelectiveBuildOperator:
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
@@ -105,8 +106,8 @@ class SelectiveBuildOperator:
            _debug_info=None,
        )

-    def to_dict(self) -> Dict[str, object]:
-        ret: Dict[str, object] = {
+    def to_dict(self) -> dict[str, object]:
+        ret: dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
@@ -118,9 +119,9 @@ class SelectiveBuildOperator:


 def merge_debug_info(
-    lhs: Optional[Tuple[str, ...]],
-    rhs: Optional[Tuple[str, ...]],
-) -> Optional[Tuple[str, ...]]:
+    lhs: tuple[str, ...] | None,
+    rhs: tuple[str, ...] | None,
+) -> tuple[str, ...] | None:
    # Ensure that when merging, each entry shows up just once.
    if lhs is None and rhs is None:
        return None
@@ -129,8 +130,8 @@ def merge_debug_info(


 def combine_operators(
-    lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator"
-) -> "SelectiveBuildOperator":
+    lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
+) -> SelectiveBuildOperator:
    if str(lhs.name) != str(rhs.name):
        raise Exception(  # noqa: TRY002
            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
@@ -152,10 +153,10 @@ def combine_operators(


 def merge_operator_dicts(
-    lhs: Dict[str, SelectiveBuildOperator],
-    rhs: Dict[str, SelectiveBuildOperator],
-) -> Dict[str, SelectiveBuildOperator]:
-    operators: Dict[str, SelectiveBuildOperator] = {}
+    lhs: dict[str, SelectiveBuildOperator],
+    rhs: dict[str, SelectiveBuildOperator],
+) -> dict[str, SelectiveBuildOperator]:
+    operators: dict[str, SelectiveBuildOperator] = {}
    for op_name, op in list(lhs.items()) + list(rhs.items()):
        new_op = op
        if op_name in operators:
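Note: `merge_debug_info` states its contract in the hunk above: when two debug tuples are merged, "each entry shows up just once." A sketch of an order-preserving merge that satisfies that contract; it mirrors the documented behavior rather than reproducing torchgen's exact body, which the diff truncates:

    from __future__ import annotations


    def merge_debug_info(
        lhs: tuple[str, ...] | None,
        rhs: tuple[str, ...] | None,
    ) -> tuple[str, ...] | None:
        if lhs is None and rhs is None:
            return None
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return tuple(dict.fromkeys((lhs or ()) + (rhs or ())))


    assert merge_debug_info(("model_a", "model_b"), ("model_b", "model_c")) == (
        "model_a",
        "model_b",
        "model_c",
    )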
@@ -1,11 +1,12 @@
+from __future__ import annotations
+
 from collections import defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING

 import yaml

-from torchgen.model import NativeFunction
 from torchgen.selective_build.operator import (
    merge_debug_info,
    merge_operator_dicts,
@@ -14,6 +15,10 @@ from torchgen.selective_build.operator import (
 )


+if TYPE_CHECKING:
+    from torchgen.model import NativeFunction
+
+
 # A SelectiveBuilder holds information extracted from the selective build
 # YAML specification.
 #
@@ -28,10 +33,10 @@ class SelectiveBuilder:
    include_all_operators: bool

    # Debug Information at the selective/custom build level.
-    _debug_info: Optional[Tuple[str, ...]]
+    _debug_info: tuple[str, ...] | None

    # A dictionary of operator -> operator metadata.
-    operators: Dict[str, SelectiveBuildOperator]
+    operators: dict[str, SelectiveBuildOperator]

    # A dictionary of selected kernel tags and dtypes. Typically a
    # PyTorch Operator Kernel (function) may have many code paths
@@ -39,22 +44,22 @@ class SelectiveBuilder:
    # one per kernel function, but there could be many per kernel
    # function. The tag isn't a kernel function name, but some fragment
    # of the kernel function implementation itself.
-    kernel_metadata: Dict[str, List[str]]
+    kernel_metadata: dict[str, list[str]]

    # ExecuTorch only. A dictionary of kernel tag -> list of (list of input
    # dtypes for tensor-like input args).
    # This is from selective.yaml
-    et_kernel_metadata: Dict[str, List[str]]
+    et_kernel_metadata: dict[str, list[str]]

    # A set of all the custom torch bind classes used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
-    custom_classes: Set[str]
+    custom_classes: set[str]

    # A set of all the build features used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
-    build_features: Set[str]
+    build_features: set[str]

    # If true, then fragments for all dtypes for all kernel functions
    # are included as well as all custom classes. This is typically set when any one of the
@@ -63,11 +68,11 @@ class SelectiveBuilder:
    include_all_non_op_selectives: bool

    @staticmethod
-    def get_nop_selector() -> "SelectiveBuilder":
+    def get_nop_selector() -> SelectiveBuilder:
        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

    @staticmethod
-    def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder":
+    def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
        valid_top_level_keys = {
            "include_all_non_op_selectives",
            "include_all_operators",
@@ -135,20 +140,20 @@ class SelectiveBuilder:
        )

    @staticmethod
-    def from_yaml_str(config_contents: str) -> "SelectiveBuilder":
+    def from_yaml_str(config_contents: str) -> SelectiveBuilder:
        contents = yaml.safe_load(config_contents)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
-    def from_yaml_path(config_path: str) -> "SelectiveBuilder":
+    def from_yaml_path(config_path: str) -> SelectiveBuilder:
        with open(config_path) as f:
            contents = yaml.safe_load(f)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_legacy_op_registration_allow_list(
-        allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool
-    ) -> "SelectiveBuilder":
+        allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
+    ) -> SelectiveBuilder:
        operators = {}
        for op in allow_list:
            operators[op] = {
@@ -231,7 +236,7 @@ class SelectiveBuilder:
            and dtype in self.kernel_metadata[kernel_tag]
        )

-    def et_get_selected_kernels(self, op_name: str, kernel_key: List[str]) -> List[str]:
+    def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
        """
        Return a list of kernel keys that cover the used ops
        """
@@ -261,8 +266,8 @@ class SelectiveBuilder:

        return list(result_set)

-    def to_dict(self) -> Dict[str, object]:
-        ret: Dict[str, object] = {
+    def to_dict(self) -> dict[str, object]:
+        ret: dict[str, object] = {
            "include_all_non_op_selectives": self.include_all_non_op_selectives,
            "include_all_operators": self.include_all_operators,
        }
@@ -288,10 +293,10 @@ class SelectiveBuilder:


 def merge_kernel_metadata(
-    lhs: Dict[str, List[str]],
-    rhs: Dict[str, List[str]],
-) -> Dict[str, List[str]]:
-    kernel_metadata: Dict[str, List[str]] = {}
+    lhs: dict[str, list[str]],
+    rhs: dict[str, list[str]],
+) -> dict[str, list[str]]:
+    kernel_metadata: dict[str, list[str]] = {}
    for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
        dtypes_copy = set(dtypes)
        if tag_name in kernel_metadata:
@@ -303,10 +308,10 @@ def merge_kernel_metadata(


 def merge_et_kernel_metadata(
-    lhs: Dict[str, List[str]],
-    rhs: Dict[str, List[str]],
-) -> Dict[str, List[str]]:
-    merge_et_kernel_metadata: Dict[str, Set[str]] = defaultdict(set)
+    lhs: dict[str, list[str]],
+    rhs: dict[str, list[str]],
+) -> dict[str, list[str]]:
+    merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set)
    for op in list(lhs.keys()) + list(rhs.keys()):
        merge_et_kernel_metadata[op].update(lhs.get(op, []))
        merge_et_kernel_metadata[op].update(rhs.get(op, []))
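Note: the `if TYPE_CHECKING:` block added above is the standard companion to postponed annotations. `NativeFunction` is only used in annotations in this module, so its runtime import can be dropped entirely (cutting import cost and avoiding potential cycles) while type checkers, for which `TYPE_CHECKING` is true, still see the name. A minimal sketch of the pattern with a stand-in import:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Executed by type checkers only, never at runtime.
        from collections import OrderedDict


    def first_key(d: OrderedDict[str, int]) -> str:
        # Fine at runtime even though OrderedDict was never imported here:
        # the annotation is just the string "OrderedDict[str, int]".
        return next(iter(d))


    import collections

    assert first_key(collections.OrderedDict(a=1, b=2)) == "a"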
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
-import importlib.util
 import os
 import sys
+from importlib.util import module_from_spec, spec_from_file_location
 from itertools import chain
 from pathlib import Path

@@ -18,9 +18,9 @@ you are in the root directory of the Pytorch git repo"""
 if not file_path.exists():
    raise Exception(err_msg)  # noqa: TRY002

-spec = importlib.util.spec_from_file_location(module_name, file_path)
+spec = spec_from_file_location(module_name, file_path)
 assert spec is not None
-module = importlib.util.module_from_spec(spec)
+module = module_from_spec(spec)
 sys.modules[module_name] = module
 assert spec.loader is not None
 assert module is not None
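Note: this hunk only changes how the two importlib helpers are spelled; the surrounding code is the standard recipe for loading a module from an explicit file path, shown here in one piece. The module name and path are placeholders:

    import sys
    from importlib.util import module_from_spec, spec_from_file_location

    module_name = "my_generated_module"        # placeholder
    file_path = "/tmp/my_generated_module.py"  # placeholder

    spec = spec_from_file_location(module_name, file_path)
    assert spec is not None and spec.loader is not None
    module = module_from_spec(spec)
    # Register before exec_module so the module can be found while executing.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)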
@@ -1,9 +1,9 @@
-from typing import Dict, Union
+from __future__ import annotations

 from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup


-def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str:
+def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
    if isinstance(g, NativeFunctionsGroup):
        return str(g.functional.func.name.name.base)
    else:
@@ -55,12 +55,12 @@ is_hand_written_ops_ = frozenset(
 )


-def is_hand_written(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
+def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
    name_base = func_name_base_str(g)
    return name_base in is_hand_written_ops_


-def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
+def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
    assert index == 0 or index == 1
    if op_name == "addr":
        if index == 0:
|
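This file shows the core mechanics of the change: once
`from __future__ import annotations` is in effect (PEP 563), every
annotation is stored as a string and never evaluated at import time, so
PEP 604 unions (`X | Y`) and PEP 585 builtin generics (`dict[str, str]`)
are legal in annotations even on interpreters that predate them. A minimal
sketch:

    from __future__ import annotations


    def pick(values: dict[str, int], key: str) -> int | None:
        # Nothing in the signature is evaluated at runtime; the annotations
        # survive only as strings in pick.__annotations__.
        return values.get(key)


    print(pick.__annotations__)
    # {'values': 'dict[str, int]', 'key': 'str', 'return': 'int | None'}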
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import argparse
 import itertools
 import os
@@ -28,7 +30,7 @@ def group_functions_by_op_name(
         return []
     groups = []
 
-    def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
+    def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
         with native_function_manager(g):
             return generator.is_supported(g)
 
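Per the hunk header, this file buckets the grouped native functions by
base op name and keeps only groups passing a local `is_supported`
predicate. A hedged sketch of that group-then-filter shape, with plain
strings standing in for the real function-group objects:

    from __future__ import annotations

    from collections import defaultdict
    from typing import Sequence


    def group_by_op_name(names: Sequence[str]) -> list[list[str]]:
        buckets: dict[str, list[str]] = defaultdict(list)
        for name in names:
            # "add.out" and "add.Tensor" share the base op name "add".
            buckets[name.split(".")[0]].append(name)

        def is_supported(name: str) -> bool:
            return not name.endswith("_")  # illustrative rule only

        return [g for g in buckets.values() if all(map(is_supported, g))]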
@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import json
 import logging
 import math
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Sequence
 
 import torchgen.api.cpp as cpp
 from torchgen.context import native_function_manager
@@ -25,7 +27,7 @@ logger: logging.Logger = logging.getLogger()
 
 
 def has_alias(
-    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
+    arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
 ) -> bool:
     for arg in arguments:
         annotation = getattr(arg, "annotation", None)
@@ -237,7 +239,7 @@ BLOCKED_OPS = frozenset(
 )
 
 
-def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
+def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
     base_op_name = ""
     func = None
     if isinstance(g, NativeFunctionsViewGroup):
@@ -298,8 +300,8 @@ def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bo
 
 
 def ivalue_type_conversion_method(
-    arg_type: Union[BaseType, OptionalType, Type]
-) -> Optional[Tuple[bool, str]]:
+    arg_type: BaseType | OptionalType | Type,
+) -> tuple[bool, str] | None:
     """
     Return the method call expression of `c10::ivalue' to convert its contained value to
     the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
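Per its docstring, ivalue_type_conversion_method maps a schema type to the
c10::IValue accessor expression the generated C++ should call. A
deliberately simplified sketch of that shape; the table entries and the
meaning given to the bool are illustrative assumptions, not the real
mapping:

    from __future__ import annotations

    # Illustrative only: type name -> (needs_special_handling?, accessor).
    _ACCESSORS: dict[str, tuple[bool, str]] = {
        "Tensor": (True, "toTensor()"),
        "int": (False, "toInt()"),
        "bool": (False, "toBool()"),
    }


    def accessor_for(type_name: str) -> tuple[bool, str] | None:
        # None mirrors the optional return above: no known conversion.
        return _ACCESSORS.get(type_name)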
@@ -394,7 +396,7 @@ def test_tensor_dim(op_name: str) -> int:
 
 
 test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
-test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string)
+test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
 
 
 def test_tensor_shape(op_name: str) -> str:
@@ -405,7 +407,7 @@ def test_tensor_shape(op_name: str) -> str:
 
 
 def test_value_expression(
-    arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
+    arg_type: BaseType | OptionalType | Type, index: int, op_name: str
 ) -> str:
     tensor_size_ex = test_tensor_shape(op_name)
     if tensor_size_ex == "":
@@ -475,8 +477,8 @@ generate_test_ir_arguments_base_ty_to_type_str_ = {
 
 def generate_test_ir_arguments(
     schema: FunctionSchema,
-) -> List[Tuple[str, Optional[str]]]:
-    def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
+) -> list[tuple[str, str | None]]:
+    def ir_argument(arg: Argument) -> tuple[str, str | None]:
         t = arg.type
         add_optional = False
         if isinstance(t, OptionalType):
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import functools
 import hashlib
@@ -5,31 +7,29 @@ import os
 import re
 import sys
 import textwrap
-from argparse import Namespace
 from dataclasses import fields, is_dataclass
 from enum import auto, Enum
 from typing import (
     Any,
     Callable,
-    Dict,
     Generic,
     Iterable,
     Iterator,
-    List,
     Literal,
     NoReturn,
-    Optional,
     Sequence,
-    Set,
-    Tuple,
+    TYPE_CHECKING,
     TypeVar,
-    Union,
 )
 from typing_extensions import Self
 
 from torchgen.code_template import CodeTemplate
 
 
+if TYPE_CHECKING:
+    from argparse import Namespace
+
+
 # Many of these functions share logic for defining both the definition
 # and declaration (for example, the function signature is the same), so
 # we organize them into one function that takes a Target to say which
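The Namespace import can move under `if TYPE_CHECKING:` precisely because
postponed annotations never evaluate at runtime: TYPE_CHECKING is False
when the program runs and True only for static type checkers, so the
import costs nothing at import time. The pattern in isolation:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by mypy and friends; never executed at runtime.
        from argparse import Namespace


    def describe(options: Namespace) -> str:
        # Fine at runtime: "Namespace" stays an unevaluated string.
        return str(sorted(vars(options)))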
@@ -57,7 +57,7 @@ IDENT_REGEX = r"(^|\W){}($|\W)"
 
 
 # TODO: Use a real parser here; this will get bamboozled
-def split_name_params(schema: str) -> Tuple[str, List[str]]:
+def split_name_params(schema: str) -> tuple[str, list[str]]:
     m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
     if m is None:
         raise RuntimeError(f"Unsupported function schema: {schema}")
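On simple schemas the regex splits cleanly into name, optional overload,
and parameter list; nested parentheses inside the parameter list are what
the TODO above warns about. For example:

    import re

    m = re.match(r"(\w+)(\.\w+)?\((.*)\)", "add.Tensor(Tensor self, Tensor other)")
    assert m is not None
    print(m.group(1))  # add
    print(m.group(2))  # .Tensor
    print(m.group(3))  # Tensor self, Tensor other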
@@ -73,7 +73,7 @@ S = TypeVar("S")
 
 
 # Map over function that may return None; omit Nones from output sequence
-def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
+def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
     for x in xs:
         r = func(x)
         if r is not None:
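mapMaybe applies the function and silently drops the None results, a
Haskell-style filter-map. A self-contained usage sketch (the trailing
`yield r` completes the body the hunk cuts off):

    from __future__ import annotations

    from typing import Callable, Iterable, Iterator, TypeVar

    T = TypeVar("T")
    S = TypeVar("S")


    def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
        for x in xs:
            r = func(x)
            if r is not None:
                yield r


    def parse_int(s: str) -> int | None:
        return int(s) if s.isdigit() else None


    assert list(mapMaybe(parse_int, ["1", "x", "22"])) == [1, 22]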
@@ -127,7 +127,7 @@ class FileManager:
     install_dir: str
     template_dir: str
     dry_run: bool
-    filenames: Set[str]
+    filenames: set[str]
 
     def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
         self.install_dir = install_dir
@@ -136,7 +136,7 @@ class FileManager:
         self.dry_run = dry_run
 
     def _write_if_changed(self, filename: str, contents: str) -> None:
-        old_contents: Optional[str]
+        old_contents: str | None
         try:
             with open(filename) as f:
                 old_contents = f.read()
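_write_if_changed reads the existing file first and skips the write when
nothing changed, so downstream build systems see stable mtimes and avoid
spurious rebuilds. A standalone sketch of the idea (the directory-creation
step is an added convenience, not necessarily what FileManager does):

    from __future__ import annotations

    import os


    def write_if_changed(filename: str, contents: str) -> None:
        old_contents: str | None
        try:
            with open(filename) as f:
                old_contents = f.read()
        except OSError:
            old_contents = None
        if contents != old_contents:
            os.makedirs(os.path.dirname(filename) or ".", exist_ok=True)
            with open(filename, "w") as f:
                f.write(contents)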
@@ -150,7 +150,7 @@ class FileManager:
 
     # Read from template file and replace pattern with callable (type could be dict or str).
     def substitute_with_template(
-        self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
+        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
     ) -> str:
         template_path = os.path.join(self.template_dir, template_fn)
         env = env_callable()
@@ -171,7 +171,7 @@ class FileManager:
         self,
         filename: str,
         template_fn: str,
-        env_callable: Callable[[], Union[str, Dict[str, Any]]],
+        env_callable: Callable[[], str | dict[str, Any]],
     ) -> None:
         filename = f"{self.install_dir}/{filename}"
         assert filename not in self.filenames, "duplicate file write {filename}"
@@ -186,7 +186,7 @@ class FileManager:
     def write(
         self,
         filename: str,
-        env_callable: Callable[[], Union[str, Dict[str, Any]]],
+        env_callable: Callable[[], str | dict[str, Any]],
     ) -> None:
         self.write_with_template(filename, filename, env_callable)
 
@@ -196,13 +196,13 @@ class FileManager:
         items: Iterable[T],
         *,
         key_fn: Callable[[T], str],
-        env_callable: Callable[[T], Dict[str, List[str]]],
+        env_callable: Callable[[T], dict[str, list[str]]],
         num_shards: int,
-        base_env: Optional[Dict[str, Any]] = None,
-        sharded_keys: Set[str],
+        base_env: dict[str, Any] | None = None,
+        sharded_keys: set[str],
     ) -> None:
-        everything: Dict[str, Any] = {"shard_id": "Everything"}
-        shards: List[Dict[str, Any]] = [
+        everything: dict[str, Any] = {"shard_id": "Everything"}
+        shards: list[dict[str, Any]] = [
             {"shard_id": f"_{i}"} for i in range(num_shards)
         ]
         all_shards = [everything] + shards
@@ -221,7 +221,7 @@ class FileManager:
             else:
                 shard[key] = []
 
-        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
+        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
             for k, v in from_.items():
                 assert k in sharded_keys, f"undeclared sharded key {k}"
                 into[k] += v
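write_sharded fans the items out across num_shards outputs (plus the
combined "Everything" variant built above) and merges each item's env into
its shard with merge_env. Assignment has to be deterministic across runs,
which rules out the process-salted builtin hash(). A sketch of stable
sharding under that assumption (the digest choice is illustrative):

    from __future__ import annotations

    import hashlib


    def shard_index(key: str, num_shards: int) -> int:
        # Stable across processes, unlike hash(), which is randomly salted.
        return int(hashlib.sha256(key.encode()).hexdigest(), 16) % num_shards


    def assign_shards(keys: list[str], num_shards: int) -> list[list[str]]:
        shards: list[list[str]] = [[] for _ in range(num_shards)]
        for key in keys:
            shards[shard_index(key, num_shards)].append(key)
        return shards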
@@ -275,7 +275,7 @@ class FileManager:
 
 # Helper function to generate file manager
 def make_file_manager(
-    options: Namespace, install_dir: Optional[str] = None
+    options: Namespace, install_dir: str | None = None
 ) -> FileManager:
     template_dir = os.path.join(options.source_path, "templates")
     install_dir = install_dir if install_dir else options.install_dir
@@ -335,7 +335,7 @@ def _pformat(
 
 
 def _format_dict(
-    attr: Dict[Any, Any],
+    attr: dict[Any, Any],
     indent: int,
     width: int,
     curr_indent: int,
@@ -355,7 +355,7 @@ def _format_dict(
 
 
 def _format_list(
-    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
+    attr: list[Any] | set[Any] | tuple[Any, ...],
     indent: int,
     width: int,
     curr_indent: int,
@@ -370,7 +370,7 @@ def _format_list(
 
 
 def _format(
-    fields_str: List[str],
+    fields_str: list[str],
     indent: int,
     width: int,
     curr_indent: int,
@@ -402,7 +402,9 @@ class NamespaceHelper:
     } // namespace torch
     """
 
-    def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
+    def __init__(
+        self, namespace_str: str, entity_name: str = "", max_level: int = 2
+    ) -> None:
         # cpp_namespace can be a colon joined string such as torch::lazy
         cpp_namespaces = namespace_str.split("::")
         assert (
@@ -419,7 +421,7 @@ class NamespaceHelper:
     @staticmethod
     def from_namespaced_entity(
         namespaced_entity: str, max_level: int = 2
-    ) -> "NamespaceHelper":
+    ) -> NamespaceHelper:
         """
         Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
         """
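The unquoted `-> NamespaceHelper` return type is the payoff of postponed
annotations: a class can name itself in its own method signatures without
string quoting. Functionally, the helper splits an entity like
"torch::lazy::add" into namespaces plus entity name and emits the nested
open/close blocks shown in the class docstring. A simplified sketch (the
class and method names here are illustrative, and max_level handling is
omitted):

    from __future__ import annotations


    class NsHelper:
        def __init__(self, namespaced_entity: str) -> None:
            parts = namespaced_entity.split("::")
            self.namespaces, self.entity_name = parts[:-1], parts[-1]

        def open(self) -> str:
            return "\n".join(f"namespace {ns} {{" for ns in self.namespaces)

        def close(self) -> str:
            return "\n".join(
                f"}} // namespace {ns}" for ns in reversed(self.namespaces)
            )


    h = NsHelper("torch::lazy::add")
    print(h.open())   # namespace torch { / namespace lazy {
    print(h.close())  # } // namespace lazy / } // namespace torch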
@@ -452,9 +454,9 @@ class NamespaceHelper:
 
 
 class OrderedSet(Generic[T]):
-    storage: Dict[T, Literal[None]]
+    storage: dict[T, Literal[None]]
 
-    def __init__(self, iterable: Optional[Iterable[T]] = None):
+    def __init__(self, iterable: Iterable[T] | None = None) -> None:
         if iterable is None:
             self.storage = {}
         else:
@@ -466,28 +468,28 @@ class OrderedSet(Generic[T]):
     def __iter__(self) -> Iterator[T]:
         return iter(self.storage.keys())
 
-    def update(self, items: "OrderedSet[T]") -> None:
+    def update(self, items: OrderedSet[T]) -> None:
         self.storage.update(items.storage)
 
     def add(self, item: T) -> None:
         self.storage[item] = None
 
-    def copy(self) -> "OrderedSet[T]":
+    def copy(self) -> OrderedSet[T]:
         ret: OrderedSet[T] = OrderedSet()
         ret.storage = self.storage.copy()
         return ret
 
     @staticmethod
-    def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
+    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
         ret = args[0].copy()
         for s in args[1:]:
             ret.update(s)
         return ret
 
-    def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
+    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
         return OrderedSet.union(self, other)
 
-    def __ior__(self, other: "OrderedSet[T]") -> Self:
+    def __ior__(self, other: OrderedSet[T]) -> Self:
         self.update(other)
         return self
 
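OrderedSet leans on the fact that dict preserves insertion order (Python
3.7+): the keys of a dict[T, Literal[None]] form an ordered set with O(1)
membership tests, and the now-unquoted OrderedSet[T] annotations above are
another postponed-annotations win. A usage sketch, assuming the elided
__init__ else-branch seeds storage from the iterable:

    s: OrderedSet[str] = OrderedSet(["b", "a", "b"])
    s.add("c")
    print(list(s))  # ['b', 'a', 'c'] -- insertion order kept, duplicate dropped
    t = s | OrderedSet(["a", "d"])
    print(list(t))  # ['b', 'a', 'c', 'd']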