Prevent legacy_load when weights_only=True (correctly) (#145020)

Only prevent `legacy_load` (.tar format removed in https://github.com/pytorch/pytorch/pull/713), not the whole of `_legacy_load` (.tar format + _use_new_zipfile_serialization=False)

Differential Revision: [D68301405](https://our.internmc.facebook.com/intern/diff/D68301405)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145020
Approved by: https://github.com/kit1980, https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2025-01-17 08:51:46 -08:00
committed by PyTorch MergeBot
parent 2ef7b68666
commit 0eda02a94c
2 changed files with 17 additions and 7 deletions

View File

@ -466,7 +466,11 @@ class SerializationMixin:
b += [a[0].storage()]
b += [a[0].reshape(-1)[1:4].clone().storage()]
path = download_file('https://download.pytorch.org/test_data/legacy_serialized.pt')
c = torch.load(path, weights_only=weights_only)
if weights_only:
with self.assertRaisesRegex(RuntimeError,
"Cannot use ``weights_only=True`` with files saved in the legacy .tar format."):
c = torch.load(path, weights_only=weights_only)
c = torch.load(path, weights_only=False)
self.assertEqual(b, c, atol=0, rtol=0)
self.assertTrue(isinstance(c[0], torch.FloatTensor))
self.assertTrue(isinstance(c[1], torch.FloatTensor))