mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-01 13:34:57 +08:00
[AOTInductor] Prepare for ProxyExecutor, OSS only change (#107065)
Summary: Minor fixes to export schema and serialization Test Plan: OSS CI Differential Revision: D48280809 Pull Request resolved: https://github.com/pytorch/pytorch/pull/107065 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
4a6ca4cc05
commit
1e007d044d
@ -2,7 +2,7 @@
|
||||
# Anything is subject to change and no guarantee is provided at this point.
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
from enum import Enum
|
||||
from enum import IntEnum
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
# TODO (zhxchen17) Move to a separate file.
|
||||
@ -27,8 +27,14 @@ class _Union:
|
||||
assert val_type is not None
|
||||
return val_type
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
class ScalarType(Enum):
|
||||
def __repr__(self):
|
||||
return f"{type(self).__name__}({self.type}={self.value})"
|
||||
|
||||
|
||||
class ScalarType(IntEnum):
|
||||
UNKNOWN = 0
|
||||
BYTE = 1
|
||||
CHAR = 2
|
||||
@ -45,7 +51,7 @@ class ScalarType(Enum):
|
||||
BFLOAT16 = 13
|
||||
|
||||
|
||||
class Layout(Enum):
|
||||
class Layout(IntEnum):
|
||||
Unknown = 0
|
||||
SparseCoo = 1
|
||||
SparseCsr = 2
|
||||
@ -56,7 +62,7 @@ class Layout(Enum):
|
||||
Strided = 7
|
||||
|
||||
|
||||
class MemoryFormat(Enum):
|
||||
class MemoryFormat(IntEnum):
|
||||
Unknown = 0
|
||||
ContiguousFormat = 1
|
||||
ChannelsLast = 2
|
||||
@ -76,13 +82,13 @@ class SymExpr:
|
||||
hint: Optional[int]
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class SymInt(_Union):
|
||||
as_expr: SymExpr
|
||||
as_int: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class SymBool(_Union):
|
||||
as_expr: str
|
||||
as_bool: bool
|
||||
@ -99,13 +105,13 @@ class TensorMeta:
|
||||
layout: Layout
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class SymIntArgument(_Union):
|
||||
as_name: str
|
||||
as_int: int
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class SymBoolArgument(_Union):
|
||||
as_name: str
|
||||
as_bool: bool
|
||||
@ -116,7 +122,7 @@ class TensorArgument:
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class OptionalTensorArgument(_Union):
|
||||
as_tensor: str
|
||||
as_none: Tuple[()]
|
||||
@ -129,7 +135,7 @@ class GraphArgument:
|
||||
|
||||
|
||||
# This is actually a union type
|
||||
@dataclass
|
||||
@dataclass(repr=False)
|
||||
class Argument(_Union):
|
||||
as_none: Tuple[()]
|
||||
as_tensor: TensorArgument
|
||||
@ -139,6 +145,7 @@ class Argument(_Union):
|
||||
as_float: float
|
||||
as_floats: List[float]
|
||||
as_string: str
|
||||
as_strings: List[str]
|
||||
as_sym_int: SymIntArgument
|
||||
as_sym_ints: List[SymIntArgument]
|
||||
as_scalar_type: ScalarType
|
||||
|
||||
Reference in New Issue
Block a user