mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Prevent legacy_load when weights_only=True (correctly) (#145020)
Only prevent `legacy_load` (.tar format removed in https://github.com/pytorch/pytorch/pull/713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False) Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405) Pull Request resolved: https://github.com/pytorch/pytorch/pull/145020 Approved by: https://github.com/kit1980, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
2ef7b68666
commit
0eda02a94c
@ -466,7 +466,11 @@ class SerializationMixin:
|
||||
b += [a[0].storage()]
|
||||
b += [a[0].reshape(-1)[1:4].clone().storage()]
|
||||
path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
|
||||
c = torch.load(path, weights_only=weights_only)
|
||||
if weights_only:
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"Cannot use ``weights_only=True`` with files saved in the legacy .tar format."):
|
||||
c = torch.load(path, weights_only=weights_only)
|
||||
c = torch.load(path, weights_only=False)
|
||||
self.assertEqual(b, c, atol=0, rtol=0)
|
||||
self.assertTrue(isinstance(c[0], torch.FloatTensor))
|
||||
self.assertTrue(isinstance(c[1], torch.FloatTensor))
|
||||
|
Reference in New Issue
Block a user