Break reference cycle in load_state_dict (#20397)

Summary:
load_state_dict includes a recursive inner function `load` that captures
Tensors through the close-over variable `state_dict`. Because it's
recursive, it also captures itself leading to a reference cycle.

This breaks the reference cycle so that any Tensors in state_dict can be
collected immediately instead of waiting until the next GC cycle.

Alternatively, we could have passed `state_dict` and `metadata` as
arguments to load to prevent capture of Tensors. (That would still
result in cyclic garbage, but not any cyclic garbage of Tensors).

See:
https://github.com/pytorch/pytorch/issues/20199#issuecomment-491089004
Pull Request resolved: https://github.com/pytorch/pytorch/pull/20397

Differential Revision: D15414834

Pulled By: colesbury

fbshipit-source-id: 4c2275a08b2d8043deb3779db28be03bda15872d
This commit is contained in:
Sam Gross
2019-05-20 11:42:41 -07:00
committed by Facebook Github Bot
parent 796e359601
commit c1fa449763
2 changed files with 14 additions and 0 deletions

View File

@ -4396,6 +4396,19 @@ class TestNN(NNTestCase):
self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
self.assertEqual(bn.num_batches_tracked.item(), 0)
@unittest.skipIf(not PY3, 'Python 2.7 generates cyclic trash')
def test_load_state_dict_ref_cycle(self):
# load_state_dict shouldn't cause a reference cycle involving Tensors
import gc
m = torch.nn.LSTM(16, 16, bidirectional=True)
gc.collect()
m.load_state_dict(deepcopy(m).state_dict())
refcycles = gc.collect()
self.assertEqual(refcycles, 0)
def test_parameter_assignment(self):
l = nn.Linear(5, 5)

View File

@ -761,6 +761,7 @@ class Module(object):
load(child, prefix + name + '.')
load(self)
load = None # break load->load reference cycle
if strict:
if len(unexpected_keys) > 0: