mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add exhaustive module and optimizer tests for torch.load(state_dict, weights_only=True) (#121049)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121049 Approved by: https://github.com/janeyx99
This commit is contained in:
committed by
PyTorch MergeBot
parent
42821d462a
commit
d621e3e3b8
@ -212,7 +212,7 @@ class TestModule(TestCase):
|
||||
str(m)
|
||||
|
||||
@modules(module_db)
|
||||
def test_pickle(self, device, dtype, module_info, training):
|
||||
def test_save_load(self, device, dtype, module_info, training):
|
||||
# Test that module can be pickled and unpickled.
|
||||
module_cls = module_info.module_cls
|
||||
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
|
||||
@ -229,12 +229,13 @@ class TestModule(TestCase):
|
||||
m = module_cls(*args, **kwargs)
|
||||
m.to(device).to(dtype)
|
||||
m.train(training)
|
||||
sd = m.state_dict()
|
||||
|
||||
# === Do forward pass. ===
|
||||
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
output = m(*args, **kwargs)
|
||||
|
||||
# === Check unpickled module gives the same output. ===
|
||||
# === Check saved/loaded module gives the same output. ===
|
||||
with tempfile.TemporaryFile() as f:
|
||||
torch.save(m, f)
|
||||
f.seek(0)
|
||||
@ -242,6 +243,17 @@ class TestModule(TestCase):
|
||||
output_from_copy = m_copy(*args, **kwargs)
|
||||
self.assertEqual(output, output_from_copy)
|
||||
|
||||
# === Check saved/loaded state_dict are the same (including weights_only load). ===
|
||||
with tempfile.TemporaryFile() as f:
|
||||
torch.save(sd, f)
|
||||
f.seek(0)
|
||||
sd_copy = torch.load(f)
|
||||
self.assertEqual(sd_copy, sd)
|
||||
del sd_copy
|
||||
f.seek(0)
|
||||
sd_copy_wo = torch.load(f, weights_only=True)
|
||||
self.assertEqual(sd_copy_wo, sd)
|
||||
|
||||
@skipMeta
|
||||
@modules([module_info for module_info in module_db
|
||||
if 'inplace' in signature(module_info.module_cls).parameters])
|
||||
@ -816,10 +828,8 @@ class TestModule(TestCase):
|
||||
|
||||
for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
|
||||
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
|
||||
fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
|
||||
|
||||
c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
|
||||
fw_args_meta, fw_kwargs_meta = module_input_meta.forward_input.args, module_input_meta.forward_input.kwargs
|
||||
|
||||
m_cpu = module_cls(*c_args, **c_kwargs)
|
||||
|
||||
|
Reference in New Issue
Block a user