Add exhaustive module and optimizer tests for torch.load(state_dict, weights_only=True) (#121049)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121049
Approved by: https://github.com/janeyx99
This commit is contained in:
Mikayla Gawarecki
2024-03-04 15:21:55 -08:00
committed by PyTorch MergeBot
parent 42821d462a
commit d621e3e3b8
4 changed files with 77 additions and 5 deletions

View File

@ -212,7 +212,7 @@ class TestModule(TestCase):
str(m)
@modules(module_db)
def test_pickle(self, device, dtype, module_info, training):
def test_save_load(self, device, dtype, module_info, training):
# Test that module can be pickled and unpickled.
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
@ -229,12 +229,13 @@ class TestModule(TestCase):
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
m.train(training)
sd = m.state_dict()
# === Do forward pass. ===
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
output = m(*args, **kwargs)
# === Check unpickled module gives the same output. ===
# === Check saved/loaded module gives the same output. ===
with tempfile.TemporaryFile() as f:
torch.save(m, f)
f.seek(0)
@ -242,6 +243,17 @@ class TestModule(TestCase):
output_from_copy = m_copy(*args, **kwargs)
self.assertEqual(output, output_from_copy)
# === Check saved/loaded state_dict are the same (including weights_only load). ===
with tempfile.TemporaryFile() as f:
torch.save(sd, f)
f.seek(0)
sd_copy = torch.load(f)
self.assertEqual(sd_copy, sd)
del sd_copy
f.seek(0)
sd_copy_wo = torch.load(f, weights_only=True)
self.assertEqual(sd_copy_wo, sd)
@skipMeta
@modules([module_info for module_info in module_db
if 'inplace' in signature(module_info.module_cls).parameters])
@ -816,10 +828,8 @@ class TestModule(TestCase):
for module_input, module_input_meta in zip(module_inputs, module_inputs_meta):
c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs
fw_args_meta, fw_kwargs_meta = module_input_meta.forward_input.args, module_input_meta.forward_input.kwargs
m_cpu = module_cls(*c_args, **c_kwargs)