mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Make record/storage alignment in torch.save configurable (#147788)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788 Approved by: https://github.com/albanD ghstack dependencies: #147786, #147787
This commit is contained in:
committed by
PyTorch MergeBot
parent
209977e6e5
commit
be0ceee1c3
@ -211,6 +211,20 @@ def get_default_mmap_options() -> Optional[int]:
|
||||
return config.load.mmap_flags
|
||||
|
||||
|
||||
def _get_storage_alignment() -> int:
|
||||
"""
|
||||
Gets alignment for storages in torch.save files/
|
||||
|
||||
Defaults to 64.
|
||||
|
||||
Returns:
|
||||
storage_alginment: int
|
||||
"""
|
||||
from torch.utils.serialization import config
|
||||
|
||||
return config.save.storage_alignment
|
||||
|
||||
|
||||
class set_default_mmap_options:
|
||||
"""
|
||||
Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
|
||||
@ -767,10 +781,16 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
|
||||
# for writing out the file.
|
||||
self.file_stream = io.FileIO(self.name, mode="w")
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options())
|
||||
torch._C.PyTorchFileWriter(
|
||||
self.file_stream, get_crc32_options(), _get_storage_alignment()
|
||||
)
|
||||
)
|
||||
else:
|
||||
super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options()))
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(
|
||||
self.name, get_crc32_options(), _get_storage_alignment()
|
||||
)
|
||||
)
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
@ -786,7 +806,11 @@ class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
|
||||
raise AttributeError(msg)
|
||||
raise TypeError(msg)
|
||||
self.buffer = buffer
|
||||
super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options()))
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(
|
||||
buffer, get_crc32_options(), _get_storage_alignment()
|
||||
)
|
||||
)
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
@ -1188,7 +1212,13 @@ def _save(
|
||||
# .format_version is used to track
|
||||
# 1. version 1 represents the order of storages being changed from
|
||||
# lexicographical based on keys to numerically ordered based on keys
|
||||
# 2. version 2 represents including storage_alignment as a record
|
||||
# within the zipfile
|
||||
zip_file.write_record(".format_version", "1", len("1"))
|
||||
storage_alignment = str(_get_storage_alignment())
|
||||
zip_file.write_record(
|
||||
".storage_alignment", storage_alignment, len(storage_alignment)
|
||||
)
|
||||
|
||||
# Write byte order marker
|
||||
if not _disable_byteorder_record:
|
||||
@ -1886,6 +1916,10 @@ def _load(
|
||||
else:
|
||||
raise ValueError("Invalid load endianness type")
|
||||
|
||||
storage_alignment = 64
|
||||
if zip_file.has_record(".storage_alignment"):
|
||||
storage_alignment = int(zip_file.get_record(".storage_alignment"))
|
||||
|
||||
if (
|
||||
not zip_file.has_record(byteordername)
|
||||
and get_default_load_endianness() is None
|
||||
@ -1939,7 +1973,7 @@ def _load(
|
||||
storage_offset = current_offset
|
||||
else:
|
||||
storage_offset = zip_file.get_record_offset_no_read(
|
||||
current_offset, name, numel
|
||||
current_offset, name, numel, storage_alignment
|
||||
)
|
||||
local_header_offset = current_offset
|
||||
|
||||
|
Reference in New Issue
Block a user