Make record/storage alignment in torch.save configurable (#147788)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788
Approved by: https://github.com/albanD
ghstack dependencies: #147786, #147787
This commit is contained in:
Mikayla Gawarecki
2025-03-06 08:50:55 +00:00
committed by PyTorch MergeBot
parent 209977e6e5
commit be0ceee1c3
8 changed files with 132 additions and 38 deletions

View File

@ -252,7 +252,11 @@ constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;
namespace detail { namespace detail {
std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size) { std::tuple<size_t, size_t> getOffset(
size_t cursor,
size_t filename_size,
size_t size,
uint64_t alignment) {
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size + size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
sizeof(mz_uint16) * 2; sizeof(mz_uint16) * 2;
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) { if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
@ -264,8 +268,8 @@ std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t
start += sizeof(mz_uint64); start += sizeof(mz_uint64);
} }
} }
size_t mod = start % kFieldAlignment; size_t mod = start % alignment;
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod); size_t next_offset = (mod == 0) ? start : (start + alignment - mod);
std::tuple<size_t, size_t> result(next_offset, start); std::tuple<size_t, size_t> result(next_offset, start);
return result; return result;
} }
@ -274,8 +278,9 @@ size_t getPadding(
size_t cursor, size_t cursor,
size_t filename_size, size_t filename_size,
size_t size, size_t size,
std::string& padding_buf) { std::string& padding_buf,
auto [next_offset, start] = getOffset(cursor, filename_size, size); uint64_t alignment) {
auto [next_offset, start] = getOffset(cursor, filename_size, size, alignment);
size_t padding_size = next_offset - start; size_t padding_size = next_offset - start;
size_t padding_size_plus_fbxx = padding_size + 4; size_t padding_size_plus_fbxx = padding_size + 4;
if (padding_buf.size() < padding_size_plus_fbxx) { if (padding_buf.size() < padding_size_plus_fbxx) {
@ -410,8 +415,7 @@ size_t PyTorchStreamReader::getRecordMultiReaders(
} }
readSizes[i] = size; readSizes[i] = size;
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos
<< "] " << "] " << "from " << name << " of size " << n;
<< "from " << name << " of size " << n;
TORCH_CHECK( TORCH_CHECK(
threadReadSize == size, threadReadSize == size,
"record size ", "record size ",
@ -629,10 +633,12 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
size_t PyTorchStreamReader::getRecordOffsetNoRead( size_t PyTorchStreamReader::getRecordOffsetNoRead(
size_t cursor, size_t cursor,
std::string filename, std::string filename,
size_t size) { size_t size,
uint64_t alignment) {
std::string full_name = archive_name_plus_slash_ + filename; std::string full_name = archive_name_plus_slash_ + filename;
size_t full_name_size = full_name.size(); size_t full_name_size = full_name.size();
std::tuple<size_t, size_t> result = detail::getOffset(cursor, full_name_size, size); std::tuple<size_t, size_t> result =
detail::getOffset(cursor, full_name_size, size, alignment);
size_t offset = std::get<0>(result); size_t offset = std::get<0>(result);
return offset; return offset;
} }
@ -673,17 +679,22 @@ size_t ostream_write_func(
PyTorchStreamWriter::PyTorchStreamWriter( PyTorchStreamWriter::PyTorchStreamWriter(
const std::string& file_name, const std::string& file_name,
bool compute_crc32) bool compute_crc32,
: archive_name_(basename(file_name)), compute_crc32_(compute_crc32) { uint64_t alignment)
: archive_name_(basename(file_name)),
compute_crc32_(compute_crc32),
alignment_(alignment) {
setup(file_name); setup(file_name);
} }
PyTorchStreamWriter::PyTorchStreamWriter( PyTorchStreamWriter::PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func, const std::function<size_t(const void*, size_t)> writer_func,
bool compute_crc32) bool compute_crc32,
uint64_t alignment)
: archive_name_("archive"), : archive_name_("archive"),
writer_func_(writer_func), writer_func_(writer_func),
compute_crc32_(compute_crc32) { compute_crc32_(compute_crc32),
alignment_(alignment) {
setup(archive_name_); setup(archive_name_);
} }
@ -748,8 +759,12 @@ void PyTorchStreamWriter::writeRecord(
return; return;
} }
std::string full_name = archive_name_plus_slash_ + name; std::string full_name = archive_name_plus_slash_ + name;
size_t padding_size = size_t padding_size = detail::getPadding(
detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_); ar_->m_archive_size,
full_name.size(),
size,
padding_,
alignment_);
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0; uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
if (!compute_crc32_) { if (!compute_crc32_) {
#if (!defined(FBCODE_CAFFE2)) #if (!defined(FBCODE_CAFFE2))

View File

@ -174,8 +174,11 @@ class TORCH_API PyTorchStreamReader final {
size_t getRecordSize(const std::string& name); size_t getRecordSize(const std::string& name);
size_t getRecordHeaderOffset(const std::string& name); size_t getRecordHeaderOffset(const std::string& name);
size_t getRecordOffset(const std::string& name); size_t getRecordOffset(const std::string& name);
size_t size_t getRecordOffsetNoRead(
getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size); size_t cursor,
std::string filename,
size_t size,
uint64_t alignment);
bool hasRecord(const std::string& name); bool hasRecord(const std::string& name);
std::vector<std::string> getAllRecords(); std::vector<std::string> getAllRecords();
@ -222,10 +225,12 @@ class TORCH_API PyTorchStreamWriter final {
public: public:
explicit PyTorchStreamWriter( explicit PyTorchStreamWriter(
const std::string& archive_name, const std::string& archive_name,
bool compute_crc32 = true); bool compute_crc32 = true,
uint64_t alignment = 64);
explicit PyTorchStreamWriter( explicit PyTorchStreamWriter(
const std::function<size_t(const void*, size_t)> writer_func, const std::function<size_t(const void*, size_t)> writer_func,
bool compute_crc32 = true); bool compute_crc32 = true,
uint64_t alignment = 64);
void setMinVersion(const uint64_t version); void setMinVersion(const uint64_t version);
@ -267,6 +272,7 @@ class TORCH_API PyTorchStreamWriter final {
uint64_t combined_uncomp_crc32_ = 0; uint64_t combined_uncomp_crc32_ = 0;
std::string serialization_id_; std::string serialization_id_;
bool compute_crc32_; bool compute_crc32_;
uint64_t alignment_;
// This number will be updated when the model has operators // This number will be updated when the model has operators
// that have valid upgraders. // that have valid upgraders.
@ -281,8 +287,6 @@ class TORCH_API PyTorchStreamWriter final {
}; };
namespace detail { namespace detail {
// Writer-specific constants
constexpr uint64_t kFieldAlignment = 64;
// Returns a record to be appended to the local user extra data entry in order // Returns a record to be appended to the local user extra data entry in order
// to make data beginning aligned at kFieldAlignment bytes boundary. // to make data beginning aligned at the given ``alignment`` bytes boundary.
@ -290,9 +294,11 @@ size_t getPadding(
size_t cursor, size_t cursor,
size_t filename_size, size_t filename_size,
size_t size, size_t size,
std::string& padding_buf); std::string& padding_buf,
uint64_t alignment);
std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size); std::tuple<size_t, size_t>
getOffset(size_t cursor, size_t filename_size, size_t size, uint64_t alignment);
} // namespace detail } // namespace detail

View File

@ -505,6 +505,7 @@ Config
See :func:`~torch.serialization.set_crc32_options`. See :func:`~torch.serialization.set_crc32_options`.
* ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to * ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to
move storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False`` (i.e. pageable)) move storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False`` (i.e. pageable))
* ``storage_alignment``: alignment of storages in the checkpoint during ``torch.save`` in bytes. (Default ``64``)
``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``. ``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.

View File

@ -4653,6 +4653,34 @@ class TestSerialization(TestCase, SerializationMixin):
self.assertTrue(opened_zipfile.has_record(".format_version")) self.assertTrue(opened_zipfile.has_record(".format_version"))
self.assertEqual(opened_zipfile.get_record(".format_version"), b'1') self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')
def test_storage_alignment(self):
sd = torch.nn.Linear(10, 10).state_dict()
with tempfile.NamedTemporaryFile() as f:
torch.save(sd, f)
f.seek(0)
with FakeTensorMode():
sd_fake = torch.load(f)
self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 832)
self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 1344)
storage_alignment_before = serialization_config.save.storage_alignment
with tempfile.NamedTemporaryFile() as f:
try:
serialization_config.save.storage_alignment = 4096
torch.save(sd, f)
f.seek(0)
with FakeTensorMode():
sd_fake = torch.load(f)
self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 20480)
self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 24576)
f.seek(0)
sd_loaded = torch.load(f)
self.assertEqual(sd_loaded, sd)
finally:
serialization_config.save.storage_alignment = storage_alignment_before
@parametrize('path_type', (str, Path)) @parametrize('path_type', (str, Path))
@unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows") @unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
def test_mmap_load_offset_calculation(self, path_type): def test_mmap_load_offset_calculation(self, path_type):

View File

@ -1486,9 +1486,9 @@ class PyTorchFileReader:
class PyTorchFileWriter: class PyTorchFileWriter:
@overload @overload
def __init__(self, name: str, compute_crc32 = True) -> None: ... def __init__(self, name: str, compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ...
@overload @overload
def __init__(self, buffer: IO[bytes], compute_crc32 = True) -> None: ... def __init__(self, buffer: IO[bytes], compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ...
def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ... def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
def write_end_of_file(self) -> None: ... def write_end_of_file(self) -> None: ...
def set_min_version(self, version: _int) -> None: ... def set_min_version(self, version: _int) -> None: ...

View File

@ -1390,11 +1390,14 @@ void initJITBindings(PyObject* module) {
py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter") py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
.def( .def(
py::init<std::string, bool>(), py::init<std::string, bool, uint64_t>(),
py::arg("file_name"), py::arg("file_name"),
py::arg("compute_crc32") = true) py::arg("compute_crc32") = true,
py::arg("storage_alignment") = 64)
.def( .def(
py::init([](const py::object& buffer, bool compute_crc32 = true) { py::init([](const py::object& buffer,
bool compute_crc32 = true,
uint64_t storage_alignment = 64) {
auto writer_func = [=](const void* data, size_t size) { auto writer_func = [=](const void* data, size_t size) {
// Writing an empty file is a noop // Writing an empty file is a noop
if (size == 0) { if (size == 0) {
@ -1413,14 +1416,19 @@ void initJITBindings(PyObject* module) {
return size; return size;
}; };
return std::make_unique<PyTorchStreamWriter>( return std::make_unique<PyTorchStreamWriter>(
std::move(writer_func), compute_crc32); std::move(writer_func), compute_crc32, storage_alignment);
}), }),
py::arg("buffer"), py::arg("buffer"),
py::arg("compute_crc32") = true) py::arg("compute_crc32") = true,
py::arg("storage_alignment") = 64)
.def( .def(
py::init<const std::function<size_t(const void*, size_t)>&, bool>(), py::init<
const std::function<size_t(const void*, size_t)>&,
bool,
uint64_t>(),
py::arg("writer_func"), py::arg("writer_func"),
py::arg("compute_crc32") = true) py::arg("compute_crc32") = true,
py::arg("storage_alignment") = 64)
// [Note: write_record_metadata] // [Note: write_record_metadata]
// The write_record_metadata function is intended to write metadata (i.e. // The write_record_metadata function is intended to write metadata (i.e.
// the zipfile header and end of central directory record) for a file // the zipfile header and end of central directory record) for a file
@ -1630,9 +1638,10 @@ void initJITBindings(PyObject* module) {
[](PyTorchStreamReader& self, [](PyTorchStreamReader& self,
size_t zipfile_header_offset, size_t zipfile_header_offset,
const std::string filename, const std::string filename,
size_t size) { size_t size,
uint64_t storage_alignment) {
return self.getRecordOffsetNoRead( return self.getRecordOffsetNoRead(
zipfile_header_offset, filename, size); zipfile_header_offset, filename, size, storage_alignment);
}); });
// Used by torch.Package to coordinate deserialization of storages across // Used by torch.Package to coordinate deserialization of storages across

View File

@ -211,6 +211,20 @@ def get_default_mmap_options() -> Optional[int]:
return config.load.mmap_flags return config.load.mmap_flags
def _get_storage_alignment() -> int:
"""
Gets alignment for storages in torch.save files.
Defaults to 64.
Returns:
storage_alignment: int
"""
from torch.utils.serialization import config
return config.save.storage_alignment
class set_default_mmap_options: class set_default_mmap_options:
""" """
Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags. Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
@ -767,10 +781,16 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
# for writing out the file. # for writing out the file.
self.file_stream = io.FileIO(self.name, mode="w") self.file_stream = io.FileIO(self.name, mode="w")
super().__init__( super().__init__(
torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options()) torch._C.PyTorchFileWriter(
self.file_stream, get_crc32_options(), _get_storage_alignment()
)
) )
else: else:
super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options())) super().__init__(
torch._C.PyTorchFileWriter(
self.name, get_crc32_options(), _get_storage_alignment()
)
)
def __exit__(self, *args) -> None: def __exit__(self, *args) -> None:
self.file_like.write_end_of_file() self.file_like.write_end_of_file()
@ -786,7 +806,11 @@ class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
raise AttributeError(msg) raise AttributeError(msg)
raise TypeError(msg) raise TypeError(msg)
self.buffer = buffer self.buffer = buffer
super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options())) super().__init__(
torch._C.PyTorchFileWriter(
buffer, get_crc32_options(), _get_storage_alignment()
)
)
def __exit__(self, *args) -> None: def __exit__(self, *args) -> None:
self.file_like.write_end_of_file() self.file_like.write_end_of_file()
@ -1188,7 +1212,13 @@ def _save(
# .format_version is used to track # .format_version is used to track
# 1. version 1 represents the order of storages being changed from # 1. version 1 represents the order of storages being changed from
# lexicographical based on keys to numerically ordered based on keys # lexicographical based on keys to numerically ordered based on keys
# 2. version 2 represents including storage_alignment as a record
# within the zipfile
zip_file.write_record(".format_version", "1", len("1")) zip_file.write_record(".format_version", "1", len("1"))
storage_alignment = str(_get_storage_alignment())
zip_file.write_record(
".storage_alignment", storage_alignment, len(storage_alignment)
)
# Write byte order marker # Write byte order marker
if not _disable_byteorder_record: if not _disable_byteorder_record:
@ -1886,6 +1916,10 @@ def _load(
else: else:
raise ValueError("Invalid load endianness type") raise ValueError("Invalid load endianness type")
storage_alignment = 64
if zip_file.has_record(".storage_alignment"):
storage_alignment = int(zip_file.get_record(".storage_alignment"))
if ( if (
not zip_file.has_record(byteordername) not zip_file.has_record(byteordername)
and get_default_load_endianness() is None and get_default_load_endianness() is None
@ -1939,7 +1973,7 @@ def _load(
storage_offset = current_offset storage_offset = current_offset
else: else:
storage_offset = zip_file.get_record_offset_no_read( storage_offset = zip_file.get_record_offset_no_read(
current_offset, name, numel current_offset, name, numel, storage_alignment
) )
local_header_offset = current_offset local_header_offset = current_offset

View File

@ -19,6 +19,7 @@ class load:
class save: class save:
compute_crc32: bool = True compute_crc32: bool = True
use_pinned_memory_for_d2h: bool = False use_pinned_memory_for_d2h: bool = False
storage_alignment: int = 64
_install_config_module(sys.modules[__name__]) _install_config_module(sys.modules[__name__])