From b4fe5ca58a6032f76ea93d24efadfb31a7119310 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 27 May 2025 21:30:49 -0700 Subject: [PATCH] pymft lint torch/utils/weak.py (#154484) Pull Request resolved: https://github.com/pytorch/pytorch/pull/154484 Approved by: https://github.com/Skylion007 ghstack dependencies: #154483 --- .lintrunner.toml | 1 - torch/utils/weak.py | 40 ++++++++++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 6f7f62c566b9..07747c1bc39f 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -1438,7 +1438,6 @@ exclude_patterns = [ 'torch/utils/throughput_benchmark.py', 'torch/utils/viz/__init__.py', 'torch/utils/viz/_cycles.py', - 'torch/utils/weak.py', ] init_command = [ 'python3', diff --git a/torch/utils/weak.py b/torch/utils/weak.py index f729ff06489f..8bf2ba5ed02b 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -1,18 +1,25 @@ # mypy: allow-untyped-defs from __future__ import annotations -import weakref -from weakref import ref -from _weakrefset import _IterationGuard # type: ignore[attr-defined] -from collections.abc import MutableMapping, Mapping -from torch import Tensor import collections.abc as _collections_abc +import weakref + +from _weakrefset import _IterationGuard # type: ignore[attr-defined] +from collections.abc import Mapping, MutableMapping +from weakref import ref + +from torch import Tensor WeakRef = ref -__all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDictionary'] +__all__ = [ + "TensorWeakRef", + "WeakIdRef", + "WeakIdKeyDictionary", + "WeakTensorKeyDictionary", +] # This file defines a variant of WeakKeyDictionary that overrides the hashing @@ -41,7 +48,7 @@ __all__ = ['TensorWeakRef', 'WeakIdRef', 'WeakIdKeyDictionary', 'WeakTensorKeyDi # WeakIdRef(tensor) rather than weakref.ref(tensor); it handles a number of # easy to get wrong cases transparently for you. class WeakIdRef(weakref.ref): - __slots__ = ['_id'] + __slots__ = ["_id"] def __init__(self, key, callback=None): # Unlike stock weakref, which preserves hash semantics of the @@ -55,7 +62,7 @@ class WeakIdRef(weakref.ref): def __call__(self): r = super().__call__() # Special logic for Tensor PyObject resurrection - if hasattr(r, '_fix_weakref'): + if hasattr(r, "_fix_weakref"): r._fix_weakref() # type: ignore[union-attr] return r @@ -81,10 +88,11 @@ class WeakIdRef(weakref.ref): return a is b return self is other + # This is the same as WeakIdRef but equality is checked using hash() rather than id. # This will be equivalent to the one above except for classes where hash is not their id. class _WeakHashRef(weakref.ref): - __slots__ = ['_id'] + __slots__ = ["_id"] def __init__(self, key, callback=None): # Unlike stock weakref, which preserves hash semantics of the @@ -98,7 +106,7 @@ class _WeakHashRef(weakref.ref): def __call__(self): r = super().__call__() # Special logic for Tensor PyObject resurrection - if hasattr(r, '_fix_weakref'): + if hasattr(r, "_fix_weakref"): r._fix_weakref() # type: ignore[union-attr] return r @@ -115,6 +123,7 @@ class _WeakHashRef(weakref.ref): return hash(a) == hash(b) return self is other + # This is directly adapted from cpython/Lib/weakref.py class WeakIdKeyDictionary(MutableMapping): def __init__(self, dict=None, ref_type=WeakIdRef): # CHANGED @@ -132,6 +141,7 @@ class WeakIdKeyDictionary(MutableMapping): del self.data[k] except KeyError: pass + self._remove = remove # A list of dead weakrefs (keys to be removed) self._pending_removals = [] @@ -196,6 +206,7 @@ class WeakIdKeyDictionary(MutableMapping): def __deepcopy__(self, memo): from copy import deepcopy + new = self.__class__() with _IterationGuard(self): for key, value in self.data.items(): @@ -261,7 +272,9 @@ class WeakIdKeyDictionary(MutableMapping): return self.data.pop(self.ref_type(key), *args) # CHANGED def setdefault(self, key, default=None): - return self.data.setdefault(self.ref_type(key, self._remove), default) # CHANGED + return self.data.setdefault( + self.ref_type(key, self._remove), default + ) # CHANGED def update(self, dict=None, **kwargs): # type: ignore[override] d = self.data @@ -297,7 +310,10 @@ class WeakIdKeyDictionary(MutableMapping): def __eq__(self, other): if not isinstance(other, Mapping): return NotImplemented - return {id(k): v for k, v in self.items()} == {id(k): v for k, v in other.items()} + return {id(k): v for k, v in self.items()} == { + id(k): v for k, v in other.items() + } + # Convenience alias WeakTensorKeyDictionary = WeakIdKeyDictionary