reimport pr137735 due to merging check issues (#138959)

This is  a cherry-pick from #137735 by @mikaylagawarecki , that cannot be merged due to a (wrongly) failing check for codev

@diff-train-skip-merge

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138959
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Wouter Devriendt
2024-10-27 16:31:34 +00:00
committed by PyTorch MergeBot
parent 144d75d934
commit bae3426af7
10 changed files with 119 additions and 36 deletions

View File

@ -53,6 +53,8 @@ __all__ = [
"load",
"StorageType",
"LoadEndianness",
"get_crc32_options",
"set_crc32_options",
"get_default_load_endianness",
"set_default_load_endianness",
"get_default_mmap_options",
@ -167,6 +169,34 @@ def set_default_load_endianness(endianness):
_default_load_endian = endianness
_compute_crc32: bool = True
def get_crc32_options() -> bool:
"""
Get whether :func:`torch.save` computes and writes crc32 for each record.
Defaults to ``True``.
"""
return _compute_crc32
def set_crc32_options(compute_crc32: bool):
"""
Set whether :func:`torch.save` computes and writes crc32 for each record.
.. note::
Setting this to ``False`` may make unzipping of the ``torch.save`` output
fail or warn due to corrupted CRC32. However ``torch.load`` will be
able to load the file.
Args:
compute_crc32 (bool): set crc32 compuation flag
"""
global _compute_crc32
_compute_crc32 = compute_crc32
_default_mmap_options: int = MAP_PRIVATE
@ -682,9 +712,11 @@ class _open_zipfile_writer_file(_opener):
# For filenames with non-ascii characters, we rely on Python
# for writing out the file.
self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
super().__init__(
torch._C.PyTorchFileWriter(self.file_stream, _compute_crc32)
)
else:
super().__init__(torch._C.PyTorchFileWriter(self.name))
super().__init__(torch._C.PyTorchFileWriter(self.name, _compute_crc32))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
@ -700,7 +732,7 @@ class _open_zipfile_writer_buffer(_opener):
raise AttributeError(msg)
raise TypeError(msg)
self.buffer = buffer
super().__init__(torch._C.PyTorchFileWriter(buffer))
super().__init__(torch._C.PyTorchFileWriter(buffer, _compute_crc32))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()