mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145201 Approved by: https://github.com/bobrenjc93
117 lines
4.5 KiB
Python
117 lines
4.5 KiB
Python
# This code is backported from python 3.10 dataclasses. Once 3.10 becomes the
|
|
# minimum supported we should use dataclass(slots=True) instead.
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import itertools
|
|
from typing import TYPE_CHECKING, TypeVar
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Generator
|
|
|
|
from _typeshed import DataclassInstance
|
|
|
|
|
|
__all__ = ["dataclass_slots"]
|
|
|
|
_T = TypeVar("_T", bound="DataclassInstance")
|
|
|
|
|
|
def dataclass_slots(cls: type[_T]) -> type[DataclassInstance]:
|
|
assert dataclasses.is_dataclass(cls), "Can only be used on dataclasses."
|
|
|
|
def _get_slots(cls: type[DataclassInstance]) -> Generator[str, None, None]:
|
|
slots = cls.__dict__.get("__slots__")
|
|
# `__dictoffset__` and `__weakrefoffset__` can tell us whether
|
|
# the base type has dict/weakref slots, in a way that works correctly
|
|
# for both Python classes and C extension types. Extension types
|
|
# don't use `__slots__` for slot creation
|
|
if slots is None:
|
|
slots = []
|
|
if getattr(cls, "__weakrefoffset__", -1) != 0:
|
|
slots.append("__weakref__")
|
|
if getattr(cls, "__dictrefoffset__", -1) != 0:
|
|
slots.append("__dict__")
|
|
yield from slots
|
|
elif isinstance(slots, str):
|
|
yield slots
|
|
# Slots may be any iterable, but we cannot handle an iterator
|
|
# because it will already be (partially) consumed.
|
|
elif not hasattr(cls, "__next__"):
|
|
yield from slots
|
|
else:
|
|
raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
|
|
|
|
def _add_slots(
|
|
cls: type[DataclassInstance], is_frozen: bool, weakref_slot: bool
|
|
) -> type[DataclassInstance]:
|
|
# Need to create a new class, since we can't set __slots__
|
|
# after a class has been created.
|
|
|
|
# Make sure __slots__ isn't already set.
|
|
if "__slots__" in cls.__dict__:
|
|
raise TypeError(f"{cls.__name__} already specifies __slots__")
|
|
|
|
# Create a new dict for our new class.
|
|
cls_dict = dict(cls.__dict__)
|
|
field_names = tuple(f.name for f in dataclasses.fields(cls))
|
|
# Make sure slots don't overlap with those in base classes.
|
|
inherited_slots = set(
|
|
itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
|
|
)
|
|
# The slots for our class. Remove slots from our base classes. Add
|
|
# '__weakref__' if weakref_slot was given, unless it is already present.
|
|
cls_dict["__slots__"] = tuple(
|
|
itertools.filterfalse(
|
|
inherited_slots.__contains__,
|
|
itertools.chain(
|
|
# gh-93521: '__weakref__' also needs to be filtered out if
|
|
# already present in inherited_slots
|
|
field_names,
|
|
("__weakref__",) if weakref_slot else (),
|
|
),
|
|
),
|
|
)
|
|
|
|
for field_name in field_names:
|
|
# Remove our attributes, if present. They'll still be
|
|
# available in _MARKER.
|
|
cls_dict.pop(field_name, None)
|
|
|
|
# Remove __dict__ itself.
|
|
cls_dict.pop("__dict__", None)
|
|
|
|
# Clear existing `__weakref__` descriptor, it belongs to a previous type:
|
|
cls_dict.pop("__weakref__", None) # gh-102069
|
|
|
|
# And finally create the class.
|
|
qualname = getattr(cls, "__qualname__", None)
|
|
cls = type(cls.__name__, cls.__bases__, cls_dict)
|
|
if qualname is not None:
|
|
cls.__qualname__ = qualname
|
|
|
|
def _dataclass_getstate(self: _T) -> object:
|
|
fields = dataclasses.fields(self)
|
|
return [getattr(self, f.name) for f in fields]
|
|
|
|
def _dataclass_setstate(self: _T, state: list[object]) -> None:
|
|
fields = dataclasses.fields(self)
|
|
for field, value in zip(fields, state):
|
|
# use setattr because dataclass may be frozen
|
|
object.__setattr__(self, field.name, value)
|
|
|
|
if is_frozen:
|
|
# Need this for pickling frozen classes with slots.
|
|
if "__getstate__" not in cls_dict:
|
|
cls.__getstate__ = _dataclass_getstate # type: ignore[method-assign, assignment]
|
|
if "__setstate__" not in cls_dict:
|
|
cls.__setstate__ = _dataclass_setstate # type: ignore[attr-defined]
|
|
|
|
return cls
|
|
|
|
params = getattr(cls, dataclasses._PARAMS) # type: ignore[attr-defined]
|
|
weakref_slot = getattr(params, "weakref_slot", False)
|
|
return _add_slots(cls, params.frozen, weakref_slot)
|