Make record/storage alignment in torch.save configurable (#147788)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788
Approved by: https://github.com/albanD
ghstack dependencies: #147786, #147787
This commit is contained in:
Mikayla Gawarecki
2025-03-06 08:50:55 +00:00
committed by PyTorch MergeBot
parent 209977e6e5
commit be0ceee1c3
8 changed files with 132 additions and 38 deletions

View File

@ -211,6 +211,20 @@ def get_default_mmap_options() -> Optional[int]:
return config.load.mmap_flags
def _get_storage_alignment() -> int:
"""
Gets alignment for storages in torch.save files/
Defaults to 64.
Returns:
storage_alginment: int
"""
from torch.utils.serialization import config
return config.save.storage_alignment
class set_default_mmap_options:
"""
Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
@ -767,10 +781,16 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
# for writing out the file.
self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(
torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options())
torch._C.PyTorchFileWriter(
self.file_stream, get_crc32_options(), _get_storage_alignment()
)
)
else:
super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options()))
super().__init__(
torch._C.PyTorchFileWriter(
self.name, get_crc32_options(), _get_storage_alignment()
)
)
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
@ -786,7 +806,11 @@ class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
raise AttributeError(msg)
raise TypeError(msg)
self.buffer = buffer
super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options()))
super().__init__(
torch._C.PyTorchFileWriter(
buffer, get_crc32_options(), _get_storage_alignment()
)
)
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
@ -1188,7 +1212,13 @@ def _save(
# .format_version is used to track
# 1. version 1 represents the order of storages being changed from
# lexicographical based on keys to numerically ordered based on keys
# 2. version 2 represents including storage_alignment as a record
# within the zipfile
zip_file.write_record(".format_version", "1", len("1"))
storage_alignment = str(_get_storage_alignment())
zip_file.write_record(
".storage_alignment", storage_alignment, len(storage_alignment)
)
# Write byte order marker
if not _disable_byteorder_record:
@ -1886,6 +1916,10 @@ def _load(
else:
raise ValueError("Invalid load endianness type")
storage_alignment = 64
if zip_file.has_record(".storage_alignment"):
storage_alignment = int(zip_file.get_record(".storage_alignment"))
if (
not zip_file.has_record(byteordername)
and get_default_load_endianness() is None
@ -1939,7 +1973,7 @@ def _load(
storage_offset = current_offset
else:
storage_offset = zip_file.get_record_offset_no_read(
current_offset, name, numel
current_offset, name, numel, storage_alignment
)
local_header_offset = current_offset