Expose option to disable CRC-32 computation during torch.save (#137735)

Option only works in open source, not internal

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137735
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki
2024-10-11 13:02:50 -07:00
committed by PyTorch MergeBot
parent 3cc8c8b944
commit 534fa96f2d
9 changed files with 116 additions and 35 deletions

View File

@ -4308,6 +4308,35 @@ class TestSerialization(TestCase, SerializationMixin):
f.seek(0)
torch.load(f, weights_only=True)
@unittest.skipIf(IS_FBCODE, "miniz version differs between fbcode and oss")
@parametrize("compute_crc32", (True, False))
@parametrize("filename", (True, False))
def test_crc32_options(self, compute_crc32, filename):
# test both path and buffer case
file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile
sd = torch.nn.Linear(3, 5).state_dict()
with file_creation_func() as f:
try:
torch.serialization.set_crc32_options(compute_crc32)
torch.save(sd, f)
if not filename:
f.seek(0)
sd_loaded = torch.load(f, weights_only=True)
self.assertEqual(sd_loaded, sd)
finally:
torch.serialization.set_crc32_options(True)
args = () if compute_crc32 else (zipfile.BadZipFile, "Bad CRC-32 for file")
ctx = contextlib.nullcontext if compute_crc32 else self.assertRaisesRegex
if not filename:
f.seek(0)
# zip_file.extractall() will raise BadZipFile if CRC32 is not populated
# we use the context manager to check whether CRC32 was populated
with ctx(*args), tempfile.TemporaryDirectory() as temp_dir:
with zipfile.ZipFile(f) as zip_file:
zip_file.extractall(path=temp_dir)
def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)