Remove dataclass_slots (#163623)

`dataclasses.dataclass` now has a `slots` kwarg (Python 3.10+), so the `dataclass_slots` backport is no longer needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163623
Approved by: https://github.com/Skylion007
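
In short, each stacked decorator pair in the diff below collapses into a single standard-library call. A minimal before/after sketch (the Point class is illustrative, not taken from the diff):

import dataclasses

# Before: __slots__ was injected after class creation via the backport decorator.
#
#     @dataclass_slots
#     @dataclasses.dataclass
#     class Point:
#         x: int
#         y: int

# After: Python 3.10+ builds the slotted class directly.
@dataclasses.dataclass(slots=True)
class Point:
    x: int
    y: int

p = Point(1, 2)
assert Point.__slots__ == ("x", "y")
assert not hasattr(p, "__dict__")  # slotted instances carry no per-instance __dict__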
Author: Yuanyuan Chen
Date: 2025-09-26 00:54:42 +00:00
Committed by: PyTorch MergeBot
Parent: b776e0c71e
Commit: 5daa79fd6e
5 changed files with 17 additions and 154 deletions

torch/_dynamo/bytecode_transformation.py

@@ -24,7 +24,6 @@ import uuid
 from collections.abc import Iterable, Iterator, Mapping, Sequence
 from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union

-from ..utils._backport_slots import dataclass_slots
 from . import config
 from .bytecode_analysis import (
     get_indexof,
@@ -39,8 +38,7 @@ if TYPE_CHECKING:
     from .output_graph import DynamoTracerOutput


-@dataclass_slots
-@dataclasses.dataclass
+@dataclasses.dataclass(slots=True)
 class InstructionExnTabEntry:
     start: "Instruction"
     end: "Instruction"
@@ -68,8 +66,7 @@ class InstructionExnTabEntry:
         )


-@dataclass_slots
-@dataclasses.dataclass
+@dataclasses.dataclass(slots=True)
 class Instruction:
     """A mutable version of dis.Instruction"""
@@ -642,8 +639,7 @@ def linetable_311_writer(
     return linetable, update


-@dataclass_slots
-@dataclasses.dataclass
+@dataclasses.dataclass(slots=True)
 class ExceptionTableEntry:
     start: int
     end: int

torch/_guards.py

@@ -27,7 +27,6 @@ from typing import (

 import torch
 from torch.utils import _pytree as pytree
-from torch.utils._backport_slots import dataclass_slots
 from torch.utils._traceback import CapturedTraceback, format_frame
 from torch.utils.weak import WeakTensorKeyDictionary
@@ -241,8 +240,7 @@ class ShapeGuard(NamedTuple):
     size_oblivious: bool


-@dataclass_slots
-@dataclasses.dataclass
+@dataclasses.dataclass(slots=True)
 class Guard:
     # originating_source is the source that called the make_guard method to
     # construct this guard object. The property name specifies what exactly it

torch/_subclasses/_fake_tensor_utils.py

@@ -7,7 +7,6 @@ import torch
 from torch import SymInt
 from torch.fx.experimental.sym_node import SymNode
 from torch.types import py_sym_types, PySymType
-from torch.utils._backport_slots import dataclass_slots


 if TYPE_CHECKING:
@@ -18,8 +17,7 @@ if TYPE_CHECKING:
     from .fake_tensor import _DispatchCacheKey, _MetadataIntLike


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class _DeconstructedSymNode:
     """
     Represents a SymNode without the associated ShapeEnv
@@ -73,8 +71,7 @@ class _DeconstructedSymNode:
         return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class _DeconstructedSymType:
     """
     Represents a SymInt, SymFloat, SymBool without the associated ShapeEnv
@@ -103,14 +100,12 @@ class _DeconstructedSymType:
         return NotImplemented


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class _InputBackref:
     value: int


-@dataclass_slots
-@dataclass
+@dataclass(slots=True)
 class _PySymInputStub:
     """
     Represents a SymInt in the cached key. Needed because SymInt doesn't
@@ -172,8 +167,7 @@ class _PySymInputStub:
         return self.value.node._value_hash()


-@dataclass_slots
-@dataclass
+@dataclass(slots=True)
 class _SymIntOutputStub:
     """
     Represents a SymInt in the cached output.
@@ -207,8 +201,7 @@ class _SymIntOutputStub:
         raise NotImplementedError


-@dataclass_slots
-@dataclass
+@dataclass(slots=True)
 class _CacheKeyState:
     """
     State used while building our cache key.

torch/_subclasses/fake_tensor.py

@@ -40,7 +40,6 @@ from torch.fx.operator_schemas import normalize_function
 from torch.multiprocessing.reductions import StorageWeakRef
 from torch.overrides import TorchFunctionMode
 from torch.types import IntLikeType, py_sym_types
-from torch.utils._backport_slots import dataclass_slots
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import (
     is_traceable_wrapper_subclass,
@@ -1008,8 +1007,7 @@ class FakeTensor(Tensor):
 _MetadataIntLike = Union[IntLikeType, "_PySymInputStub", "_SymIntOutputStub"]


-@dataclass_slots
-@dataclass
+@dataclass(slots=True)
 class TensorMetadata:
     """
     The Tensor metadata relevant to hashing FakeTensors when caching.
@@ -1093,8 +1091,7 @@ def extract_tensor_metadata(t: Tensor) -> TensorMetadata:
     )


-@dataclass_slots
-@dataclass
+@dataclass(slots=True)
 class _DispatchCacheKey:
     """
     Key for the FakeTensor dispatch cache.
@@ -1127,8 +1124,7 @@ class SingletonConstant:
     pass


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class _DispatchCacheEntryOutputInfo:
     """
     Entry type for the FakeTensor dispatch cache for an output. Accounts for three
@@ -1147,8 +1143,7 @@ class _DispatchCacheEntryOutputInfo:
     constant_value: Optional[Any] = SingletonConstant


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class _DispatchCacheValidEntry:
     """
     Entry type for the FakeTensor dispatch cache. It supports two types of outputs
@@ -1162,8 +1157,7 @@ class _DispatchCacheValidEntry:
     is_output_tuple: bool = False


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class _DispatchCacheBypassEntry:
     """
     Entry type for a negative cache entry.
@@ -1176,8 +1170,7 @@ if TYPE_CHECKING:
     _DispatchCacheEntry = Union[_DispatchCacheValidEntry, _DispatchCacheBypassEntry]


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class _BypassDispatchCache(Exception):
     """
     Signals cases that should skip FakeTensor caching.
@@ -1186,8 +1179,7 @@ class _BypassDispatchCache(Exception):
     reason: str


-@dataclass_slots
-@dataclass(frozen=True)
+@dataclass(frozen=True, slots=True)
 class DispatchCacheInfo:
     """
     Information about the state of the FakeTensor dispatch cache.

torch/utils/_backport_slots.py (deleted)

@@ -1,116 +0,0 @@
-# This code is backported from python 3.10 dataclasses. Once 3.10 becomes the
-# minimum supported we should use dataclass(slots=True) instead.
-
-from __future__ import annotations
-
-import dataclasses
-import itertools
-from typing import TYPE_CHECKING, TypeVar
-
-if TYPE_CHECKING:
-    from collections.abc import Generator
-
-    from _typeshed import DataclassInstance
-
-__all__ = ["dataclass_slots"]
-
-_T = TypeVar("_T", bound="DataclassInstance")
-
-
-def dataclass_slots(cls: type[_T]) -> type[DataclassInstance]:
-    assert dataclasses.is_dataclass(cls), "Can only be used on dataclasses."
-
-    def _get_slots(cls: type[DataclassInstance]) -> Generator[str, None, None]:
-        slots = cls.__dict__.get("__slots__")
-        # `__dictoffset__` and `__weakrefoffset__` can tell us whether
-        # the base type has dict/weakref slots, in a way that works correctly
-        # for both Python classes and C extension types. Extension types
-        # don't use `__slots__` for slot creation
-        if slots is None:
-            slots = []
-            if getattr(cls, "__weakrefoffset__", -1) != 0:
-                slots.append("__weakref__")
-            if getattr(cls, "__dictrefoffset__", -1) != 0:
-                slots.append("__dict__")
-            yield from slots
-        elif isinstance(slots, str):
-            yield slots
-        # Slots may be any iterable, but we cannot handle an iterator
-        # because it will already be (partially) consumed.
-        elif not hasattr(cls, "__next__"):
-            yield from slots
-        else:
-            raise TypeError(f"Slots of '{cls.__name__}' cannot be determined")
-
-    def _add_slots(
-        cls: type[DataclassInstance], is_frozen: bool, weakref_slot: bool
-    ) -> type[DataclassInstance]:
-        # Need to create a new class, since we can't set __slots__
-        # after a class has been created.
-
-        # Make sure __slots__ isn't already set.
-        if "__slots__" in cls.__dict__:
-            raise TypeError(f"{cls.__name__} already specifies __slots__")
-
-        # Create a new dict for our new class.
-        cls_dict = dict(cls.__dict__)
-        field_names = tuple(f.name for f in dataclasses.fields(cls))
-        # Make sure slots don't overlap with those in base classes.
-        inherited_slots = set(
-            itertools.chain.from_iterable(map(_get_slots, cls.__mro__[1:-1]))
-        )
-        # The slots for our class. Remove slots from our base classes. Add
-        # '__weakref__' if weakref_slot was given, unless it is already present.
-        cls_dict["__slots__"] = tuple(
-            itertools.filterfalse(
-                inherited_slots.__contains__,
-                itertools.chain(
-                    # gh-93521: '__weakref__' also needs to be filtered out if
-                    # already present in inherited_slots
-                    field_names,
-                    ("__weakref__",) if weakref_slot else (),
-                ),
-            ),
-        )
-
-        for field_name in field_names:
-            # Remove our attributes, if present. They'll still be
-            # available in _MARKER.
-            cls_dict.pop(field_name, None)
-
-        # Remove __dict__ itself.
-        cls_dict.pop("__dict__", None)
-
-        # Clear existing `__weakref__` descriptor, it belongs to a previous type:
-        cls_dict.pop("__weakref__", None)  # gh-102069
-
-        # And finally create the class.
-        qualname = getattr(cls, "__qualname__", None)
-        cls = type(cls.__name__, cls.__bases__, cls_dict)
-        if qualname is not None:
-            cls.__qualname__ = qualname
-
-        def _dataclass_getstate(self: _T) -> object:
-            fields = dataclasses.fields(self)
-            return [getattr(self, f.name) for f in fields]
-
-        def _dataclass_setstate(self: _T, state: list[object]) -> None:
-            fields = dataclasses.fields(self)
-            for field, value in zip(fields, state):
-                # use setattr because dataclass may be frozen
-                object.__setattr__(self, field.name, value)
-
-        if is_frozen:
-            # Need this for pickling frozen classes with slots.
-            if "__getstate__" not in cls_dict:
-                cls.__getstate__ = _dataclass_getstate  # type: ignore[method-assign, assignment]
-            if "__setstate__" not in cls_dict:
-                cls.__setstate__ = _dataclass_setstate  # type: ignore[attr-defined]
-
-        return cls
-
-    params = getattr(cls, dataclasses._PARAMS)  # type: ignore[attr-defined]
-    weakref_slot = getattr(params, "weakref_slot", False)
-    return _add_slots(cls, params.frozen, weakref_slot)
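
For reference, everything the deleted backport did by hand is handled by the stdlib implementation of dataclass(slots=True) on Python 3.10+, including the __getstate__/__setstate__ pair needed to pickle frozen slotted classes and the offset-based detection of inherited __dict__/__weakref__ slots. A minimal sketch of that behavior (the Entry and Plain classes are illustrative, not from this commit):

import dataclasses
import pickle

@dataclasses.dataclass(frozen=True, slots=True)
class Entry:
    key: str
    value: int

e = Entry("answer", 42)

# slots=True rebuilds the class with __slots__ and no instance __dict__,
# which is what _add_slots above achieved by constructing a new type.
assert Entry.__slots__ == ("key", "value")
assert not hasattr(e, "__dict__")

# frozen + slots pickling works out of the box; the backport installed
# __getstate__/__setstate__ manually for this case.
assert pickle.loads(pickle.dumps(e)) == e

# _get_slots above used type offsets to detect dict/weakref slots: a plain
# class has a nonzero __dictoffset__, a slots-only class has zero.
class Plain:
    pass

assert Plain.__dictoffset__ != 0
assert Entry.__dictoffset__ == 0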