Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00.
Revert "[BE][Easy] enable postponed annotations in torchgen (#129376)"
This reverts commit 494057d6d4e9b40daf81a6a4d7a8c839b7424b14. Reverted https://github.com/pytorch/pytorch/pull/129376 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert to cleanly revert https://github.com/pytorch/pytorch/pull/129374, please do a rebase and reland this ([comment](https://github.com/pytorch/pytorch/pull/129375#issuecomment-2197800541))
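For context, the reverted PR switched torchgen to PEP 563 "postponed evaluation of annotations", under which every annotation is stored as an unevaluated string. The sketch below (illustrative code, not taken from this diff) contrasts the two spellings this commit swaps between; in practice the two styles would live in separate modules, since the `__future__` import applies module-wide.

```python
from __future__ import annotations  # style introduced by #129376

from typing import Optional, Tuple


# With postponed annotations, builtin generics and PEP 604 unions are
# legal in hints even on interpreters that cannot evaluate them.
def split_names(raw: str) -> tuple[str, ...] | None:
    return tuple(raw.split(",")) if raw else None


# Style restored by this revert: annotations are evaluated eagerly at
# import time, so the typing-module spellings are required for
# compatibility with older interpreters.
def split_names_reverted(raw: str) -> Optional[Tuple[str, ...]]:
    return tuple(raw.split(",")) if raw else None
```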
@@ -1,8 +1,6 @@
-from __future__ import annotations
-
 import re
 from dataclasses import dataclass
-from typing import cast, Sequence
+from typing import cast, Dict, List, Match, Optional, Sequence, Set, Tuple
 
 from torchgen import local
 from torchgen.api import cpp
@@ -50,16 +48,16 @@ class Derivative:
     original_formula: str
 
     # Names of the arguments for which this formula calculates derivatives.
-    var_names: tuple[str, ...]
+    var_names: Tuple[str, ...]
 
     # Saved inputs that are referenced by the formula.
-    saved_inputs: tuple[SavedAttribute, ...]
+    saved_inputs: Tuple[SavedAttribute, ...]
 
     # Saved outputs that are referenced by the formula.
-    saved_outputs: tuple[SavedAttribute, ...]
+    saved_outputs: Tuple[SavedAttribute, ...]
 
     # Gradients that are referenced by name in the formula.
-    named_gradients: set[str]
+    named_gradients: Set[str]
 
 
 # Represents a forward formula that calculates forward derivatives
@@ -73,17 +71,17 @@ class ForwardDerivative:
 
     # Name of the output arguments for which this formula calculates forward
     # derivatives
-    var_names: tuple[str, ...]
+    var_names: Tuple[str, ...]
 
     # Type of the output arguments for which this formula calculates forward
     # derivatives
-    var_types: tuple[Type, ...]
+    var_types: Tuple[Type, ...]
 
     # Inputs for which the forward derivatives are required for this formula
-    required_inputs_fw_grad: tuple[str, ...] | None
+    required_inputs_fw_grad: Optional[Tuple[str, ...]]
 
     # Inputs for which the primal is required for this formula
-    required_inputs_primal: tuple[str, ...] | None
+    required_inputs_primal: Optional[Tuple[str, ...]]
 
     # Flag to specify if this formula requires the original value of self
     # This is only used by inplace operations
@@ -118,7 +116,7 @@ class DifferentiabilityInfo:
     # The name of the generated autograd function.
     # It's set only if we will calculate a derivative, i.e.
     # 'args_with_derivatives' is not empty.
-    op: str | None
+    op: Optional[str]
 
     # The derivatives formulae for this function.
     # Note that the length of this sequence is the number of differentiable inputs
@@ -140,7 +138,7 @@ class DifferentiabilityInfo:
 
     # The named gradients that are used in any of the derivatives.
     # Invariant: all(name in available_named_gradients for name in used_named_gradients)
-    used_named_gradients: set[str]
+    used_named_gradients: Set[str]
 
     # The function's input arguments for which it calculates derivatives.
     # It's the union of 'var_names' of all 'derivatives', sorted by the
@@ -151,7 +149,7 @@ class DifferentiabilityInfo:
     non_differentiable_arg_names: Sequence[str]
 
     # Raw data read from derivatives.yaml.
-    output_differentiability: list[bool] | None
+    output_differentiability: Optional[List[bool]]
 
     # output_differentiability in derivatives.yaml can be a list of
     # conditions that express if the output is differentiable. In this case,
@@ -159,7 +157,7 @@ class DifferentiabilityInfo:
     # (NB: we only support one condition right now).
     # output_differentiability gets populated with True for each condition,
     # while output_differentiability_conditions gets populated with the conditions
-    output_differentiability_conditions: list[str] | None
+    output_differentiability_conditions: Optional[List[str]]
 
     @property
     def has_derivatives(self) -> bool:
@@ -172,7 +170,7 @@ class DifferentiabilityInfo:
     # See Note [Codegen'd {view}_copy Operators]
     def create_view_copy_from_view_derivative(
         self, g: NativeFunctionsViewGroup
-    ) -> DifferentiabilityInfo | None:
+    ) -> Optional["DifferentiabilityInfo"]:
         if g.view_copy is None:
             return None
         f = g.view_copy
@@ -203,7 +201,7 @@ class DifferentiabilityInfo:
     )
 
 
-def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
+def uses_ident(info: Optional[DifferentiabilityInfo], ident: str) -> bool:
     if info is None:
         return False
     for derivative in info.derivatives:
@@ -213,11 +211,11 @@ def uses_ident(info: DifferentiabilityInfo | None, ident: str) -> bool:
     return False
 
 
-def uses_retain_variables(info: DifferentiabilityInfo | None) -> bool:
+def uses_retain_variables(info: Optional[DifferentiabilityInfo]) -> bool:
     return uses_ident(info, "retain_variables")
 
 
-def uses_single_grad(info: DifferentiabilityInfo | None) -> bool:
+def uses_single_grad(info: Optional[DifferentiabilityInfo]) -> bool:
     return uses_ident(info, "grad")
 
 
@@ -255,8 +253,8 @@ class DifferentiableOutput:
 @dataclass(frozen=True)
 class NativeFunctionWithDifferentiabilityInfo:
     func: NativeFunction
-    info: dict[str, DifferentiabilityInfo] | None
-    fw_derivatives: dict[str, Sequence[ForwardDerivative]] | None
+    info: Optional[Dict[str, DifferentiabilityInfo]]
+    fw_derivatives: Optional[Dict[str, Sequence[ForwardDerivative]]]
 
 
 # TODO: Update comment below since it is out of date.
@@ -365,19 +363,19 @@ def is_reference_for_foreach(
 # TODO(crcrpar): Avoid hard coding "Default" ideally.
 def gen_foreach_derivativeinfo(
     foreach_function: NativeFunction,
-    functional_info_by_signature: dict[
-        FunctionSchema, dict[str, DifferentiabilityInfo]
+    functional_info_by_signature: Dict[
+        FunctionSchema, Dict[str, DifferentiabilityInfo]
     ],
-    non_functional_info_by_signature: dict[
-        FunctionSchema, dict[str, DifferentiabilityInfo]
+    non_functional_info_by_signature: Dict[
+        FunctionSchema, Dict[str, DifferentiabilityInfo]
     ],
     dispatch_key: str = "Default",
-) -> tuple[DifferentiabilityInfo | None, bool]:
+) -> Tuple[Optional[DifferentiabilityInfo], bool]:
     """Generate DifferentiabilityInfo for out-place foreach function, return the existing one for in-place.
 
     The second return value indicates whether the info is generated in this function.
     """
-    ref_diff_info: DifferentiabilityInfo | None = None
+    ref_diff_info: Optional[DifferentiabilityInfo] = None
 
     for function_schema, diff_info in functional_info_by_signature.items():
         if not is_reference_for_foreach(foreach_function, function_schema):
@@ -487,13 +485,13 @@ def gen_foreach_derivativeinfo(
         if arg.name in all_var_names
     ]
 
-    forward_derivatives: list[ForwardDerivative] = []
+    forward_derivatives: List[ForwardDerivative] = []
     fw_derivative: ForwardDerivative
     for fw_derivative in ref_diff_info.forward_derivatives:
-        var_names: list[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
-        var_types: list[Type] = list(fw_derivative.var_types)
-        required_inputs_fw_grad: list[str] = []
-        required_inputs_primal: list[str] = []
+        var_names: List[str] = list(fw_derivative.var_names)  # type: ignore[no-redef]
+        var_types: List[Type] = list(fw_derivative.var_types)
+        required_inputs_fw_grad: List[str] = []
+        required_inputs_primal: List[str] = []
        if fw_derivative.required_inputs_fw_grad is not None:
             required_inputs_fw_grad = list(fw_derivative.required_inputs_fw_grad)
         if fw_derivative.required_inputs_primal:
@@ -580,9 +578,9 @@ def gen_foreach_derivativeinfo(
 
 
 def match_differentiability_info(
-    native_functions: list[NativeFunction],
-    differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]],
-) -> list[NativeFunctionWithDifferentiabilityInfo]:
+    native_functions: List[NativeFunction],
+    differentiability_infos: Dict[FunctionSchema, Dict[str, DifferentiabilityInfo]],
+) -> List[NativeFunctionWithDifferentiabilityInfo]:
     """Sets the "derivative" key on declarations to matching autograd function
     In-place functions will use the out-of-place derivative definition if there
     is no in-place specific derivative.
@@ -601,7 +599,7 @@ def match_differentiability_info(
 
     def find_info(
         f: NativeFunction,
-    ) -> tuple[dict[str, DifferentiabilityInfo] | None, bool]:
+    ) -> Tuple[Optional[Dict[str, DifferentiabilityInfo]], bool]:
         # Don't bother matching info to generated out= variants
         if "generated" in f.tags and f.func.kind() == SchemaKind.out:
             return None, False
@@ -655,7 +653,7 @@ Attempted to convert a derivative formula for a mutable operator
 
         return None, False
 
-    result: list[NativeFunctionWithDifferentiabilityInfo] = []
+    result: List[NativeFunctionWithDifferentiabilityInfo] = []
    for f in native_functions:
         info_dict, is_exact_match = find_info(f)
 
@@ -679,7 +677,7 @@ Attempted to convert a derivative formula for a mutable operator
             )
             continue
 
-        fw_derivative_dict: dict[str, Sequence[ForwardDerivative]] = {}
+        fw_derivative_dict: Dict[str, Sequence[ForwardDerivative]] = {}
         for key, info in info_dict.items():
             if not info.forward_derivatives:
                 fw_derivative_dict[key] = []
@@ -715,7 +713,7 @@ Attempted to convert a derivative formula for a mutable operator
                 formula = fw_info.formula
 
                 def replace_self_with_original_self(formula: str, postfix: str) -> str:
-                    def repl(m: re.Match[str]) -> str:
+                    def repl(m: Match[str]) -> str:
                         return f"{m.group(1)}original_self{postfix}{m.group(2)}"
 
                     return re.sub(IDENT_REGEX.format(f"self{postfix}"), repl, formula)
@@ -736,7 +734,7 @@ Attempted to convert a derivative formula for a mutable operator
                 formula = replace_self_with_original_self(formula, "_t")
 
                 # replace "result" from the formula by "self_p"
-                def repl(m: re.Match[str]) -> str:
+                def repl(m: Match[str]) -> str:
                     return f"{m.group(1)}self_p{m.group(2)}"
 
                 formula = re.sub(IDENT_REGEX.format("result"), repl, formula)
@@ -760,8 +758,8 @@ Attempted to convert a derivative formula for a mutable operator
             # If there is a need, we can relax (2) to allow any op that has an in-place variant
             is_single_method_on_self_t = False
             directly_do_inplace = False
-            op_name: str | None = None
-            between_parens: str | None = None
+            op_name: Optional[str] = None
+            between_parens: Optional[str] = None
             match = re.fullmatch(r"self_t.([\w]*)\((.*)\)", formula)
             if match:
                 op_name, between_parens = match.group(1), match.group(2)
@@ -825,7 +823,7 @@ Attempted to convert a derivative formula for a mutable operator
 
 
 def is_differentiable(
-    name: str, type: Type, info: DifferentiabilityInfo | None
+    name: str, type: Type, info: Optional[DifferentiabilityInfo]
 ) -> bool:
     return type.is_tensor_like() and (
         info is None or name not in info.non_differentiable_arg_names
@@ -834,10 +832,10 @@ def is_differentiable(
 
 def gen_differentiable_outputs(
     fn: NativeFunctionWithDifferentiabilityInfo, key: str = "Default"
-) -> list[DifferentiableOutput]:
+) -> List[DifferentiableOutput]:
     f = fn.func
     info = fn.info[key] if fn.info else None
-    outputs: list[DifferentiableOutput] = [
+    outputs: List[DifferentiableOutput] = [
         DifferentiableOutput(
             name=name,
             type=ret.type,
@@ -852,7 +850,7 @@ def gen_differentiable_outputs(
                 f"The length of output_differentiability ({len(output_differentiability)}), "
                 f"does not match the number of outputs ({len(outputs)})."
             )
-        differentiable_outputs: list[DifferentiableOutput] = []
+        differentiable_outputs: List[DifferentiableOutput] = []
         if False in output_differentiability and f.func.kind() == SchemaKind.inplace:
             raise RuntimeError(
                 "output_differentiability=False for inplace operation (version_counter won't get updated)"
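The hunks above also re-quote self-referencing annotations (for example `-> DifferentiabilityInfo | None:` becomes `-> Optional["DifferentiabilityInfo"]:` inside the very class that defines the name). A minimal sketch of why the quotes return once the `__future__` import is gone; the class name here is hypothetical:

```python
from typing import Optional


class Node:
    # Evaluated while the class body is still executing, before the
    # name `Node` is bound, so the bare spelling would raise NameError:
    #     def clone(self) -> Optional[Node]: ...
    # Quoting defers the lookup to the type checker instead.
    def clone(self) -> Optional["Node"]:
        return Node()
```

The same hunks swap `re.Match[str]` back to `typing.Match[str]`: `typing.Match` is a deprecated alias slated for removal in newer Pythons, but `re.Match` only became subscriptable at runtime in Python 3.9, so eagerly evaluated annotations need the alias on 3.8.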
@@ -1,6 +1,4 @@
-from __future__ import annotations
-
-from typing import Sequence
+from typing import List, Optional, Sequence, Set, Union
 
 from torchgen import local
 from torchgen.api.types import (
@@ -96,7 +94,7 @@ def valuetype_type(
     binds: ArgName,
     remove_non_owning_ref_types: bool = False,
     symint: bool = False,
-) -> NamedCType | None:
+) -> Optional[NamedCType]:
     if isinstance(t, BaseType):
         if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
             return None
@@ -281,7 +279,7 @@ def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
 
 
 def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
-    returns: list[str] = []
+    returns: List[str] = []
     for i, r in enumerate(f.func.returns):
         # If we have an inplace function, the return argument is
         # implicitly named self.
@@ -370,17 +368,17 @@ def default_expr(d: str, t: Type, *, symint: bool) -> str:
 
 
 def argument(
-    a: Argument | TensorOptionsArguments | SelfArgument,
+    a: Union[Argument, TensorOptionsArguments, SelfArgument],
     *,
-    cpp_no_default_args: set[str],
+    cpp_no_default_args: Set[str],
     method: bool,
     faithful: bool,
     symint: bool = False,
     has_tensor_options: bool,
-) -> list[Binding]:
+) -> List[Binding]:
     def sub_argument(
-        a: Argument | TensorOptionsArguments | SelfArgument,
-    ) -> list[Binding]:
+        a: Union[Argument, TensorOptionsArguments, SelfArgument]
+    ) -> List[Binding]:
         return argument(
             a,
             cpp_no_default_args=cpp_no_default_args,
@@ -396,7 +394,7 @@ def argument(
         binds = SpecialArgName.possibly_redundant_memory_format
     else:
         binds = a.name
-    default: str | None = None
+    default: Optional[str] = None
     if a.name not in cpp_no_default_args and a.default is not None:
         default = default_expr(a.default, a.type, symint=symint)
     return [
@@ -447,9 +445,9 @@ def arguments(
     faithful: bool,
     symint: bool = False,
     method: bool,
-    cpp_no_default_args: set[str],
-) -> list[Binding]:
-    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+    cpp_no_default_args: Set[str],
+) -> List[Binding]:
+    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
     if faithful:
         args.extend(arguments.non_out)
         args.extend(arguments.out)
@@ -1,7 +1,5 @@
-from __future__ import annotations
-
 import itertools
-from typing import Sequence
+from typing import List, Sequence, Union
 
 from torchgen.api import cpp
 from torchgen.api.types import ArgName, Binding, CType, NamedCType
@@ -78,10 +76,10 @@ def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
     return cpp.returns_type(rs, symint=symint)
 
 
-def jit_arguments(func: FunctionSchema) -> list[Argument]:
+def jit_arguments(func: FunctionSchema) -> List[Argument]:
     def to_argument(
-        a: Argument | TensorOptionsArguments | SelfArgument,
-    ) -> list[Argument]:
+        a: Union[Argument, TensorOptionsArguments, SelfArgument]
+    ) -> List[Argument]:
         if isinstance(a, Argument):
             return [a]
         elif isinstance(a, SelfArgument):
@@ -116,5 +114,5 @@ def argument(
     )
 
 
-def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
+def arguments(func: FunctionSchema, *, symint: bool = True) -> List[Binding]:
     return [argument(a, symint=symint) for a in jit_arguments(func)]
@@ -1,4 +1,4 @@
-from __future__ import annotations
+from typing import List, Optional
 
 from torchgen.api import dispatcher
 from torchgen.api.types import (
@@ -93,7 +93,7 @@ def name(
     *,
     is_reverse: bool,
     include_namespace: bool,
-    reapply_views: bool | None = None,
+    reapply_views: Optional[bool] = None,
 ) -> str:
     if reapply_views is None:
         # reapply_views is only important for the fwd lambda,
@@ -124,7 +124,7 @@ def reverse_name(f: NativeFunction, include_namespace: bool) -> str:
     return f"{api_name}_inverse"
 
 
-def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> list[Binding]:
+def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
     # capture arguments include all arguments except `self`.
     # Importantly, they don't include any C++ reference types (or else we'll get a dangling reference in the capture),
     # So any reference types (IntArrayRef) need to be converted to value types (vector<int64_t>)
@@ -152,14 +152,14 @@ def returns_type(func: FunctionSchema) -> CType:
     return BaseCType(tensorT)
 
 
-def outer_arguments(*, is_reverse: bool) -> list[Binding]:
+def outer_arguments(*, is_reverse: bool) -> List[Binding]:
     if is_reverse:
         return [base_binding, mutated_view_binding, mutated_view_idx_binding]
     else:
         return [base_binding, mutated_view_idx_binding]
 
 
-def inner_call_index(func: FunctionSchema) -> Binding | None:
+def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
     # For view ops that return multiple tensors (like `split`), we generate a separate lambda for each output.
     # When we replay a view op that returns multiple tensors, we need to index into the output appropriately
     if len(func.returns) > 1 or (
@@ -169,7 +169,7 @@ def inner_call_index(func: FunctionSchema) -> Binding | None:
     return None
 
 
-def inner_arguments(func: FunctionSchema, is_reverse: bool) -> list[Binding]:
+def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
     args = func.arguments.flat_all
     assert args[0].type == BaseType(BaseTy.Tensor)
     non_self_args = args[1:]
@@ -1,6 +1,4 @@
-from __future__ import annotations
-
-from typing import Any
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from torchgen.api.types import (
     BaseCppType,
@@ -36,7 +34,7 @@ from torchgen.model import (
 )
 
 
-_valueT: BaseCppType | None = None
+_valueT: Optional[BaseCppType] = None
 
 
 # A ValueT is an IR type which represents the computation of a Tensor. In other
@@ -68,8 +66,8 @@ tensorListValueT = BaseCppType("torch::lazy", "Value")
 
 
 def process_ir_type(
-    typ: Type, properties: LazyIrProperties, *, symint: bool
-) -> BaseCType | VectorCType | OptionalCType | ListCType:
+    typ: Type, properties: "LazyIrProperties", *, symint: bool
+) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
     """
     This function takes a type from NativeFunctions and converts it for use with
     lazy tensor codegen.
@@ -149,7 +147,7 @@ def process_ir_type(
 #
 # Invariant: passed typ should be an *owning* CType (e.g., we will report
 # that ArrayRef<Value> is NOT a value type)
-def isValueType(typ: CType, properties: LazyIrProperties | None = None) -> bool:
+def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
     """
     Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
@@ -204,7 +202,7 @@ def isGeneratorType(typ: Type) -> bool:
 class LazyArgument:
     name: str
     orig_type: Type
-    lazy_type_: CType | None
+    lazy_type_: Optional[CType]
     is_wrapped_scalar: bool
     is_generator: bool
     # TODO: this is lies, it is false for symint list
@@ -216,9 +214,7 @@ class LazyArgument:
     # true if this argument is or contains a lazy IR value
     is_lazy_value: bool
 
-    def __init__(
-        self, arg: Argument, properties: LazyIrProperties, *, symint: bool
-    ) -> None:
+    def __init__(self, arg: Argument, properties: "LazyIrProperties", *, symint: bool):
         self.name = arg.name
         self.orig_type = arg.type
         self.symint = symint
@@ -252,7 +248,7 @@ class LazyIrProperties:
     attributes. The mutual exclusivity is automatically handled.
     """
 
-    Properties: tuple[tuple[str, ...], ...] = (
+    Properties: Tuple[Tuple[str, ...], ...] = (
         (
             "ShapePrecompute",  # Assume shape has been precomputed
             "ShapeCompute",  # Need to compute the shape on construction
@@ -275,8 +271,8 @@ class LazyIrProperties:
         ),
     )
 
-    def __init__(self, *default_properties: str) -> None:
-        properties: dict[tuple[str, ...], str | None] = dict.fromkeys(
+    def __init__(self, *default_properties: str):
+        properties: Dict[Tuple[str, ...], Optional[str]] = dict.fromkeys(
             LazyIrProperties.Properties
         )
         self.__dict__["properties"] = properties
@@ -309,17 +305,17 @@ class LazyIrProperties:
 # TODO: This is not idiomatic with how other torchgen APIs transform on schema.
 class LazyIrSchema:
     # The name of the operator this function schema describes.
-    name: OperatorName
+    name: "OperatorName"
 
-    positional_args: tuple[LazyArgument, ...]
-    keyword_args: tuple[LazyArgument, ...]
+    positional_args: Tuple[LazyArgument, ...]
+    keyword_args: Tuple[LazyArgument, ...]
 
     # TODO: Need to handle collisions with argument names at some point
-    returns: tuple[Return, ...]
+    returns: Tuple["Return", ...]
 
     # if this schema has a Generator arg, list its orig ctype/name but don't
     # build a LazyArgument since lazy IR doesn't support it
-    generator_arg: NamedCType | None = None
+    generator_arg: Optional[NamedCType] = None
 
     # original function schema
     func: FunctionSchema
@@ -333,21 +329,21 @@ class LazyIrSchema:
         "Lower",
         "CanBeReused",
     )
-    opkind: str | None = None
+    opkind: Optional[str] = None
 
     def __init__(
         self,
         func: FunctionSchema,
-        properties: LazyIrProperties | None = None,
+        properties: Optional[LazyIrProperties] = None,
         *,
         symint: bool,
-    ) -> None:
+    ):
         if properties:
            self.properties = properties
 
         self.func = func
         self.symint = symint
-        positional_args: list[LazyArgument] = []
+        positional_args: List[LazyArgument] = []
         for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
             if arg_field == "self_arg" and func.arguments.self_arg is not None:
                 arg = func.arguments.self_arg.argument
@@ -361,7 +357,7 @@ class LazyIrSchema:
             )
         self.positional_args = tuple(positional_args)
 
-        keyword_args: list[LazyArgument] = []
+        keyword_args: List[LazyArgument] = []
         for arg_field in [
             "pre_tensor_options_kwarg_only",
             "tensor_options",
@@ -415,13 +411,13 @@ class LazyIrSchema:
         values: bool = True,
         scalars: bool = True,
         generator: bool = True,
-    ) -> list[LazyArgument]:
+    ) -> List[LazyArgument]:
         # This function maintains the sorted order of arguments but provides different filtered views.
         # Some parts of the code care about kwargs vs args (TS lowerings),
         # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
         # Generators are special cased, as they are needed for fallback/shape-inference but not supported
         # in TS lowerings and therefore also omitted from lazy IR.
-        args: list[LazyArgument] = []
+        args: List[LazyArgument] = []
         if positional:
             args.extend(self.positional_args)
         if keyword:
@@ -443,25 +439,25 @@ class LazyIrSchema:
             return []
 
     @property
-    def positional_values(self) -> list[LazyArgument]:
+    def positional_values(self) -> List[LazyArgument]:
         return self.filtered_args(
             positional=True, keyword=False, values=True, scalars=False
         )
 
     @property
-    def positional_scalars(self) -> list[LazyArgument]:
+    def positional_scalars(self) -> List[LazyArgument]:
         return self.filtered_args(
             positional=True, keyword=False, values=False, scalars=True
         )
 
     @property
-    def keyword_values(self) -> list[LazyArgument]:
+    def keyword_values(self) -> List[LazyArgument]:
         return self.filtered_args(
             positional=False, keyword=True, values=True, scalars=False
         )
 
     @property
-    def keyword_scalars(self) -> list[LazyArgument]:
+    def keyword_scalars(self) -> List[LazyArgument]:
         return self.filtered_args(
             positional=False, keyword=True, values=False, scalars=True
         )
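The `Properties: Tuple[Tuple[str, ...], ...]` hunk above shows why the revert has to rewrite spellings rather than merely drop the `__future__` import: class-body and module-level annotations are evaluated eagerly without PEP 563, and subscripting the builtin `tuple` is only supported at runtime from Python 3.9 onward. A small interpreter-level illustration (not torchgen code):

```python
# On a 3.8 interpreter, evaluating a builtin generic at runtime fails:
#     >>> tuple[tuple[int, ...], ...]
#     TypeError: 'type' object is not subscriptable
# so eagerly evaluated annotations must use the typing spellings.
from typing import Tuple


class LayerSpec:
    # Evaluated at class-creation time and stored in __annotations__;
    # the typing form works on every interpreter this code supports.
    shapes: Tuple[Tuple[int, ...], ...] = ((1, 2), (3,))
```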
@@ -1,6 +1,4 @@
-from __future__ import annotations
-
-from typing import Sequence
+from typing import List, Optional, Sequence, Union
 
 from torchgen import local
 from torchgen.api import cpp
@@ -83,11 +81,11 @@ def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
 
 
 def argument(
-    a: Argument | SelfArgument | TensorOptionsArguments,
+    a: Union[Argument, SelfArgument, TensorOptionsArguments],
     *,
     is_out: bool,
     symint: bool,
-) -> list[Binding]:
+) -> List[Binding]:
     # Ideally, we NEVER default native functions. However, there are a number
     # of functions that call native:: directly and rely on the defaulting
     # existing. So for BC, we generate defaults for non-out variants (but not
@@ -95,7 +93,7 @@ def argument(
     # default)
     should_default = not is_out
     if isinstance(a, Argument):
-        default: str | None = None
+        default: Optional[str] = None
         if should_default and a.default is not None:
             default = cpp.default_expr(a.default, a.type, symint=symint)
         return [
@@ -146,8 +144,8 @@ def argument(
         assert_never(a)
 
 
-def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
-    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
+    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
     args.extend(func.arguments.non_out)
     args.extend(func.arguments.out)
     return [
@@ -1,7 +1,5 @@
-from __future__ import annotations
-
 from dataclasses import dataclass
-from typing import Sequence
+from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
 
 from torchgen.api import cpp
 from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
@@ -199,14 +197,14 @@ from torchgen.model import (
 
 @dataclass(frozen=True)
 class PythonReturns:
-    returns: tuple[Return, ...]
+    returns: Tuple[Return, ...]
 
 
 @dataclass(frozen=True)
 class PythonArgument:
     name: str
     type: Type
-    default: str | None
+    default: Optional[str]
 
     # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
     #
@@ -214,7 +212,7 @@ class PythonArgument:
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                                              ^
    #                                              +--- default_init str
-    default_init: str | None
+    default_init: Optional[str]
 
     # Compute argument formal for python argument parsing.
     # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
@@ -302,10 +300,12 @@ class PythonOutArgument(PythonArgument):
     # 'auto out = _r.tensorlist_n<2>(2);',
     # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
     # TODO: maybe don't need keep scattered out fields for python signature?
-    outputs: tuple[PythonArgument, ...]
+    outputs: Tuple[PythonArgument, ...]
 
     @staticmethod
-    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
+    def from_outputs(
+        outputs: Tuple[PythonArgument, ...]
+    ) -> Optional["PythonOutArgument"]:
         if not outputs:
             return None
 
@@ -339,13 +339,13 @@ class PythonSignature:
 
     # Positional arguments.
     # TODO: create a dedicated SelfArgument type for 'self'?
-    input_args: tuple[PythonArgument, ...]
+    input_args: Tuple[PythonArgument, ...]
 
     # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
     # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
-    input_kwargs: tuple[PythonArgument, ...]
+    input_kwargs: Tuple[PythonArgument, ...]
 
-    output_args: PythonOutArgument | None
+    output_args: Optional[PythonOutArgument]
 
     # Return types, which are only used by pyi
     returns: PythonReturns
@@ -356,7 +356,7 @@ class PythonSignature:
     # for out variant), in which case they will be used as scattered fields without
     # being packed into 'options'.
     # TODO: maybe create a PythonTensorOptionsArgument?
-    tensor_options_args: tuple[PythonArgument, ...]
+    tensor_options_args: Tuple[PythonArgument, ...]
 
     # method or function signature?
     method: bool
@@ -367,8 +367,8 @@ class PythonSignature:
 
     def arguments(
         self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
-    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
-        result: list[PythonArgument | PythonOutArgument] = []
+    ) -> Tuple[Union[PythonArgument, PythonOutArgument], ...]:
+        result: List[Union[PythonArgument, PythonOutArgument]] = []
         result.extend(self.input_args)
         result.extend(self.input_kwargs)
         if self.output_args is not None and not skip_outputs:
@@ -394,7 +394,7 @@ class PythonSignature:
     # signature_str_pyi().
     def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: list[str] = [
+        schema_formals: List[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
         ]
         positional_argc = len(self.input_args)
@@ -405,7 +405,7 @@ class PythonSignature:
 
     def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: list[str] = [
+        schema_formals: List[str] = [
            a.argument_str_pyi(method=self.method) for a in args
         ]
         positional_argc = len(self.input_args)
@@ -419,10 +419,10 @@ class PythonSignature:
             schema_formals.insert(0, "self")
         return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
 
-    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
         # only pyi uses vararg signatures
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: list[str] = [
+        schema_formals: List[str] = [
            a.argument_str_pyi(method=self.method) for a in args
         ]
         # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
@@ -470,7 +470,7 @@ class PythonSignatureDeprecated(PythonSignature):
     # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
     # [func call]: self.addmm(mat1, mat2, beta, 1)
     # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
-    deprecated_args_exprs: tuple[str, ...]
+    deprecated_args_exprs: Tuple[str, ...]
 
     @property
     def deprecated(self) -> bool:
@@ -486,7 +486,7 @@ class PythonSignatureDeprecated(PythonSignature):
 
     def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
         args = self.arguments(skip_outputs=skip_outputs)
-        schema_formals: list[str] = [
+        schema_formals: List[str] = [
            a.argument_str_pyi(method=self.method, deprecated=True) for a in args
         ]
         positional_argc = len(self.input_args)
@@ -496,7 +496,7 @@ class PythonSignatureDeprecated(PythonSignature):
         returns_str = returns_str_pyi(self)
         return f'def {self.name}({", ".join(schema_formals)}) -> {returns_str}: ...'
 
-    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
+    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> Optional[str]:
         # the codegen doesn't include vararg variants for deprecated signatures
         return None
 
@@ -530,14 +530,14 @@ class PythonSignatureGroup:
     base: NativeFunction
 
     # The out variant (e.g. conv2d_out)
-    outplace: NativeFunction | None
+    outplace: Optional[NativeFunction]
 
     @classmethod
     def from_pairs(
         cls,
         functional: PythonSignatureNativeFunctionPair,
-        out: PythonSignatureNativeFunctionPair | None,
-    ) -> PythonSignatureGroup:
+        out: Optional[PythonSignatureNativeFunctionPair],
+    ) -> "PythonSignatureGroup":
         if out is None:
             return PythonSignatureGroup(
                 signature=functional.signature,
@@ -716,7 +716,7 @@ def argument_type_str(
     raise RuntimeError(f"unrecognized type {repr(t)}")
 
 
-def argument_type_size(t: Type) -> int | None:
+def argument_type_size(t: Type) -> Optional[int]:
     l = t.is_list_like()
     if l is not None and str(l.elem) != "bool":
         return l.size
@@ -750,11 +750,11 @@ def signature(
 def signature_from_schema(
     func: FunctionSchema,
     *,
-    category_override: str | None,
+    category_override: Optional[str],
     method: bool = False,
     pyi: bool = False,
 ) -> PythonSignature:
-    args: list[Argument] = []
+    args: List[Argument] = []
     args.extend(func.arguments.pre_self_positional)
     # Skip SelfArgument if this is method.
     if not method and func.arguments.self_arg is not None:
@@ -807,10 +807,10 @@ def signature_from_schema(
     )
     is_dummy_function = category_override == "dummy"
 
-    tensor_options_args: list[PythonArgument] = []
+    tensor_options_args: List[PythonArgument] = []
     if (is_factory_function or is_like_or_new_function) and not is_dummy_function:
 
-        def topt_default_init(name: str) -> str | None:
+        def topt_default_init(name: str) -> Optional[str]:
             topt_args = func.arguments.tensor_options
             if topt_args is None:
                 return None
@@ -891,7 +891,7 @@ def signature_from_schema(
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
 
 
-def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
+def structseq_fieldnames(returns: Tuple[Return, ...]) -> List[str]:
     if len(returns) <= 1 or all(r.name is None for r in returns):
         return []
     else:
@@ -1002,7 +1002,7 @@ def return_type_str_pyi(t: Type) -> str:
     return argument_type_str_pyi(t)
 
 
-def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
+def returns_structseq_pyi(signature: PythonSignature) -> Optional[Tuple[str, str]]:
     python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
     structseq_name = signature.name
     field_names = structseq_fieldnames(signature.returns.returns)
@@ -1104,7 +1104,7 @@ def returns_str_pyi(signature: PythonSignature) -> str:
 
 def dispatch_lambda_args(
     ps: PythonSignature, f: NativeFunction, symint: bool = True
-) -> tuple[DispatchLambdaArgument, ...]:
+) -> Tuple[DispatchLambdaArgument, ...]:
     if isinstance(ps, PythonSignatureDeprecated):
         schema = ps.deprecated_schema
     else:
@@ -1118,7 +1118,7 @@ def dispatch_lambda_args(
         method=False,
         cpp_no_default_args=f.cpp_no_default_args,
     )
-    out_args: set[str] = {a.name for a in schema.arguments.out}
+    out_args: Set[str] = {a.name for a in schema.arguments.out}
 
     # Convert from cpp argument to lambda argument
     def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
@@ -1224,11 +1224,11 @@ def cpp_dispatch_target(f: NativeFunction) -> str:
 def cpp_dispatch_exprs(
     f: NativeFunction,
     *,
-    python_signature: PythonSignature | None = None,
-) -> tuple[str, ...]:
+    python_signature: Optional[PythonSignature] = None,
+) -> Tuple[str, ...]:
     cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
 
-    exprs: tuple[str, ...] = tuple()
+    exprs: Tuple[str, ...] = tuple()
     if not isinstance(python_signature, PythonSignatureDeprecated):
         # By default the exprs are consistent with the C++ signature.
         exprs = tuple(a.name for a in cpp_args)
@@ -1262,7 +1262,7 @@ def cpp_dispatch_exprs(
 # For certain cases it is intentionally more restrictive than necessary,
 # e.g.: it doesn't accepts doublelist with definite size.
 def arg_parser_unpack_method(
-    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
+    t: Type, default: Optional[str], default_init: Optional[str], *, symint: bool = True
 ) -> str:
     has_default_init = default_init is not None
     if has_default_init and str(t) not in (
@@ -1377,7 +1377,7 @@ def arg_parser_output_expr(
 # Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
 def arg_parser_output_exprs(
     ps: PythonSignature, f: NativeFunction, *, symint: bool = True
-) -> dict[str, PythonArgParserOutputExpr]:
+) -> Dict[str, PythonArgParserOutputExpr]:
     return {
         e.name: e
         for i, a in enumerate(ps.arguments())
@@ -1404,8 +1404,8 @@ def dispatch_lambda_exprs(
     # outputs.
     arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
     lambda_args = dispatch_lambda_args(ps, f, symint=symint)
-    inits: list[str] = []
-    lambda_args_exprs: dict[str, str] = {}
+    inits: List[str] = []
+    lambda_args_exprs: Dict[str, str] = {}
 
     has_toptions = has_tensor_options(f)
 
@@ -1,4 +1,4 @@
-from __future__ import annotations
+from typing import List, Union
 
 from torchgen.api import cpp
 from torchgen.api.types import (
@@ -97,7 +97,7 @@ def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
 
 
 # Structured kernels are never defaulted
-def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
+def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments]) -> List[Binding]:
     if isinstance(a, Argument):
         return [
             Binding(
@@ -115,15 +115,15 @@ def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Bindin
         assert_never(a)
 
 
-def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
-    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+def impl_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
 
     if g.out.precomputed:
         # A list of parameters for the impl function with
         # certain parameters replaced with precomputed counterparts
        # as specified in native_functions.yaml.
-        non_out_args_replaced: list[
-            Argument | TensorOptionsArguments | SelfArgument
+        non_out_args_replaced: List[
+            Union[Argument, TensorOptionsArguments, SelfArgument]
         ] = []
         for a in g.out.func.arguments.non_out:
             if isinstance(a, Argument) and a.name in g.out.precomputed.replace:
@@ -145,13 +145,13 @@ def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
     return [r for arg in args for r in argument(arg)]
 
 
-def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
-    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+def meta_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
     args.extend(g.functional.func.arguments.non_out)
     return [r for arg in args for r in argument(arg)]
 
 
-def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
-    args: list[Argument | TensorOptionsArguments | SelfArgument] = []
+def out_arguments(g: NativeFunctionsGroup) -> List[Binding]:
+    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
     args.extend(g.out.func.arguments.out)
     return [r for arg in args for r in argument(arg)]
@@ -1,6 +1,4 @@
-from __future__ import annotations
-
-from typing import NoReturn, Sequence
+from typing import Dict, List, NoReturn, Sequence, Union
 
 from torchgen.api.types import (
     ArrayRefCType,
@@ -97,13 +95,13 @@ class UnsatError(RuntimeError):
 # something more complicated, e.g., tracking the set of bindings in a context,
 # you may find using these smaller types more convenient.
 def translate(
-    bindings: Sequence[Expr | Binding],
-    goals: Sequence[NamedCType | Binding],
+    bindings: Sequence[Union[Expr, Binding]],
+    goals: Sequence[Union[NamedCType, Binding]],
     *,
     method: bool = False,
     allow_expensive_conversions: bool = False,
-) -> list[Expr]:
-    binding_exprs: list[Expr] = []
+) -> List[Expr]:
+    binding_exprs: List[Expr] = []
     for b in bindings:
         if isinstance(b, Binding):
             binding_exprs.append(
@@ -115,7 +113,7 @@ def translate(
         else:
             binding_exprs.append(b)
 
-    goal_ctypes: list[NamedCType] = []
+    goal_ctypes: List[NamedCType] = []
     for g in goals:
         if isinstance(g, Binding):
             goal_ctypes.append(g.nctype)
@@ -123,7 +121,7 @@ def translate(
             goal_ctypes.append(g)
 
     # Add all the bindings to the context
-    ctx: dict[NamedCType, str] = {}
+    ctx: Dict[NamedCType, str] = {}
     for b in binding_exprs:
         ctx[b.type] = b.expr
 
@@ -1,19 +1,14 @@
-from __future__ import annotations
-
 from dataclasses import dataclass
-from typing import Iterator, Sequence, TYPE_CHECKING
+from typing import Iterator, List, Optional, Sequence, Set, Tuple, Union
 
 from torchgen.api.types.types_base import Binding, CType, Expr
 
-
-if TYPE_CHECKING:
-    from torchgen.model import (
-        BackendIndex,
-        FunctionSchema,
-        NativeFunction,
-        NativeFunctionsGroup,
-        NativeFunctionsViewGroup,
-    )
-
+from torchgen.model import (
+    BackendIndex,
+    FunctionSchema,
+    NativeFunction,
+    NativeFunctionsGroup,
+    NativeFunctionsViewGroup,
+)
 
 @dataclass(frozen=True)
@@ -43,7 +38,7 @@ class CppSignature:
     symint: bool
 
     # The set of C++ arguments which should not have defaults applied to them
-    cpp_no_default_args: set[str]
+    cpp_no_default_args: Set[str]
 
     # Is this a fallback C++ binding? Fallback bindings are enabled by
     # manual_cpp_binding: True and are alternate, non-public API that
@@ -77,7 +72,7 @@ class CppSignature:
     def decl(
         self,
         *,
-        name: str | None = None,
+        name: Optional[str] = None,
         prefix: str = "",
         is_redispatching_fn: bool = False,
         suppress_symint_suffix: bool = False,
@@ -98,7 +93,7 @@ class CppSignature:
     def defn(
         self,
         *,
-        name: str | None = None,
+        name: Optional[str] = None,
         prefix: str = "",
         is_redispatching_fn: bool = False,
     ) -> str:
@@ -131,9 +126,9 @@ class CppSignature:
 class CppSignatureGroup:
     func: FunctionSchema
     signature: CppSignature
-    faithful_signature: CppSignature | None
-    symint_signature: CppSignature | None
-    symint_faithful_signature: CppSignature | None
+    faithful_signature: Optional[CppSignature]
+    symint_signature: Optional[CppSignature]
+    symint_faithful_signature: Optional[CppSignature]
 
     def most_faithful_signature(self) -> CppSignature:
         if self.faithful_signature:
@@ -154,7 +149,7 @@ class CppSignatureGroup:
     @staticmethod
     def from_native_function(
         f: NativeFunction, *, method: bool, fallback_binding: bool = False
-    ) -> CppSignatureGroup:
+    ) -> "CppSignatureGroup":
         func = f.func
 
         def make_sig(*, faithful: bool, symint: bool) -> CppSignature:
@@ -167,16 +162,16 @@ class CppSignatureGroup:
                 cpp_no_default_args=f.cpp_no_default_args,
             )
 
-        def make_sigs(*, symint: bool) -> tuple[CppSignature, CppSignature | None]:
-            faithful_signature: CppSignature | None = None
+        def make_sigs(*, symint: bool) -> Tuple[CppSignature, Optional[CppSignature]]:
+            faithful_signature: Optional[CppSignature] = None
             if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
                 faithful_signature = make_sig(faithful=True, symint=symint)
             signature = make_sig(faithful=False, symint=symint)
             return signature, faithful_signature
 
         signature, faithful_signature = make_sigs(symint=False)
-        symint_signature: CppSignature | None = None
-        symint_faithful_signature: CppSignature | None = None
+        symint_signature: Optional[CppSignature] = None
+        symint_faithful_signature: Optional[CppSignature] = None
         if func.has_symint():
             symint_signature, symint_faithful_signature = make_sigs(symint=True)
 
@@ -201,20 +196,20 @@ class DispatcherSignature:
 
     symint: bool = True
 
-    def arguments(self) -> list[Binding]:
+    def arguments(self) -> List[Binding]:
         return dispatcher.arguments(self.func, symint=self.symint)
 
     def name(self) -> str:
         return self.prefix + dispatcher.name(self.func)
 
-    def decl(self, name: str | None = None) -> str:
+    def decl(self, name: Optional[str] = None) -> str:
         args_str = ", ".join(a.decl() for a in self.arguments())
         if name is None:
             name = self.name()
         return f"{self.returns_type().cpp_type()} {name}({args_str})"
 
     def defn(
-        self, name: str | None = None, *, is_redispatching_fn: bool = False
+        self, name: Optional[str] = None, *, is_redispatching_fn: bool = False
     ) -> str:
         args = [a.defn() for a in self.arguments()]
         if is_redispatching_fn:
@@ -224,7 +219,7 @@ class DispatcherSignature:
             name = self.name()
         return f"{self.returns_type().cpp_type()} {name}({args_str})"
 
-    def exprs(self) -> list[Expr]:
+    def exprs(self) -> List[Expr]:
         return [Expr(a.name, a.nctype) for a in self.arguments()]
 
     def returns_type(self) -> CType:
@@ -242,7 +237,7 @@ class DispatcherSignature:
     @staticmethod
     def from_schema(
         func: FunctionSchema, *, prefix: str = "", symint: bool = True
-    ) -> DispatcherSignature:
+    ) -> "DispatcherSignature":
         return DispatcherSignature(func, prefix, symint)
 
 
@@ -258,13 +253,13 @@ class NativeSignature:
     def name(self) -> str:
         return self.prefix + native.name(self.func)
 
-    def decl(self, name: str | None = None) -> str:
+    def decl(self, name: Optional[str] = None) -> str:
         args_str = ", ".join(a.decl() for a in self.arguments())
         if name is None:
             name = self.name()
         return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} {name}({args_str})"
 
-    def defn(self, name: str | None = None) -> str:
+    def defn(self, name: Optional[str] = None) -> str:
         args_str = ", ".join(a.defn() for a in self.arguments())
         if name is None:
             name = self.name()
@@ -275,13 +270,13 @@ class NativeSignature:
         args_str = ", ".join(a.defn() for a in self.arguments())
         return f"{native.returns_type(self.func.returns, symint=self.symint).cpp_type()} (*)({args_str})"
 
-    def arguments(self) -> list[Binding]:
+    def arguments(self) -> List[Binding]:
         return native.arguments(self.func, symint=self.symint)
 
     def returns_type(self) -> CType:
         return native.returns_type(self.func.returns, symint=self.symint)
 
-    def dispatcher_exprs(self) -> list[Expr]:
+    def dispatcher_exprs(self) -> List[Expr]:
         return translate.translate(
             self.arguments(), dispatcher.arguments(self.func), method=False
         )
@@ -312,7 +307,7 @@ class FunctionalizationLambda:
     # are we generating the forward lambda or the reverse lambda?
     is_reverse: bool
 
-    def captures(self) -> list[Expr]:
+    def captures(self) -> List[Expr]:
         # The lambda lives inside of a kernel following the dispatcher API, so its outer context is the dispatcher arguments
         # We also need to read the "reapply views" TLS at the time that the functionalization kernel was executed,
        # and plumb it into the lambda.
@@ -341,7 +336,7 @@ class FunctionalizationLambda:
         ]
         return f"[{capture_str}]({', '.join(decls)}) -> {return_type.cpp_type()}"
 
-    def inner_call(self, *, reapply_views: bool | None = None) -> str:
+    def inner_call(self, *, reapply_views: Optional[bool] = None) -> str:
         inner_call_name = functionalization.name(
             self.g,
             is_reverse=self.is_reverse,
@@ -371,7 +366,7 @@ class FunctionalizationLambda:
     @staticmethod
     def from_func(
         g: NativeFunctionsViewGroup, *, is_reverse: bool
-    ) -> FunctionalizationLambda:
+    ) -> "FunctionalizationLambda":
         return FunctionalizationLambda(g, is_reverse)
 
 
@@ -380,11 +375,11 @@ class StructuredImplSignature:
     g: NativeFunctionsGroup
     name: str
 
-    def defn(self, name: str | None = None) -> str:
+    def defn(self, name: Optional[str] = None) -> str:
         args_str = ", ".join(a.defn() for a in self.arguments())
         return f"TORCH_IMPL_FUNC({self.name})({args_str})"
 
-    def arguments(self) -> list[Binding]:
+    def arguments(self) -> List[Binding]:
         return structured.impl_arguments(self.g)
 
 
@@ -393,7 +388,7 @@ class StructuredImplSignature:
 
 def kernel_signature(
     f: NativeFunction, backend_index: BackendIndex, *, prefix: str = ""
-) -> NativeSignature | DispatcherSignature:
+) -> Union["NativeSignature", "DispatcherSignature"]:
     # Note [External Backends Follow Dispatcher API]
     # Kernel signatures for in-tree backends follow the "native" API,
     # while kernels for out-of-tree backends follow the dispatcher API.
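The first hunk of the section above moves the `torchgen.model` imports out of an `if TYPE_CHECKING:` guard. Such guards rely on annotations never being evaluated; once the `__future__` import is removed, every name used in a signature must exist at runtime, which is why the revert imports the names eagerly instead. A sketch of the failure mode, using a hypothetical module in place of `torchgen.model`:

```python
from __future__ import annotations  # delete this line to reproduce the bug

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported for the type checker; absent at runtime.
    from mypkg.model import Schema  # hypothetical stand-in


def describe(s: Schema) -> str:
    # With postponed annotations the hint stays the string "Schema" and
    # is never looked up, so the guarded import is safe. Without the
    # __future__ import, `Schema` is evaluated when the `def` statement
    # runs and raises NameError at import time.
    return str(s)
```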
@@ -12,10 +12,8 @@ if we want to generate code for another C++ library.
 Add new types to `types.py` if these types are ATen/c10 related.
 Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
 """
-
-from __future__ import annotations
-
 from dataclasses import dataclass
+from typing import Dict
 
 from torchgen.api.types.types_base import (
     BaseCppType,
@@ -85,7 +83,7 @@ symIntArrayRefT = BaseCppType("c10", "SymIntArrayRef")
 scalar_t = BaseCppType("", "scalar_t")
 opmath_t = BaseCppType("", "opmath_t")
 
-ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
+ScalarTypeToCppMapping: Dict[ScalarType, BaseCppType] = {
     ScalarType.Byte: byteT,
     ScalarType.Char: charT,
     ScalarType.Short: shortT,
@@ -104,7 +102,7 @@ ScalarTypeToCppMapping: dict[ScalarType, BaseCppType] = {
     ScalarType.Float8_e4m3fnuz: float8_e4m3fnuzT,
 }
 
-BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
+BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
     BaseTy.int: longT,
     BaseTy.float: doubleT,
     BaseTy.bool: boolT,
@@ -130,7 +128,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
 
 @dataclass(frozen=True)
 class OptionalCType(CType):
-    elem: CType
+    elem: "CType"
 
     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -139,13 +137,13 @@ class OptionalCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"::std::optional<{self.elem.cpp_type_registration_declarations()}>"
 
-    def remove_const_ref(self) -> CType:
+    def remove_const_ref(self) -> "CType":
         return OptionalCType(self.elem.remove_const_ref())
 
 
 @dataclass(frozen=True)
 class ListCType(CType):
-    elem: CType
+    elem: "CType"
 
     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -154,13 +152,13 @@ class ListCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"c10::List<{self.elem.cpp_type_registration_declarations()}>"
 
-    def remove_const_ref(self) -> CType:
+    def remove_const_ref(self) -> "CType":
         return ListCType(self.elem.remove_const_ref())
 
 
 @dataclass(frozen=True)
 class ArrayRefCType(CType):
-    elem: CType
+    elem: "CType"
 
     def cpp_type(self, *, strip_ref: bool = False) -> str:
         # Do not pass `strip_ref` recursively.
@@ -169,7 +167,7 @@ class ArrayRefCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         return f"ArrayRef<{self.elem.cpp_type_registration_declarations()}>"
 
-    def remove_const_ref(self) -> CType:
+    def remove_const_ref(self) -> "CType":
         return ArrayRefCType(self.elem.remove_const_ref())
 
 
@@ -187,5 +185,5 @@ class VectorizedCType(CType):
     def cpp_type_registration_declarations(self) -> str:
         raise NotImplementedError
 
-    def remove_const_ref(self) -> CType:
+    def remove_const_ref(self) -> "CType":
         return self
@ -12,17 +12,12 @@ if we want to generate code for another C++ library.
Add new types to `types.py` if these types are ATen/c10 related.
Add new types to `types_base.py` if they are basic and not attached to ATen/c10.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import auto, Enum
from typing import TYPE_CHECKING, Union
from typing import List, Optional, Union


if TYPE_CHECKING:
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments
from torchgen.model import Argument, SelfArgument, TensorOptionsArguments


# An ArgName is just the str name of the argument in schema;
@ -41,7 +36,7 @@ ArgName = Union[str, SpecialArgName]
# This class shouldn't be created directly; instead, use/create one of the singletons below.
@dataclass(frozen=True)
class BaseCppType:
ns: str | None
ns: Optional[str]
name: str

def __str__(self) -> str:
@ -76,7 +71,7 @@ class CType(ABC):
raise NotImplementedError

@abstractmethod
def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self


@ -92,13 +87,13 @@ class BaseCType(CType):
def cpp_type_registration_declarations(self) -> str:
return str(self.type).replace("at::", "")

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self


@dataclass(frozen=True)
class ConstRefCType(CType):
elem: CType
elem: "CType"

def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
@ -108,13 +103,13 @@ class ConstRefCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"const {self.elem.cpp_type_registration_declarations()} &"

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()


@dataclass(frozen=True)
class VectorCType(CType):
elem: CType
elem: "CType"

def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -123,13 +118,13 @@ class VectorCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"::std::vector<{self.elem.cpp_type_registration_declarations()}>"

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return VectorCType(self.elem.remove_const_ref())


@dataclass(frozen=True)
class ArrayCType(CType):
elem: CType
elem: "CType"
size: int

def cpp_type(self, *, strip_ref: bool = False) -> str:
@ -139,13 +134,13 @@ class ArrayCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"::std::array<{self.elem.cpp_type_registration_declarations()},{self.size}>"

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return ArrayCType(self.elem.remove_const_ref(), self.size)


@dataclass(frozen=True)
class TupleCType(CType):
elems: list[CType]
elems: List["CType"]

def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -154,13 +149,13 @@ class TupleCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f'::std::tuple<{",".join([e.cpp_type_registration_declarations() for e in self.elems])}>'

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return TupleCType([e.remove_const_ref() for e in self.elems])


@dataclass(frozen=True)
class MutRefCType(CType):
elem: CType
elem: "CType"

def cpp_type(self, *, strip_ref: bool = False) -> str:
if strip_ref:
@ -170,7 +165,7 @@ class MutRefCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"{self.elem.cpp_type_registration_declarations()} &"

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return self.elem.remove_const_ref()


@ -195,10 +190,10 @@ class NamedCType:
def cpp_type_registration_declarations(self) -> str:
return self.type.cpp_type_registration_declarations()

def remove_const_ref(self) -> NamedCType:
def remove_const_ref(self) -> "NamedCType":
return NamedCType(self.name, self.type.remove_const_ref())

def with_name(self, name: str) -> NamedCType:
def with_name(self, name: str) -> "NamedCType":
return NamedCType(name, self.type)


@ -213,11 +208,11 @@ class NamedCType:
class Binding:
name: str
nctype: NamedCType
argument: Argument | TensorOptionsArguments | SelfArgument
argument: Union[Argument, TensorOptionsArguments, SelfArgument]
# TODO: maybe don't represent default here
default: str | None = None
default: Optional[str] = None

def rename(self, name: str) -> Binding:
def rename(self, name: str) -> "Binding":
return Binding(
name=name,
nctype=self.nctype,
@ -229,7 +224,7 @@ class Binding:
def type(self) -> str:
return self.nctype.cpp_type()

def no_default(self) -> Binding:
def no_default(self) -> "Binding":
return Binding(
name=self.name,
nctype=self.nctype,
@ -260,7 +255,7 @@ class Binding:
def defn(self) -> str:
return f"{self.type} {self.name}"

def with_name(self, name: str) -> Binding:
def with_name(self, name: str) -> "Binding":
return Binding(
name=name, nctype=self.nctype, argument=self.argument, default=self.default
)
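The other recurring edit in this file is re-quoting self-referential annotations (`elem: "CType"`, `-> "NamedCType"`, `-> "Binding"`). Without postponed evaluation, an annotation inside a class body runs while the class object does not exist yet, so the name must be deferred as a string. A small sketch of the rule, using a hypothetical `Node` class rather than torchgen's types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass(frozen=True)
    class Node:
        # Unquoted `Optional[Node]` would raise NameError here: `Node` is
        # not bound until the class statement finishes executing. The
        # string form is resolved lazily (e.g. by typing.get_type_hints).
        elem: Optional["Node"] = None

        def remove_const_ref(self) -> "Node":
            # Method return annotations are evaluated when `def` runs,
            # still inside the class body, so they need quoting too.
            return self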
@ -1,6 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional

import torchgen.api.types as api_types
from torchgen.api import cpp, structured
@ -39,7 +38,7 @@ def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
# argument registers)
#
# NB: used for CPU only
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
def dispatchstub_type(t: Type, *, binds: ArgName) -> Optional[NamedCType]:
# Dispatch stubs are always plain ints
r = cpp.valuetype_type(t, binds=binds, symint=False)
if r is not None:
@ -135,8 +134,8 @@ def ufunc_argument(a: Argument, compute_t: CType) -> Binding:

@dataclass(frozen=True)
class UfunctorBindings:
ctor: list[Binding]
apply: list[Binding]
ctor: List[Binding]
apply: List[Binding]


# ufunctors are a CUDA-only concept representing functors that take some of
@ -157,7 +156,7 @@ class UfunctorBindings:
# The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
# to the operator() definition
def ufunctor_arguments(
g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
g: NativeFunctionsGroup, *, scalar_tensor_idx: Optional[int], scalar_t: BaseCppType
) -> UfunctorBindings:
ctor = []
apply = []
@ -186,7 +185,7 @@ def ufunctor_arguments(
# }
#
# In this file, we refer to T as compute_t which is bound by caller
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> List[Binding]:
return [
ufunc_argument(a, compute_t=compute_t)
for a in g.functional.func.arguments.flat_non_out
@ -198,7 +197,7 @@ def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Bindin
#
# using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
# DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
def stub_arguments(g: NativeFunctionsGroup) -> List[Binding]:
# stubs drop all tensor arguments (they are implicit in the TensorIterator
# argument and keep everything else)
return [
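The `int | None` to `Optional[int]` edits here are the PEP 604 half of the same story: writing a union with `|` between real type objects only works at runtime on Python 3.10+, so once annotations are evaluated eagerly the code must spell the union through `typing`. A sketch, assuming a pre-3.10 interpreter and an illustrative function name:

    from typing import Optional

    # Evaluated when `def` executes. On Python < 3.10 the unquoted form
    # `idx: int | None` raises "TypeError: unsupported operand type(s)
    # for |: 'type' and 'NoneType'", so the revert uses Optional[int]:
    def scalar_tensor_slot(idx: Optional[int] = None) -> Optional[int]:
        return idx  # pass-through; only the annotations matter here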
@ -1,4 +1,4 @@
from __future__ import annotations
from typing import List, Tuple

from torchgen.api import cpp
from torchgen.api.types import Binding, CppSignatureGroup, CType
@ -103,7 +103,7 @@ def name(f: NativeFunction) -> str:


# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
def convert_arguments(f: NativeFunction) -> Tuple[List[Binding], List[str]]:
# we need the 'self' argument so method needs to be False
args = (
CppSignatureGroup.from_native_function(f, method=False)
@ -138,7 +138,7 @@ def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
def argumenttype_ivalue_convert(
t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
) -> Tuple[str, CType, List[str], List[str]]:
# Unboxing is for mobile, which doesn't care about SymInts
ctype = cpp.argumenttype_type(
t=t, mutable=mutable, binds=arg_name, symint=False
@ -172,7 +172,7 @@ def argumenttype_ivalue_convert(

def _gen_code_base_type(
arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
return [
f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []
@ -180,7 +180,7 @@ def _gen_code_base_type(

def _gen_code_optional_type(
arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_opt_in"
res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
return (
@ -203,7 +203,7 @@ if ({arg_name}_opt.has_value()) {{

def _gen_code_list_type(
arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"]
@ -1,7 +1,5 @@
from __future__ import annotations

import re
from typing import Mapping, Sequence
from typing import Mapping, Match, Optional, Sequence


# match $identifier or ${identifier} and replace with value in env
@ -22,7 +20,7 @@ class CodeTemplate:
filename: str

@staticmethod
def from_file(filename: str) -> CodeTemplate:
def from_file(filename: str) -> "CodeTemplate":
with open(filename) as f:
return CodeTemplate(f.read(), filename)

@ -31,7 +29,7 @@ class CodeTemplate:
self.filename = filename

def substitute(
self, env: Mapping[str, object] | None = None, **kwargs: object
self, env: Optional[Mapping[str, object]] = None, **kwargs: object
) -> str:
if env is None:
env = {}
@ -45,7 +43,7 @@ class CodeTemplate:
[indent + l + "\n" for e in v for l in str(e).splitlines()]
).rstrip()

def replace(match: re.Match[str]) -> str:
def replace(match: Match[str]) -> str:
indent = match.group(1)
key = match.group(2)
comma_before = ""
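The `re.Match[str]` to `Match[str]` swap is the same compatibility story in another guise: `re.Match` only became subscriptable with PEP 585 (Python 3.9), while the `typing.Match` alias works earlier. For orientation, a heavily simplified sketch of the `$identifier`/`${identifier}` substitution this class performs; it is not the torchgen implementation, which also handles indentation and list expansion:

    import re
    from typing import Mapping, Match

    def substitute(template: str, env: Mapping[str, object]) -> str:
        # Replace $name or ${name} with the corresponding env entry.
        def replace(match: Match[str]) -> str:
            return str(env[match.group(1) or match.group(2)])

        return re.sub(r"\$(?:(\w+)|\{(\w+)\})", replace, template)

    print(substitute("Tensor ${name}($args);", {"name": "add", "args": "int x"}))
    # Tensor add(int x);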
@ -1,8 +1,6 @@
from __future__ import annotations

import contextlib
import functools
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union

import torchgen.local as local
from torchgen.model import (
@ -40,7 +38,7 @@ F3 = TypeVar("F3", Tuple[NativeFunction, Any], List[NativeFunction])

@contextlib.contextmanager
def native_function_manager(
g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction]
) -> Iterator[None]:
if isinstance(g, NativeFunctionsGroup):
# By default, we associate all errors with structured native functions
@ -120,10 +118,10 @@ def with_native_function_and_index(

# Convenience decorator for functions that explicitly take in a Dict of BackendIndices
def with_native_function_and_indices(
func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
func: Callable[[F, Dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]:
@functools.wraps(func)
def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T:
with native_function_manager(f):
return func(f, backend_indices)
@ -1,9 +1,7 @@
from __future__ import annotations

import itertools
from abc import ABC
from dataclasses import dataclass
from typing import Any
from typing import Any, Dict, List, Optional, Tuple, Union

import torchgen.api.dispatcher as dispatcher
from torchgen.api.lazy import (
@ -111,7 +109,7 @@ def node_ctor_inputs(schema: LazyIrSchema) -> str:

def gen_fallback_code(
schema: LazyIrSchema,
sig: DispatcherSignature | NativeSignature,
sig: Union[DispatcherSignature, NativeSignature],
overload_name: str,
) -> str:
"""
@ -149,9 +147,9 @@ def aten_symbol(schema: LazyIrSchema) -> str:
# converts all tensor-like arguments to meta tensors. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
context: List[Binding] = []
unwrapped_tensor_args: List[str] = []
for arg in sig.arguments():
if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
unwrapped_name = f"{arg.name}_meta"
@ -173,7 +171,7 @@ class GenLazyIR(ABC):
use_lazy_shape: bool

@method_with_native_function
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
metadata = self.backend_index.get_kernel(
f.functional if isinstance(f, NativeFunctionsGroup) else f
@ -238,7 +236,7 @@ class GenLazyIR(ABC):
/* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))"""

def gen(self, schema: LazyIrSchema) -> list[str]:
def gen(self, schema: LazyIrSchema) -> List[str]:
opkind = schema.opkind or aten_symbol(schema)

# for now, we just want one IR class decl and soon after also the method defs
@ -415,7 +413,7 @@ class GenLazyNativeFuncDefinition:
def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
value_args = schema.filtered_args(values=True, scalars=False)
# Generates lazy_{name} variables for LazyTensors wrapping input tensors
lazy_tensor_decls: list[str] = []
lazy_tensor_decls: List[str] = []
for arg in value_args:
if arg.is_wrapped_scalar:
if isinstance(arg.lazy_type, OptionalCType):
@ -462,7 +460,7 @@ class GenLazyNativeFuncDefinition:
func: NativeFunction,
schema: LazyIrSchema,
metadata: BackendMetadata,
sig: DispatcherSignature | NativeSignature,
sig: Union[DispatcherSignature, NativeSignature],
) -> str:
if self.gen_forced_fallback_code:
return gen_fallback_code(
@ -576,7 +574,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
}}
"""

def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
def create_lazy_tensor(self, first_tensor_name: Optional[str] = None) -> str:
# xla uses an instance method for tensor creation, for the time being
if self.create_from_first_tensor:
# TODO(whc) remove this if XLA switches to using static method for creation
@ -617,7 +615,7 @@ std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type()
return bridge_str

@method_with_native_function
def __call__(self, func: NativeFunction) -> list[str]:
def __call__(self, func: NativeFunction) -> List[str]:
sig = kernel_signature(func, self.backend_index)
metadata = self.backend_index.get_kernel(func)
assert metadata is not None
@ -641,7 +639,7 @@ class ComputeShapeSignature:
Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
"""

def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool):
self.__schema = LazyIrSchema(f.func, symint=symint)
self.__dispatch_args = ", ".join(
[a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
@ -672,7 +670,7 @@ class GenLazyShapeInferenceDefinition:
tensor_class: str

@method_with_native_function
def __call__(self, f: NativeFunction) -> list[str]:
def __call__(self, f: NativeFunction) -> List[str]:
metadata = self.backend_index.get_kernel(f)
assert metadata is not None

@ -689,8 +687,8 @@ class GenLazyShapeInferenceDefinition:


def generate_non_native_lazy_ir_nodes(
non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> list[str]:
non_native: List[Dict[str, Any]], gen_lazy_ir: GenLazyIR
) -> List[str]:
"""Generate the non-native lazy IR node classes"""
nodes = []
for op in non_native:
@ -1,4 +1,4 @@
from __future__ import annotations
from typing import List, Optional, Union

import torchgen.api.meta as meta
import torchgen.api.structured as structured
@ -9,7 +9,7 @@ from torchgen.utils import mapMaybe


@with_native_function_and_index
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None:
def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> Optional[str]:
sig = kernel_signature(f, backend_index)
metadata = backend_index.get_kernel(f)
if metadata is None:
@ -22,7 +22,7 @@ def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | No


@with_native_function_and_index
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list[str]:
def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> List[str]:
meta_name = meta.name(g)
out_args = structured.impl_arguments(g)
metadata = backend_index.get_kernel(g)
@ -42,8 +42,8 @@ void impl({', '.join(a.decl() for a in out_args)});
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function_and_index
def compute_native_function_declaration(
g: NativeFunctionsGroup | NativeFunction, backend_index: BackendIndex
) -> list[str]:
g: Union[NativeFunctionsGroup, NativeFunction], backend_index: BackendIndex
) -> List[str]:
metadata = backend_index.get_kernel(g)
if isinstance(g, NativeFunctionsGroup):
if metadata is not None and metadata.structured:
@ -1,9 +1,7 @@
from __future__ import annotations

import itertools
import textwrap
from dataclasses import dataclass
from typing import Literal, TYPE_CHECKING
from typing import List, Literal, Optional, Tuple, Union

import torchgen.api.cpp as cpp
import torchgen.api.meta as meta
@ -36,18 +34,15 @@ from torchgen.model import (
SchemaKind,
TensorOptionsArguments,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import assert_never, mapMaybe, Target


if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder


def gen_registration_headers(
backend_index: BackendIndex,
per_operator_headers: bool,
rocm: bool,
) -> list[str]:
) -> List[str]:
if per_operator_headers:
headers = ["#include <ATen/ops/as_strided_native.h>"]
else:
@ -78,7 +73,7 @@ def gen_registration_headers(

def gen_empty_impl_names(
backend_index: BackendIndex,
) -> tuple[str | None, str | None]:
) -> Tuple[Optional[str], Optional[str]]:
empty_impl = None
empty_strided_impl = None

@ -102,7 +97,7 @@ def gen_empty_impl_names(
return empty_impl, empty_strided_impl


def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
def gen_create_out_helper(backend_index: BackendIndex) -> List[str]:
if backend_index.dispatch_key == DispatchKey.Meta:
empty_options = "options.device(at::kMeta)"
else:
@ -125,7 +120,7 @@ Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &o
]


def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> List[str]:
_, empty_strided_impl = gen_empty_impl_names(backend_index)
return (
[]
@ -143,7 +138,7 @@ std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, I
)


def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
def gen_resize_out_helper(backend_index: BackendIndex) -> List[str]:
if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
# The function isn't used by this key (since only functional ops have a kernel for this key),
# so we need to not include it to avoid a defined-but-not-used error.
@ -173,7 +168,7 @@ void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const
]


def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
def gen_check_inplace_helper(backend_index: BackendIndex) -> List[str]:
return [
"""
void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
@ -196,7 +191,7 @@ void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &o
]


def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
def gen_registration_helpers(backend_index: BackendIndex) -> List[str]:
return [
'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
*gen_create_out_helper(backend_index),
@ -254,7 +249,7 @@ class RegisterDispatchKey:
# Finally, this field is currently Optional because it is only used by external backends.
# It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
# all of the existing kernel signatures scattered across aten/src/ATen/native.
class_method_name: str | None
class_method_name: Optional[str]

# Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
# operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
@ -262,7 +257,7 @@ class RegisterDispatchKey:

@staticmethod
def gen_device_check(
type: DeviceCheckType, args: list[Argument], method_name: str
type: DeviceCheckType, args: List[Argument], method_name: str
) -> str:
if type == DeviceCheckType.NoCheck:
return " // No device check\n"
@ -277,7 +272,7 @@ class RegisterDispatchKey:
return device_check

@method_with_native_function
def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
def __call__(self, f: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]:
if isinstance(f, NativeFunctionsGroup):
g: NativeFunctionsGroup = f
# Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
@ -296,7 +291,7 @@ class RegisterDispatchKey:

def wrapper_kernel_sig(
self, f: NativeFunction
) -> NativeSignature | DispatcherSignature:
) -> Union[NativeSignature, DispatcherSignature]:
# The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
return DispatcherSignature.from_schema(
f.func,
@ -305,8 +300,8 @@ class RegisterDispatchKey:
)

def gen_out_inplace_wrapper(
self, f: NativeFunction, g: NativeFunctionsGroup | None
) -> str | None:
self, f: NativeFunction, g: Optional[NativeFunctionsGroup]
) -> Optional[str]:
if g is None:
return None
k = f.func.kind()
@ -355,7 +350,7 @@ class RegisterDispatchKey:
}}
"""

def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
def gen_structured(self, g: NativeFunctionsGroup) -> List[str]:
metadata = self.backend_index.get_kernel(g)
if self.backend_index.dispatch_key == DispatchKey.Meta:
assert not self.backend_index.has_kernel(g.out), (
@ -385,8 +380,8 @@ class RegisterDispatchKey:
return list(mapMaybe(structured_gen.gen_one, g.functions()))

def gen_unstructured(
self, f: NativeFunction, g: NativeFunctionsGroup | None = None
) -> str | None:
self, f: NativeFunction, g: Optional[NativeFunctionsGroup] = None
) -> Optional[str]:
with native_function_manager(f):
inplace_meta = False
gets_out_inplace_wrapper = False
@ -737,7 +732,7 @@ resize_out(out, sizes, strides, options);
return "\n".join(line for line in lines if line)

@method_with_native_function
def gen_one(self, f: NativeFunction) -> str | None:
def gen_one(self, f: NativeFunction) -> Optional[str]:
assert not f.manual_kernel_registration

if (
@ -811,7 +806,7 @@ return {sig.name()}({', '.join(e.expr for e in translate(cpp_sig.arguments(), si
sig_body = []
# We'll use context to keep track of any variables we've brought
# into scope while generating code
context: list[Binding | Expr] = list(sig.arguments())
context: List[Union[Binding, Expr]] = list(sig.arguments())

# Initialize the class corresponding to this structured
# operator; feeding it the output argument(s) if it is known
@ -1,7 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
@ -16,6 +14,7 @@ from torchgen.api.types import (
StructuredImplSignature,
VectorizedCType,
)
from torchgen.api.ufunc import UfunctorBindings
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
@ -29,10 +28,6 @@ from torchgen.model import (
from torchgen.utils import OrderedSet


if TYPE_CHECKING:
from torchgen.api.ufunc import UfunctorBindings


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CUDA STUFF
@ -65,7 +60,7 @@ if TYPE_CHECKING:
@dataclass(frozen=True)
class UfunctorSignature:
g: NativeFunctionsGroup
scalar_tensor_idx: int | None
scalar_tensor_idx: Optional[int]
name: str

def arguments(self) -> UfunctorBindings:
@ -73,7 +68,7 @@ class UfunctorSignature:
self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
)

def fields(self) -> list[Binding]:
def fields(self) -> List[Binding]:
# fields are renamed to have a trailing underscore, as is conventional
return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

@ -103,10 +98,10 @@ class UfuncSignature:
name: str
compute_t: CType

def arguments(self) -> list[Binding]:
def arguments(self) -> List[Binding]:
return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

def call(self, ctx: Sequence[Binding | Expr]) -> str:
def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"


@ -137,10 +132,10 @@ def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:

def compute_ufunc_cuda_functors(
g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
# First, build the functors.
ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
ufunctors: list[str] = []
ufunctor_sigs: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
ufunctors: List[str] = []
loops = g.out.ufunc_inner_loop
scalar_tensor_idx_lookup = {
UfuncKey.CUDAFunctorOnSelf: 1,
@ -242,7 +237,7 @@ BinaryScalarSpecializationConfigs = [
def compute_ufunc_cuda_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfunctorSignature],
inner_loops: Dict[UfuncKey, UfunctorSignature],
parent_ctx: Sequence[Binding],
) -> str:
body = "using opmath_t = at::opmath_type<scalar_t>;"
@ -254,7 +249,7 @@ def compute_ufunc_cuda_dtype_body(
scalar_idx = config.scalar_idx + 1
# Make a copy and at the same time widen the type (not permissible
# without copy; we don't want to mutate the input argument anyway)
ctx: list[Expr | Binding] = list(parent_ctx)
ctx: List[Union[Expr, Binding]] = list(parent_ctx)
ctx.append(
Expr(
expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
@ -351,7 +346,7 @@ class StubSignature:
def type_name(self) -> str:
return f"{str(self.g.functional.func.name.name)}_fn"

def arguments(self) -> list[Binding]:
def arguments(self) -> List[Binding]:
return ufunc.stub_arguments(self.g)

def type(self) -> str:
@ -398,7 +393,7 @@ def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
def compute_ufunc_cpu_dtype_body(
g: NativeFunctionsGroup,
dtype: ScalarType,
inner_loops: dict[UfuncKey, UfuncSignature],
inner_loops: Dict[UfuncKey, UfuncSignature],
parent_ctx: Sequence[Binding],
) -> str:
assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
@ -464,8 +459,8 @@ def compute_ufunc_cpu_dtype_body(
)
)

def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
r: list[Expr | Binding] = []
def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
r: List[Union[Expr, Binding]] = []
r.extend(ctx)
r.extend(b)
return r
@ -494,7 +489,7 @@ def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:

# Reindex the ufunc by dtypes; processing generic/scalaronly as well
loops = g.out.ufunc_inner_loop
ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
lks = []
# ORDER MATTERS: this specifies overriding precedence
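This file (like register_dispatch_key.py above and the executorch files below) also moves the `UfunctorBindings` import out of an `if TYPE_CHECKING:` guard and back to module scope. The guard is only viable while annotations stay as unevaluated strings; once they are evaluated eagerly, every name an annotation mentions must actually be importable at runtime. A minimal sketch of the two regimes, with a hypothetical helper module:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by the type checker only, never imported at runtime.
        from some_module import Helper  # hypothetical module

    def use(h: "Helper") -> None:
        # The string annotation above is never evaluated, so the missing
        # runtime import is harmless. An unquoted `h: Helper` would raise
        # NameError at import time, which is why this commit turns the
        # guarded imports back into real ones.
        ...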
@ -1,29 +1,24 @@
from __future__ import annotations

from collections import defaultdict
from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING
from typing import Dict, List, Optional, Sequence, Tuple

from torchgen import dest

# disable import sorting to avoid circular dependency.
from torchgen.api.types import DispatcherSignature # usort:skip
from torchgen.context import method_with_native_function
from torchgen.executorch.model import ETKernelIndex
from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import concatMap, Target


if TYPE_CHECKING:
from torchgen.executorch.model import ETKernelIndex
from torchgen.selective_build.selector import SelectiveBuilder


# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
# model authoring side.
@dataclass(frozen=True)
class ComputeNativeFunctionStub:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
if Variant.function not in f.variants:
return None

@ -85,7 +80,7 @@ def gen_custom_ops_registration(
selector: SelectiveBuilder,
kernel_index: ETKernelIndex,
rocm: bool,
) -> tuple[str, str]:
) -> Tuple[str, str]:
"""
Generate custom ops registration code for dest.RegisterDispatchKey.

@ -102,7 +97,7 @@ def gen_custom_ops_registration(
dispatch_key = DispatchKey.CPU
backend_index = kernel_index._to_backend_index()
static_init_dispatch_registrations = ""
ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
ns_grouped_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
for native_function in native_functions:
ns_grouped_native_functions[native_function.namespace].append(native_function)
@ -1,6 +1,4 @@
from __future__ import annotations

from typing import Sequence
from typing import List, Optional, Sequence, Set, Union

from torchgen import local
from torchgen.api.types import (
@ -65,7 +63,7 @@ def valuetype_type(
*,
binds: ArgName,
remove_non_owning_ref_types: bool = False,
) -> NamedCType | None:
) -> Optional[NamedCType]:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
return None
@ -211,7 +209,7 @@ def returns_type(rs: Sequence[Return]) -> CType:


def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
returns: list[str] = []
returns: List[str] = []
for i, r in enumerate(f.func.returns):
# If we have an inplace function, the return argument is
# implicitly named self.
@ -297,16 +295,16 @@ def default_expr(d: str, t: Type) -> str:


def argument(
a: Argument | TensorOptionsArguments | SelfArgument,
a: Union[Argument, TensorOptionsArguments, SelfArgument],
*,
cpp_no_default_args: set[str],
cpp_no_default_args: Set[str],
method: bool,
faithful: bool,
has_tensor_options: bool,
) -> list[Binding]:
) -> List[Binding]:
def sub_argument(
a: Argument | TensorOptionsArguments | SelfArgument,
) -> list[Binding]:
a: Union[Argument, TensorOptionsArguments, SelfArgument]
) -> List[Binding]:
return argument(
a,
cpp_no_default_args=cpp_no_default_args,
@ -321,7 +319,7 @@ def argument(
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: str | None = None
default: Optional[str] = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type)
return [
@ -349,9 +347,9 @@ def arguments(
*,
faithful: bool,
method: bool,
cpp_no_default_args: set[str],
) -> list[Binding]:
args: list[Argument | TensorOptionsArguments | SelfArgument] = []
cpp_no_default_args: Set[str],
) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
if faithful:
args.extend(arguments.non_out)
args.extend(arguments.out)
@ -1,15 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import List, Optional, Set

import torchgen.api.cpp as aten_cpp
from torchgen.api.types import Binding, CType
from torchgen.executorch.api.types.types import contextArg


if TYPE_CHECKING:
from torchgen.api.types import Binding, CType
from torchgen.model import FunctionSchema, NativeFunction
from torchgen.model import FunctionSchema, NativeFunction


@dataclass(frozen=True)
@ -25,14 +20,14 @@ class ExecutorchCppSignature:
func: FunctionSchema

# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: set[str]
cpp_no_default_args: Set[str]

# Allows you to prepend an arbitrary prefix to the signature name.
# This is useful for parts of the codegen that generate wrappers around kernels,
# and need to avoid naming collisions.
prefix: str = ""

def arguments(self, *, include_context: bool = True) -> list[Binding]:
def arguments(self, *, include_context: bool = True) -> List[Binding]:
return ([contextArg] if include_context else []) + et_cpp.arguments(
self.func.arguments,
faithful=True, # always faithful, out argument at the end
@ -46,7 +41,7 @@ class ExecutorchCppSignature:
faithful_name_for_out_overloads=True,
)

def decl(self, name: str | None = None, *, include_context: bool = True) -> str:
def decl(self, name: Optional[str] = None, *, include_context: bool = True) -> str:
args_str = ", ".join(
a.decl() for a in self.arguments(include_context=include_context)
)
@ -54,7 +49,7 @@ class ExecutorchCppSignature:
name = self.name()
return f"{self.returns_type().cpp_type()} {name}({args_str})"

def defn(self, name: str | None = None) -> str:
def defn(self, name: Optional[str] = None) -> str:
args = [a.defn() for a in self.arguments()]
args_str = ", ".join(args)
if name is None:
@ -67,7 +62,7 @@ class ExecutorchCppSignature:
@staticmethod
def from_native_function(
f: NativeFunction, *, prefix: str = ""
) -> ExecutorchCppSignature:
) -> "ExecutorchCppSignature":
return ExecutorchCppSignature(
func=f.func, prefix=prefix, cpp_no_default_args=f.cpp_no_default_args
)
@ -1,6 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict

from torchgen.api.types import (
BaseCppType,
@ -41,7 +40,7 @@ contextArg = Binding(
default=None,
)

BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {
BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
BaseTy.int: longT,
BaseTy.float: doubleT,
BaseTy.bool: boolT,
@ -55,7 +54,7 @@ BaseTypeToCppMapping: dict[BaseTy, BaseCppType] = {

@dataclass(frozen=True)
class OptionalCType(CType):
elem: CType
elem: "CType"

def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -64,13 +63,13 @@ class OptionalCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::optional<{self.elem.cpp_type_registration_declarations()}>"

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return OptionalCType(self.elem.remove_const_ref())


@dataclass(frozen=True)
class ArrayRefCType(CType):
elem: CType
elem: "CType"

def cpp_type(self, *, strip_ref: bool = False) -> str:
# Do not pass `strip_ref` recursively.
@ -79,5 +78,5 @@ class ArrayRefCType(CType):
def cpp_type_registration_declarations(self) -> str:
return f"torch::executor::ArrayRef<{self.elem.cpp_type_registration_declarations()}>"

def remove_const_ref(self) -> CType:
def remove_const_ref(self) -> "CType":
return ArrayRefCType(self.elem.remove_const_ref())
@ -1,8 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Sequence, TYPE_CHECKING
from typing import Callable, List, Sequence, Tuple

from torchgen.api.types import Binding, CType, NamedCType
from torchgen.model import (
Argument,
BaseTy,
@ -14,10 +13,6 @@ from torchgen.model import (
)


if TYPE_CHECKING:
from torchgen.api.types import Binding, CType, NamedCType


connector = "\n\t"


@ -57,7 +52,7 @@ class Unboxing:
# Convert all the arguments in a NativeFunction to C++ code
def convert_arguments(
self, args: Sequence[Binding]
) -> tuple[list[Binding], list[str]]:
) -> Tuple[List[Binding], List[str]]:
code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
binding_list = []
for arg in args:
@ -77,7 +72,7 @@ class Unboxing:

def argumenttype_evalue_convert(
self, t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
) -> Tuple[str, CType, List[str], List[str]]:
"""
Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
(1) the C++ code necessary to unbox the argument
@ -112,14 +107,14 @@ class Unboxing:

def _gen_code_base_type(
self, arg_name: str, out_name: str, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
return [
f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
], []

def _gen_code_optional_type(
self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_opt_in"
res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
t.elem, in_name
@ -135,7 +130,7 @@ class Unboxing:

def _gen_code_list_type(
self, arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
in_name = f"{arg_name}_list_in"
elem_name = f"{arg_name}_elem"
code = []
@ -1,12 +1,11 @@
# Represents all kernels used by an Executorch model.
# It maintains a Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] structure.

from __future__ import annotations

import itertools
from collections import defaultdict, namedtuple
from dataclasses import dataclass
from enum import IntEnum
from typing import Dict, List, Tuple, Union

from torchgen.model import (
BackendIndex,
@ -42,7 +41,7 @@ class ETKernelKeyOpArgMeta:
arg_name: str
dtype: str
# The order of the dimensions if entry is a Tensor
dim_order: tuple[int, ...]
dim_order: Tuple[int, ...]

def to_native_string(self) -> str:
dtype_str = ScalarType[self.dtype].value
@ -53,7 +52,7 @@ class ETKernelKeyOpArgMeta:
@dataclass(frozen=True)
class ETKernelKey:
# Field undefined is default = True
arg_meta: tuple[ETKernelKeyOpArgMeta, ...] = ()
arg_meta: Tuple[ETKernelKeyOpArgMeta, ...] = ()

# Indicator for this kernel being used as a catch all
default: bool = False
@ -62,10 +61,10 @@ class ETKernelKey:

@staticmethod
def gen_from_yaml(
args: dict[str, tuple[str, str]],
type_alias_map: dict[str, list[str]], # TODO: Support unwrapped str val
dim_order_alias_map: dict[str, list[int]],
) -> list[ETKernelKey]:
args: Dict[str, Tuple[str, str]],
type_alias_map: Dict[str, List[str]], # TODO: Support unwrapped str val
dim_order_alias_map: Dict[str, List[int]],
) -> List["ETKernelKey"]:
"""Generate ETKernelKeys from arg kernel specs
Multiple ETKernelKeys are returned due to dtype permutations from utilizing
type_alias_map (actualizing each potential type permutation as a KernelKey)
@ -138,15 +137,15 @@ class ETKernelKey:

@dataclass(frozen=True)
class ETKernelIndex:
index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]]
index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]]

def has_kernels(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
def has_kernels(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
m = self.get_kernels(g)
return m is not None

def get_kernels(
self, g: NativeFunction | NativeFunctionsGroup
) -> dict[ETKernelKey, BackendMetadata]:
self, g: Union[NativeFunction, NativeFunctionsGroup]
) -> Dict[ETKernelKey, BackendMetadata]:
if isinstance(g, NativeFunction):
f = g
elif isinstance(g, NativeFunctionsGroup):
@ -159,8 +158,8 @@ class ETKernelIndex:

@staticmethod
def grow_from_backend_indices(
kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]],
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
kernel_index: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]],
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
for dk in backend_indices:
index = backend_indices[dk]
@ -172,17 +171,17 @@ class ETKernelIndex:

@staticmethod
def from_backend_indices(
backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
kernel_index: dict[
OperatorName, dict[ETKernelKey, BackendMetadata]
backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
) -> "ETKernelIndex":
kernel_index: Dict[
OperatorName, Dict[ETKernelKey, BackendMetadata]
] = defaultdict(dict)
ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices)
return ETKernelIndex(kernel_index)

def grow(
self, backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]]
) -> ETKernelIndex:
self, backend_indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]]
) -> "ETKernelIndex":
ETKernelIndex.grow_from_backend_indices(self.index, backend_indices)
return self

@ -190,7 +189,7 @@ class ETKernelIndex:
"""
WARNING: this will be deprecated once all the codegen places know how to handle ETKernelIndex.
"""
index: dict[OperatorName, BackendMetadata] = {}
index: Dict[OperatorName, BackendMetadata] = {}
for op in self.index:
kernel_dict = self.index[op]
assert (
@ -210,7 +209,9 @@ class ETKernelIndex:

# Note duplicate ETKernelKey from index_b will clobber the metadata from index_a
@staticmethod
def merge_indices(index_a: ETKernelIndex, index_b: ETKernelIndex) -> ETKernelIndex:
def merge_indices(
index_a: "ETKernelIndex", index_b: "ETKernelIndex"
) -> "ETKernelIndex":
combined = defaultdict(dict, index_a.index.copy())

for op, entry in index_b.index.items():
@ -1,7 +1,5 @@
from __future__ import annotations

from collections import defaultdict, namedtuple
from typing import Any
from typing import Any, Dict, List, Optional, Set, Tuple

import yaml

@ -24,7 +22,7 @@ ETParsedYaml = namedtuple("ETParsedYaml", ["native_functions", "et_kernel_indice
ET_FIELDS = ["kernels", "type_alias", "dim_order_alias"]


def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]:
def parse_from_yaml(ei: Dict[str, object]) -> Dict[ETKernelKey, BackendMetadata]:
"""Given a loaded yaml representing kernel assignment information, extract the
mapping from `kernel keys` to `BackendMetadata` (the latter representing the kernel instance)

@ -36,11 +34,11 @@ def parse_from_yaml(ei: dict[str, object]) -> dict[ETKernelKey, BackendMetadata]
if (kernels := e.pop("kernels", None)) is None:
return {}

type_alias: dict[str, list[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
dim_order_alias: dict[str, list[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
type_alias: Dict[str, List[str]] = e.pop("type_alias", {}) # type: ignore[assignment]
dim_order_alias: Dict[str, List[str]] = e.pop("dim_order_alias", {}) # type: ignore[assignment]
dim_order_alias.pop("__line__", None)

kernel_mapping: dict[ETKernelKey, BackendMetadata] = {}
kernel_mapping: Dict[ETKernelKey, BackendMetadata] = {}

for entry in kernels: # type: ignore[attr-defined]
arg_meta = entry.get("arg_meta")
@ -78,7 +76,7 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
of `kernel keys` to `BackendMetadata` (the latter representing the kernel instance
that should be used by the kernel key).
"""
indices: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = {}
indices: Dict[OperatorName, Dict[ETKernelKey, BackendMetadata]] = {}
for ei in es: # type: ignore[attr-defined]
e = ei.copy()

@ -97,11 +95,11 @@ def parse_et_yaml_struct(es: object) -> ETKernelIndex:
return ETKernelIndex(indices)


def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
def extract_kernel_fields(es: object) -> Dict[OperatorName, Dict[str, Any]]:
"""Given a loaded yaml representing a list of operators, extract the
kernel key related fields indexed by the operator name.
"""
fields: dict[OperatorName, dict[str, Any]] = defaultdict(dict)
fields: Dict[OperatorName, Dict[str, Any]] = defaultdict(dict)
for ei in es: # type: ignore[attr-defined]
funcs = ei.get("func")
assert isinstance(funcs, str), f"not a str: {funcs}"
@ -120,9 +118,9 @@ def extract_kernel_fields(es: object) -> dict[OperatorName, dict[str, Any]]:
def parse_et_yaml(
path: str,
tags_yaml_path: str,
ignore_keys: set[DispatchKey] | None = None,
ignore_keys: Optional[Set[DispatchKey]] = None,
skip_native_fns_gen: bool = False,
) -> tuple[list[NativeFunction], dict[OperatorName, dict[str, Any]]]:
) -> Tuple[List[NativeFunction], Dict[OperatorName, Dict[str, Any]]]:
"""Parse native_functions.yaml into NativeFunctions and an Operator Indexed Dict
of fields to persist from native_functions.yaml to functions.yaml
"""

278 torchgen/gen.py
@ -1,13 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
from collections import defaultdict, namedtuple, OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Literal, Sequence, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import yaml
|
||||
|
||||
@ -138,20 +148,20 @@ class LineLoader(YamlLoader):
ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])


_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
_GLOBAL_PARSE_NATIVE_YAML_CACHE: Dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: Dict[str, Set[str]] = {}


def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    valid_tags: Set[str],
    ignore_keys: Optional[Set[DispatchKey]] = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    assert isinstance(es, list)
    rs: list[NativeFunction] = []
    bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    rs: List[NativeFunction] = []
    bs: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]] = defaultdict(dict)
    for e in es:
        assert isinstance(e, dict), f"expected to be dict: {e}"
        assert isinstance(e.get("__line__"), int), e
@ -164,7 +174,7 @@ def parse_native_yaml_struct(
        BackendIndex.grow_index(bs, m)
    error_check_native_functions(rs)
    # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
    indices: dict[DispatchKey, BackendIndex] = defaultdict(
    indices: Dict[DispatchKey, BackendIndex] = defaultdict(
        lambda: BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
@ -190,9 +200,9 @@ def parse_native_yaml_struct(
    return ParsedYaml(rs, indices)


def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> Set[str]:
    assert isinstance(es, list)
    rs: set[str] = set()
    rs: Set[str] = set()
    for e in es:
        assert isinstance(e.get("__line__"), int), e
        loc = Location(path, e["__line__"])
@ -208,7 +218,7 @@ def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:


@functools.lru_cache(maxsize=None)
def parse_tags_yaml(path: str) -> set[str]:
def parse_tags_yaml(path: str) -> Set[str]:
    global _GLOBAL_PARSE_TAGS_YAML_CACHE
    if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        with open(path) as f:
@ -221,10 +231,10 @@ def parse_tags_yaml(path: str) -> set[str]:
def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    ignore_keys: Optional[Set[DispatchKey]] = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
    loaded_yaml: Optional[object] = None,
) -> ParsedYaml:
    global _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
@ -251,8 +261,8 @@ def parse_native_yaml(
# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    func_map: Dict[OperatorName, NativeFunction] = {}
    base_func_map: Dict[BaseOperatorName, List[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)
@ -319,7 +329,7 @@ def cpp_string(s: str) -> str:
# and similar functional combinators.


def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
def static_dispatch_keys(backends: List[BackendIndex]) -> List[DispatchKey]:
    if len(backends) == 0:
        return []
    else:
@ -333,7 +343,7 @@ def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:

def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
) -> Optional[DispatchKey]:
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        # TODO: for ops with structured_delegate it should check the dispatch table of
        # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
@ -352,8 +362,8 @@ def get_static_dispatch_backend(


def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    f: NativeFunction, backend_index: List[BackendIndex]
) -> Optional[str]:
    if backend_index is None or f.manual_kernel_registration:
        return None

@ -367,7 +377,7 @@ def static_dispatch_ops_header(
    return "\n".join(output)


def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
def static_dispatch_extra_headers(backends: List[BackendIndex]) -> List[str]:
    return [
        f"#include <ATen/{dispatch_key}Functions.h>"
        for dispatch_key in static_dispatch_keys(backends)
@ -378,12 +388,12 @@ def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
# Note that we have a special case for `memory_format` argument and this case is not covered by
# tools.codegen.api.translate() yet as its application is limited to static dispatch.
def translate_args(
    sig: CppSignature | DispatcherSignature,
    sig: Union[CppSignature, DispatcherSignature],
    cpp_sig: CppSignature,
) -> str:
    # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
    def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
        output_bindings: list[Binding] = []
    def add_spl_memory_format_binding(input_bindings: List[Binding]) -> List[Binding]:
        output_bindings: List[Binding] = []
        for binding in input_bindings:
            if binding.name == "memory_format":
                spl_mem_format_binding = Binding(
@ -413,7 +423,7 @@ def translate_args(


def generate_static_dispatch_backend_call(
    sig: CppSignature | DispatcherSignature,
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
@ -431,9 +441,9 @@ def generate_static_dispatch_backend_call(


def generate_static_dispatch_fallback_call(
    sig: CppSignature | DispatcherSignature,
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_indices: list[BackendIndex],
    backend_indices: List[BackendIndex],
) -> str:
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
@ -460,9 +470,9 @@ def generate_static_dispatch_fallback_call(


def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
    backend_indices: list[BackendIndex],
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
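
# A hedged sketch of the selection rule the docstring above describes
# (hypothetical helper, not the generated C++): pick the single backend that
# can service `f`; with zero or several candidates, emit a fallback instead.
#
#     def pick_static_backend(
#         f: NativeFunction, backend_indices: List[BackendIndex]
#     ) -> Optional[BackendIndex]:
#         candidates = [b for b in backend_indices if b.has_kernel(f)]
#         if len(candidates) == 1:
#             return candidates[0]  # unambiguous: emit a direct backend call
#         return None  # ambiguous or empty: emit a fallback call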
@ -502,7 +512,7 @@ def static_dispatch(
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: list[str] = []
    subexprs: List[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
@ -538,10 +548,10 @@ def static_dispatch(
@dataclass(frozen=True)
class RegisterSchema:
    selector: SelectiveBuilder
    known_tags: dict[str, int] = field(default_factory=dict)
    known_tags: Dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not self.selector.is_native_function_selected(f):
            return None
        tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
@ -563,7 +573,7 @@ class RegisterSchema:
@dataclass(frozen=True)
class ComputeOperators:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str:
@ -660,7 +670,7 @@ static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
    def __call__(self, f: NativeFunction) -> Optional[str]:
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
@ -708,10 +718,10 @@ namespace symint {{
@dataclass(frozen=True)
class ComputeTensorMethod:
    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]
    static_dispatch_backend_indices: List[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if Variant.method not in f.variants:
            return None

@ -754,7 +764,7 @@ inline {sig.defn(prefix="Tensor::")} const {{
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
    def __call__(self, f: NativeFunction) -> Optional[str]:
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not methods,
        sig_group = CppSignatureGroup.from_native_function(
@ -788,7 +798,7 @@ def compute_aten_op(f: NativeFunction) -> str:


# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> Optional[str]:
    if not g.structured:
        return None
    with native_function_manager(g.out):
@ -933,7 +943,7 @@ class ComputeBackendSelect:
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
    def __call__(self, f: NativeFunction) -> Optional[str]:
        if not needs_backend_select(f, self.selector):
            return None

@ -949,7 +959,7 @@ class ComputeBackendSelect:

        dispatcher_sig = DispatcherSignature.from_schema(f.func)

        sig: NativeSignature | DispatcherSignature
        sig: Union[NativeSignature, DispatcherSignature]
        sig = dispatcher_sig
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
@ -1049,7 +1059,7 @@ def dynamic_type(t: Type) -> str:
    ).cpp_type()


def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
def compute_method_of_yaml(variants: Set[Variant]) -> List[str]:
    # This is written out explicitly to ensure that Tensor and
    # namespace are put into the list in the right order
    method_of = ["Type"]
@ -1062,7 +1072,7 @@ def compute_method_of_yaml(variants: set[Variant]) -> list[str]:

def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
@ -1102,7 +1112,7 @@ def compute_returns_yaml(
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}
    name_to_field_name: Dict[str, str] = {}

    # Compute the returns field of the YAML entry
    names = cpp.return_names(f)
@ -1131,12 +1141,12 @@ def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    if isinstance(cpp_a.argument, TensorOptionsArguments):
        arg: dict[str, object] = {
        arg: Dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
@ -1163,11 +1173,11 @@ def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
    kwarg_only_set: Set[str],
    out_arg_set: Set[str],
    name_to_field_name: Dict[str, str],
) -> object:
    arg: dict[str, object] = {
    arg: Dict[str, object] = {
        "annotation": str(a.annotation) if a.annotation else None,
        "dynamic_type": dynamic_type(a.type),
        "is_nullable": a.type.is_nullable(),
@ -1293,7 +1303,7 @@ def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:

@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
    f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]
) -> str:
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(
@ -1301,7 +1311,7 @@ def compute_registration_declarations(
    ).cpp_type_registration_declarations()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl_registration_declarations() for a in args)
    comment_data: dict[str, str] = {
    comment_data: Dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
        "dispatch": str(
@ -1327,8 +1337,8 @@ def compute_registration_declarations(


def get_custom_build_selector(
    provided_op_registration_allowlist: list[str] | None,
    op_selection_yaml_path: str | None,
    provided_op_registration_allowlist: Optional[List[str]],
    op_selection_yaml_path: Optional[str],
) -> SelectiveBuilder:
    assert not (
        provided_op_registration_allowlist is not None
@ -1339,7 +1349,7 @@ def get_custom_build_selector(
        + "same time."
    )

    op_registration_allowlist: set[str] | None = None
    op_registration_allowlist: Optional[Set[str]] = None
    if provided_op_registration_allowlist is not None:
        op_registration_allowlist = set(provided_op_registration_allowlist)

@ -1359,11 +1369,11 @@ def get_custom_build_selector(

def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
) -> Sequence[Union[NativeFunction, NativeFunctionsViewGroup]]:
    def maybe_create_view_group(
        d: dict[ViewSchemaKind | SchemaKind, NativeFunction]
    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
        funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
        d: Dict[Union[ViewSchemaKind, SchemaKind], NativeFunction]
    ) -> List[Union[NativeFunction, NativeFunctionsViewGroup]]:
        funcs: List[Union[NativeFunction, NativeFunctionsViewGroup]] = []
        if ViewSchemaKind.aliasing in d:
            view = d.pop(ViewSchemaKind.aliasing)
            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
@ -1381,8 +1391,8 @@ def get_grouped_by_view_native_functions(
        funcs.extend(d.values())
        return funcs

    grouped_by_views: dict[
        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
    grouped_by_views: Dict[
        FunctionSchema, Dict[Union[SchemaKind, ViewSchemaKind], NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        schema = f.func.view_signature()
@ -1406,10 +1416,10 @@ def get_grouped_by_view_native_functions(

def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
    def flatten_pre_group(
        d: dict[SchemaKind, NativeFunction]
    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
        d: Dict[SchemaKind, NativeFunction]
    ) -> Sequence[Union[NativeFunction, NativeFunctionsGroup]]:
        r = NativeFunctionsGroup.from_dict(d)
        if r is None:
            # Invariant: any NativeFunctions that are code-generated
@ -1428,13 +1438,13 @@ def get_grouped_native_functions(

def get_ns_grouped_kernels(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
        [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
    ] = dest.compute_native_function_declaration,
) -> dict[str, list[str]]:
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
) -> Dict[str, List[str]]:
    ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
    for f in grouped_native_functions:
        native_function_namespaces = set()
        dispatch_keys = set()
@ -1457,9 +1467,9 @@ def get_ns_grouped_kernels(

def get_native_function_declarations_from_ns_grouped_kernels(
    *,
    ns_grouped_kernels: dict[str, list[str]],
) -> list[str]:
    declarations: list[str] = []
    ns_grouped_kernels: Dict[str, List[str]],
) -> List[str]:
    declarations: List[str] = []
    newline = "\n"
    for namespace, kernels in ns_grouped_kernels.items():
        ns_helper = NamespaceHelper(
@ -1485,12 +1495,12 @@ def get_native_function_declarations_from_ns_grouped_kernels(
# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
        [Union[NativeFunctionsGroup, NativeFunction], BackendIndex], List[str]
    ] = dest.compute_native_function_declaration,
) -> list[str]:
) -> List[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
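
# Roughly, the grouping this function performs reduces to the sketch below
# (hypothetical `namespace`/`decl()` accessors, shown only to make the shape
# of the result concrete):
#
#     ns_grouped: Dict[str, List[str]] = defaultdict(list)
#     for kernel in kernels:
#         ns_grouped[kernel.namespace].append(kernel.decl())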
@ -1510,7 +1520,7 @@ def get_native_function_declarations(


def get_kernel_namespace(
    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
    *, f: Union[NativeFunction, NativeFunctionsGroup], backend_idx: BackendIndex
) -> str:
    backend_metadata = backend_idx.get_kernel(f)
    assert not backend_metadata or "::native" in backend_metadata.cpp_namespace, (
@ -1528,7 +1538,7 @@ def get_kernel_namespace(
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
@ -1536,11 +1546,11 @@ def get_native_function_definitions(
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> list[str]:
    definitions: list[str] = []
    ns_definitions: dict[str, list[str]] = defaultdict(list)
    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
    registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
) -> List[str]:
    definitions: List[str] = []
    ns_definitions: Dict[str, List[str]] = defaultdict(list)
    anonymous_definitions: Dict[str, List[str]] = defaultdict(list)
    registrations: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
    newline = "\n"
    ns_gen = dest.RegisterDispatchKey(
        backend_idx,
@ -1630,15 +1640,15 @@ TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> list[str]:
    declarations: list[str] = []
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
) -> List[str]:
    declarations: List[str] = []
    ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
    newline = "\n"
    func = dest.RegisterDispatchKey(
        backend_idx,
@ -1682,8 +1692,8 @@ def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
    ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
) -> Tuple[List[str], str]:
    ns_native_functions: Dict[str, List[NativeFunction]] = defaultdict(list)
    for native_function in native_functions:
        ns_native_functions[native_function.namespace].append(native_function)
    schema_registrations = ""
@ -1717,14 +1727,14 @@ def get_native_function_schema_registrations(
def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    functions_keys: set[DispatchKey],
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
@ -1838,25 +1848,25 @@ def gen_aggregated_headers(
def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    functions_keys: set[DispatchKey],
    functions_keys: Set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
    functions_by_root_name: Dict[str, List[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: dict[
        str, list[NativeFunction | NativeFunctionsGroup]
    grouped_functions_by_root_name: Dict[
        str, List[Union[NativeFunction, NativeFunctionsGroup]]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        name = group.root_name
@ -2032,18 +2042,18 @@ def gen_per_operator_headers(
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: set[str],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    valid_tags: Set[str],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    static_dispatch_idx: List[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    cuda_fm: FileManager,
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
@ -2123,8 +2133,8 @@ def gen_headers(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> dict[str, str]:
        attrs: set[str] = set()  # All function argument names
    def gen_aten_interned_strings() -> Dict[str, str]:
        attrs: Set[str] = set()  # All function argument names
        names = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
@ -2161,7 +2171,7 @@ def gen_headers(

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> dict[str, str]:
    def gen_tags_enum() -> Dict[str, str]:
        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}

    core_fm.write("enum_tag.h", gen_tags_enum)
@ -2170,19 +2180,19 @@ def gen_headers(
def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: list[BackendIndex],
    backend_indices: dict[DispatchKey, BackendIndex],
    static_dispatch_idx: List[BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    aoti_fm: FileManager,
    core_fm: FileManager,
    cpu_fm: FileManager,
    cpu_vec_fm: FileManager,
    cuda_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    functions_keys: Set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
@ -2206,7 +2216,7 @@ def gen_source_files(

        if per_operator_headers:

            def operator_headers() -> list[str]:
            def operator_headers() -> List[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
@ -2248,7 +2258,7 @@ def gen_source_files(

        else:

            def operator_headers() -> list[str]:
            def operator_headers() -> List[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
                    headers.append("#include <ATen/Functions.h>")
@ -2439,7 +2449,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
        del fm

    # BackendSelect is generated specially
    def gen_backend_select() -> dict[str, list[str]]:
    def gen_backend_select() -> Dict[str, List[str]]:
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
@ -2484,7 +2494,7 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
    )

    def key_func(
        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
        fn: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> str:
        return fn.root_name

@ -2526,11 +2536,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
    )

    def functionalization_env_callable(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> dict[str, list[str]]:
        g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ) -> Dict[str, List[str]]:
        def gen_op_headers(
            g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
        ) -> list[str]:
            g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
        ) -> List[str]:
            if isinstance(g, NativeFunctionsViewGroup):
                # view ops always get a functionalization kernel
                headers = [
@ -2580,8 +2590,8 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
        ),
    }

    all_groups: list[
        NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
    all_groups: List[
        Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup]
    ] = list(structured_native_functions) + list(
        view_groups  # type: ignore[assignment, arg-type, operator]
    )
@ -2590,11 +2600,11 @@ codegen to generate the correct cpp call for this op. Contact AOTInductor team f
    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
    # Although this could go away long-term if we add a dedicated dispatch key for decompositions.
    structured_map: dict[OperatorName, NativeFunction] = {
    structured_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: dict[OperatorName, NativeFunction] = {
    view_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    for f in native_functions:
@ -2705,12 +2715,12 @@ def gen_declarations_yaml(
    )


def get_torchgen_root() -> Path:
def get_torchgen_root() -> pathlib.Path:
    """
    If you're depending on torchgen out-of-tree, you can use the root to figure
    out the path to native_functions.yaml
    """
    return Path(__file__).parent.resolve()
    return pathlib.Path(__file__).parent.resolve()
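
# Example out-of-tree use of get_torchgen_root(); the packaged-yaml layout is
# an assumption about the installed wheel, not something this diff guarantees:
#
#     yaml_path = (
#         get_torchgen_root() / "packaged" / "ATen" / "native" / "native_functions.yaml"
#     )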


def main() -> None:
@ -2872,11 +2882,11 @@ def main() -> None:
    #
    # Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(core_install_dir).mkdir(parents=True, exist_ok=True)
    ops_install_dir = f"{options.install_dir}/ops"
    Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
    aoti_install_dir = f"{options.aoti_install_dir}"
    Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
    pathlib.Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
@ -2906,7 +2916,7 @@ def main() -> None:
        if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
    ]

    static_dispatch_idx: list[BackendIndex] = []
    static_dispatch_idx: List[BackendIndex] = []
    if options.static_dispatch_backend:
        static_dispatch_idx = [
            backend_indices[DispatchKey.parse(key)]
@ -2963,7 +2973,7 @@ def main() -> None:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

@ -1,8 +1,6 @@
from __future__ import annotations

import textwrap
from dataclasses import dataclass
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Tuple, Union

from torchgen.api.types import DispatcherSignature
from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
@ -71,7 +69,7 @@ base_type_to_callsite_expr = {


# convert args to C types, names in declarations, and expressions in function bodies
def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]:  # type: ignore[return]
def convert_arg_type_and_name(typ: Type, name: str) -> Tuple[List[str], List[str], List[str], List[str]]:  # type: ignore[return]
    if isinstance(typ, BaseType):
        if typ.name in base_type_to_c_type:
            return (
@ -169,12 +167,12 @@ def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str
    )


def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
def zip_type_and_name(types: List[str], names: List[str]) -> List[str]:
    return [typ + " " + name for typ, name in zip(types, names)]
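
# For example (values illustrative):
#
#     zip_type_and_name(["int64_t", "double"], ["dim", "alpha"])
#     # -> ["int64_t dim", "double alpha"]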


# Generate argument declarations and callsite expressions
def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[str]]:
def gen_arguments(flat_arguments: Sequence[Argument]) -> Tuple[List[str], List[str]]:
    types = []
    new_names = []
    callsite_exprs = []
@ -191,7 +189,7 @@ def gen_arguments(flat_arguments: Sequence[Argument]) -> tuple[list[str], list[s
# Return values are passed out as pointer arguments because all the C shim functions
# are expected to return AOTITorchError.
# Generate returns as declarations and callsite expressions
def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
def gen_returns(schema: FunctionSchema) -> Tuple[List[str], List[str]]:
    types = []
    names = []
    for idx, ret in enumerate(schema.returns):
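
# Sketch of the convention described above, with a hypothetical operator name:
# every shim entry point returns an AOTITorchError status code and hands results
# back through out-pointer arguments, e.g.
#
#     AOTITorchError aoti_torch_cpu_foo(AtenTensorHandle self, AtenTensorHandle* ret0);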
@ -224,7 +222,7 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
            ret_pointer_can_be_null = True
            break

    callsite_exprs: list[str] = []
    callsite_exprs: List[str] = []
    for idx, ret in enumerate(schema.returns):
        tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
        assert isinstance(ret.type, BaseType)
@ -238,12 +236,12 @@ def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:


# gen.py generates header first and then src, so caching the result here to avoid duplicate work
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}
declaration_definition_cache: Dict[Tuple[str, str, str], Tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema, device: str, backend_call: str
) -> tuple[str, str]:
) -> Tuple[str, str]:
    func_name = schema.name.unambiguous_name()

    global declaration_definition_cache
@ -256,7 +254,7 @@ def gen_declaration_and_definition(
        args, callsite_exprs = gen_arguments(
            [*schema.arguments.out, *schema.arguments.flat_non_out]
        )
        ret_assignments: list[str] = []
        ret_assignments: List[str] = []
    else:
        args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
        # ignore return values for inplace ops
@ -286,7 +284,7 @@ def gen_declaration_and_definition(


def gen_static_dispatch_backend_call_signature(
    sig: CppSignature | DispatcherSignature,
    sig: Union[CppSignature, DispatcherSignature],
    f: NativeFunction,
) -> CppSignature:
    sig = DispatcherSignature.from_schema(f.func)
@ -312,10 +310,10 @@ def gen_static_dispatch_backend_call(

def get_backend_index_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> BackendIndex | None:
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[BackendIndex]:
    backend_index = None
    if backend_indices[dispatch_key].has_kernel(func) or (
        func.structured_delegate is not None
@ -343,10 +341,10 @@ def get_backend_index_for_aoti(

def get_header_for_aoti(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
) -> str | None:
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> Optional[str]:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
@ -367,11 +365,11 @@ def get_fallback_op_name(func: NativeFunction) -> str:

def gen_c_shim(
    func: NativeFunction,
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    header: bool,
) -> str | None:
) -> Optional[str]:
    backend_index = get_backend_index_for_aoti(
        func, func_group_mapping, dispatch_key, backend_indices
    )
@ -401,16 +399,16 @@ def gen_c_shim(

@dataclass(frozen=True)
class ShimGenerator:
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup]
    dispatch_key: DispatchKey
    backend_indices: dict[DispatchKey, BackendIndex]
    backend_indices: Dict[DispatchKey, BackendIndex]
    header: bool  # True to generate .h and False to generate .cpp

    @method_with_native_function
    def __call__(
        self,
        func: NativeFunction,
    ) -> str | None:
    ) -> Optional[str]:
        result = gen_c_shim(
            func,
            self.func_group_mapping,
@ -423,9 +421,9 @@ class ShimGenerator:

def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    func_group_mapping: Dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_indices: dict[DispatchKey, BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    header: bool,
    includes: str = "",
) -> str:
@ -1,11 +1,9 @@
from __future__ import annotations

import argparse
import os
import re
from collections import Counter, defaultdict, namedtuple
from pathlib import Path
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Set, Union

import yaml

@ -38,10 +36,10 @@ ParsedExternalYaml = namedtuple(

def parse_backend_yaml(
    backend_yaml_path: str,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_indices: Dict[DispatchKey, BackendIndex],
) -> ParsedExternalYaml:
    native_functions_map: dict[OperatorName, NativeFunction] = {
    native_functions_map: Dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(
            lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
@ -121,14 +119,14 @@ def parse_backend_yaml(
Only the following keys are supported: {", ".join(valid_keys)}'

    def create_backend_index(
        backend_ops: list[str],
        symint_ops: set[str],
        backend_ops: List[str],
        symint_ops: Set[str],
        dispatch_key: DispatchKey,
        *,
        use_out_as_primary: bool,
        use_device_guard: bool,
    ) -> BackendIndex:
        metadata: dict[OperatorName, BackendMetadata] = {}
        metadata: Dict[OperatorName, BackendMetadata] = {}
        for op in backend_ops:
            op_name = OperatorName.parse(op)
            assert (
@ -151,7 +149,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
            index=metadata,
        )

    backend_key: DispatchKey | None = None
    backend_key: Optional[DispatchKey] = None
    if len(supported) > 0:
        with context(
            lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
@ -168,7 +166,7 @@ Only the following keys are supported: {", ".join(valid_keys)}'
        assert backend_key not in backend_indices
        backend_indices[backend_key] = backend_idx

    autograd_key: DispatchKey | None = None
    autograd_key: Optional[DispatchKey] = None
    if len(supported_autograd) > 0:
        with context(
            lambda: f'The "autograd" key was specified, which indicates that you would like to override \
@ -247,12 +245,12 @@ autograd key. They cannot be mix and matched. If this is something you need, fee

def error_on_missing_kernels(
    native_functions: Sequence[NativeFunction],
    backend_indices: dict[DispatchKey, BackendIndex],
    backend_indices: Dict[DispatchKey, BackendIndex],
    backend_key: DispatchKey,
    autograd_key: DispatchKey | None,
    autograd_key: Optional[DispatchKey],
    class_name: str,
    kernel_defn_file_path: str,
    full_codegen: list[OperatorName] | None = None,
    full_codegen: Optional[List[OperatorName]] = None,
) -> None:
    try:
        with open(kernel_defn_file_path) as f:
@ -270,7 +268,7 @@ def error_on_missing_kernels(
    )
    # Quick mapping from each OperatorName used by the external backend
    # to its backend kernel name
    expected_backend_op_names: dict[OperatorName, str] = dict(
    expected_backend_op_names: Dict[OperatorName, str] = dict(
        list(
            concatMap(
                lambda index: [
@ -280,13 +278,13 @@ def error_on_missing_kernels(
            )
        )
    )
    expected_backend_native_funcs: list[NativeFunction] = [
    expected_backend_native_funcs: List[NativeFunction] = [
        f
        for f in native_functions
        if f.func.name in expected_backend_op_names.keys()
        and f.func.name not in full_codegen
    ]
    expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
    expected_backend_kernel_name_counts: Dict[str, List[NativeFunction]] = defaultdict(
        list
    )
    for native_f in expected_backend_native_funcs:
@ -358,10 +356,10 @@ def gen_dispatchkey_nativefunc_headers(
    fm: FileManager,
    class_name: str,
    cpp_namespace: str,
    backend_indices: dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    autograd_dispatch_key: DispatchKey | None,
    autograd_dispatch_key: Optional[DispatchKey],
    backend_name: str = "",
) -> None:
    assert class_name is not None
@ -415,11 +413,11 @@ def gen_dispatcher_registrations(
    fm: FileManager,
    output_dir: str,
    class_name: str,
    backend_indices: dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: Dict[DispatchKey, BackendIndex],
    grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
    backend_dispatch_key: DispatchKey,
    dispatch_key: DispatchKey,
    selector: SelectiveBuilder,
    selector: "SelectiveBuilder",
    # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
    build_in_tree: bool = False,
    per_operator_headers: bool = False,
@ -526,7 +524,7 @@ TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {


def run(
    source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
    source_yaml: str, output_dir: str, dry_run: bool, impl_path: Optional[str] = None
) -> None:
    # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
    pytorch_root = Path(__file__).absolute().parent.parent
@ -1,11 +1,9 @@
from __future__ import annotations

import argparse
import os
import pathlib
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Sequence, TextIO, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union

import yaml

@ -47,6 +45,7 @@ from torchgen.model import (
    OperatorName,
    Variant,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import (
    context,
    FileManager,
@ -56,11 +55,7 @@ from torchgen.utils import (
)


if TYPE_CHECKING:
    from torchgen.selective_build.selector import SelectiveBuilder


def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:
def _sig_decl_wrapper(sig: Union[CppSignature, ExecutorchCppSignature]) -> str:
    """
    A wrapper function to basically get `sig.decl(include_context=True)`.
    For ATen kernel, the codegen has no idea about ET contextArg, so we
@ -77,9 +72,9 @@ def _sig_decl_wrapper(sig: CppSignature | ExecutorchCppSignature) -> str:


def static_dispatch(
    sig: CppSignature | ExecutorchCppSignature,
    sig: Union[CppSignature, ExecutorchCppSignature],
    f: NativeFunction,
    backend_indices: list[BackendIndex],
    backend_indices: List[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding native function and dispatch to it. If zero or more than one
@ -118,7 +113,7 @@ TORCH_API inline {_sig_decl_wrapper(sig)} {{
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    static_dispatch_backend_indices: list[BackendIndex]
    static_dispatch_backend_indices: List[BackendIndex]

    selector: SelectiveBuilder

@ -127,7 +122,7 @@ class ComputeFunction:
    is_custom_op: Callable[[NativeFunction], bool]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
    def __call__(self, f: NativeFunction) -> Optional[str]:
        is_method_variant = False
        if not self.selector.is_root_operator(f"{f.namespace}::{f.func.name}"):
            return None
@ -141,7 +136,7 @@ class ComputeFunction:
                f"Can't handle native function {f.func} with the following variant specification {f.variants}."
            )

        sig: CppSignature | ExecutorchCppSignature = (
        sig: Union[CppSignature, ExecutorchCppSignature] = (
            CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=f.manual_cpp_binding
            ).most_faithful_signature()
@ -184,10 +179,10 @@ class ComputeCodegenUnboxedKernels:
    @method_with_nested_native_function
    def __call__(
        self,
        unbox_kernel_entry: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]],
        unbox_kernel_entry: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]],
    ) -> str:
        f: NativeFunction = unbox_kernel_entry[0]
        kernel_key: ETKernelKey | list[ETKernelKey] = unbox_kernel_entry[1][0]
        kernel_key: Union[ETKernelKey, List[ETKernelKey]] = unbox_kernel_entry[1][0]
        kernel_meta: BackendMetadata = unbox_kernel_entry[1][1]

        op_name = f"{f.namespace}::{f.func.name}"
@ -201,7 +196,7 @@ class ComputeCodegenUnboxedKernels:
        )
        if not used_kernel_keys:
            return ""
        sig: CppSignature | ExecutorchCppSignature
        sig: Union[CppSignature, ExecutorchCppSignature]
        argument_type_gen: Callable[..., NamedCType]
        return_type_gen: Callable[..., CType]
        if self.use_aten_lib:
@ -295,11 +290,11 @@ def gen_unboxing(
) -> None:
    # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata))
    def key_func(
        item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]
        item: Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]
    ) -> str:
        return item[0].root_name + ":" + item[1][0].to_native_string()

    items: list[tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]]] = [
    items: List[Tuple[NativeFunction, Tuple[ETKernelKey, BackendMetadata]]] = [
        (native_function, (kernel_key, metadata))
        for native_function in native_functions
        for kernel_key, metadata in kernel_index.get_kernels(native_function).items()
@ -330,8 +325,8 @@ def gen_unboxing(

@with_native_function_and_index  # type: ignore[arg-type]
def compute_native_function_declaration(
    g: NativeFunctionsGroup | NativeFunction, kernel_index: ETKernelIndex
) -> list[str]:
    g: Union[NativeFunctionsGroup, NativeFunction], kernel_index: ETKernelIndex
) -> List[str]:
    assert isinstance(g, NativeFunction)
    sig = ExecutorchCppSignature.from_native_function(f=g)
    metadata_list = kernel_index.get_kernels(g).values()
@ -357,7 +352,7 @@ def gen_functions_declarations(
    kernel_index: ETKernelIndex,
    selector: SelectiveBuilder,
    use_aten_lib: bool,
    custom_ops_native_functions: Sequence[NativeFunction] | None = None,
    custom_ops_native_functions: Optional[Sequence[NativeFunction]] = None,
) -> str:
    """
    Generates namespace separated C++ function API inline declaration/definitions.
@ -411,13 +406,13 @@ def get_ns_grouped_kernels(
    kernel_index: ETKernelIndex,
    native_function_decl_gen: Callable[
        [
            NativeFunctionsGroup | NativeFunction,
            Union[NativeFunctionsGroup, NativeFunction],
            ETKernelIndex,
        ],
        list[str],
        List[str],
    ],
) -> dict[str, list[str]]:
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
) -> Dict[str, List[str]]:
    ns_grouped_kernels: Dict[str, List[str]] = defaultdict(list)
    for f in native_functions:
        native_function_namespaces = set()
        op_kernels = kernel_index.get_kernels(f)
@ -600,7 +595,7 @@ def gen_custom_ops(
def translate_native_yaml(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: str | None,
    native_yaml_path: Optional[str],
    use_aten_lib: bool,
    out_file: TextIO,
) -> None:
@ -651,15 +646,15 @@ def translate_native_yaml(
        skip_native_fns_gen=False,
    )

    func_to_scoped_name: dict[FunctionSchema, str] = {
    func_to_scoped_name: Dict[FunctionSchema, str] = {
        f.func: f"{f.namespace}::{f.func.name}" for f in native_functions
    }
    op_to_scoped_name: dict[OperatorName, str] = {
    op_to_scoped_name: Dict[OperatorName, str] = {
        func.name: name for func, name in func_to_scoped_name.items()
    }

    schema_dict = {name: str(func) for func, name in func_to_scoped_name.items()}
    kernel_persist_dict: dict[str, dict[str, Any]] = {
    kernel_persist_dict: Dict[str, Dict[str, Any]] = {
        op_to_scoped_name[op]: v for op, v in persisted_fields.items()
    }

@ -697,13 +692,13 @@ def translate_native_yaml(


def parse_yaml(
    path: str | None,
    path: Optional[str],
    tags_yaml_path: str,
    function_filter: Callable[[NativeFunction], bool],
    skip_native_fns_gen: bool = False,
) -> tuple[
    list[NativeFunction],
    dict[DispatchKey, dict[OperatorName, BackendMetadata]] | ETKernelIndex,
) -> Tuple[
    List[NativeFunction],
    Union[Dict[DispatchKey, Dict[OperatorName, BackendMetadata]], ETKernelIndex],
]:
    if path and os.path.exists(path) and os.stat(path).st_size > 0:
        with open(path) as f:
@ -740,8 +735,8 @@ def parse_yaml(

        # (2) Return BackendIndices if kernel index is absent
        def map_index(
            m: dict[OperatorName, BackendMetadata]
        ) -> dict[OperatorName, BackendMetadata]:
            m: Dict[OperatorName, BackendMetadata]
        ) -> Dict[OperatorName, BackendMetadata]:
            return {op: m[op] for op in m if op in op_names}

        backend_indices = {
@ -756,11 +751,11 @@ def parse_yaml(
def parse_yaml_files(
    tags_yaml_path: str,
    aten_yaml_path: str,
    native_yaml_path: str | None,
    custom_ops_yaml_path: str | None,
    native_yaml_path: Optional[str],
    custom_ops_yaml_path: Optional[str],
    selector: SelectiveBuilder,
    use_aten_lib: bool,
) -> tuple[ETParsedYaml, ETParsedYaml | None]:
) -> Tuple[ETParsedYaml, Optional[ETParsedYaml]]:
    """Parses functions.yaml and custom_ops.yaml files.

    Args:
@ -983,7 +978,7 @@ def main() -> None:
    )

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_path = pathlib.Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
from torchgen.api import cpp, dispatcher
|
||||
from torchgen.api.translate import translate
|
||||
@ -48,13 +46,10 @@ from torchgen.native_function_generation import (
|
||||
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT,
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY,
)
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import dataclass_repr


if TYPE_CHECKING:
from torchgen.selective_build.selector import SelectiveBuilder


# Note: [Mutable Ops Not Using Functionalization]
# Ops in this list currently do not work with functionalization and should be fixed.
MUTABLE_OPS_NOT_USING_FUNCTIONALIZATION = (
@ -93,7 +88,7 @@ class GenCompositeViewCopyKernel:
backend_index: BackendIndex

@method_with_native_function
def __call__(self, g: NativeFunctionsViewGroup) -> str | None:
def __call__(self, g: NativeFunctionsViewGroup) -> Optional[str]:
if g.view_copy is None:
return None
elif g.view_copy.func.name.name.base != f"{g.view.func.name.name}_copy":
@ -165,7 +160,7 @@ at::Tensor view_copy_symint(const at::Tensor & self, at::SymIntArrayRef size) {
"""


def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
return ""
@ -189,7 +184,7 @@ def wrapper_name(func: FunctionSchema) -> str:
return cpp.name(func)


def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
def is_tensor_like(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> bool:
return isinstance(a, SelfArgument) or (
isinstance(a, Argument) and a.type.is_tensor_like()
)
@ -199,7 +194,7 @@ def is_tensor_like(a: Argument | TensorOptionsArguments | SelfArgument) -> bool:
# Some op schemas include non-owning types though (like TensorList),
# and when we unwrap them we expect to get out an owning type!
# We also return a lambda that tells you how to convert the non-owning type argument into the owning type.
def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
def get_owning_type(t: CType) -> Tuple[CType, Callable[[str], str]]:
if t == BaseCType(tensorListT):
return VectorCType(BaseCType(tensorT)), lambda x: f"{x}.vec()"
if t == BaseCType(iTensorListRefT):
@ -214,9 +209,9 @@ def get_owning_type(t: CType) -> tuple[CType, Callable[[str], str]]:
# (2) a context, to be used by translate(), with all of the relevant bindings.
def unwrap_tensor_args(
sig: DispatcherSignature, *, is_view_op: bool
) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
) -> Tuple[str, List[Binding]]:
context: List[Binding] = []
unwrapped_tensor_args: List[str] = []
for arg in sig.arguments():
if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@ -252,9 +247,9 @@ def unwrap_tensor_args(
# converts all tensor-like arguments to meta tensors, which are used to compute stride info. Returns:
# (1) a string containing all of the logic that does the conversions.
# (2) a context, to be used by translate(), with all of the relevant bindings.
def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
context: list[Binding] = []
unwrapped_tensor_args: list[str] = []
def convert_to_meta_tensors(sig: DispatcherSignature) -> Tuple[str, List[Binding]]:
context: List[Binding] = []
unwrapped_tensor_args: List[str] = []
for arg in sig.arguments():
if is_tensor_like(arg.argument):
# for tensor inputs, we want to unwrap them before passing them into the redispatch calls.
@ -322,7 +317,7 @@ def emit_expr_has_symbolic_values(expr: str, type: CType) -> str:

# Detects whether any of the SymInt arguments are, in fact, symbolic values.
# This is used in the constructor of ViewMeta.
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> tuple[str, str]:
def emit_has_symbolic_inputs(sig: DispatcherSignature) -> Tuple[str, str]:
name = "has_symbolic_inputs"
statements = [
f"{name} = {name} | ({emit_expr_has_symbolic_values(binding.name, binding.nctype.type)});"
@ -527,7 +522,7 @@ def maybe_create_output(f: NativeFunction, var_name: str) -> str:
# - the names of returns corresponding to the (immutable) outputs of the inner redispatched function
def get_mutable_redispatch_return_names(
f: NativeFunction, inner_return_var: str
) -> tuple[list[str], list[str]]:
) -> Tuple[List[str], List[str]]:
aliased_returns = []
non_aliased_returns = []
for i, name in enumerate(f.func.aliased_return_names()):
@ -756,11 +751,11 @@ def emit_inplace_functionalization_body(
# See Note [Functionalization Pass: View Inverses].
def gen_functionalization_view_inverse_declaration(
selector: SelectiveBuilder, g: NativeFunctionsViewGroup
) -> str | None:
) -> Optional[str]:
# For every (non-composite) view op, we need a corresponding "inverse view" function.
# This generates the declarations so we get a good compiler error when someone adds a new view.
@with_native_function
def emit_decl_helper(g: NativeFunctionsViewGroup) -> str | None:
def emit_decl_helper(g: NativeFunctionsViewGroup) -> Optional[str]:
if g.view.has_composite_implicit_autograd_kernel:
return None
view_inverse_sig = ViewInverseSignature(g)
@ -771,9 +766,9 @@ def gen_functionalization_view_inverse_declaration(

def gen_functionalization_registration(
selector: SelectiveBuilder,
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
composite_implicit_autograd_index: BackendIndex,
) -> list[str]:
) -> List[str]:
@with_native_function
def emit_registration_helper(f: NativeFunction) -> str:
assert not f.has_composite_implicit_autograd_kernel
@ -837,8 +832,8 @@ def gen_functionalization_definition(
# (and instead only need to operate on grouped NativeFunctions).
# The only reason currently is because we need to emit direct dispatch registrations
# for CompositeImplicitAutograd operators, which are potentially ungrouped.
g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
) -> list[str]:
g: Union[NativeFunction, NativeFunctionsGroup, NativeFunctionsViewGroup],
) -> List[str]:
# Don't generate kernels in mobile build
if not selector.include_all_operators:
return []

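Every hunk in this revert follows one mechanical pattern: PEP 585/604 spellings (list[str], str | None), which the removed `from __future__ import annotations` line made usable in annotations, go back to their typing-module equivalents (List[str], Optional[str]). A minimal, runnable sketch of the two equivalent spellings; the function names here are hypothetical and not from torchgen:

# Style the revert removes: postponed annotations allow builtin generics
# and the `X | None` union syntax in annotations.
from __future__ import annotations


def overload_names(base: str, overloads: list[str] | None = None) -> dict[str, tuple[str, ...]]:
    return {base: tuple(overloads or ())}


# Style the revert restores: the same types spelled via the typing module,
# valid without the __future__ import.
from typing import Dict, List, Optional, Tuple


def overload_names_compat(base: str, overloads: Optional[List[str]] = None) -> Dict[str, Tuple[str, ...]]:
    return {base: tuple(overloads or ())}
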
@ -1,10 +1,19 @@
from __future__ import annotations

import argparse
import os
from collections import namedtuple
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Sequence
from typing import (
Any,
Callable,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)

import yaml

@ -93,8 +102,8 @@ ParsedExternalYaml = namedtuple(

def parse_native_functions_keys(
backend_yaml_path: str,
grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
grouped_native_functions: Sequence[Union[NativeFunction, NativeFunctionsGroup]],
) -> Tuple[List[OperatorName], List[Any], List[OperatorName]]:
with open(backend_yaml_path) as f:
yaml_values = yaml.load(f, Loader=YamlLoader)
assert isinstance(yaml_values, dict)
@ -111,7 +120,7 @@ def parse_native_functions_keys(


def validate_shape_inference_header(
shape_inference_hdr: str, expected_shape_infr_decls: list[str]
shape_inference_hdr: str, expected_shape_infr_decls: List[str]
) -> None:
try:
with open(shape_inference_hdr) as f:
@ -171,12 +180,12 @@ std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {

class default_args:
node_base: str = "Node"
node_base_hdr: str | None = None
node_base_hdr: Optional[str] = None
shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
tensor_class: str = "torch::lazy::LazyTensor"
tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
lazy_ir_generator: type[GenLazyIR] = GenLazyIR
native_func_definition_generator: type[
lazy_ir_generator: Type[GenLazyIR] = GenLazyIR
native_func_definition_generator: Type[
GenLazyNativeFuncDefinition
] = GenLazyNativeFuncDefinition
backend_name: str = "TorchScript"
@ -254,10 +263,10 @@ def main() -> None:
# Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
torch_root = Path(__file__).absolute().parents[2]
aten_path = str(torch_root / "aten" / "src" / "ATen")
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator
if options.gen_ts_lowerings:
lazy_ir_generator = GenTSLazyIR
native_func_definition_generator: type[
native_func_definition_generator: Type[
GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator

@ -283,14 +292,14 @@ def run_gen_lazy_tensor(
source_yaml: str,
output_dir: str,
dry_run: bool,
impl_path: str | None,
impl_path: Optional[str],
node_base: str = default_args.node_base,
node_base_hdr: str | None = default_args.node_base_hdr,
node_base_hdr: Optional[str] = default_args.node_base_hdr,
tensor_class: str = default_args.tensor_class,
tensor_class_hdr: str = default_args.tensor_class_hdr,
shape_inference_hdr: str = default_args.shape_inference_hdr,
lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
native_func_definition_generator: type[
lazy_ir_generator: Type[GenLazyIR] = default_args.lazy_ir_generator,
native_func_definition_generator: Type[
GenLazyNativeFuncDefinition
] = default_args.native_func_definition_generator,
# build_in_tree is true for TS backend and affects include paths
@ -338,7 +347,7 @@ def run_gen_lazy_tensor(
)
grouped_native_functions = get_grouped_native_functions(native_functions)

def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
def sort_native_function(f: Union[NativeFunctionsGroup, NativeFunction]) -> str:
"""
We sort the native function because of the note in concat_map_codegen.
TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
@ -368,8 +377,8 @@ def run_gen_lazy_tensor(

def concat_map_codegen(
func: Callable[[NativeFunction], Sequence[str]],
xs: Iterable[NativeFunctionsGroup | NativeFunction],
ops_list: list[OperatorName] = full_codegen,
xs: Iterable[Union[NativeFunctionsGroup, NativeFunction]],
ops_list: List[OperatorName] = full_codegen,
) -> Iterator[str]:
"""
We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we

@ -1,8 +1,6 @@
from __future__ import annotations

import textwrap
from dataclasses import dataclass
from typing import Sequence
from typing import List, Optional, Sequence, Tuple

from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
@ -34,7 +32,7 @@ def is_tensor_list(typ: Type) -> bool:
return isinstance(typ, ListType) and is_tensor(typ.elem)


def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
def unwrap_tensor(name: str, cur_level_var: str) -> List[str]:
result = f"""\
Tensor {name}_value;
optional<int64_t> {name}_bdim;
@ -42,7 +40,7 @@ def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
return textwrap.dedent(result).split("\n")


def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
def unwrap_optional_tensor(name: str, cur_level_var: str) -> List[str]:
result = f"""\
optional<Tensor> {name}_value;
optional<int64_t> {name}_bdim;
@ -54,7 +52,7 @@ def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:

def gen_unwraps(
flat_arguments: Sequence[Argument], cur_level_var: str
) -> tuple[str, list[str]]:
) -> Tuple[str, List[str]]:
arg_names = [a.name for a in flat_arguments]
arg_types = [a.type for a in flat_arguments]

@ -101,7 +99,7 @@ if ({' && '.join(conditions)}) {{


def gen_returns(
returns: tuple[Return, ...], cur_level_var: str, results_var: str
returns: Tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
idx = 0
wrapped_returns = []
@ -134,7 +132,7 @@ def is_mutated_arg(argument: Argument) -> bool:
return argument.annotation is not None and argument.annotation.is_write


def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
# Assumptions:
# - only one argument is being modified in-place
# - the argument that is being modified in-place is the first argument
@ -199,7 +197,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
}}"""


def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
schema = native_function.func
sig = DispatcherSignature.from_schema(schema)
returns = schema.returns
@ -246,7 +244,7 @@ template <typename batch_rule_t, batch_rule_t batch_rule>
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
@method_with_native_function
def __call__(self, f: NativeFunction) -> str | None:
def __call__(self, f: NativeFunction) -> Optional[str]:
result = gen_vmap_plumbing(f)
return result

@ -1,8 +1,6 @@
from __future__ import annotations

import threading
from contextlib import contextmanager
from typing import Iterator
from typing import Iterator, Optional


# Simple dynamic scoping implementation. The name "parametrize" comes
@ -19,8 +17,8 @@ from typing import Iterator


class Locals(threading.local):
use_const_ref_for_mutable_tensors: bool | None = None
use_ilistref_for_tensor_lists: bool | None = None
use_const_ref_for_mutable_tensors: Optional[bool] = None
use_ilistref_for_tensor_lists: Optional[bool] = None


_locals = Locals()

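For context, the hunk above touches torchgen's thread-local settings module: a threading.local subclass holds per-thread flags, and a context manager temporarily overrides them (the "dynamic scoping" the comment mentions). A sketch of how such a parametrize-style context manager is typically wired up; the body below is illustrative, not the exact torchgen implementation:

import threading
from contextlib import contextmanager
from typing import Iterator, Optional


class Locals(threading.local):
    # Per-thread setting; None means "unset" until a parametrize block sets it.
    use_const_ref_for_mutable_tensors: Optional[bool] = None


_locals = Locals()


@contextmanager
def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]:
    # Save and restore the previous value so nested scopes unwind correctly.
    old = _locals.use_const_ref_for_mutable_tensors
    try:
        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
        yield
    finally:
        _locals.use_const_ref_for_mutable_tensors = old


with parametrize(use_const_ref_for_mutable_tensors=True):
    assert _locals.use_const_ref_for_mutable_tensors is True
assert _locals.use_const_ref_for_mutable_tensors is None
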
@ -1,11 +1,9 @@
from __future__ import annotations

import dataclasses
import itertools
import re
from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, Iterator, Sequence
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

from torchgen.utils import assert_never, NamespaceHelper, OrderedSet

@ -231,7 +229,7 @@ class DispatchKey(Enum):
return str(self).lower()

@staticmethod
def parse(value: str) -> DispatchKey:
def parse(value: str) -> "DispatchKey":
for k, v in DispatchKey.__members__.items():
if k == value:
return v
@ -352,20 +350,20 @@ class ScalarType(Enum):
return self.name

@staticmethod
def maybe_parse(value: str) -> ScalarType | None:
def maybe_parse(value: str) -> Optional["ScalarType"]:
for k, v in ScalarType.__members__.items():
if k == value:
return v
return None

@staticmethod
def parse(value: str) -> ScalarType:
def parse(value: str) -> "ScalarType":
mb_r = ScalarType.maybe_parse(value)
assert mb_r is not None, f"unknown dtype {value}"
return mb_r

@staticmethod
def parse_set(values: str) -> OrderedSet[ScalarType]:
def parse_set(values: str) -> OrderedSet["ScalarType"]:
dtypes: OrderedSet[ScalarType] = OrderedSet()
for value in values.split(", "):
if value in DTYPE_CLASSES:
@ -375,7 +373,7 @@ class ScalarType(Enum):
return dtypes


DTYPE_CLASSES: dict[str, OrderedSet[ScalarType]] = {}
DTYPE_CLASSES: Dict[str, OrderedSet[ScalarType]] = {}
# NB: Integral doesn't include boolean
DTYPE_CLASSES["Integral"] = OrderedSet(
[
@ -421,7 +419,7 @@ class UfuncKey(Enum):
return self.name

@staticmethod
def parse(value: str) -> UfuncKey:
def parse(value: str) -> "UfuncKey":
for k, v in UfuncKey.__members__.items():
if k == value:
return v
@ -464,7 +462,7 @@ class NativeFunction:
# (This type is quoted as we are forward referencing a type
# defined later in the file. I opted for this ordering of the
# classes for expository clarity.)
func: FunctionSchema
func: "FunctionSchema"

# Whether or not to generate mutable tensor arguments like regular
# ones
@ -477,14 +475,14 @@ class NativeFunction:
device_check: DeviceCheckType

# What python module to put the function in
python_module: str | None
python_module: Optional[str]

# TODO: figure out what this does
category_override: str | None
category_override: Optional[str]

# If no variants are specified in native_functions.yaml, this is
# assumed to be {'function'}.
variants: set[Variant]
variants: Set[Variant]

# Whether or not we should skip generating registrations for
# this kernel. This is a bit of a double-edged sword, as manual
@ -499,7 +497,7 @@ class NativeFunction:

# The location in the YAML file where this native function entry was
# defined. This is for conveniently reporting error messages!
loc: Location
loc: "Location"

# A list of operators that are expected to be auto-generated for this NativeFunction.
# Note: This list isn't actually directly used by the codegen to generate anything.
@ -507,11 +505,11 @@ class NativeFunction:
# function schema, and uses the autogen declarations to error check.
# We expect every NativeFunction that gets auto-generated to be explicitly called out
# in native_functions.yaml
autogen: list[OperatorName]
autogen: List["OperatorName"]

# If non-empty, this kernel is subject to ufunc codegen.
# Sorted by ufunc_key
ufunc_inner_loop: dict[UfuncKey, UfuncInnerLoop]
ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"]

# Whether or not this out function is a "structured kernel". Structured
# kernels are defined a little differently from normal kernels; in
@ -524,13 +522,13 @@ class NativeFunction:

# Whether or not this non-out function is a structured kernel, defined
# in terms of the out kernel referenced by the string here.
structured_delegate: OperatorName | None
structured_delegate: Optional["OperatorName"]

# Only valid for structured kernels. Specifies an alternative of what
# to inherit from when defining the meta class for the structured
# operator. This will usually be TensorIteratorBase. This also
# changes the semantics of set_output to call the parent class.
structured_inherits: str | None
structured_inherits: Optional[str]

# Structured kernels can declare elements as "precomputed". These elements
# are returned by the meta function in one struct and passed to the impl
@ -538,11 +536,11 @@ class NativeFunction:
# elements supersede. Information about the names and types of these
# precomputed elements and how they correspond to kernel arguments is stored
# in this member, if applicable.
precomputed: Precompute | None
precomputed: Optional["Precompute"]

# Argument names whose default should be excluded from the C++ interface.
# Intended for resolving overload ambiguities between signatures.
cpp_no_default_args: set[str]
cpp_no_default_args: Set[str]

# Note [Abstract ATen methods]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -562,7 +560,7 @@ class NativeFunction:

# Tags are used to describe semantic information about (groups of) operators,
# that aren't easily inferrable directly from the operator's schema.
tags: set[str]
tags: Set[str]

# NB: The benefit of defining a dataclass is that we automatically get
# a constructor defined for all the fields we specify. No need
@ -571,11 +569,13 @@ class NativeFunction:
# We parse both the NativeFunction + backend-specific information about it, which is stored in a corresponding BackendIndex.
@staticmethod
def from_yaml(
ei: dict[str, object],
loc: Location,
valid_tags: set[str],
ignore_keys: set[DispatchKey] | None = None,
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
ei: Dict[str, object],
loc: "Location",
valid_tags: Set[str],
ignore_keys: Optional[Set[DispatchKey]] = None,
) -> Tuple[
"NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]
]:
"""
Parse a NativeFunction from a dictionary as directly parsed
from native_functions.yaml
@ -602,7 +602,7 @@ class NativeFunction:

variants_s = e.pop("variants", "function")
assert isinstance(variants_s, str)
variants: set[Variant] = set()
variants: Set[Variant] = set()
for v in variants_s.split(", "):
if v == "function":
variants.add(Variant.function)
@ -646,7 +646,7 @@ class NativeFunction:
"namespace is not supported in structured delegate,"
" using the same namespace as the native function"
)
structured_delegate: OperatorName | None = None
structured_delegate: Optional[OperatorName] = None
if structured_delegate_s is not None:
structured_delegate = OperatorName.parse(structured_delegate_s)

@ -685,7 +685,7 @@ class NativeFunction:
if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
tags_inp.append("pt2_compliant_tag")

tags: set[str] = set()
tags: Set[str] = set()
for t in tags_inp:
assert len(valid_tags) > 0
# TODO: verify that the tag is valid and has an entry in tags.yaml
@ -698,7 +698,7 @@ class NativeFunction:

raw_dispatch = e.pop("dispatch", None)
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
dispatch: dict[DispatchKey, BackendMetadata] = {}
dispatch: Dict[DispatchKey, BackendMetadata] = {}
num_dispatch_keys: int = 0
if raw_dispatch is not None:
assert not manual_kernel_registration, (
@ -1081,8 +1081,8 @@ class SchemaKind(Enum):
@dataclass(frozen=True)
class NativeFunctionsGroup:
functional: NativeFunction
inplace: NativeFunction | None
mutable: NativeFunction | None
inplace: Optional[NativeFunction]
mutable: Optional[NativeFunction]
out: NativeFunction

@property
@ -1136,7 +1136,7 @@ class NativeFunctionsGroup:
[str(f.func.name) for f in self.functions() if "generated" in f.tags]
)
generated_fns_str = ", ".join(str(x) for x in generated_fns)
expected_generated_fns: set[str] = set()
expected_generated_fns: Set[str] = set()
for f in self.functions():
expected_generated_fns.update(str(op) for op in f.autogen)
expected_generated_fns_str = ", ".join(
@ -1155,7 +1155,7 @@ class NativeFunctionsGroup:
f" Instead, it found 'autogen: {expected_generated_fns_str}'"
)

def signature(self) -> FunctionSchema:
def signature(self) -> "FunctionSchema":
return self.out.func.signature()

def functions(self) -> Iterator[NativeFunction]:
@ -1171,7 +1171,9 @@ class NativeFunctionsGroup:
return self.functional.root_name

@staticmethod
def from_dict(d: dict[SchemaKind, NativeFunction]) -> NativeFunctionsGroup | None:
def from_dict(
d: Dict[SchemaKind, NativeFunction]
) -> Optional["NativeFunctionsGroup"]:
assert d
if len(d) == 1:
return None
@ -1227,7 +1229,7 @@ class UfuncInnerLoop:
ufunc_key: UfuncKey

@staticmethod
def parse(value: str, ufunc_key: UfuncKey) -> UfuncInnerLoop:
def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop":
name, supported_dtypes_str = value.split(" ", 1)
assert supported_dtypes_str[0] == "("
assert supported_dtypes_str[-1] == ")"
@ -1259,12 +1261,12 @@ class BackendIndex:
# Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA)
external: bool
# Other backend-specific information that is on a per-operator basis
index: dict[OperatorName, BackendMetadata]
index: Dict["OperatorName", BackendMetadata]

@staticmethod
def grow_index(
parent_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
child_index: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]],
) -> None:
for k, v in child_index.items():
for op_name, metadata in v.items():
@ -1279,13 +1281,13 @@ class BackendIndex:
else:
return g.functional

def has_kernel(self, g: NativeFunction | NativeFunctionsGroup) -> bool:
def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool:
m = self.get_kernel(g)
return m is not None

def get_kernel(
self, g: NativeFunction | NativeFunctionsGroup
) -> BackendMetadata | None:
self, g: Union[NativeFunction, NativeFunctionsGroup]
) -> Optional[BackendMetadata]:
if isinstance(g, NativeFunction):
f = g
elif isinstance(g, NativeFunctionsGroup):
@ -1296,7 +1298,7 @@ class BackendIndex:
return None
return self.index[f.func.name]

def native_function_class_name(self) -> str | None:
def native_function_class_name(self) -> Optional[str]:
if self.external:
return f"{str(self.dispatch_key)}NativeFunctions"
else:
@ -1362,16 +1364,16 @@ class BackendIndex:
@dataclass(frozen=True)
class FunctionSchema:
# The name of the operator this function schema describes.
name: OperatorName
name: "OperatorName"

arguments: Arguments
arguments: "Arguments"

# TODO: Need to handle collisions with argument names at some point
returns: tuple[Return, ...]
returns: Tuple["Return", ...]

@property
def is_mutable(self) -> bool:
def is_write(arg: Argument) -> bool:
def is_write(arg: "Argument") -> bool:
if arg.annotation is None:
return False
return arg.annotation.is_write
@ -1380,7 +1382,7 @@ class FunctionSchema:
# See aten/src/ATen/core/function_schema.h (keep these in sync)
return any(is_write(a) for a in self.arguments.flat_all)

def schema_order_arguments(self) -> Iterator[Argument]:
def schema_order_arguments(self) -> Iterator["Argument"]:
return itertools.chain(
self.arguments.flat_positional,
self.arguments.flat_kwarg_only,
@ -1390,7 +1392,7 @@ class FunctionSchema:
decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)")

@staticmethod
def parse(func: str) -> FunctionSchema:
def parse(func: str) -> "FunctionSchema":
# We should probably get a proper parser here
decls = FunctionSchema.decl_re.findall(func)
assert len(decls) == 1, f"Invalid function schema: {func}"
@ -1585,8 +1587,8 @@ class FunctionSchema:
# - If the return aliases an input, we return the input name
# - Otherwise, we return None.
# If return names were enforced to be consistent with aliasing information, then we wouldn't need this.
def aliased_return_names(self) -> list[str | None]:
outs: list[str | None] = []
def aliased_return_names(self) -> List[Optional[str]]:
outs: List[Optional[str]] = []
for r in self.returns:
aliased_args = [
a
@ -1610,7 +1612,7 @@ class FunctionSchema:
strip_default: bool = False,
strip_view_copy_name: bool = False,
keep_return_names: bool = False,
) -> FunctionSchema:
) -> "FunctionSchema":
"""
Certain schemas are 'related', in that they are simply
inplace/out/functional versions of the same function. This method
@ -1707,10 +1709,10 @@ class FunctionSchema:
returns=returns,
)

def view_signature(self) -> FunctionSchema:
def view_signature(self) -> "FunctionSchema":
return self.signature(strip_view_copy_name=True)

def with_name(self, name: OperatorName) -> FunctionSchema:
def with_name(self, name: "OperatorName") -> "FunctionSchema":
return FunctionSchema(
name=name,
arguments=self.arguments,
@ -1745,12 +1747,12 @@ class FunctionSchema:
class Annotation:
# Typically only has one element. Not actually a set so
# we can conveniently assume it is canonically ordered
alias_set: tuple[str, ...]
alias_set: Tuple[str, ...]
is_write: bool
alias_set_after: tuple[str, ...]
alias_set_after: Tuple[str, ...]

@staticmethod
def parse(ann: str) -> Annotation:
def parse(ann: str) -> "Annotation":
# TODO: implement a proper parser if this gets more ugly
# Regex Explanation:
# Example: "a! -> a|b"
@ -1803,13 +1805,13 @@ class Annotation:
@dataclass(frozen=True)
class Type:
@staticmethod
def parse(t: str) -> Type:
def parse(t: str) -> "Type":
r = Type._parse(t)
assert str(r) == t, f"{r} != {t}"
return r

@staticmethod
def _parse(t: str) -> Type:
def _parse(t: str) -> "Type":
m = re.match(r"^(.+)\?$", t)
if m is not None:
return OptionalType(Type.parse(m.group(1)))
@ -1835,7 +1837,7 @@ class Type:
# so we can conveniently generate legacy Declarations.yaml but
# really we should probably just remove these at some point

def is_base_ty_like(self, base_ty: BaseTy) -> bool:
def is_base_ty_like(self, base_ty: "BaseTy") -> bool:
raise NotImplementedError

def is_tensor_like(self) -> bool:
@ -1850,7 +1852,7 @@ class Type:
def is_nullable(self) -> bool:
raise NotImplementedError

def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
raise NotImplementedError


@ -1890,7 +1892,7 @@ class BaseType(Type):
def is_nullable(self) -> bool:
return False

def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return None

def is_symint_like(self) -> bool:
@ -1914,7 +1916,7 @@ class OptionalType(Type):
def is_nullable(self) -> bool:
return True

def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return self.elem.is_list_like()


@ -1941,7 +1943,7 @@ class CustomClassType(Type):
"""
return False

def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return None


@ -1955,7 +1957,7 @@ class CustomClassType(Type):
@dataclass(frozen=True)
class ListType(Type):
elem: Type
size: int | None
size: Optional[int]

def __str__(self) -> str:
size = f"{self.size}" if self.size else ""
@ -1970,7 +1972,7 @@ class ListType(Type):
def is_nullable(self) -> bool:
return self.elem.is_nullable()

def is_list_like(self) -> ListType | None:
def is_list_like(self) -> Optional["ListType"]:
return self


@ -1981,7 +1983,7 @@ class Argument:

name: str
type: Type
default: str | None
default: Optional[str]

# The semantics of the annotation field are a little strange.
#
@ -2002,16 +2004,16 @@ class Argument:
# structure of annotated types is very simple. So we just hard
# code it here. But if we ever do get anything more complex, this
# model will have to change!
annotation: Annotation | None
annotation: Optional[Annotation]

@property
def alias_info(self) -> Annotation | None:
def alias_info(self) -> Optional[Annotation]:
return self.annotation

@staticmethod
def parse(arg: str) -> Argument:
def parse(arg: str) -> "Argument":
name: str
default: str | None
default: Optional[str]
assert " " in arg, f"illegal argument '{arg}'"
type_and_annot, name_and_default = arg.rsplit(" ", 1)
if "=" in name_and_default:
@ -2024,7 +2026,7 @@ class Argument:
default = None
# TODO: deduplicate annotation matching with Return
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
annotation: Annotation | None
annotation: Optional[Annotation]
if match:
# If you update this, make sure the __str__ still works too
assert match.group(2) in [
@ -2067,24 +2069,24 @@ class Argument:

@dataclass(frozen=True)
class Return:
name: str | None
name: Optional[str]
type: Type
annotation: Annotation | None
annotation: Optional[Annotation]

@property
def alias_info(self) -> Annotation | None:
def alias_info(self) -> Optional[Annotation]:
return self.annotation

@staticmethod
def parse(arg: str) -> Return:
name: str | None
def parse(arg: str) -> "Return":
name: Optional[str]
if " " in arg:
type_and_annot, name = arg.rsplit(" ", 1)
else:
type_and_annot = arg
name = None
match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot)
annotation: Annotation | None
annotation: Optional[Annotation]
if match:
# If you update this, make sure the __str__ still works too
assert match.group(2) in [
@ -2146,34 +2148,34 @@ class Arguments:
# pre_self_positional is usually empty, but is notably non-empty
# for where.self, where the condition argument comes before the
# self argument
pre_self_positional: tuple[Argument, ...]
self_arg: SelfArgument | None
post_self_positional: tuple[Argument, ...]
pre_self_positional: Tuple[Argument, ...]
self_arg: Optional[SelfArgument]
post_self_positional: Tuple[Argument, ...]

pre_tensor_options_kwarg_only: tuple[Argument, ...]
tensor_options: TensorOptionsArguments | None
pre_tensor_options_kwarg_only: Tuple[Argument, ...]
tensor_options: Optional[TensorOptionsArguments]
# post_tensor_options is typically memory format, which should be
# part of tensor options but isn't right now, and is usually
# placed after the tensor options arguments
post_tensor_options_kwarg_only: tuple[Argument, ...]
post_tensor_options_kwarg_only: Tuple[Argument, ...]

# Unlike in the previous codegen, we have factored out 'out' arguments
# in the canonical representation, removing them from kwarg
# arguments. This choice is justified by numerous downstream
# transformations which treat out arguments specially; additionally,
# you can see that canonicity is not violated!
out: tuple[Argument, ...]  # these are also kwarg-only
out: Tuple[Argument, ...]  # these are also kwarg-only

@property
def flat_non_out(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.flat_positional)
ret.extend(self.flat_kwarg_only)
return ret

@property
def flat_positional(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.pre_self_positional)
if self.self_arg is not None:
ret.append(self.self_arg.argument)
@ -2187,7 +2189,7 @@ class Arguments:
# NB: doesn't contain out arguments
@property
def flat_kwarg_only(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None:
ret.extend(self.tensor_options.all())
@ -2196,7 +2198,7 @@ class Arguments:

@property
def flat_all(self) -> Sequence[Argument]:
ret: list[Argument] = []
ret: List[Argument] = []
ret.extend(self.flat_positional)
ret.extend(self.flat_kwarg_only)
ret.extend(self.out)
@ -2205,15 +2207,15 @@ class Arguments:
@property
def non_out(
self,
) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
ret.extend(self.positional)
ret.extend(self.kwarg_only)
return ret

@property
def positional(self) -> Sequence[Argument | SelfArgument]:
ret: list[Argument | SelfArgument] = []
def positional(self) -> Sequence[Union[Argument, SelfArgument]]:
ret: List[Union[Argument, SelfArgument]] = []
ret.extend(self.pre_self_positional)
if self.self_arg is not None:
ret.append(self.self_arg)
@ -2221,8 +2223,8 @@ class Arguments:
return ret

@property
def kwarg_only(self) -> Sequence[Argument | TensorOptionsArguments]:
ret: list[Argument | TensorOptionsArguments] = []
def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]:
ret: List[Union[Argument, TensorOptionsArguments]] = []
ret.extend(self.pre_tensor_options_kwarg_only)
if self.tensor_options is not None:
ret.append(self.tensor_options)
@ -2230,14 +2232,14 @@ class Arguments:
return ret

@property
def all(self) -> Sequence[Argument | SelfArgument | TensorOptionsArguments]:
ret: list[Argument | SelfArgument | TensorOptionsArguments] = []
def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]:
ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = []
ret.extend(self.positional)
ret.extend(self.kwarg_only)
ret.extend(self.out)
return ret

def mutable_arg_names(self) -> list[str]:
def mutable_arg_names(self) -> List[str]:
return [
a.name
for a in self.flat_all
@ -2253,7 +2255,7 @@ class Arguments:
def has_generator_arg(self) -> bool:
return any(a.type.is_generator_like() for a in self.flat_non_out)

def signature(self, *, strip_default: bool = False) -> Arguments:
def signature(self, *, strip_default: bool = False) -> "Arguments":
# dataclasses.replace could be used here, but it is less
# type safe so for now I've opted to type everything out
def strip_arg_annotation(a: Argument) -> Argument:
@ -2288,7 +2290,7 @@ class Arguments:
out=(),
)

def remove_self_annotation(self) -> Arguments:
def remove_self_annotation(self) -> "Arguments":
assert self.self_arg is not None
return dataclasses.replace(
self,
@ -2297,7 +2299,7 @@ class Arguments:
),
)

def with_out_args(self, outs: list[Argument]) -> Arguments:
def with_out_args(self, outs: List[Argument]) -> "Arguments":
assert len(self.out) == 0
return dataclasses.replace(
self,
@ -2305,10 +2307,10 @@ class Arguments:
)

@staticmethod
def _preparse(args: str) -> tuple[list[Argument], list[Argument], list[Argument]]:
positional: list[Argument] = []
kwarg_only: list[Argument] = []
out: list[Argument] = []
def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]:
positional: List[Argument] = []
kwarg_only: List[Argument] = []
out: List[Argument] = []
arguments_acc = positional

# TODO: Use a real parser here; this will get bamboozled
@ -2341,7 +2343,7 @@ class Arguments:
return positional, kwarg_only, out

@staticmethod
def parse(args: str) -> Arguments:
def parse(args: str) -> "Arguments":
"""
Input: 'int x, int y, int z'
"""
@ -2359,9 +2361,9 @@ class Arguments:
if a.name == "self":
self_ix = i
break
pre_self_positional: list[Argument]
self_arg: SelfArgument | None
post_self_positional: list[Argument]
pre_self_positional: List[Argument]
self_arg: Optional[SelfArgument]
post_self_positional: List[Argument]
if self_ix is not None:
pre_self_positional = positional[:self_ix]
self_arg = SelfArgument(positional[self_ix])
@ -2372,9 +2374,9 @@ class Arguments:
post_self_positional = positional

# Group tensor options arguments
pre_tensor_options_kwarg_only: list[Argument] = []
tensor_options: TensorOptionsArguments | None = None
post_tensor_options_kwarg_only: list[Argument] = []
pre_tensor_options_kwarg_only: List[Argument] = []
tensor_options: Optional[TensorOptionsArguments] = None
post_tensor_options_kwarg_only: List[Argument] = []
kwarg_only_acc = pre_tensor_options_kwarg_only

def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
@ -2421,7 +2423,7 @@ class Arguments:
)

def __str__(self) -> str:
all_arguments: list[str] = []
all_arguments: List[str] = []
all_arguments.extend(map(str, self.flat_positional))
if self.flat_kwarg_only or self.out:
all_arguments.append("*")
@ -2500,7 +2502,7 @@ class BaseOperatorName:
functional_overload: bool = False

@staticmethod
def parse(op: str) -> BaseOperatorName:
def parse(op: str) -> "BaseOperatorName":
assert op != ""
assert not op.endswith("_out"), (
"_out suffix is reserved and not permitted for operator names; "
@ -2572,7 +2574,7 @@ class OperatorName:
overload_name: str

@staticmethod
def parse(op_name: str) -> OperatorName:
def parse(op_name: str) -> "OperatorName":
if "." in op_name:
name, overload_name = op_name.split(".", 1)
else:
@ -2599,7 +2601,7 @@ class OperatorName:
else:
return f"{self.name}"

def remove_inplace(self) -> OperatorName:
def remove_inplace(self) -> "OperatorName":
return OperatorName(
name=BaseOperatorName(
base=self.name.base,
@ -2609,7 +2611,7 @@ class OperatorName:
overload_name=self.overload_name,
)

def with_overload(self, overload: str) -> OperatorName:
def with_overload(self, overload: str) -> "OperatorName":
return OperatorName(
name=BaseOperatorName(
base=self.name.base,
@ -2647,9 +2649,9 @@ class NativeFunctionsViewGroup:
# Note: the {view}_copy operator is optional because we currently don't generate copy variants
# for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views
# (we already get them "for free" through decomposition)
view_copy: NativeFunction | None
view_copy: Optional[NativeFunction]
# view_inplace ops are also optional, but every view_inplace op should have an out-of-place variant.
view_inplace: NativeFunction | None
view_inplace: Optional[NativeFunction]

def __post_init__(self) -> None:
assert self.view.is_view_op
@ -2729,7 +2731,7 @@ def gets_generated_view_copy(f: NativeFunction) -> bool:

# Given a NativeFunction that corresponds to a view op,
# returns the OperatorName of the corresponding "copy" variant of the op.
def get_view_copy_name(f: NativeFunction) -> OperatorName:
def get_view_copy_name(f: NativeFunction) -> "OperatorName":
# Right now, when asking for a view op's corresponding "view_copy" name
# we assert for sanity that the op is allowed to have a generated view_copy variant.
# (We can do this because "gets_generated_view_copy()" tells us which ops get a generated view_copy op).
@ -2753,7 +2755,7 @@ def get_view_copy_name(f: NativeFunction) -> OperatorName:
# Helper functions for parsing argument lists (both inputs and returns)


def parse_returns(return_decl: str) -> tuple[Return, ...]:
def parse_returns(return_decl: str) -> Tuple[Return, ...]:
"""
Input: '()'
Output: []
@ -2772,12 +2774,12 @@ def parse_returns(return_decl: str) -> tuple[Return, ...]:
class Precompute:
# A map from kernel argument name -> a list of precomputed
# elements that replaces/supersedes it.
replace: dict[str, list[Argument]]
replace: Dict[str, List[Argument]]
# List of precomputed args added without replacement
add: list[Argument]
add: List[Argument]

@staticmethod
def parse(src: object) -> Precompute:
def parse(src: object) -> "Precompute":
assert isinstance(src, list)

# src is a list of strings of the format:
@ -2822,7 +2824,7 @@ class Precompute:
for a in args:
assert a.name.upper() != a.name

def to_list(self) -> list[str]:
def to_list(self) -> List[str]:
replace_list = []
for kernel_param, replacement_params in self.replace.items():
replacements = ", ".join(str(param) for param in replacement_params)

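The model.py hunks above also restore quoted annotations such as func: "FunctionSchema". Without `from __future__ import annotations`, an annotation that names a class defined later in the same module must be written as a string forward reference, or evaluating the class body raises NameError. A small self-contained sketch; the class names are hypothetical, not the torchgen definitions:

from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class FuncLike:
    # "SchemaLike" is defined further down, so without postponed annotations
    # the name must be quoted as a string forward reference.
    schema: "SchemaLike"
    returns: Tuple["SchemaLike", ...] = ()
    delegate: Optional["SchemaLike"] = None


@dataclass(frozen=True)
class SchemaLike:
    name: str


print(FuncLike(schema=SchemaLike("add")))
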
@ -1,7 +1,5 @@
from __future__ import annotations

from collections import defaultdict
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
@ -103,9 +101,9 @@ INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY = [
# But have differing SchemaKinds.
def pre_group_native_functions(
native_functions: Sequence[NativeFunction],
) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: dict[
FunctionSchema, dict[SchemaKind, NativeFunction]
) -> Dict[FunctionSchema, Dict[SchemaKind, NativeFunction]]:
pre_grouped_native_functions: Dict[
FunctionSchema, Dict[SchemaKind, NativeFunction]
] = defaultdict(dict)
for f in native_functions:
d = pre_grouped_native_functions[f.func.signature()]
@ -115,7 +113,7 @@ def pre_group_native_functions(


# Returns the out variant overload name given a base function overload name
def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
def get_expected_out_variant_overload_name(overload_name: Optional[str]) -> str:
return "out" if not overload_name else f"{overload_name}_out"


@ -180,7 +178,7 @@ def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations.
def generate_out_args_from_schema(
func: FunctionSchema,
) -> tuple[list[Return], list[Argument]]:
) -> Tuple[List[Return], List[Argument]]:
# More of a sanity check - our existing restrictions on schemas should enforce that
# mutable schema kinds never return their mutable arguments.
assert not any(
@ -200,11 +198,11 @@ def generate_out_args_from_schema(

all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

new_out_args: list[Argument] = []
new_out_args: List[Argument] = []
# The end result of new_returns is that:
# - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
# - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
new_returns: list[Return] = []
new_returns: List[Return] = []
for i, r in enumerate(func.returns):
if r.type.is_tensor_like():
new_out = Argument(
@ -268,7 +266,7 @@ def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
# Details are in the function, but we only generate composite kernels (in some cases) today.
def generate_function(
f: NativeFunction, k: SchemaKind
) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
) -> Tuple[NativeFunction, Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]]]:
from torchgen.api import cpp

if k == SchemaKind.functional:
@ -377,8 +375,8 @@ def generate_function(
# Note: this function *mutates* its two inputs,
# adding the new NativeFunctions / BackendMetadata to them
def add_generated_native_functions(
rs: list[NativeFunction],
indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
rs: List[NativeFunction],
indices: Dict[DispatchKey, Dict[OperatorName, BackendMetadata]],
) -> None:
# The main code for generating new NativeFunctions
# First we group NativeFunctions by schema kind,
@ -499,7 +497,7 @@ out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THA
rs.append(fn)


def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
def return_str(rets: Tuple[Return, ...], names: List[str]) -> str:
assert len(rets) == len(names)
if len(rets) == 0:
return ""
@ -511,7 +509,7 @@ def return_str(rets: tuple[Return, ...], names: list[str]) -> str:

# Given a function, and the name of a variable corresponding to the output of that function,
# gather up all of the individual returns that are not aliased
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> List[str]:
aliased_rets = func.aliased_return_names()
non_aliased_names = []
is_out_var_a_tuple = len(func.returns) > 1
@ -526,7 +524,7 @@ def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str
# Generates functional kernels in terms of their inplace.mutable counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# We should only be generating these for code-generated NativeFunctions
if "generated" not in g.functional.tags:
return None
@ -543,7 +541,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
sig = DispatcherSignature(g.functional.func)
target_sig = DispatcherSignature(target_f.func)

context: list[Binding | Expr] = []
context: List[Union[Binding, Expr]] = []
clone_mutable_inputs = []
cloned_return_names = []
# We can't just directly pass all of the arguments from the functional op into the mutating op.
@ -589,7 +587,7 @@ def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
# Generates out= kernels in terms of their functional counterparts.
# We only do this for "generated" NativeFunctions
@with_native_function
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
def gen_composite_out_kernel(g: NativeFunctionsGroup) -> Optional[str]:
# We should only be generating these for code-generated NativeFunctions
if "generated" not in g.out.tags:
return None

@ -1,12 +1,9 @@
#!/usr/bin/env python3

from __future__ import annotations

import os
from enum import Enum
from operator import itemgetter
from pathlib import Path
from typing import Any
from typing import Any, Dict, List

import torch
from torch.jit.generate_bytecode import generate_upgraders_bytecode
@ -188,7 +185,7 @@ PER_OPERATOR_UPGRADER_LIST = CodeTemplate(
)


def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
def construct_instruction(instruction_list_from_yaml: List[Any]) -> str:
instruction_list_part = []
for instruction in instruction_list_from_yaml:
instruction_list_part.append(
@ -203,7 +200,7 @@ def construct_instruction(instruction_list_from_yaml: list[Any]) -> str:
)


def construct_constants(constants_list_from_yaml: list[Any]) -> str:
def construct_constants(constants_list_from_yaml: List[Any]) -> str:
constants_list_part = []
for constant_from_yaml in constants_list_from_yaml:
convert_constant = None
@ -229,7 +226,7 @@ def construct_constants(constants_list_from_yaml: list[Any]) -> str:
)


def construct_operators(operator_list_from_yaml: list[Any]) -> str:
def construct_operators(operator_list_from_yaml: List[Any]) -> str:
operator_list_part = []
for operator in operator_list_from_yaml:
operator_list_part.append(
@ -244,7 +241,7 @@ def construct_operators(operator_list_from_yaml: list[Any]) -> str:
)


def construct_types(types_tr_list_from_yaml: list[Any]) -> str:
def construct_types(types_tr_list_from_yaml: List[Any]) -> str:
types_tr_list_part = []
for types_tr in types_tr_list_from_yaml:
types_tr_list_part.append(ONE_TYPE.substitute(type_str=types_tr))
@ -263,7 +260,7 @@ def construct_register_size(register_size_from_yaml: int) -> str:


def construct_version_maps(
upgrader_bytecode_function_to_index_map: dict[str, Any]
upgrader_bytecode_function_to_index_map: Dict[str, Any]
) -> str:
version_map = torch._C._get_operator_version_map()
sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0))  # type: ignore[no-any-return]
@ -305,8 +302,8 @@ def construct_version_maps(


def get_upgrader_bytecode_function_to_index_map(
upgrader_dict: list[dict[str, Any]]
) -> dict[str, Any]:
upgrader_dict: List[Dict[str, Any]]
) -> Dict[str, Any]:
upgrader_bytecode_function_to_index_map = {}
index = 0
for upgrader_bytecode in upgrader_dict:
@ -318,7 +315,7 @@ def get_upgrader_bytecode_function_to_index_map(
return upgrader_bytecode_function_to_index_map


def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
def write_cpp(cpp_path: str, upgrader_dict: List[Dict[str, Any]]) -> None:
body_parts = []
upgrader_bytecode_function_to_index_map = (
get_upgrader_bytecode_function_to_index_map(upgrader_dict)
@ -373,7 +370,7 @@ def write_cpp(cpp_path: str, upgrader_dict: list[dict[str, Any]]) -> None:
out_file.write(upgrader_file_content.encode("utf-8"))


def sort_upgrader(upgrader_list: list[dict[str, Any]]) -> list[dict[str, Any]]:
def sort_upgrader(upgrader_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
sorted_upgrader_list = sorted(
upgrader_list, key=lambda one_upgrader: next(iter(one_upgrader))
)

@ -1,6 +1,5 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional, Tuple


# This class holds information about a single operator used to determine
@ -47,12 +46,12 @@ class SelectiveBuildOperator:
    include_all_overloads: bool

    # Debug Information at the operator level
    _debug_info: tuple[str, ...] | None
    _debug_info: Optional[Tuple[str, ...]]

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: dict[str, object]
    ) -> SelectiveBuildOperator:
        op_name: str, op_info: Dict[str, object]
    ) -> "SelectiveBuildOperator":
        allowed_keys = {
            "name",
            "is_root_operator",
@ -80,7 +79,7 @@ class SelectiveBuildOperator:
        include_all_overloads = op_info.get("include_all_overloads", True)
        assert isinstance(include_all_overloads, bool)

        debug_info: tuple[str, ...] | None = None
        debug_info: Optional[Tuple[str, ...]] = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            assert isinstance(di_list, list)
@ -97,7 +96,7 @@ class SelectiveBuildOperator:
    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> SelectiveBuildOperator:
    ) -> "SelectiveBuildOperator":
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
@ -106,8 +105,8 @@ class SelectiveBuildOperator:
            _debug_info=None,
        )

    def to_dict(self) -> dict[str, object]:
        ret: dict[str, object] = {
    def to_dict(self) -> Dict[str, object]:
        ret: Dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
@ -119,9 +118,9 @@ class SelectiveBuildOperator:


def merge_debug_info(
    lhs: tuple[str, ...] | None,
    rhs: tuple[str, ...] | None,
) -> tuple[str, ...] | None:
    lhs: Optional[Tuple[str, ...]],
    rhs: Optional[Tuple[str, ...]],
) -> Optional[Tuple[str, ...]]:
    # Ensure that when merging, each entry shows up just once.
    if lhs is None and rhs is None:
        return None
@ -130,8 +129,8 @@ def merge_debug_info(


def combine_operators(
    lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
) -> SelectiveBuildOperator:
    lhs: "SelectiveBuildOperator", rhs: "SelectiveBuildOperator"
) -> "SelectiveBuildOperator":
    if str(lhs.name) != str(rhs.name):
        raise Exception(  # noqa: TRY002
            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
@ -153,10 +152,10 @@ def combine_operators(


def merge_operator_dicts(
    lhs: dict[str, SelectiveBuildOperator],
    rhs: dict[str, SelectiveBuildOperator],
) -> dict[str, SelectiveBuildOperator]:
    operators: dict[str, SelectiveBuildOperator] = {}
    lhs: Dict[str, SelectiveBuildOperator],
    rhs: Dict[str, SelectiveBuildOperator],
) -> Dict[str, SelectiveBuildOperator]:
    operators: Dict[str, SelectiveBuildOperator] = {}
    for op_name, op in list(lhs.items()) + list(rhs.items()):
        new_op = op
        if op_name in operators:
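The merge helpers above combine two selective-build records; `merge_debug_info` in particular must tolerate `None` on either side and keep each entry only once. A small runnable sketch of that contract; the set-based dedup and unordered result are assumptions consistent with the comment in the hunk:

from typing import Optional, Tuple

def merge_debug_info(
    lhs: Optional[Tuple[str, ...]],
    rhs: Optional[Tuple[str, ...]],
) -> Optional[Tuple[str, ...]]:
    # Each entry shows up just once after the merge.
    if lhs is None and rhs is None:
        return None
    return tuple(set((lhs or ()) + (rhs or ())))

assert merge_debug_info(None, None) is None
assert sorted(merge_debug_info(("model_a",), ("model_a", "model_b"))) == ["model_a", "model_b"]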
@ -1,12 +1,11 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Dict, List, Optional, Set, Tuple

import yaml

from torchgen.model import NativeFunction
from torchgen.selective_build.operator import (
    merge_debug_info,
    merge_operator_dicts,
@ -15,10 +14,6 @@ from torchgen.selective_build.operator import (
)


if TYPE_CHECKING:
    from torchgen.model import NativeFunction


# A SelectiveBuilder holds information extracted from the selective build
# YAML specification.
#
@ -33,10 +28,10 @@ class SelectiveBuilder:
    include_all_operators: bool

    # Debug Information at the selective/custom build level.
    _debug_info: tuple[str, ...] | None
    _debug_info: Optional[Tuple[str, ...]]

    # A dictionary of operator -> operator metadata.
    operators: dict[str, SelectiveBuildOperator]
    operators: Dict[str, SelectiveBuildOperator]

    # A dictionary of selected kernel tags and dtypes. Typically a
    # PyTorch Operator Kernel (function) may have many code paths
@ -44,22 +39,22 @@ class SelectiveBuilder:
    # one per kernel function, but there could be many per kernel
    # function. The tag isn't a kernel function name, but some fragment
    # of the kernel function implementation itself.
    kernel_metadata: dict[str, list[str]]
    kernel_metadata: Dict[str, List[str]]

    # ExecuTorch only. A dictionary of kernel tag -> list of (list of input
    # dtypes for tensor-like input args).
    # This is from selective.yaml
    et_kernel_metadata: dict[str, list[str]]
    et_kernel_metadata: Dict[str, List[str]]

    # A set of all the custom torch bind classes used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    custom_classes: set[str]
    custom_classes: Set[str]

    # A set of all the build features used by the selected models
    # Stored as a set internally to remove duplicates proactively, but written
    # as a list to yamls
    build_features: set[str]
    build_features: Set[str]

    # If true, then fragments for all dtypes for all kernel functions
    # are included as well as all custom classes. This is typically set when any one of the
@ -68,11 +63,11 @@ class SelectiveBuilder:
    include_all_non_op_selectives: bool

    @staticmethod
    def get_nop_selector() -> SelectiveBuilder:
    def get_nop_selector() -> "SelectiveBuilder":
        return SelectiveBuilder.from_yaml_dict({"include_all_operators": True})

    @staticmethod
    def from_yaml_dict(data: dict[str, object]) -> SelectiveBuilder:
    def from_yaml_dict(data: Dict[str, object]) -> "SelectiveBuilder":
        valid_top_level_keys = {
            "include_all_non_op_selectives",
            "include_all_operators",
@ -140,20 +135,20 @@ class SelectiveBuilder:
        )

    @staticmethod
    def from_yaml_str(config_contents: str) -> SelectiveBuilder:
    def from_yaml_str(config_contents: str) -> "SelectiveBuilder":
        contents = yaml.safe_load(config_contents)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_yaml_path(config_path: str) -> SelectiveBuilder:
    def from_yaml_path(config_path: str) -> "SelectiveBuilder":
        with open(config_path) as f:
            contents = yaml.safe_load(f)
        return SelectiveBuilder.from_yaml_dict(contents)

    @staticmethod
    def from_legacy_op_registration_allow_list(
        allow_list: set[str], is_root_operator: bool, is_used_for_training: bool
    ) -> SelectiveBuilder:
        allow_list: Set[str], is_root_operator: bool, is_used_for_training: bool
    ) -> "SelectiveBuilder":
        operators = {}
        for op in allow_list:
            operators[op] = {
@ -236,7 +231,7 @@ class SelectiveBuilder:
            and dtype in self.kernel_metadata[kernel_tag]
        )

    def et_get_selected_kernels(self, op_name: str, kernel_key: list[str]) -> list[str]:
    def et_get_selected_kernels(self, op_name: str, kernel_key: List[str]) -> List[str]:
        """
        Return a list of kernel keys that cover the used ops
        """
@ -266,8 +261,8 @@ class SelectiveBuilder:

        return list(result_set)

    def to_dict(self) -> dict[str, object]:
        ret: dict[str, object] = {
    def to_dict(self) -> Dict[str, object]:
        ret: Dict[str, object] = {
            "include_all_non_op_selectives": self.include_all_non_op_selectives,
            "include_all_operators": self.include_all_operators,
        }
@ -293,10 +288,10 @@ class SelectiveBuilder:


def merge_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    kernel_metadata: dict[str, list[str]] = {}
    lhs: Dict[str, List[str]],
    rhs: Dict[str, List[str]],
) -> Dict[str, List[str]]:
    kernel_metadata: Dict[str, List[str]] = {}
    for tag_name, dtypes in list(lhs.items()) + list(rhs.items()):
        dtypes_copy = set(dtypes)
        if tag_name in kernel_metadata:
@ -308,10 +303,10 @@ def merge_kernel_metadata(


def merge_et_kernel_metadata(
    lhs: dict[str, list[str]],
    rhs: dict[str, list[str]],
) -> dict[str, list[str]]:
    merge_et_kernel_metadata: dict[str, set[str]] = defaultdict(set)
    lhs: Dict[str, List[str]],
    rhs: Dict[str, List[str]],
) -> Dict[str, List[str]]:
    merge_et_kernel_metadata: Dict[str, Set[str]] = defaultdict(set)
    for op in list(lhs.keys()) + list(rhs.keys()):
        merge_et_kernel_metadata[op].update(lhs.get(op, []))
        merge_et_kernel_metadata[op].update(rhs.get(op, []))
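The tail of `merge_et_kernel_metadata` above accumulates kernel keys into a `defaultdict(set)` so duplicates collapse before the merged result is written back out as lists. A runnable sketch of that merge pattern; the `sorted()` call for a deterministic result is an assumption, not shown in the hunk:

from collections import defaultdict
from typing import Dict, List, Set

def merge_et_kernel_metadata(
    lhs: Dict[str, List[str]],
    rhs: Dict[str, List[str]],
) -> Dict[str, List[str]]:
    merged: Dict[str, Set[str]] = defaultdict(set)
    for op in list(lhs.keys()) + list(rhs.keys()):
        merged[op].update(lhs.get(op, []))
        merged[op].update(rhs.get(op, []))
    # Sets dedupe kernel keys; sort so the output order is stable.
    return {op: sorted(keys) for op, keys in merged.items()}

print(merge_et_kernel_metadata({"aten::add": ["v1/6;0,1"]}, {"aten::add": ["v1/6;0,1", "v1/3;0,1"]}))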
@ -1,7 +1,7 @@
#!/usr/bin/env python3
import importlib.util
import os
import sys
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain
from pathlib import Path

@ -18,9 +18,9 @@ you are in the root directory of the Pytorch git repo"""
if not file_path.exists():
    raise Exception(err_msg)  # noqa: TRY002

spec = spec_from_file_location(module_name, file_path)
spec = importlib.util.spec_from_file_location(module_name, file_path)
assert spec is not None
module = module_from_spec(spec)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
assert module is not None
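Both spellings in the hunk above load a module from an explicit file path; the revert only switches from names imported out of `importlib.util` to fully qualified calls. The standard-library pattern, including the `exec_module` step the excerpt cuts off before, looks like this sketch (module name and path are placeholders):

import importlib.util
import sys

module_name = "jit_shape_functions"          # placeholder name
file_path = "torch/jit/_shape_functions.py"  # placeholder path

spec = importlib.util.spec_from_file_location(module_name, file_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
spec.loader.exec_module(module)  # actually run the module's code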
@ -1,9 +1,9 @@
from __future__ import annotations
from typing import Dict, Union

from torchgen.model import NativeFunctionsGroup, NativeFunctionsViewGroup


def func_name_base_str(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> str:
def func_name_base_str(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> str:
    if isinstance(g, NativeFunctionsGroup):
        return str(g.functional.func.name.name.base)
    else:
@ -55,12 +55,12 @@ is_hand_written_ops_ = frozenset(
)


def is_hand_written(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
def is_hand_written(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    name_base = func_name_base_str(g)
    return name_base in is_hand_written_ops_


def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> None:
def override_test_values(arg_map: Dict[str, str], op_name: str, index: int) -> None:
    assert index == 0 or index == 1
    if op_name == "addr":
        if index == 0:
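These paired signatures show the point of the whole revert: `X | Y` unions and builtin generics like `dict[str, str]` are only valid at runtime on newer interpreters (3.10+ and 3.9+ respectively) unless every module opts into `from __future__ import annotations`, while the `typing` spellings work unchanged back to 3.8. A minimal illustration, with version thresholds per the CPython documentation:

from typing import Dict, Union

# Works on Python 3.8+: typing generics are ordinary objects at runtime.
def f(arg_map: Dict[str, str], mode: Union[int, str]) -> None: ...

# Evaluated eagerly, `dict[str, str]` needs 3.9+ and `int | str` needs 3.10+;
# quoting them (or `from __future__ import annotations`) defers evaluation,
# so the module still imports on older interpreters.
def g(arg_map: "dict[str, str]", mode: "int | str") -> None: ...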
@ -1,5 +1,3 @@
from __future__ import annotations

import argparse
import itertools
import os
@ -30,7 +28,7 @@ def group_functions_by_op_name(
        return []
    groups = []

    def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
    def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
        with native_function_manager(g):
            return generator.is_supported(g)
@ -1,9 +1,7 @@
from __future__ import annotations

import json
import logging
import math
from typing import Sequence
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torchgen.api.cpp as cpp
from torchgen.context import native_function_manager
@ -27,7 +25,7 @@ logger: logging.Logger = logging.getLogger()


def has_alias(
    arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
    arguments: Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]
) -> bool:
    for arg in arguments:
        annotation = getattr(arg, "annotation", None)
@ -239,7 +237,7 @@ BLOCKED_OPS = frozenset(
)


def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
def is_supported(g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup]) -> bool:
    base_op_name = ""
    func = None
    if isinstance(g, NativeFunctionsViewGroup):
@ -300,8 +298,8 @@ def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:


def ivalue_type_conversion_method(
    arg_type: BaseType | OptionalType | Type,
) -> tuple[bool, str] | None:
    arg_type: Union[BaseType, OptionalType, Type]
) -> Optional[Tuple[bool, str]]:
    """
    Return the method call expression of `c10::ivalue' to convert its contained value to
    the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
@ -396,7 +394,7 @@ def test_tensor_dim(op_name: str) -> int:


test_tensor_shapes_string = '{"view_as_complex": "{2, 2}"}'
test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
test_tensor_shape_json: Dict[str, str] = json.loads(test_tensor_shapes_string)


def test_tensor_shape(op_name: str) -> str:
@ -407,7 +405,7 @@ def test_tensor_shape(op_name: str) -> str:


def test_value_expression(
    arg_type: BaseType | OptionalType | Type, index: int, op_name: str
    arg_type: Union[BaseType, OptionalType, Type], index: int, op_name: str
) -> str:
    tensor_size_ex = test_tensor_shape(op_name)
    if tensor_size_ex == "":
@ -477,8 +475,8 @@ generate_test_ir_arguments_base_ty_to_type_str_ = {

def generate_test_ir_arguments(
    schema: FunctionSchema,
) -> list[tuple[str, str | None]]:
    def ir_argument(arg: Argument) -> tuple[str, str | None]:
) -> List[Tuple[str, Optional[str]]]:
    def ir_argument(arg: Argument) -> Tuple[str, Optional[str]]:
        t = arg.type
        add_optional = False
        if isinstance(t, OptionalType):
@ -1,5 +1,3 @@
from __future__ import annotations

import contextlib
import functools
import hashlib
@ -7,29 +5,31 @@ import os
import re
import sys
import textwrap
from argparse import Namespace
from dataclasses import fields, is_dataclass
from enum import auto, Enum
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    Iterator,
    List,
    Literal,
    NoReturn,
    Optional,
    Sequence,
    TYPE_CHECKING,
    Set,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import Self

from torchgen.code_template import CodeTemplate


if TYPE_CHECKING:
    from argparse import Namespace


# Many of these functions share logic for defining both the definition
# and declaration (for example, the function signature is the same), so
# we organize them into one function that takes a Target to say which
@ -57,7 +57,7 @@ IDENT_REGEX = r"(^|\W){}($|\W)"


# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> tuple[str, list[str]]:
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    m = re.match(r"(\w+)(\.\w+)?\((.*)\)", schema)
    if m is None:
        raise RuntimeError(f"Unsupported function schema: {schema}")
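The regex in `split_name_params` splits a schema string into a base name, an optional overload suffix, and a raw parameter list; as the TODO warns, it is not a real parser. A quick sketch of what the match groups capture, on an illustrative schema string:

import re

m = re.match(r"(\w+)(\.\w+)?\((.*)\)", "add.Tensor(Tensor self, Tensor other)")
assert m is not None
print(m.group(1))  # 'add'      -> function base name
print(m.group(2))  # '.Tensor'  -> overload suffix, may be None
print(m.group(3))  # 'Tensor self, Tensor other' -> raw parameter list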
@ -73,7 +73,7 @@ S = TypeVar("S")


# Map over function that may return None; omit Nones from output sequence
def mapMaybe(func: Callable[[T], S | None], xs: Iterable[T]) -> Iterator[S]:
def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        r = func(x)
        if r is not None:
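`mapMaybe` is the Haskell-style filter-map: apply `func` to each item and keep only the non-`None` results. A self-contained sketch; the trailing `yield r` is the one line the hunk cuts off before:

from typing import Callable, Iterable, Iterator, Optional, TypeVar

T = TypeVar("T")
S = TypeVar("S")

def mapMaybe(func: Callable[[T], Optional[S]], xs: Iterable[T]) -> Iterator[S]:
    for x in xs:
        r = func(x)
        if r is not None:
            yield r  # Nones are silently dropped

def parse_int(s: str) -> Optional[int]:
    try:
        return int(s)
    except ValueError:
        return None

assert list(mapMaybe(parse_int, ["1", "x", "3"])) == [1, 3]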
@ -127,7 +127,7 @@ class FileManager:
    install_dir: str
    template_dir: str
    dry_run: bool
    filenames: set[str]
    filenames: Set[str]

    def __init__(self, install_dir: str, template_dir: str, dry_run: bool) -> None:
        self.install_dir = install_dir
@ -136,7 +136,7 @@ class FileManager:
        self.dry_run = dry_run

    def _write_if_changed(self, filename: str, contents: str) -> None:
        old_contents: str | None
        old_contents: Optional[str]
        try:
            with open(filename) as f:
                old_contents = f.read()
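`_write_if_changed` reads the existing file first so generated files are rewritten only when their content actually differs, which keeps timestamps stable for incremental builds. The visible logic completes roughly like this sketch; the comparison-and-write tail is inferred, not shown in the hunk:

from typing import Optional

def write_if_changed(filename: str, contents: str) -> None:
    old_contents: Optional[str]
    try:
        with open(filename) as f:
            old_contents = f.read()
    except OSError:
        old_contents = None  # a missing file counts as "changed"
    if contents != old_contents:
        with open(filename, "w") as f:
            f.write(contents)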
@ -150,7 +150,7 @@ class FileManager:

    # Read from template file and replace pattern with callable (type could be dict or str).
    def substitute_with_template(
        self, template_fn: str, env_callable: Callable[[], str | dict[str, Any]]
        self, template_fn: str, env_callable: Callable[[], Union[str, Dict[str, Any]]]
    ) -> str:
        template_path = os.path.join(self.template_dir, template_fn)
        env = env_callable()
@ -171,7 +171,7 @@ class FileManager:
        self,
        filename: str,
        template_fn: str,
        env_callable: Callable[[], str | dict[str, Any]],
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        filename = f"{self.install_dir}/{filename}"
        assert filename not in self.filenames, "duplicate file write {filename}"
@ -186,7 +186,7 @@ class FileManager:
    def write(
        self,
        filename: str,
        env_callable: Callable[[], str | dict[str, Any]],
        env_callable: Callable[[], Union[str, Dict[str, Any]]],
    ) -> None:
        self.write_with_template(filename, filename, env_callable)

@ -196,13 +196,13 @@ class FileManager:
        items: Iterable[T],
        *,
        key_fn: Callable[[T], str],
        env_callable: Callable[[T], dict[str, list[str]]],
        env_callable: Callable[[T], Dict[str, List[str]]],
        num_shards: int,
        base_env: dict[str, Any] | None = None,
        sharded_keys: set[str],
        base_env: Optional[Dict[str, Any]] = None,
        sharded_keys: Set[str],
    ) -> None:
        everything: dict[str, Any] = {"shard_id": "Everything"}
        shards: list[dict[str, Any]] = [
        everything: Dict[str, Any] = {"shard_id": "Everything"}
        shards: List[Dict[str, Any]] = [
            {"shard_id": f"_{i}"} for i in range(num_shards)
        ]
        all_shards = [everything] + shards
@ -221,7 +221,7 @@ class FileManager:
                else:
                    shard[key] = []

        def merge_env(into: dict[str, list[str]], from_: dict[str, list[str]]) -> None:
        def merge_env(into: Dict[str, List[str]], from_: Dict[str, List[str]]) -> None:
            for k, v in from_.items():
                assert k in sharded_keys, f"undeclared sharded key {k}"
                into[k] += v
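The sharded write fans generated entries out across `num_shards` files plus an "Everything" shard, and the nested `merge_env` appends each item's lists onto the shard chosen for it by `key_fn`. A simplified sketch of that distribution; the SHA-1-based shard choice here is an assumption for illustration, since the real selection logic sits outside this hunk:

import hashlib
from typing import Any, Dict, List

def distribute(items: List[str], num_shards: int) -> List[Dict[str, Any]]:
    everything: Dict[str, Any] = {"shard_id": "Everything", "entries": []}
    shards: List[Dict[str, Any]] = [
        {"shard_id": f"_{i}", "entries": []} for i in range(num_shards)
    ]
    for item in items:
        # Stable hash so a given item always lands in the same shard.
        idx = int(hashlib.sha1(item.encode()).hexdigest(), 16) % num_shards
        shards[idx]["entries"].append(item)
        everything["entries"].append(item)  # the unsharded view gets everything
    return [everything] + shards

for shard in distribute(["aten::add", "aten::mul", "aten::cat"], 2):
    print(shard["shard_id"], shard["entries"])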
@ -275,7 +275,7 @@ class FileManager:

# Helper function to generate file manager
def make_file_manager(
    options: Namespace, install_dir: str | None = None
    options: Namespace, install_dir: Optional[str] = None
) -> FileManager:
    template_dir = os.path.join(options.source_path, "templates")
    install_dir = install_dir if install_dir else options.install_dir
@ -335,7 +335,7 @@ def _pformat(


def _format_dict(
    attr: dict[Any, Any],
    attr: Dict[Any, Any],
    indent: int,
    width: int,
    curr_indent: int,
@ -355,7 +355,7 @@ def _format_dict(


def _format_list(
    attr: list[Any] | set[Any] | tuple[Any, ...],
    attr: Union[List[Any], Set[Any], Tuple[Any, ...]],
    indent: int,
    width: int,
    curr_indent: int,
@ -370,7 +370,7 @@ def _format_list(


def _format(
    fields_str: list[str],
    fields_str: List[str],
    indent: int,
    width: int,
    curr_indent: int,
@ -402,9 +402,7 @@ class NamespaceHelper:
    } // namespace torch
    """

    def __init__(
        self, namespace_str: str, entity_name: str = "", max_level: int = 2
    ) -> None:
    def __init__(self, namespace_str: str, entity_name: str = "", max_level: int = 2):
        # cpp_namespace can be a colon joined string such as torch::lazy
        cpp_namespaces = namespace_str.split("::")
        assert (
@ -421,7 +419,7 @@ class NamespaceHelper:
    @staticmethod
    def from_namespaced_entity(
        namespaced_entity: str, max_level: int = 2
    ) -> NamespaceHelper:
    ) -> "NamespaceHelper":
        """
        Generate helper from nested namespaces as long as class/function name. E.g.: "torch::lazy::add"
        """
@ -454,9 +452,9 @@ class NamespaceHelper:


class OrderedSet(Generic[T]):
    storage: dict[T, Literal[None]]
    storage: Dict[T, Literal[None]]

    def __init__(self, iterable: Iterable[T] | None = None) -> None:
    def __init__(self, iterable: Optional[Iterable[T]] = None):
        if iterable is None:
            self.storage = {}
        else:
@ -468,28 +466,28 @@ class OrderedSet(Generic[T]):
    def __iter__(self) -> Iterator[T]:
        return iter(self.storage.keys())

    def update(self, items: OrderedSet[T]) -> None:
    def update(self, items: "OrderedSet[T]") -> None:
        self.storage.update(items.storage)

    def add(self, item: T) -> None:
        self.storage[item] = None

    def copy(self) -> OrderedSet[T]:
    def copy(self) -> "OrderedSet[T]":
        ret: OrderedSet[T] = OrderedSet()
        ret.storage = self.storage.copy()
        return ret

    @staticmethod
    def union(*args: OrderedSet[T]) -> OrderedSet[T]:
    def union(*args: "OrderedSet[T]") -> "OrderedSet[T]":
        ret = args[0].copy()
        for s in args[1:]:
            ret.update(s)
        return ret

    def __or__(self, other: OrderedSet[T]) -> OrderedSet[T]:
    def __or__(self, other: "OrderedSet[T]") -> "OrderedSet[T]":
        return OrderedSet.union(self, other)

    def __ior__(self, other: OrderedSet[T]) -> Self:
    def __ior__(self, other: "OrderedSet[T]") -> Self:
        self.update(other)
        return self
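`OrderedSet` gets set semantics with deterministic iteration by storing members as keys of a `dict`, which preserves insertion order on Python 3.7+. A usage sketch of the class above, assuming the elided `else` branch of `__init__` populates `storage` via something like `dict.fromkeys(iterable)`:

a: "OrderedSet[str]" = OrderedSet(["add", "mul", "add"])  # dedupes, keeps insertion order
b = OrderedSet(["cat"])
merged = a | b                   # __or__ delegates to OrderedSet.union
merged |= OrderedSet(["relu"])   # __ior__ updates in place
print(list(merged))              # ['add', 'mul', 'cat', 'relu']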