"""Codegen infrastructure for generating lazy-tensor IR nodes, including for
non-native ops.

Backends declare non-native ops by adding a `non_native` key to their
`{backend}_native_functions.yaml` file, containing schema definitions in the
same format as `native_functions.yaml`, e.g.

    non_native:
        ...
        - func: expand(Tensor input, int[] size, bool is_scalar_expand) -> Tensor
        ...

These definitions are parsed into a `LazyIrSchema`, which `GenLazyIR` then uses
to generate IR nodes.

See https://github.com/pytorch/pytorch/pull/76535 (fixes #74628).
"""
from typing import Any, Dict, List, Union, Tuple, Optional

from torchgen.model import (
    Type,
    BaseTy,
    BaseType,
    OptionalType,
    ListType,
    OperatorName,
    FunctionSchema,
    Return,
    TensorOptionsArguments,
    Argument,
)
from torchgen.api.types import (
    CType,
    BaseCppType,
    BaseCType,
    OptionalCType,
    NamedCType,
    deviceT,
    layoutT,
    VectorCType,
    boolT,
    longT,
    doubleT,
    ListCType,
    stringT,
    scalarT,
    scalarTypeT,
    memoryFormatT,
    SymIntT,
)


_valueT: Optional[BaseCppType] = None


def getValueT() -> BaseCppType:
    global _valueT
    if not _valueT:
        raise NotImplementedError(
            "The value type needs to be set with setValueT() in run_gen_lazy_tensor()"
        )

    return _valueT


def setValueT(val: BaseCppType) -> None:
    global _valueT
    _valueT = val

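# Usage sketch (hedged; not executed here): a backend's run_gen_lazy_tensor()
# is expected to call setValueT() before any codegen runs. For example, a
# backend using torch::lazy::Value as its IR value type would do:
#
#     setValueT(BaseCppType("torch::lazy", "Value"))
#     getValueT()  # returns that BaseCppType; raises NotImplementedError if unset
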
# This is a bad hack. I need to refactor the data model to represent each arg
# in the schema as an object, making it easier to represent special properties
# of an arg.
tensorListValueT = BaseCppType("torch::lazy", "Value")

def process_ir_type(
    typ: Type, properties: "LazyIrProperties"
) -> Union[BaseCType, VectorCType, OptionalCType, ListCType]:
    """
    This function takes a type from NativeFunctions and converts it for use with
    lazy tensor codegen.

    Type conversion for lazy currently consists of
      (1) changing at::Tensors into lazy::Values
      (2) wrapping everything in a BaseCType
      (3) making cpp-reference types into cpp-value types (e.g. vector instead of IntArrayRef)

    (1) converts at::Tensors to lazy::Values (which wrap lazy::Nodes, the
    representation Lazy IR uses for tensors). There is special handling for
    Optional[Tensor], List[Tensor], etc., hence 'tensor-like'.

    This is incomplete; there are assertions in places because it is expected
    that more types will need to be added as the codegen is used with more
    operators.
    """
    if isinstance(typ, BaseType):
        if typ.name == BaseTy.Tensor:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.Scalar:
            if properties.TreatScalarsAsConstants:
                return BaseCType(scalarT)
            # at::Scalar has special handling and is wrapped in a lazy::Value,
            # just like at::Tensor
            return BaseCType(getValueT())
        elif typ.name == BaseTy.ScalarType:
            return BaseCType(scalarTypeT)
        elif typ.name == BaseTy.int:
            return BaseCType(longT)
        elif typ.name == BaseTy.SymInt:
            return BaseCType(getValueT())
        elif typ.name == BaseTy.bool:
            return BaseCType(boolT)
        elif typ.name == BaseTy.float:
            return BaseCType(doubleT)
        elif typ.name == BaseTy.str:
            return BaseCType(stringT)
        elif typ.name == BaseTy.Device:
            return BaseCType(deviceT)
        elif typ.name == BaseTy.Layout:
            return BaseCType(layoutT)
        elif typ.name == BaseTy.MemoryFormat:
            return BaseCType(memoryFormatT)
        else:
            raise AssertionError(f"TODO add support for type {repr(typ)}")
    elif isinstance(typ, OptionalType):
        return OptionalCType(process_ir_type(typ.elem, properties))
    elif isinstance(typ, ListType):
        if str(typ.elem) == "Tensor?":
            # TODO(whc) is this actually correct? or should it use a Vector like above
            return ListCType(OptionalCType(BaseCType(getValueT())))
        elif str(typ.elem) == "Tensor":
            # this is a TensorList which comes in from GetTensorList as a Value
            return BaseCType(tensorListValueT)
        else:
            return VectorCType(process_ir_type(typ.elem, properties))
    else:
        raise AssertionError(f"unrecognized type {repr(typ)}")

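# Summary of the mapping above (illustrative; the Tensor/Scalar/SymInt rows
# depend on the value type installed via setValueT()):
#
#   Tensor      -> BaseCType(getValueT())
#   Scalar      -> BaseCType(getValueT())  (or BaseCType(scalarT) when
#                                           properties.TreatScalarsAsConstants)
#   SymInt      -> BaseCType(getValueT())
#   int         -> BaseCType(longT)
#   int[]       -> VectorCType(BaseCType(longT))
#   Tensor?     -> OptionalCType(BaseCType(getValueT()))
#   Tensor[]    -> BaseCType(tensorListValueT)
#   Tensor?[]   -> ListCType(OptionalCType(BaseCType(getValueT())))
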
def isValueType(typ: CType, properties: "Optional[LazyIrProperties]" = None) -> bool:
    """
    Given a type, determine if it is a Value-like type. This is equivalent to
    being Tensor-like, but assumes the type has already been transformed.
    """
    if isinstance(typ, BaseCType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar
        # in a lazy::Value, while preserving other 'scalar' types as scalars in the IR
        treat_scalars_as_constants = properties and properties.TreatScalarsAsConstants
        return (
            typ.type == getValueT()
            or (typ.type == scalarT and not treat_scalars_as_constants)
            or typ.type == SymIntT
        )
    elif isinstance(typ, (OptionalCType, ListCType, VectorCType)):
        return isValueType(typ.elem, properties)
    return False

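# For example, assuming setValueT(BaseCppType("torch::lazy", "Value")) was called:
#
#   isValueType(BaseCType(getValueT()))                 -> True
#   isValueType(OptionalCType(BaseCType(getValueT())))  -> True
#   isValueType(BaseCType(scalarT))                     -> True (wrapped scalar),
#       unless a LazyIrProperties with TreatScalarsAsConstants is passed
#   isValueType(BaseCType(longT))                       -> False
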
def isSymIntType(typ: Type) -> bool:
    return isinstance(typ, BaseType) and typ.name == BaseTy.SymInt

def isWrappedScalarType(typ: Type) -> bool:
    """
    Given a type, determine if it is a c10::Scalar which we will wrap in a lazy Value.
    Since we literally change the type from scalarT to valueT, information is lost.
    This function helps build a list of wrapped scalars to save that information.
    """
    if isinstance(typ, BaseType):
        # I am regretting my naming conventions, but now we are wrapping at::Scalar
        # in a lazy::Value, while preserving other 'scalar' types as scalars in the IR
        return typ.name == BaseTy.Scalar
    elif isinstance(typ, (OptionalType, ListType)):
        return isWrappedScalarType(typ.elem)
    return False

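# The check recurses through Optional and List containers, e.g.:
#   Scalar -> True, Scalar? -> True, Scalar[] -> True, Tensor -> False, int -> False
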
def isGeneratorType(typ: Type) -> bool:
    if isinstance(typ, BaseType):
        return typ.name == BaseTy.Generator
    elif isinstance(typ, OptionalType):
        return isGeneratorType(typ.elem)
    return False

class LazyArgument:
    name: str
    orig_type: Type
    lazy_type_: Optional[CType]
    is_optional: bool
    is_wrapped_scalar: bool
    is_generator: bool
    is_symint_or_list: bool

    # true if this argument is or contains a lazy IR value
    is_lazy_value: bool

    def __init__(self, arg: Argument, properties: "LazyIrProperties"):
        self.name = arg.name
        self.orig_type = arg.type
        self.is_optional = isinstance(arg.type, OptionalType)
        self.is_generator = isGeneratorType(arg.type)
        if self.is_generator:
            assert (
                self.is_optional
            ), "We expect all generators to be optional since currently they are"
            # There is no handling for generators in TorchScript IR (or XLA), so
            # we fall back to eager if the (optional) generator has a value;
            # otherwise it is null and safe to exclude from lazy IR.
            self.lazy_type_ = None
        else:
            self.lazy_type_ = process_ir_type(arg.type, properties)
        self.is_wrapped_scalar = isWrappedScalarType(arg.type)
        self.is_symint_or_list = isSymIntType(arg.type)

        self.is_lazy_value = not self.is_generator and isValueType(
            self.lazy_type, properties
        )

    @property
    def lazy_type(self) -> CType:
        assert (
            self.lazy_type_ is not None
        ), f"Attempted to access lazy_type for invalid argument {self.name}"
        return self.lazy_type_

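# A hedged construction sketch (Argument.parse is torchgen's parser for a
# single schema argument; setValueT() must have been called first so that
# process_ir_type can map Tensor to an IR value type):
#
#   arg = Argument.parse("Tensor self")
#   lazy_arg = LazyArgument(arg, LazyIrProperties("ShapePrecompute", "Lower"))
#   lazy_arg.is_lazy_value  # -> True: Tensor args become lazy IR values
#   lazy_arg.lazy_type      # -> BaseCType(getValueT())
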
class LazyIrProperties:
    """Collection of properties for an IR node

    The property groups are listed below. Each group is mutually
    exclusive, meaning that only one property from each group can be True
    at any one time. The properties can be accessed as if they were normal
    attributes. The mutual exclusivity is automatically handled.
    """

    Properties: Tuple[Tuple[str, ...], ...] = (
        (
            "ShapePrecompute",  # Assume shape has been precomputed
            "ShapeCompute",  # Need to compute the shape on construction
            "ShapeCache",  # Utilize the shape cache to defer computation
        ),
        (
            "Lower",  # Codegen full lower function
            "LowerDeclOnly",  # Codegen only lower function declaration
        ),
        (
            "CanBeReused",  # Codegen full reuse function
            "CanBeReusedDeclOnly",  # Codegen only reuse function declaration
        ),
        (
            "CreateFn",  # Codegen full create function
            "CreateFnDeclOnly",  # Codegen only create function declaration
        ),
        (
            "TreatScalarsAsConstants",  # Treat Scalars as constants instead of handling like values
        ),
    )

    def __init__(self, *default_properties: str):
        properties: Dict[Tuple[str, ...], Optional[str]] = {
            p: None for p in LazyIrProperties.Properties
        }
        self.__dict__["properties"] = properties
        for p in default_properties:
            setattr(self, p, True)

    def __getattr__(self, key: str) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                return properties[values] == key

        return self.__getattribute__(key)

    def __setattr__(self, key: str, value: Any) -> Any:
        properties = self.__dict__["properties"]
        for values in LazyIrProperties.Properties:
            if key in values:
                properties[values] = key if value else None
                return value

        raise KeyError(f"Invalid property: {key}")

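# Example of the mutual-exclusion behavior implemented by __setattr__ above:
#
#   props = LazyIrProperties("ShapePrecompute", "Lower")
#   props.ShapePrecompute  # -> True
#   props.ShapeCache = True
#   props.ShapePrecompute  # -> False; setting ShapeCache cleared the rest of
#                          #    its group, since only one can be active at a time
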
# Inspired by a FunctionSchema object, a LazyIrSchema holds the schema of a Lazy IR node.
# Unlike a FunctionSchema, it has no round-trippable string form (relating to the YAML),
# but carries type information from a native FunctionSchema modified for use with IR nodes,
# while preserving the original argument names.
class LazyIrSchema:
    # The name of the operator this function schema describes.
    name: "OperatorName"

    positional_args: Tuple[LazyArgument, ...]
    keyword_args: Tuple[LazyArgument, ...]

    # TODO: Need to handle collisions with argument names at some point
    returns: Tuple["Return", ...]

    # if this schema has a Generator arg, list its orig ctype/name, but don't
    # build a LazyArgument since lazy IR doesn't support it
    generator_arg: Optional[NamedCType] = None

    properties: LazyIrProperties = LazyIrProperties(
        # default properties
        "ShapePrecompute",
        "Lower",
        "CanBeReused",
    )
    opkind: Optional[str] = None

    def __init__(
        self, func: FunctionSchema, properties: Optional[LazyIrProperties] = None
    ):
        if properties:
            self.properties = properties

        positional_args: List[LazyArgument] = []
        for arg_field in ["pre_self_positional", "self_arg", "post_self_positional"]:
            if arg_field == "self_arg" and func.arguments.self_arg is not None:
                arg = getattr(func.arguments, "self_arg").argument
                positional_args.append(LazyArgument(arg, self.properties))
            elif getattr(func.arguments, arg_field) is not None:
                positional_args.extend(
                    LazyArgument(arg, self.properties)
                    for arg in getattr(func.arguments, arg_field)
                )
        self.positional_args = tuple(positional_args)

        keyword_args: List[LazyArgument] = []
        for arg_field in [
            "pre_tensor_options_kwarg_only",
            "tensor_options",
            "post_tensor_options_kwarg_only",
            "out",
        ]:
            curr_args = getattr(func.arguments, arg_field)
            if curr_args is not None:
                if isinstance(curr_args, TensorOptionsArguments):
                    curr_args = curr_args.all()
                for arg in curr_args:
                    if isGeneratorType(arg.type):
                        assert (
                            self.generator_arg is None
                        ), "We expect there is only one generator arg"
                        self.generator_arg = NamedCType(arg.name, arg.type)
                keyword_args.extend(
                    LazyArgument(arg, self.properties) for arg in curr_args
                )
        self.keyword_args = tuple(keyword_args)
        self.name = func.name
        self.returns = func.returns

    @property
    def node_name(self) -> str:
        """
        Return the camel-case version of the op name, used to name the IR node.

        Note: This function also appends any `overload_name` in the operation.
        For example, if the op is `bitwise_and.Tensor`, the returned name
        will be `BitwiseAndTensor`.
        """
        op_name = f"{self.name.name}_{self.name.overload_name}".lower()
        return "".join(word.capitalize() or "" for word in op_name.split("_"))

    @property
    def aten_name(self) -> str:
        return str(self.name.name)

    @property
    def base_name(self) -> str:
        return f"{self.name.name.base}"

    def filtered_args(
        self,
        positional: bool = True,
        keyword: bool = True,
        values: bool = True,
        scalars: bool = True,
        generator: bool = False,
    ) -> List[LazyArgument]:
        # This function maintains the sorted order of arguments but provides different filtered views.
        # Some parts of the code care about kwargs vs args (TS lowerings),
        # other parts care about whether they need to wrap the arg in a lazy value or leave it alone.
        # Generators are special-cased, as they are needed for fallback/shape-inference but not supported
        # in TS lowerings, and therefore also omitted from lazy IR.
        args: List[LazyArgument] = []
        if positional:
            args.extend(self.positional_args)
        if keyword:
            args.extend(self.keyword_args)

        if values and scalars and generator:
            return args
        elif values and scalars:
            return [a for a in args if not a.is_generator]
        elif values:
            return [a for a in args if a.is_lazy_value]
        elif scalars:
            return [
                a
                for a in args
                if not a.is_lazy_value and (generator or not a.is_generator)
            ]

        return []

    @property
    def positional_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=True, scalars=False
        )

    @property
    def positional_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=True, keyword=False, values=False, scalars=True
        )

    @property
    def keyword_values(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=True, scalars=False
        )

    @property
    def keyword_scalars(self) -> List[LazyArgument]:
        return self.filtered_args(
            positional=False, keyword=True, values=False, scalars=True
        )

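# A hedged end-to-end sketch (FunctionSchema.parse is torchgen's schema parser;
# setValueT() must be called first so Tensor args can map to IR values):
#
#   setValueT(BaseCppType("torch::lazy", "Value"))
#   schema = LazyIrSchema(
#       FunctionSchema.parse("bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor")
#   )
#   schema.node_name   # -> "BitwiseAndTensor"
#   schema.aten_name   # -> "bitwise_and"
#   [a.name for a in schema.positional_values]  # -> ["self", "other"]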