[AOTInductor] Prepare for ProxyExecutor, OSS only change (#107065)

Summary: Minor fixes to export schema and serialization

Test Plan: OSS CI

Differential Revision: D48280809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107065
Approved by: https://github.com/zhxchen17
This commit is contained in:
Sherlock Huang
2023-08-14 20:04:41 +00:00
committed by PyTorch MergeBot
parent 4a6ca4cc05
commit 1e007d044d
5 changed files with 27 additions and 12 deletions

View File

@ -2,7 +2,7 @@
# Anything is subject to change and no guarantee is provided at this point.
from dataclasses import dataclass, fields
from enum import Enum
from enum import IntEnum
from typing import Dict, List, Optional, Tuple
# TODO (zhxchen17) Move to a separate file.
@ -27,8 +27,14 @@ class _Union:
assert val_type is not None
return val_type
def __str__(self):
return self.__repr__()
class ScalarType(Enum):
def __repr__(self):
return f"{type(self).__name__}({self.type}={self.value})"
class ScalarType(IntEnum):
UNKNOWN = 0
BYTE = 1
CHAR = 2
@ -45,7 +51,7 @@ class ScalarType(Enum):
BFLOAT16 = 13
class Layout(Enum):
class Layout(IntEnum):
Unknown = 0
SparseCoo = 1
SparseCsr = 2
@ -56,7 +62,7 @@ class Layout(Enum):
Strided = 7
class MemoryFormat(Enum):
class MemoryFormat(IntEnum):
Unknown = 0
ContiguousFormat = 1
ChannelsLast = 2
@ -76,13 +82,13 @@ class SymExpr:
hint: Optional[int]
@dataclass
@dataclass(repr=False)
class SymInt(_Union):
as_expr: SymExpr
as_int: int
@dataclass
@dataclass(repr=False)
class SymBool(_Union):
as_expr: str
as_bool: bool
@ -99,13 +105,13 @@ class TensorMeta:
layout: Layout
@dataclass
@dataclass(repr=False)
class SymIntArgument(_Union):
as_name: str
as_int: int
@dataclass
@dataclass(repr=False)
class SymBoolArgument(_Union):
as_name: str
as_bool: bool
@ -116,7 +122,7 @@ class TensorArgument:
name: str
@dataclass
@dataclass(repr=False)
class OptionalTensorArgument(_Union):
as_tensor: str
as_none: Tuple[()]
@ -129,7 +135,7 @@ class GraphArgument:
# This is actually a union type
@dataclass
@dataclass(repr=False)
class Argument(_Union):
as_none: Tuple[()]
as_tensor: TensorArgument
@ -139,6 +145,7 @@ class Argument(_Union):
as_float: float
as_floats: List[float]
as_string: str
as_strings: List[str]
as_sym_int: SymIntArgument
as_sym_ints: List[SymIntArgument]
as_scalar_type: ScalarType