From be0ceee1c3740bde65981d428a463cea61d88f5f Mon Sep 17 00:00:00 2001 From: Mikayla Gawarecki Date: Thu, 6 Mar 2025 08:50:55 +0000 Subject: [PATCH] Make record/storage alignment in torch.save configurable (#147788) Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788 Approved by: https://github.com/albanD ghstack dependencies: #147786, #147787 --- caffe2/serialize/inline_container.cc | 45 ++++++++++++++++++---------- caffe2/serialize/inline_container.h | 22 +++++++++----- docs/source/notes/serialization.rst | 1 + test/test_serialization.py | 28 +++++++++++++++++ torch/_C/__init__.pyi.in | 4 +-- torch/csrc/jit/python/init.cpp | 27 +++++++++++------ torch/serialization.py | 42 +++++++++++++++++++++++--- torch/utils/serialization/config.py | 1 + 8 files changed, 132 insertions(+), 38 deletions(-) diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 2b8545af9f8e..4972c6518cfc 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -252,7 +252,11 @@ constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50; namespace detail { -std::tuple getOffset(size_t cursor, size_t filename_size, size_t size) { +std::tuple getOffset( + size_t cursor, + size_t filename_size, + size_t size, + uint64_t alignment) { size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size + sizeof(mz_uint16) * 2; if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) { @@ -264,8 +268,8 @@ std::tuple getOffset(size_t cursor, size_t filename_size, size_t start += sizeof(mz_uint64); } } - size_t mod = start % kFieldAlignment; - size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod); + size_t mod = start % alignment; + size_t next_offset = (mod == 0) ? 
start : (start + alignment - mod); std::tuple result(next_offset, start); return result; } @@ -274,8 +278,9 @@ size_t getPadding( size_t cursor, size_t filename_size, size_t size, - std::string& padding_buf) { - auto [next_offset, start] = getOffset(cursor, filename_size, size); + std::string& padding_buf, + uint64_t alignment) { + auto [next_offset, start] = getOffset(cursor, filename_size, size, alignment); size_t padding_size = next_offset - start; size_t padding_size_plus_fbxx = padding_size + 4; if (padding_buf.size() < padding_size_plus_fbxx) { @@ -410,8 +415,7 @@ size_t PyTorchStreamReader::getRecordMultiReaders( } readSizes[i] = size; LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos - << "] " - << "from " << name << " of size " << n; + << "] " << "from " << name << " of size " << n; TORCH_CHECK( threadReadSize == size, "record size ", @@ -629,10 +633,12 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) { size_t PyTorchStreamReader::getRecordOffsetNoRead( size_t cursor, std::string filename, - size_t size) { + size_t size, + uint64_t alignment) { std::string full_name = archive_name_plus_slash_ + filename; size_t full_name_size = full_name.size(); - std::tuple result = detail::getOffset(cursor, full_name_size, size); + std::tuple result = + detail::getOffset(cursor, full_name_size, size, alignment); size_t offset = std::get<0>(result); return offset; } @@ -673,17 +679,22 @@ size_t ostream_write_func( PyTorchStreamWriter::PyTorchStreamWriter( const std::string& file_name, - bool compute_crc32) - : archive_name_(basename(file_name)), compute_crc32_(compute_crc32) { + bool compute_crc32, + uint64_t alignment) + : archive_name_(basename(file_name)), + compute_crc32_(compute_crc32), + alignment_(alignment) { setup(file_name); } PyTorchStreamWriter::PyTorchStreamWriter( const std::function writer_func, - bool compute_crc32) + bool compute_crc32, + uint64_t alignment) : archive_name_("archive"), writer_func_(writer_func), - 
compute_crc32_(compute_crc32) { + compute_crc32_(compute_crc32), + alignment_(alignment) { setup(archive_name_); } @@ -748,8 +759,12 @@ void PyTorchStreamWriter::writeRecord( return; } std::string full_name = archive_name_plus_slash_ + name; - size_t padding_size = - detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_); + size_t padding_size = detail::getPadding( + ar_->m_archive_size, + full_name.size(), + size, + padding_, + alignment_); uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0; if (!compute_crc32_) { #if (!defined(FBCODE_CAFFE2)) diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 7b183fb0969a..e098bede1420 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -174,8 +174,11 @@ class TORCH_API PyTorchStreamReader final { size_t getRecordSize(const std::string& name); size_t getRecordHeaderOffset(const std::string& name); size_t getRecordOffset(const std::string& name); - size_t - getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size); + size_t getRecordOffsetNoRead( + size_t cursor, + std::string filename, + size_t size, + uint64_t alignment); bool hasRecord(const std::string& name); std::vector getAllRecords(); @@ -222,10 +225,12 @@ class TORCH_API PyTorchStreamWriter final { public: explicit PyTorchStreamWriter( const std::string& archive_name, - bool compute_crc32 = true); + bool compute_crc32 = true, + uint64_t alignment = 64); explicit PyTorchStreamWriter( const std::function writer_func, - bool compute_crc32 = true); + bool compute_crc32 = true, + uint64_t alignment = 64); void setMinVersion(const uint64_t version); @@ -267,6 +272,7 @@ class TORCH_API PyTorchStreamWriter final { uint64_t combined_uncomp_crc32_ = 0; std::string serialization_id_; bool compute_crc32_; + uint64_t alignment_; // This number will be updated when the model has operators // that have valid upgraders. 
@@ -281,8 +287,6 @@ class TORCH_API PyTorchStreamWriter final { }; namespace detail { -// Writer-specific constants -constexpr uint64_t kFieldAlignment = 64; // Returns a record to be appended to the local user extra data entry in order // to make data beginning aligned at kFieldAlignment bytes boundary. @@ -290,9 +294,11 @@ size_t getPadding( size_t cursor, size_t filename_size, size_t size, - std::string& padding_buf); + std::string& padding_buf, + uint64_t alignment); -std::tuple getOffset(size_t cursor, size_t filename_size, size_t size); +std::tuple +getOffset(size_t cursor, size_t filename_size, size_t size, uint64_t alignment); } // namespace detail diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index b3ba4feb22e1..019865e3b535 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -505,6 +505,7 @@ Config See :func:`~torch.serialization.set_crc32_options`. * ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to move storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False`` (i.e. pageable)) + * ``storage_alignment``: alignment of storages in the checkpoint during ``torch.save`` in bytes. (Default: ``64``) ``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``. 
diff --git a/test/test_serialization.py b/test/test_serialization.py index 9f9baebb6066..111c206b282a 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -4653,6 +4653,34 @@ class TestSerialization(TestCase, SerializationMixin): self.assertTrue(opened_zipfile.has_record(".format_version")) self.assertEqual(opened_zipfile.get_record(".format_version"), b'1') + def test_storage_alignment(self): + sd = torch.nn.Linear(10, 10).state_dict() + + with tempfile.NamedTemporaryFile() as f: + torch.save(sd, f) + f.seek(0) + with FakeTensorMode(): + sd_fake = torch.load(f) + self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 832) + self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 1344) + + storage_alignment_before = serialization_config.save.storage_alignment + with tempfile.NamedTemporaryFile() as f: + try: + serialization_config.save.storage_alignment = 4096 + torch.save(sd, f) + f.seek(0) + with FakeTensorMode(): + sd_fake = torch.load(f) + self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 20480) + self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 24576) + f.seek(0) + sd_loaded = torch.load(f) + self.assertEqual(sd_loaded, sd) + finally: + serialization_config.save.storage_alignment = storage_alignment_before + + @parametrize('path_type', (str, Path)) @unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows") def test_mmap_load_offset_calculation(self, path_type): diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index cc3f4c1e219f..bdccb19f88a1 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1486,9 +1486,9 @@ class PyTorchFileReader: class PyTorchFileWriter: @overload - def __init__(self, name: str, compute_crc32 = True) -> None: ... + def __init__(self, name: str, compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ... @overload - def __init__(self, buffer: IO[bytes], compute_crc32 = True) -> None: ... 
+ def __init__(self, buffer: IO[bytes], compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ... def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ... def write_end_of_file(self) -> None: ... def set_min_version(self, version: _int) -> None: ... diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 2ba2094b3f34..5911064b22f2 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1390,11 +1390,14 @@ void initJITBindings(PyObject* module) { py::class_(m, "PyTorchFileWriter") .def( - py::init(), + py::init(), py::arg("file_name"), - py::arg("compute_crc32") = true) + py::arg("compute_crc32") = true, + py::arg("storage_alignment") = 64) .def( - py::init([](const py::object& buffer, bool compute_crc32 = true) { + py::init([](const py::object& buffer, + bool compute_crc32 = true, + uint64_t storage_alignment = 64) { auto writer_func = [=](const void* data, size_t size) { // Writing an empty file is a noop if (size == 0) { @@ -1413,14 +1416,19 @@ void initJITBindings(PyObject* module) { return size; }; return std::make_unique( - std::move(writer_func), compute_crc32); + std::move(writer_func), compute_crc32, storage_alignment); }), py::arg("buffer"), - py::arg("compute_crc32") = true) + py::arg("compute_crc32") = true, + py::arg("storage_alignment") = 64) .def( - py::init&, bool>(), + py::init< + const std::function&, + bool, + uint64_t>(), py::arg("writer_func"), - py::arg("compute_crc32") = true) + py::arg("compute_crc32") = true, + py::arg("storage_alignment") = 64) // [Note: write_record_metadata] // The write_record_metadata function is intended to write metadata (i.e. 
// the zipfile header and end of central directory record) for a file @@ -1630,9 +1638,10 @@ void initJITBindings(PyObject* module) { [](PyTorchStreamReader& self, size_t zipfile_header_offset, const std::string filename, - size_t size) { + size_t size, + uint64_t storage_alignment) { return self.getRecordOffsetNoRead( - zipfile_header_offset, filename, size); + zipfile_header_offset, filename, size, storage_alignment); }); // Used by torch.Package to coordinate deserialization of storages across diff --git a/torch/serialization.py b/torch/serialization.py index e65f860d969e..5c00be7310eb 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -211,6 +211,20 @@ def get_default_mmap_options() -> Optional[int]: return config.load.mmap_flags +def _get_storage_alignment() -> int: + """ + Gets alignment for storages in torch.save files. + + Defaults to 64. + + Returns: + storage_alignment: int + """ + from torch.utils.serialization import config + + return config.save.storage_alignment + + class set_default_mmap_options: """ Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. @@ -767,10 +781,16 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]): # for writing out the file. 
self.file_stream = io.FileIO(self.name, mode="w") super().__init__( - torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options()) + torch._C.PyTorchFileWriter( + self.file_stream, get_crc32_options(), _get_storage_alignment() + ) ) else: - super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options())) + super().__init__( + torch._C.PyTorchFileWriter( + self.name, get_crc32_options(), _get_storage_alignment() + ) + ) def __exit__(self, *args) -> None: self.file_like.write_end_of_file() @@ -786,7 +806,11 @@ class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]): raise AttributeError(msg) raise TypeError(msg) self.buffer = buffer - super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options())) + super().__init__( + torch._C.PyTorchFileWriter( + buffer, get_crc32_options(), _get_storage_alignment() + ) + ) def __exit__(self, *args) -> None: self.file_like.write_end_of_file() @@ -1188,7 +1212,13 @@ def _save( # .format_version is used to track # 1. version 1 represents the order of storages being changed from # lexicographical based on keys to numerically ordered based on keys + # 2. 
version 2 represents including storage_alignment as a record + # within the zipfile zip_file.write_record(".format_version", "1", len("1")) + storage_alignment = str(_get_storage_alignment()) + zip_file.write_record( + ".storage_alignment", storage_alignment, len(storage_alignment) + ) # Write byte order marker if not _disable_byteorder_record: @@ -1886,6 +1916,10 @@ def _load( else: raise ValueError("Invalid load endianness type") + storage_alignment = 64 + if zip_file.has_record(".storage_alignment"): + storage_alignment = int(zip_file.get_record(".storage_alignment")) + if ( not zip_file.has_record(byteordername) and get_default_load_endianness() is None @@ -1939,7 +1973,7 @@ def _load( storage_offset = current_offset else: storage_offset = zip_file.get_record_offset_no_read( - current_offset, name, numel + current_offset, name, numel, storage_alignment ) local_header_offset = current_offset diff --git a/torch/utils/serialization/config.py b/torch/utils/serialization/config.py index 0ef12f77d9d5..0a3fba9f5b82 100644 --- a/torch/utils/serialization/config.py +++ b/torch/utils/serialization/config.py @@ -19,6 +19,7 @@ class load: class save: compute_crc32: bool = True use_pinned_memory_for_d2h: bool = False + storage_alignment: int = 64 _install_config_module(sys.modules[__name__])