mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
TST Adds pickle testing for ModuleInfo (#63736)
Summary: Follow up to https://github.com/pytorch/pytorch/pull/61935 This PR adds `test_pickle` to `test_modules`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/63736 Reviewed By: heitorschueroff Differential Revision: D30522462 Pulled By: jbschlosser fbshipit-source-id: a03b66ea0d81c6d0845c4fddf0ddc3714bbf0ab1
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8dda299d96
commit
58ef99bd5a
@ -1,3 +1,5 @@
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
from torch.testing._internal.common_modules import module_db, modules
|
||||
@ -108,6 +110,36 @@ class TestModule(TestCase):
|
||||
buffer.dtype, dtype,
|
||||
f'Buffer {name} is of dtype {buffer.dtype} instead of the expected dtype {dtype}')
|
||||
|
||||
@modules(module_db)
|
||||
def test_pickle(self, device, dtype, module_info):
|
||||
# Test that module can be pickled and unpickled.
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
requires_grad=False)
|
||||
for module_input in module_inputs:
|
||||
if module_input.forward_input is None:
|
||||
continue
|
||||
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
|
||||
with freeze_rng_state():
|
||||
# === Instantiate the module. ===
|
||||
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
|
||||
# === Do forward pass. ===
|
||||
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
output = m(*args, **kwargs)
|
||||
|
||||
# === Check unpickled module gives the same output. ===
|
||||
with tempfile.TemporaryFile() as f:
|
||||
torch.save(m, f)
|
||||
f.seek(0)
|
||||
m_copy = torch.load(f)
|
||||
output_from_copy = m_copy(*args, **kwargs)
|
||||
self.assertEqual(output, output_from_copy)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestModule, globals())
|
||||
|
||||
|
Reference in New Issue
Block a user