mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Break reference cycle in load_state_dict (#20397)
Summary: load_state_dict includes a recursive inner function `load` that captures Tensors through the close-over variable `state_dict`. Because it's recursive, it also captures itself leading to a reference cycle. This breaks the reference cycle so that any Tensors in state_dict can be collected immediately instead of waiting until the next GC cycle. Alternatively, we could have passed `state_dict` and `metadata` as arguments to load to prevent capture of Tensors. (That would still result in cyclic garbage, but not any cyclic garbage of Tensors). See: https://github.com/pytorch/pytorch/issues/20199#issuecomment-491089004 Pull Request resolved: https://github.com/pytorch/pytorch/pull/20397 Differential Revision: D15414834 Pulled By: colesbury fbshipit-source-id: 4c2275a08b2d8043deb3779db28be03bda15872d
This commit is contained in:
committed by
Facebook Github Bot
parent
796e359601
commit
c1fa449763
@ -4396,6 +4396,19 @@ class TestNN(NNTestCase):
|
||||
self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
|
||||
self.assertEqual(bn.num_batches_tracked.item(), 0)
|
||||
|
||||
@unittest.skipIf(not PY3, 'Python 2.7 generates cyclic trash')
|
||||
def test_load_state_dict_ref_cycle(self):
|
||||
# load_state_dict shouldn't cause a reference cycle involving Tensors
|
||||
import gc
|
||||
|
||||
m = torch.nn.LSTM(16, 16, bidirectional=True)
|
||||
|
||||
gc.collect()
|
||||
m.load_state_dict(deepcopy(m).state_dict())
|
||||
refcycles = gc.collect()
|
||||
|
||||
self.assertEqual(refcycles, 0)
|
||||
|
||||
def test_parameter_assignment(self):
|
||||
l = nn.Linear(5, 5)
|
||||
|
||||
|
@ -761,6 +761,7 @@ class Module(object):
|
||||
load(child, prefix + name + '.')
|
||||
|
||||
load(self)
|
||||
load = None # break load->load reference cycle
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
|
Reference in New Issue
Block a user