Files
pytorch/torch/_export/serde/union.py

97 lines
2.9 KiB
Python

# mypy: allow-untyped-defs
import functools
from collections.abc import Hashable
from dataclasses import dataclass, fields
from typing import TypeVar
from typing_extensions import dataclass_transform
T = TypeVar("T", bound="_Union")
class _UnionTag(str):
    """A ``str`` naming a union field that also records the owning union class.

    The owning class is kept in the ``_cls`` slot so that equality checks can
    verify the compared string is one of that class's dataclass field names.
    """

    __slots__ = ("_cls",)
    _cls: Hashable

    @staticmethod
    def create(t, cls):
        """Return a new tag with text ``t`` bound to the union class ``cls``."""
        instance = _UnionTag(t)
        # A freshly built tag must not have its owner slot populated yet.
        assert not hasattr(instance, "_cls")
        instance._cls = cls
        return instance

    def __eq__(self, cmp) -> bool:
        """Compare against another string, asserting it is a valid tag name.

        Asserts that ``cmp`` is a ``str`` and that it names a dataclass field
        of the owning class before doing the plain string comparison.
        """
        assert isinstance(cmp, str)
        # Work on an exact ``str`` copy so the comparison below is the plain
        # string one.
        other = str(cmp)
        assert other in _get_field_names(self._cls), (
            f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
        )
        return other == str(self)

    def __hash__(self):
        """Hash exactly like the underlying plain string."""
        return str.__hash__(self)
@functools.cache
def _get_field_names(cls) -> set[str]:
    """Collect the dataclass field names of ``cls`` into a set.

    The result is memoized by ``functools.cache``, so repeated calls with the
    same class return the same set object.
    """
    names: set[str] = set()
    for field_info in fields(cls):
        names.add(field_info.name)
    return names
# If you turn a schema class that inherits from _Union into a dataclass, please
# use this decorator to configure it. It's safe, faster, and allows code sharing.
#
# For example, _union_dataclass customizes the __eq__ method to check only the
# `type` and `value` properties, instead of the default dataclass implementation,
# which goes through every field in the dataclass.
@dataclass_transform(eq_default=False)
def _union_dataclass(cls: type[T]) -> type[T]:
    """Turn a ``_Union`` subclass into a dataclass configured for unions.

    The dataclass is generated with ``repr=False`` and ``eq=False``, so no
    ``__repr__`` or ``__eq__`` is synthesized and the ones defined on
    ``_Union`` remain in effect.

    Args:
        cls: The class to convert; it must be a subclass of ``_Union``.

    Returns:
        The class returned by ``dataclass(repr=False, eq=False)(cls)``.

    Raises:
        AssertionError: If ``cls`` does not inherit from ``_Union``.
    """
    assert issubclass(cls, _Union), f"{cls} must inherit from {_Union}."
    return dataclass(repr=False, eq=False)(cls)
class _Union:
    """Base class for tagged-union schema types.

    An instance has one dataclass field per alternative. ``create`` sets a
    single field, fills the others with ``None`` and stores the name of the
    chosen field in ``_type``. Reading a ``None``-valued field other than the
    active one raises ``AttributeError``.
    """

    # Name of the active field. Assigned by `create`, not by `__init__`.
    _type: _UnionTag

    @classmethod
    def create(cls, **kwargs):
        """Instantiate the union with exactly one active field.

        Args:
            **kwargs: A single ``field_name=value`` pair selecting the active
                field and its value.

        Returns:
            A new ``cls`` instance whose remaining fields are ``None`` and
            whose ``_type`` is a ``_UnionTag`` naming the active field.

        Raises:
            AssertionError: If not exactly one keyword argument is given.
        """
        assert len(kwargs) == 1
        # Pass every declared field as None, then override the chosen one.
        obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs})  # type: ignore[arg-type]
        obj._type = _UnionTag.create(next(iter(kwargs)), cls)
        return obj

    def __post_init__(self):
        """Reject dataclass fields whose names clash with this class's own API."""
        assert not any(
            f.name in ("type", "_type", "create", "value")
            for f in fields(self)  # type: ignore[arg-type, misc]
        )

    @property
    def type(self) -> str:
        """Name of the active field.

        Raises:
            RuntimeError: If ``_type`` was never set, i.e. the instance was
                not built through ``create``.
        """
        try:
            return self._type
        except AttributeError as e:
            raise RuntimeError(
                f"Please use {type(self).__name__}.create to instantiate the union type."
            ) from e

    @property
    def value(self):
        """Value stored in the active field."""
        return getattr(self, self.type)

    def __getattribute__(self, name):
        """Return attribute ``name``, refusing unset (``None``) inactive fields.

        Raises:
            AttributeError: If ``name`` is a dataclass field holding ``None``
                and it is not the active field.
        """
        attr = super().__getattribute__(name)
        # Only a None result needs the field lookup; everything else is
        # returned straight away.
        if attr is None and name in _get_field_names(type(self)) and name != self.type:  # type: ignore[arg-type]
            raise AttributeError(f"Field {name} is not set.")
        return attr

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` is a union with the same tag and value.

        Anything that is not a ``_Union`` instance compares unequal.
        """
        if not isinstance(other, _Union):
            return False
        # Compare the tags as plain strings. _UnionTag.__eq__ asserts that the
        # right-hand tag is a field name of the left-hand union class, so a
        # direct tag comparison raises AssertionError (instead of yielding
        # False) when `other` is a different union class with a different
        # active field.
        return str(self.type) == str(other.type) and self.value == other.value

    def __str__(self):
        """Same text as ``__repr__``."""
        return self.__repr__()

    def __repr__(self):
        """Render as ``ClassName(active_field=value)``."""
        return f"{type(self).__name__}({self.type}={getattr(self, self.type)})"