mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/13141 This is an example diff to show what lint rules are being applied. Reviewed By: mingzhe09088 Differential Revision: D10858478 fbshipit-source-id: cbeb013f10f755b0095478adf79366e7cf7836ff
60 lines
1.9 KiB
Python
60 lines
1.9 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals
|
|
import collections
|
|
import weakref
|
|
import warnings
|
|
|
|
|
|
class RemovableHandle(object):
    """Handle that lets the caller detach a previously registered hook.

    The handle stores only a weak reference to the dictionary the hook was
    placed in, plus the integer key it was assigned, so the handle by itself
    never keeps that dictionary alive.  Using the handle as a context manager
    removes the hook when the ``with`` block exits.
    """

    # Class-wide counter: the key the next constructed handle will receive.
    next_id = 0

    def __init__(self, hooks_dict):
        # Take the weak reference first: if ``hooks_dict`` cannot be weakly
        # referenced this raises before any id is consumed.
        self.hooks_dict_ref = weakref.ref(hooks_dict)
        self.id = RemovableHandle.next_id
        RemovableHandle.next_id += 1

    def remove(self):
        """Delete this handle's entry from its hook dictionary, if possible.

        Does nothing when the dictionary has been garbage-collected or when
        the entry is already gone, so calling it repeatedly is harmless.
        """
        target = self.hooks_dict_ref()
        if target is None:
            return
        if self.id not in target:
            return
        del target[self.id]

    def __getstate__(self):
        # weakref objects are not picklable, so pickle the referent itself
        # (``None`` if it has already been collected) alongside the id.
        live_dict = self.hooks_dict_ref()
        return (live_dict, self.id)

    def __setstate__(self, state):
        saved_dict = state[0]
        if saved_dict is not None:
            self.hooks_dict_ref = weakref.ref(saved_dict)
        else:
            # Nothing was pickled: reference a throwaway OrderedDict so the
            # weak reference is dead once that temporary is collected, which
            # turns remove() into a no-op.
            self.hooks_dict_ref = weakref.ref(collections.OrderedDict())
        self.id = state[1]
        # Keep the class counter ahead of any restored id so new handles
        # never collide with it.
        RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1)

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        # Always detach the hook; returning None lets exceptions propagate.
        self.remove()
|
|
|
|
|
|
def unserializable_hook(f):
    """
    Mark ``f`` as a hook that is intentionally not serialized.

    Sets ``f.__torch_unserializable__ = True`` and returns ``f`` itself, so
    this can be applied as a decorator.  The marker is meant to silence the
    warning emitted when a tensor carrying this hook is serialized.
    """
    setattr(f, "__torch_unserializable__", True)
    return f
|
|
|
|
|
|
def warn_if_has_hooks(tensor):
    """Warn once for each backward hook on ``tensor`` that will not be serialized.

    Hooks carrying a ``__torch_unserializable__`` attribute (as set by
    ``unserializable_hook``) are skipped silently.

    Args:
        tensor: object exposing a ``_backward_hooks`` attribute; presumably a
            mapping from hook id to hook callable, or a falsy value (e.g.
            ``None``) when no hooks are registered.
    """
    if not tensor._backward_hooks:
        return
    for hook in tensor._backward_hooks.values():
        # The opt-out marker lives on the hook object itself.  The mapping's
        # keys are hook ids (presumably RemovableHandle integer ids) and never
        # carry it, so checking the key would warn even for marked hooks.
        if not hasattr(hook, "__torch_unserializable__"):
            warnings.warn("backward hook {} on tensor will not be "
                          "serialized. If this is expected, you can "
                          "decorate the function with @torch.utils.hooks.unserializable_hook "
                          "to suppress this warning".format(repr(hook)))
|