mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Remove dataclass_slots (#163623)
`dataclass` now has `slots` kwarg. Pull Request resolved: https://github.com/pytorch/pytorch/pull/163623 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
b776e0c71e
commit
5daa79fd6e
@ -24,7 +24,6 @@ import uuid
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from ..utils._backport_slots import dataclass_slots
|
||||
from . import config
|
||||
from .bytecode_analysis import (
|
||||
get_indexof,
|
||||
@ -39,8 +38,7 @@ if TYPE_CHECKING:
|
||||
from .output_graph import DynamoTracerOutput
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclasses.dataclass
|
||||
@dataclasses.dataclass(slots=True)
|
||||
class InstructionExnTabEntry:
|
||||
start: "Instruction"
|
||||
end: "Instruction"
|
||||
@ -68,8 +66,7 @@ class InstructionExnTabEntry:
|
||||
)
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclasses.dataclass
|
||||
@dataclasses.dataclass(slots=True)
|
||||
class Instruction:
|
||||
"""A mutable version of dis.Instruction"""
|
||||
|
||||
@ -642,8 +639,7 @@ def linetable_311_writer(
|
||||
return linetable, update
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclasses.dataclass
|
||||
@dataclasses.dataclass(slots=True)
|
||||
class ExceptionTableEntry:
|
||||
start: int
|
||||
end: int
|
||||
|
@ -27,7 +27,6 @@ from typing import (
|
||||
|
||||
import torch
|
||||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._backport_slots import dataclass_slots
|
||||
from torch.utils._traceback import CapturedTraceback, format_frame
|
||||
from torch.utils.weak import WeakTensorKeyDictionary
|
||||
|
||||
@ -241,8 +240,7 @@ class ShapeGuard(NamedTuple):
|
||||
size_oblivious: bool
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclasses.dataclass
|
||||
@dataclasses.dataclass(slots=True)
|
||||
class Guard:
|
||||
# originating_source is the source that called the make_guard method to
|
||||
# construct this guard object. The property name specifies what exactly it
|
||||
|
@ -7,7 +7,6 @@ import torch
|
||||
from torch import SymInt
|
||||
from torch.fx.experimental.sym_node import SymNode
|
||||
from torch.types import py_sym_types, PySymType
|
||||
from torch.utils._backport_slots import dataclass_slots
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -18,8 +17,7 @@ if TYPE_CHECKING:
|
||||
from .fake_tensor import _DispatchCacheKey, _MetadataIntLike
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _DeconstructedSymNode:
|
||||
"""
|
||||
Represents a SymNode without the associated ShapeEnv
|
||||
@ -73,8 +71,7 @@ class _DeconstructedSymNode:
|
||||
return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _DeconstructedSymType:
|
||||
"""
|
||||
Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
|
||||
@ -103,14 +100,12 @@ class _DeconstructedSymType:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _InputBackref:
|
||||
value: int
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class _PySymInputStub:
|
||||
"""
|
||||
Represents a SymInt in the cached key. Needed because SymInt doesn't
|
||||
@ -172,8 +167,7 @@ class _PySymInputStub:
|
||||
return self.value.node._value_hash()
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class _SymIntOutputStub:
|
||||
"""
|
||||
Represents a SymInt in the cached output.
|
||||
@ -207,8 +201,7 @@ class _SymIntOutputStub:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class _CacheKeyState:
|
||||
"""
|
||||
State used while building our cache key.
|
||||
|
@ -40,7 +40,6 @@ from torch.fx.operator_schemas import normalize_function
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
from torch.overrides import TorchFunctionMode
|
||||
from torch.types import IntLikeType, py_sym_types
|
||||
from torch.utils._backport_slots import dataclass_slots
|
||||
from torch.utils._mode_utils import no_dispatch
|
||||
from torch.utils._python_dispatch import (
|
||||
is_traceable_wrapper_subclass,
|
||||
@ -1008,8 +1007,7 @@ class FakeTensor(Tensor):
|
||||
_MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class TensorMetadata:
|
||||
"""
|
||||
The Tensor metadata relevant to hashing FakeTensors when caching.
|
||||
@ -1093,8 +1091,7 @@ def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
|
||||
)
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass
|
||||
@dataclass(slots=True)
|
||||
class _DispatchCacheKey:
|
||||
"""
|
||||
Key for the FakeTensor dispatch cache.
|
||||
@ -1127,8 +1124,7 @@ class SingletonConstant:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _DispatchCacheEntryOutputInfo:
|
||||
"""
|
||||
Entry type for the FakeTensor dispatch cache for an output. Accounts for three
|
||||
@ -1147,8 +1143,7 @@ class _DispatchCacheEntryOutputInfo:
|
||||
constant_value: Optional[Any] = SingletonConstant
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _DispatchCacheValidEntry:
|
||||
"""
|
||||
Entry type for the FakeTensor dispatch cache. It supports two types of outputs
|
||||
@ -1162,8 +1157,7 @@ class _DispatchCacheValidEntry:
|
||||
is_output_tuple: bool = False
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _DispatchCacheBypassEntry:
|
||||
"""
|
||||
Entry type for a negative cache entry.
|
||||
@ -1176,8 +1170,7 @@ if TYPE_CHECKING:
|
||||
_DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry]
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class _BypassDispatchCache(Exception):
|
||||
"""
|
||||
Signals cases that should skip FakeTensor caching.
|
||||
@ -1186,8 +1179,7 @@ class _BypassDispatchCache(Exception):
|
||||
reason: str
|
||||
|
||||
|
||||
@dataclass_slots
|
||||
@dataclass(frozen=True)
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class DispatchCacheInfo:
|
||||
"""
|
||||
Information about the state of the FakeTensor dispatch cache.
|
||||
|
@ -1,116 +0,0 @@
|
||||
# This code is backported from python 3.10 dataclasses. Once 3.10 becomes the
|
||||
# minimum supported we should use dataclass(slots=True) instead.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
from _typeshed import DataclassInstance
|
||||
|
||||
|
||||
__all__ = ["dataclass_slots"]
|
||||
|
||||
_T = TypeVar("_T", bound="DataclassInstance")
|
||||
|
||||
|
||||
def dataclass_slots(cls: type[_T]) -> type[DataclassInstance]:
|
||||
assert dataclasses.is_dataclass(cls), "Can only be used on dataclasses."
|
||||
|
||||
def _get_slots(cls: type[DataclassInstance]) -> Generator[str, None, None]:
|
||||
slots = cls.__dict__.get("__slots__")
|
||||
# `__dictoffset__` and `__weakrefoffset__` can tell us whether
|
||||
# the base type has dict/weakref slots, in a way that works correctly
|
||||
# for both Python classes and C extension types. Extension types
|
||||
# don't use `__slots__` for slot creation
|
||||
if slots is None:
|
||||
slots = []
|
||||
if getattr(cls, "__weakrefoffset__", -1) != 0:
|
||||
slots.append("__weakref__")
|
||||
if getattr(cls, "__dictrefoffset__", -1) != 0:
|
||||
slots.append("__dict__")
|
||||
yield from slots
|
||||
elif isinstance(slots, str):
|
||||
yield slots
|
||||
# Slots may be any iterable, but we cannot handle an iterator
|
||||
# because it will already be (partially) consumed.
|
||||
elif not hasattr(cls, "__next__"):
|
||||
yield from slots
|
||||
else:
|
||||
raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
|
||||
|
||||
def _add_slots(
|
||||
cls: type[DataclassInstance], is_frozen: bool, weakref_slot: bool
|
||||
) -> type[DataclassInstance]:
|
||||
# Need to create a new class, since we can't set __slots__
|
||||
# after a class has been created.
|
||||
|
||||
# Make sure __slots__ isn't already set.
|
||||
if "__slots__" in cls.__dict__:
|
||||
raise TypeError(f"{cls.__name__} already specifies __slots__")
|
||||
|
||||
# Create a new dict for our new class.
|
||||
cls_dict = dict(cls.__dict__)
|
||||
field_names = tuple(f.name for f in dataclasses.fields(cls))
|
||||
# Make sure slots don't overlap with those in base classes.
|
||||
inherited_slots = set(
|
||||
itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
|
||||
)
|
||||
# The slots for our class. Remove slots from our base classes. Add
|
||||
# '__weakref__' if weakref_slot was given, unless it is already present.
|
||||
cls_dict["__slots__"] = tuple(
|
||||
itertools.filterfalse(
|
||||
inherited_slots.__contains__,
|
||||
itertools.chain(
|
||||
# gh-93521: '__weakref__' also needs to be filtered out if
|
||||
# already present in inherited_slots
|
||||
field_names,
|
||||
("__weakref__",) if weakref_slot else (),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
for field_name in field_names:
|
||||
# Remove our attributes, if present. They'll still be
|
||||
# available in _MARKER.
|
||||
cls_dict.pop(field_name, None)
|
||||
|
||||
# Remove __dict__ itself.
|
||||
cls_dict.pop("__dict__", None)
|
||||
|
||||
# Clear existing `__weakref__` descriptor, it belongs to a previous type:
|
||||
cls_dict.pop("__weakref__", None) # gh-102069
|
||||
|
||||
# And finally create the class.
|
||||
qualname = getattr(cls, "__qualname__", None)
|
||||
cls = type(cls.__name__, cls.__bases__, cls_dict)
|
||||
if qualname is not None:
|
||||
cls.__qualname__ = qualname
|
||||
|
||||
def _dataclass_getstate(self: _T) -> object:
|
||||
fields = dataclasses.fields(self)
|
||||
return [getattr(self, f.name) for f in fields]
|
||||
|
||||
def _dataclass_setstate(self: _T, state: list[object]) -> None:
|
||||
fields = dataclasses.fields(self)
|
||||
for field, value in zip(fields, state):
|
||||
# use setattr because dataclass may be frozen
|
||||
object.__setattr__(self, field.name, value)
|
||||
|
||||
if is_frozen:
|
||||
# Need this for pickling frozen classes with slots.
|
||||
if "__getstate__" not in cls_dict:
|
||||
cls.__getstate__ = _dataclass_getstate # type: ignore[method-assign, assignment]
|
||||
if "__setstate__" not in cls_dict:
|
||||
cls.__setstate__ = _dataclass_setstate # type: ignore[attr-defined]
|
||||
|
||||
return cls
|
||||
|
||||
params = getattr(cls, dataclasses._PARAMS) # type: ignore[attr-defined]
|
||||
weakref_slot = getattr(params, "weakref_slot", False)
|
||||
return _add_slots(cls, params.frozen, weakref_slot)
|
Reference in New Issue
Block a user