mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-26 08:34:52 +08:00 
			
		
		
		
	Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/141938 Approved by: https://github.com/ezyang
		
			
				
	
	
		
			98 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			98 lines
		
	
	
		
			3.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Any, Optional, Union
 | |
| 
 | |
| from torchgen.model import (
 | |
|     Annotation,
 | |
|     Argument,
 | |
|     Arguments,
 | |
|     BaseOperatorName,
 | |
|     BaseTy,
 | |
|     BaseType,
 | |
|     CustomClassType,
 | |
|     FunctionSchema,
 | |
|     ListType,
 | |
|     OperatorName,
 | |
|     Return,
 | |
| )
 | |
| 
 | |
| 
 | |
# Note: These utilities aren't used by torchgen itself; they generate a schema
# from real arguments. For example, they are used to generate the schemas of
# HigherOrderOperators (HOPs), since those schemas can vary between different
# instances of the same HOP.
 | |
| 
 | |
| 
 | |
class TypeGen:
    """Infer a torchgen schema type from an example runtime value."""

    # Maps exact Python scalar types to schema base types. Lookup below is by
    # ``type(obj)`` rather than ``isinstance``, so ``True`` maps to ``bool``
    # (not ``int``) and subclasses of these types are rejected as unsupported.
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        """Return the schema type describing the example value ``obj``.

        Handled values:
          * ``torch.fx.GraphModule``, ``torch.Tensor``, ``torch.SymInt`` and
            ``torch.SymBool`` map to the corresponding ``BaseType``;
          * ``torch.ScriptObject`` maps to a ``CustomClassType`` named after
            the object's script type;
          * a non-empty list/tuple whose elements all infer to the same type
            maps to a ``ListType`` of that element type and the sequence's
            length (elements are inferred recursively, so nesting works);
          * values whose exact type is a key of ``convert_to_base_ty``.

        Raises:
            RuntimeError: if ``obj`` is an empty list/tuple, a list/tuple whose
                elements infer to different types, or a value of an
                unsupported type.
        """
        # Local import; presumably so this module can be imported without
        # torch being available (the top-level imports only need torchgen).
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        elif isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        elif isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        elif isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        elif isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        elif isinstance(obj, (list, tuple)):
            # Explicit check instead of ``assert`` so it still fires under
            # ``python -O``; otherwise an empty sequence would fail later with
            # an opaque IndexError on ``all_base_tys[0]``.
            if not obj:
                raise RuntimeError(
                    "Cannot generate schema for an empty sequence: "
                    "its element type cannot be inferred."
                )
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            # The resulting ListType has a single element type, so every
            # element must infer to the same type.
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and give proper names to them if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])
 | |
| 
 | |
| 
 | |
class ReturnGen:
    """Build a torchgen ``Return`` entry from an example output value."""

    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        """Describe ``obj`` as a schema return called ``name``.

        The return's type is inferred with ``TypeGen.from_example``; any
        ``RuntimeError`` raised there propagates unchanged. ``annotation`` is
        attached to the return as given.
        """
        inferred_type = TypeGen.from_example(obj)
        return Return(name=name, type=inferred_type, annotation=annotation)
 | |
| 
 | |
| 
 | |
class ArgumentGen:
    """Build a torchgen ``Argument`` entry from an example input value."""

    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        """Describe ``obj`` as a schema argument called ``name``.

        The argument's type is inferred with ``TypeGen.from_example``; any
        ``RuntimeError`` raised there propagates unchanged. ``default`` and
        ``annotation`` are stored on the argument as given.
        """
        inferred_type = TypeGen.from_example(obj)
        return Argument(
            name=name,
            type=inferred_type,
            default=default,
            annotation=annotation,
        )
 | |
| 
 | |
| 
 | |
class FunctionSchemaGen:
    """Build a torchgen ``FunctionSchema`` from example inputs and outputs."""

    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: tuple[tuple[str, Any], ...],
        example_outputs: tuple[Any, ...],
    ) -> FunctionSchema:
        """Generate a schema for the operator named ``op_name``.

        Each ``(name, value)`` pair in ``example_inputs`` becomes an argument
        with no default and no annotation. Each value in ``example_outputs``
        becomes an unnamed, unannotated return. All types are inferred with
        ``TypeGen.from_example``. The operator name is built with an empty
        overload name.
        """
        # Annotations and other argument attributes are ignored for now; they
        # can be added when needed.
        generated_args = tuple(
            ArgumentGen.from_example(arg_name, value, None, None)
            for arg_name, value in example_inputs
        )
        # Only the third field of Arguments receives the generated arguments;
        # every other field is left empty (``()``) or ``None``.
        arguments = Arguments((), None, generated_args, (), None, (), ())
        returns = tuple(
            ReturnGen.from_example(None, out, None) for out in example_outputs
        )
        schema_name = OperatorName(
            BaseOperatorName(op_name, False, False, False), ""
        )
        return FunctionSchema(schema_name, arguments, returns)
 |