mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
pymft lint torch/utils/weak.py (#154484)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154484 Approved by: https://github.com/Skylion007 ghstack dependencies: #154483
This commit is contained in:
committed by
PyTorch MergeBot
parent
4de1b25df7
commit
b4fe5ca58a
@ -1438,7 +1438,6 @@ exclude_patterns = [
|
||||
'torch/utils/throughput_benchmark.py',
|
||||
'torch/utils/viz/__init__.py',
|
||||
'torch/utils/viz/_cycles.py',
|
||||
'torch/utils/weak.py',
|
||||
]
|
||||
init_command = [
|
||||
'python3',
|
||||
|
@ -1,18 +1,25 @@
|
||||
# mypy: allow-untyped-defs
|
||||
from __future__ import annotations
|
||||
|
||||
import weakref
|
||||
from weakref import ref
|
||||
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
|
||||
from collections.abc import MutableMapping, Mapping
|
||||
from torch import Tensor
|
||||
import collections.abc as _collections_abc
|
||||
import weakref
|
||||
|
||||
from _weakrefset import _IterationGuard # type: ignore[attr-defined]
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from weakref import ref
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
WeakRef = ref
|
||||
|
||||
|
||||
__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary']
|
||||
__all__ = [
|
||||
"TensorWeakRef",
|
||||
"WeakIdRef",
|
||||
"WeakIdKeyDictionary",
|
||||
"WeakTensorKeyDictionary",
|
||||
]
|
||||
|
||||
|
||||
# This file defines a variant of WeakKeyDictionary that overrides the hashing
|
||||
@ -41,7 +48,7 @@ __all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDi
|
||||
# WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of
|
||||
# easy to get wrong cases transparently for you.
|
||||
class WeakIdRef(weakref.ref):
|
||||
__slots__ = ['_id']
|
||||
__slots__ = ["_id"]
|
||||
|
||||
def __init__(self, key, callback=None):
|
||||
# Unlike stock weakref, which preserves hash semantics of the
|
||||
@ -55,7 +62,7 @@ class WeakIdRef(weakref.ref):
|
||||
def __call__(self):
|
||||
r = super().__call__()
|
||||
# Special logic for Tensor PyObject resurrection
|
||||
if hasattr(r, '_fix_weakref'):
|
||||
if hasattr(r, "_fix_weakref"):
|
||||
r._fix_weakref() # type: ignore[union-attr]
|
||||
return r
|
||||
|
||||
@ -81,10 +88,11 @@ class WeakIdRef(weakref.ref):
|
||||
return a is b
|
||||
return self is other
|
||||
|
||||
|
||||
# This is the same as WeakIdRef but equality is checked using hash() rather than id.
|
||||
# This will be equivalent to the one above except for classes where hash is not their id.
|
||||
class _WeakHashRef(weakref.ref):
|
||||
__slots__ = ['_id']
|
||||
__slots__ = ["_id"]
|
||||
|
||||
def __init__(self, key, callback=None):
|
||||
# Unlike stock weakref, which preserves hash semantics of the
|
||||
@ -98,7 +106,7 @@ class _WeakHashRef(weakref.ref):
|
||||
def __call__(self):
|
||||
r = super().__call__()
|
||||
# Special logic for Tensor PyObject resurrection
|
||||
if hasattr(r, '_fix_weakref'):
|
||||
if hasattr(r, "_fix_weakref"):
|
||||
r._fix_weakref() # type: ignore[union-attr]
|
||||
return r
|
||||
|
||||
@ -115,6 +123,7 @@ class _WeakHashRef(weakref.ref):
|
||||
return hash(a) == hash(b)
|
||||
return self is other
|
||||
|
||||
|
||||
# This is directly adapted from cpython/Lib/weakref.py
|
||||
class WeakIdKeyDictionary(MutableMapping):
|
||||
def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED
|
||||
@ -132,6 +141,7 @@ class WeakIdKeyDictionary(MutableMapping):
|
||||
del self.data[k]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
self._remove = remove
|
||||
# A list of dead weakrefs (keys to be removed)
|
||||
self._pending_removals = []
|
||||
@ -196,6 +206,7 @@ class WeakIdKeyDictionary(MutableMapping):
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
from copy import deepcopy
|
||||
|
||||
new = self.__class__()
|
||||
with _IterationGuard(self):
|
||||
for key, value in self.data.items():
|
||||
@ -261,7 +272,9 @@ class WeakIdKeyDictionary(MutableMapping):
|
||||
return self.data.pop(self.ref_type(key), *args) # CHANGED
|
||||
|
||||
def setdefault(self, key, default=None):
|
||||
return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED
|
||||
return self.data.setdefault(
|
||||
self.ref_type(key, self._remove), default
|
||||
) # CHANGED
|
||||
|
||||
def update(self, dict=None, **kwargs): # type: ignore[override]
|
||||
d = self.data
|
||||
@ -297,7 +310,10 @@ class WeakIdKeyDictionary(MutableMapping):
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Mapping):
|
||||
return NotImplemented
|
||||
return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()}
|
||||
return {id(k): v for k, v in self.items()} == {
|
||||
id(k): v for k, v in other.items()
|
||||
}
|
||||
|
||||
|
||||
# Convenience alias
|
||||
WeakTensorKeyDictionary = WeakIdKeyDictionary
|
||||
|
Reference in New Issue
Block a user