mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Expose option to disable CRC-32 computation during torch.save
(#137735)
Option only works in open source, not internal Pull Request resolved: https://github.com/pytorch/pytorch/pull/137735 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
3cc8c8b944
commit
534fa96f2d
@ -4308,6 +4308,35 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
f.seek(0)
|
||||
torch.load(f, weights_only=True)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "miniz version differs between fbcode and oss")
|
||||
@parametrize("compute_crc32", (True, False))
|
||||
@parametrize("filename", (True, False))
|
||||
def test_crc32_options(self, compute_crc32, filename):
|
||||
# test both path and buffer case
|
||||
file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile
|
||||
sd = torch.nn.Linear(3, 5).state_dict()
|
||||
with file_creation_func() as f:
|
||||
try:
|
||||
torch.serialization.set_crc32_options(compute_crc32)
|
||||
torch.save(sd, f)
|
||||
if not filename:
|
||||
f.seek(0)
|
||||
sd_loaded = torch.load(f, weights_only=True)
|
||||
self.assertEqual(sd_loaded, sd)
|
||||
finally:
|
||||
torch.serialization.set_crc32_options(True)
|
||||
|
||||
args = () if compute_crc32 else (zipfile.BadZipFile, "Bad CRC-32 for file")
|
||||
ctx = contextlib.nullcontext if compute_crc32 else self.assertRaisesRegex
|
||||
|
||||
if not filename:
|
||||
f.seek(0)
|
||||
# zip_file.extractall() will raise BadZipFile if CRC32 is not populated
|
||||
# we use the context manager to check whether CRC32 was populated
|
||||
with ctx(*args), tempfile.TemporaryDirectory() as temp_dir:
|
||||
with zipfile.ZipFile(f) as zip_file:
|
||||
zip_file.extractall(path=temp_dir)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=True):
|
||||
return super().run(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user