Files
pytorch/torch/_export/serde/union.py
Aaron Gokaslan 7f65a20884 [BE]: Enable ruff SLOT checks (#146276)
This enables a check that a class which only inherits from immutable classes like str, tuple, and NamedTuple also defines `__slots__`, so that its instances don't allocate memory unnecessarily. This also ensures contributors think about how they define classes that subclass NamedTuple and str, of which we have many in our codebase.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146276
Approved by: https://github.com/aorenste
2025-02-04 19:18:23 +00:00

72 lines
2.0 KiB
Python

# mypy: allow-untyped-defs
import functools
from collections.abc import Hashable
from dataclasses import fields
class _UnionTag(str):
    """String naming the active member of a ``_Union``, bound to its union class.

    The tag behaves like a plain ``str`` (same text, same hash). It also
    remembers the union class it was created for in ``_cls``. That lets
    ``==`` assert that the string it is compared against is a real field
    name of that class, which catches typos in tag comparisons.

    NOTE: only ``__eq__`` is overridden. ``!=`` resolves to the inherited
    ``str.__ne__`` and therefore performs no tag validation.
    """

    __slots__ = ("_cls",)
    _cls: Hashable

    @staticmethod
    def create(t, cls):
        """Return a new tag with text ``t`` bound to the union class ``cls``."""
        new_tag = _UnionTag(t)
        # A freshly built tag must not already carry a class binding.
        assert not hasattr(new_tag, "_cls")
        new_tag._cls = cls
        return new_tag

    def __eq__(self, cmp) -> bool:
        """Compare against a string that must be a valid tag of ``self._cls``.

        Raises:
            AssertionError: if ``cmp`` is not a ``str``, or is not one of the
                dataclass field names of ``self._cls``.
        """
        assert isinstance(cmp, str)
        candidate = str(cmp)
        # Cached by _get_field_names, so this is the same set object the
        # message below would otherwise look up again.
        allowed = _get_field_names(self._cls)
        assert candidate in allowed, (
            f"{candidate} is not a valid tag for {self._cls}. Available tags: {allowed}"
        )
        return candidate == str(self)

    def __hash__(self):
        # Defining __eq__ disables inherited hashing, so restore it
        # explicitly: a tag hashes exactly like its underlying text.
        return str.__hash__(self)
@functools.cache
def _get_field_names(cls) -> set[str]:
    """Return the set of dataclass field names declared on ``cls``.

    The result is memoized by ``functools.cache``, keyed on ``cls`` (which
    must therefore be hashable). The very same set object is returned on
    every call, so callers must treat it as read-only.
    """
    names: set[str] = set()
    for field_info in fields(cls):
        names.add(field_info.name)
    return names
class _Union:
    """Base class for tagged unions implemented as dataclasses.

    Subclasses are dataclasses whose fields are the union's alternatives.
    :meth:`create` sets every field to ``None`` except the one supplied, and
    records that field's name as a ``_UnionTag`` in ``_type``. Reading any
    other field whose value is ``None`` raises ``AttributeError``, so unset
    alternatives look absent rather than ``None``.
    """

    _type: _UnionTag

    @classmethod
    def create(cls, **kwargs):
        """Build an instance of ``cls`` with exactly one active field.

        Every dataclass field of ``cls`` is initialised to ``None``. The single
        keyword argument then overrides its own field, and its name becomes
        the union's tag.

        Raises:
            AssertionError: if not exactly one keyword argument is given.
        """
        assert len(kwargs) == 1, (
            f"{cls.__name__}.create expects exactly one keyword argument, "
            f"got {len(kwargs)}: {list(kwargs)}"
        )
        obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs})  # type: ignore[arg-type]
        obj._type = _UnionTag.create(next(iter(kwargs)), cls)
        return obj

    def __post_init__(self):
        # Field names must not shadow the union's own API (``type``,
        # ``create``, ``value``) or the attribute holding the tag (``_type``).
        assert not any(f.name in ("type", "_type", "create", "value") for f in fields(self)), (  # type: ignore[arg-type, misc]
            f"{type(self).__name__} must not declare fields named "
            "'type', '_type', 'create' or 'value'"
        )

    @property
    def type(self) -> str:
        """The name of the active field, as a ``_UnionTag``.

        Raises:
            RuntimeError: if ``_type`` was never set, i.e. the instance was
                built with the dataclass constructor instead of :meth:`create`.
        """
        try:
            return self._type
        except AttributeError as e:
            raise RuntimeError(
                f"Please use {type(self).__name__}.create to instantiate the union type."
            ) from e

    @property
    def value(self):
        """The value stored in the active field."""
        return getattr(self, self.type)

    def __getattribute__(self, name):
        # Runs on every attribute read. ``attr is None`` is tested first so the
        # field-name lookup and the ``self.type`` read only happen for
        # ``None``-valued attributes. For instances not built via ``create``,
        # reading a ``None`` field therefore raises RuntimeError from ``type``.
        attr = super().__getattribute__(name)
        if attr is None and name in _get_field_names(type(self)) and name != self.type:  # type: ignore[arg-type]
            raise AttributeError(f"Field {name} is not set.")
        return attr

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        # Renders as ``ClassName(tag=value)``, showing only the active field.
        return f"{type(self).__name__}({self.type}={getattr(self, self.type)})"