TST Adds pickle testing for ModuleInfo (#63736)

Summary:
Follow up to https://github.com/pytorch/pytorch/pull/61935

This PR adds `test_pickle` to `test_modules`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63736

Reviewed By: heitorschueroff

Differential Revision: D30522462

Pulled By: jbschlosser

fbshipit-source-id: a03b66ea0d81c6d0845c4fddf0ddc3714bbf0ab1
This commit is contained in:
Thomas J. Fan
2021-08-24 18:55:23 -07:00
committed by Facebook GitHub Bot
parent 8dda299d96
commit 58ef99bd5a

View File

@ -1,3 +1,5 @@
import tempfile
import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_modules import module_db, modules
@ -108,6 +110,36 @@ class TestModule(TestCase):
buffer.dtype, dtype,
f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')
@modules(module_db)
def test_pickle(self, device, dtype, module_info):
# Test that module can be pickled and unpickled.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=False)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
with freeze_rng_state():
# === Instantiate the module. ===
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
# === Do forward pass. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
output = m(*args, **kwargs)
# === Check unpickled module gives the same output. ===
with tempfile.TemporaryFile() as f:
torch.save(m, f)
f.seek(0)
m_copy = torch.load(f)
output_from_copy = m_copy(*args, **kwargs)
self.assertEqual(output, output_from_copy)
instantiate_device_type_tests(TestModule, globals())