pymft lint torch/utils/weak.py (#154484)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154484
Approved by: https://github.com/Skylion007
ghstack dependencies: #154483
This commit is contained in:
Laith Sakka
2025-05-27 21:30:49 -07:00
committed by PyTorch MergeBot
parent 4de1b25df7
commit b4fe5ca58a
2 changed files with 28 additions and 13 deletions

View File

@ -1438,7 +1438,6 @@ exclude_patterns = [
'torch/utils/throughput_benchmark.py',
'torch/utils/viz/__init__.py',
'torch/utils/viz/_cycles.py',
'torch/utils/weak.py',
]
init_command = [
'python3',

View File

@ -1,18 +1,25 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import weakref
from weakref import ref
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
from collections.abc import MutableMapping, Mapping
from torch import Tensor
import collections.abc as _collections_abc
import weakref
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
from collections.abc import Mapping, MutableMapping
from weakref import ref
from torch import Tensor
WeakRef = ref
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
__all__ = [
"TensorWeakRef",
"WeakIdRef",
"WeakIdKeyDictionary",
"WeakTensorKeyDictionary",
]
# This file defines a variant of WeakKeyDictionary that overrides the hashing
@ -41,7 +48,7 @@ __all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDi
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
# easy to get wrong cases transparently for you.
class WeakIdRef(weakref.ref):
__slots__ = ['_id']
__slots__ = ["_id"]
def __init__(self, key, callback=None):
# Unlike stock weakref, which preserves hash semantics of the
@ -55,7 +62,7 @@ class WeakIdRef(weakref.ref):
def __call__(self):
r = super().__call__()
# Special logic for Tensor PyObject resurrection
if hasattr(r, '_fix_weakref'):
if hasattr(r, "_fix_weakref"):
r._fix_weakref() # type: ignore[union-attr]
return r
@ -81,10 +88,11 @@ class WeakIdRef(weakref.ref):
return a is b
return self is other
# This is the same as WeakIdRef but equality is checked using hash() rather than id.
# This will be equivalent to the one above except for classes where hash is not their id.
class _WeakHashRef(weakref.ref):
__slots__ = ['_id']
__slots__ = ["_id"]
def __init__(self, key, callback=None):
# Unlike stock weakref, which preserves hash semantics of the
@ -98,7 +106,7 @@ class _WeakHashRef(weakref.ref):
def __call__(self):
r = super().__call__()
# Special logic for Tensor PyObject resurrection
if hasattr(r, '_fix_weakref'):
if hasattr(r, "_fix_weakref"):
r._fix_weakref() # type: ignore[union-attr]
return r
@ -115,6 +123,7 @@ class _WeakHashRef(weakref.ref):
return hash(a) == hash(b)
return self is other
# This is directly adapted from cpython/Lib/weakref.py
class WeakIdKeyDictionary(MutableMapping):
def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED
@ -132,6 +141,7 @@ class WeakIdKeyDictionary(MutableMapping):
del self.data[k]
except KeyError:
pass
self._remove = remove
# A list of dead weakrefs (keys to be removed)
self._pending_removals = []
@ -196,6 +206,7 @@ class WeakIdKeyDictionary(MutableMapping):
def __deepcopy__(self, memo):
from copy import deepcopy
new = self.__class__()
with _IterationGuard(self):
for key, value in self.data.items():
@ -261,7 +272,9 @@ class WeakIdKeyDictionary(MutableMapping):
return self.data.pop(self.ref_type(key), *args) # CHANGED
def setdefault(self, key, default=None):
return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED
return self.data.setdefault(
self.ref_type(key, self._remove), default
) # CHANGED
def update(self, dict=None, **kwargs): # type: ignore[override]
d = self.data
@ -297,7 +310,10 @@ class WeakIdKeyDictionary(MutableMapping):
def __eq__(self, other):
if not isinstance(other, Mapping):
return NotImplemented
return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
return {id(k): v for k, v in self.items()} == {
id(k): v for k, v in other.items()
}
# Convenience alias
WeakTensorKeyDictionary = WeakIdKeyDictionary