mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[HOP] support generating schema for hop (#133521)
Add a way of generating a FunctionSchema from example values because hop's schema varies even for the same hop. We didn't use torch._C.FunctionSchema because we cannot construct the classes directly (e.g. "__init__" cannot be used for torch._C.FunctionSchema). Also extending the Basic types in c++ seems not that easy. Pull Request resolved: https://github.com/pytorch/pytorch/pull/133521 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
dd5a7c8397
commit
6835f20d20
97
torchgen/gen_schema_utils.py
Normal file
97
torchgen/gen_schema_utils.py
Normal file
@ -0,0 +1,97 @@
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
from torchgen.model import (
|
||||
Annotation,
|
||||
Argument,
|
||||
Arguments,
|
||||
BaseOperatorName,
|
||||
BaseTy,
|
||||
BaseType,
|
||||
CustomClassType,
|
||||
FunctionSchema,
|
||||
ListType,
|
||||
OperatorName,
|
||||
Return,
|
||||
)
|
||||
|
||||
|
||||
# Note: These aren't actually used in torchgen, they're some utilities for generating a schema
|
||||
# from real arguments. For example, this is used to generate HigherOrderOperators' schema since
|
||||
# their schemas can vary for different instances of the same HOP.
|
||||
|
||||
|
||||
class TypeGen:
|
||||
convert_to_base_ty = {
|
||||
int: BaseTy.int,
|
||||
float: BaseTy.float,
|
||||
str: BaseTy.str,
|
||||
bool: BaseTy.bool,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
|
||||
import torch
|
||||
|
||||
if isinstance(obj, torch.fx.GraphModule):
|
||||
return BaseType(BaseTy.GraphModule)
|
||||
elif isinstance(obj, torch.Tensor):
|
||||
return BaseType(BaseTy.Tensor)
|
||||
elif isinstance(obj, torch.SymInt):
|
||||
return BaseType(BaseTy.SymInt)
|
||||
elif isinstance(obj, torch.SymBool):
|
||||
return BaseType(BaseTy.SymBool)
|
||||
elif isinstance(obj, torch.ScriptObject):
|
||||
return CustomClassType(obj._type().name()) # type: ignore[attr-defined]
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
assert len(obj) > 0
|
||||
all_base_tys = [TypeGen.from_example(x) for x in obj]
|
||||
if len(set(all_base_tys)) > 1:
|
||||
raise RuntimeError(
|
||||
f"Cannot generate schema for a seqeunce of args of heterogeneous types: {all_base_tys}. "
|
||||
"Consider unpacking the argument and give proper names to them if possible "
|
||||
"instead of using *args."
|
||||
)
|
||||
return ListType(all_base_tys[0], len(obj))
|
||||
tp = type(obj)
|
||||
if tp not in TypeGen.convert_to_base_ty:
|
||||
raise RuntimeError(f"unsupported type {tp}")
|
||||
return BaseType(TypeGen.convert_to_base_ty[tp])
|
||||
|
||||
|
||||
class ReturnGen:
|
||||
@staticmethod
|
||||
def from_example(
|
||||
name: Optional[str], obj: Any, annotation: Optional[Annotation]
|
||||
) -> Return:
|
||||
return Return(name, TypeGen.from_example(obj), annotation)
|
||||
|
||||
|
||||
class ArgumentGen:
|
||||
@staticmethod
|
||||
def from_example(
|
||||
name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
|
||||
) -> Argument:
|
||||
return Argument(
|
||||
name, TypeGen.from_example(obj), default=default, annotation=annotation
|
||||
)
|
||||
|
||||
|
||||
class FunctionSchemaGen:
|
||||
@staticmethod
|
||||
def from_example(
|
||||
op_name: str,
|
||||
example_inputs: Tuple[Tuple[str, Any], ...],
|
||||
example_outputs: Tuple[Any, ...],
|
||||
) -> FunctionSchema:
|
||||
args = []
|
||||
for name, inp in example_inputs:
|
||||
args.append(ArgumentGen.from_example(name, inp, None, None))
|
||||
# ignore the annotations and other attributes for now, we could add more when needed.
|
||||
arguments = Arguments(
|
||||
tuple(), None, tuple(args), tuple(), None, tuple(), tuple()
|
||||
)
|
||||
returns = tuple(
|
||||
ReturnGen.from_example(None, out, None) for out in example_outputs
|
||||
)
|
||||
op_name = OperatorName(BaseOperatorName(op_name, False, False, False), "")
|
||||
return FunctionSchema(op_name, arguments, returns)
|
Reference in New Issue
Block a user