mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Make record/storage alignment in torch.save configurable (#147788)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788 Approved by: https://github.com/albanD ghstack dependencies: #147786, #147787
This commit is contained in:
committed by
PyTorch MergeBot
parent
209977e6e5
commit
be0ceee1c3
@ -252,7 +252,11 @@ constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;
|
||||
|
||||
namespace detail {
|
||||
|
||||
std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size) {
|
||||
std::tuple<size_t, size_t> getOffset(
|
||||
size_t cursor,
|
||||
size_t filename_size,
|
||||
size_t size,
|
||||
uint64_t alignment) {
|
||||
size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
|
||||
sizeof(mz_uint16) * 2;
|
||||
if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
|
||||
@ -264,8 +268,8 @@ std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t
|
||||
start += sizeof(mz_uint64);
|
||||
}
|
||||
}
|
||||
size_t mod = start % kFieldAlignment;
|
||||
size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
|
||||
size_t mod = start % alignment;
|
||||
size_t next_offset = (mod == 0) ? start : (start + alignment - mod);
|
||||
std::tuple<size_t, size_t> result(next_offset, start);
|
||||
return result;
|
||||
}
|
||||
@ -274,8 +278,9 @@ size_t getPadding(
|
||||
size_t cursor,
|
||||
size_t filename_size,
|
||||
size_t size,
|
||||
std::string& padding_buf) {
|
||||
auto [next_offset, start] = getOffset(cursor, filename_size, size);
|
||||
std::string& padding_buf,
|
||||
uint64_t alignment) {
|
||||
auto [next_offset, start] = getOffset(cursor, filename_size, size, alignment);
|
||||
size_t padding_size = next_offset - start;
|
||||
size_t padding_size_plus_fbxx = padding_size + 4;
|
||||
if (padding_buf.size() < padding_size_plus_fbxx) {
|
||||
@ -410,8 +415,7 @@ size_t PyTorchStreamReader::getRecordMultiReaders(
|
||||
}
|
||||
readSizes[i] = size;
|
||||
LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos
|
||||
<< "] "
|
||||
<< "from " << name << " of size " << n;
|
||||
<< "] " << "from " << name << " of size " << n;
|
||||
TORCH_CHECK(
|
||||
threadReadSize == size,
|
||||
"record size ",
|
||||
@ -629,10 +633,12 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
|
||||
size_t PyTorchStreamReader::getRecordOffsetNoRead(
|
||||
size_t cursor,
|
||||
std::string filename,
|
||||
size_t size) {
|
||||
size_t size,
|
||||
uint64_t alignment) {
|
||||
std::string full_name = archive_name_plus_slash_ + filename;
|
||||
size_t full_name_size = full_name.size();
|
||||
std::tuple<size_t, size_t> result = detail::getOffset(cursor, full_name_size, size);
|
||||
std::tuple<size_t, size_t> result =
|
||||
detail::getOffset(cursor, full_name_size, size, alignment);
|
||||
size_t offset = std::get<0>(result);
|
||||
return offset;
|
||||
}
|
||||
@ -673,17 +679,22 @@ size_t ostream_write_func(
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
const std::string& file_name,
|
||||
bool compute_crc32)
|
||||
: archive_name_(basename(file_name)), compute_crc32_(compute_crc32) {
|
||||
bool compute_crc32,
|
||||
uint64_t alignment)
|
||||
: archive_name_(basename(file_name)),
|
||||
compute_crc32_(compute_crc32),
|
||||
alignment_(alignment) {
|
||||
setup(file_name);
|
||||
}
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)> writer_func,
|
||||
bool compute_crc32)
|
||||
bool compute_crc32,
|
||||
uint64_t alignment)
|
||||
: archive_name_("archive"),
|
||||
writer_func_(writer_func),
|
||||
compute_crc32_(compute_crc32) {
|
||||
compute_crc32_(compute_crc32),
|
||||
alignment_(alignment) {
|
||||
setup(archive_name_);
|
||||
}
|
||||
|
||||
@ -748,8 +759,12 @@ void PyTorchStreamWriter::writeRecord(
|
||||
return;
|
||||
}
|
||||
std::string full_name = archive_name_plus_slash_ + name;
|
||||
size_t padding_size =
|
||||
detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
|
||||
size_t padding_size = detail::getPadding(
|
||||
ar_->m_archive_size,
|
||||
full_name.size(),
|
||||
size,
|
||||
padding_,
|
||||
alignment_);
|
||||
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
|
||||
if (!compute_crc32_) {
|
||||
#if (!defined(FBCODE_CAFFE2))
|
||||
|
@ -174,8 +174,11 @@ class TORCH_API PyTorchStreamReader final {
|
||||
size_t getRecordSize(const std::string& name);
|
||||
size_t getRecordHeaderOffset(const std::string& name);
|
||||
size_t getRecordOffset(const std::string& name);
|
||||
size_t
|
||||
getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size);
|
||||
size_t getRecordOffsetNoRead(
|
||||
size_t cursor,
|
||||
std::string filename,
|
||||
size_t size,
|
||||
uint64_t alignment);
|
||||
bool hasRecord(const std::string& name);
|
||||
std::vector<std::string> getAllRecords();
|
||||
|
||||
@ -222,10 +225,12 @@ class TORCH_API PyTorchStreamWriter final {
|
||||
public:
|
||||
explicit PyTorchStreamWriter(
|
||||
const std::string& archive_name,
|
||||
bool compute_crc32 = true);
|
||||
bool compute_crc32 = true,
|
||||
uint64_t alignment = 64);
|
||||
explicit PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)> writer_func,
|
||||
bool compute_crc32 = true);
|
||||
bool compute_crc32 = true,
|
||||
uint64_t alignment = 64);
|
||||
|
||||
void setMinVersion(const uint64_t version);
|
||||
|
||||
@ -267,6 +272,7 @@ class TORCH_API PyTorchStreamWriter final {
|
||||
uint64_t combined_uncomp_crc32_ = 0;
|
||||
std::string serialization_id_;
|
||||
bool compute_crc32_;
|
||||
uint64_t alignment_;
|
||||
|
||||
// This number will be updated when the model has operators
|
||||
// that have valid upgraders.
|
||||
@ -281,8 +287,6 @@ class TORCH_API PyTorchStreamWriter final {
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Writer-specific constants
|
||||
constexpr uint64_t kFieldAlignment = 64;
|
||||
|
||||
// Returns a record to be appended to the local user extra data entry in order
|
||||
// to make data beginning aligned at the specified alignment bytes boundary.
|
||||
@ -290,9 +294,11 @@ size_t getPadding(
|
||||
size_t cursor,
|
||||
size_t filename_size,
|
||||
size_t size,
|
||||
std::string& padding_buf);
|
||||
std::string& padding_buf,
|
||||
uint64_t alignment);
|
||||
|
||||
std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size);
|
||||
std::tuple<size_t, size_t>
|
||||
getOffset(size_t cursor, size_t filename_size, size_t size, uint64_t alignment);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
|
@ -505,6 +505,7 @@ Config
|
||||
See :func:`~torch.serialization.set_crc32_options`.
|
||||
* ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to
|
||||
move storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False`` (i.e. pageable))
|
||||
* ``storage_alignment``: alignment of storages in the checkpoint during ``torch.save`` in bytes. (Default: ``64``)
|
||||
|
||||
``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.
|
||||
|
||||
|
@ -4653,6 +4653,34 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
self.assertTrue(opened_zipfile.has_record(".format_version"))
|
||||
self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')
|
||||
|
||||
def test_storage_alignment(self):
|
||||
sd = torch.nn.Linear(10, 10).state_dict()
|
||||
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
torch.save(sd, f)
|
||||
f.seek(0)
|
||||
with FakeTensorMode():
|
||||
sd_fake = torch.load(f)
|
||||
self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 832)
|
||||
self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 1344)
|
||||
|
||||
storage_alignment_before = serialization_config.save.storage_alignment
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
try:
|
||||
serialization_config.save.storage_alignment = 4096
|
||||
torch.save(sd, f)
|
||||
f.seek(0)
|
||||
with FakeTensorMode():
|
||||
sd_fake = torch.load(f)
|
||||
self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 20480)
|
||||
self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 24576)
|
||||
f.seek(0)
|
||||
sd_loaded = torch.load(f)
|
||||
self.assertEqual(sd_loaded, sd)
|
||||
finally:
|
||||
serialization_config.save.storage_alignment = storage_alignment_before
|
||||
|
||||
|
||||
@parametrize('path_type', (str, Path))
|
||||
@unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
|
||||
def test_mmap_load_offset_calculation(self, path_type):
|
||||
|
@ -1486,9 +1486,9 @@ class PyTorchFileReader:
|
||||
|
||||
class PyTorchFileWriter:
|
||||
@overload
|
||||
def __init__(self, name: str, compute_crc32 = True) -> None: ...
|
||||
def __init__(self, name: str, compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ...
|
||||
@overload
|
||||
def __init__(self, buffer: IO[bytes], compute_crc32 = True) -> None: ...
|
||||
def __init__(self, buffer: IO[bytes], compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ...
|
||||
def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
|
||||
def write_end_of_file(self) -> None: ...
|
||||
def set_min_version(self, version: _int) -> None: ...
|
||||
|
@ -1390,11 +1390,14 @@ void initJITBindings(PyObject* module) {
|
||||
|
||||
py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
|
||||
.def(
|
||||
py::init<std::string, bool>(),
|
||||
py::init<std::string, bool, uint64_t>(),
|
||||
py::arg("file_name"),
|
||||
py::arg("compute_crc32") = true)
|
||||
py::arg("compute_crc32") = true,
|
||||
py::arg("storage_alignment") = 64)
|
||||
.def(
|
||||
py::init([](const py::object& buffer, bool compute_crc32 = true) {
|
||||
py::init([](const py::object& buffer,
|
||||
bool compute_crc32 = true,
|
||||
uint64_t storage_alignment = 64) {
|
||||
auto writer_func = [=](const void* data, size_t size) {
|
||||
// Writing an empty file is a noop
|
||||
if (size == 0) {
|
||||
@ -1413,14 +1416,19 @@ void initJITBindings(PyObject* module) {
|
||||
return size;
|
||||
};
|
||||
return std::make_unique<PyTorchStreamWriter>(
|
||||
std::move(writer_func), compute_crc32);
|
||||
std::move(writer_func), compute_crc32, storage_alignment);
|
||||
}),
|
||||
py::arg("buffer"),
|
||||
py::arg("compute_crc32") = true)
|
||||
py::arg("compute_crc32") = true,
|
||||
py::arg("storage_alignment") = 64)
|
||||
.def(
|
||||
py::init<const std::function<size_t(const void*, size_t)>&, bool>(),
|
||||
py::init<
|
||||
const std::function<size_t(const void*, size_t)>&,
|
||||
bool,
|
||||
uint64_t>(),
|
||||
py::arg("writer_func"),
|
||||
py::arg("compute_crc32") = true)
|
||||
py::arg("compute_crc32") = true,
|
||||
py::arg("storage_alignment") = 64)
|
||||
// [Note: write_record_metadata]
|
||||
// The write_record_metadata function is intended to write metadata (i.e.
|
||||
// the zipfile header and end of central directory record) for a file
|
||||
@ -1630,9 +1638,10 @@ void initJITBindings(PyObject* module) {
|
||||
[](PyTorchStreamReader& self,
|
||||
size_t zipfile_header_offset,
|
||||
const std::string filename,
|
||||
size_t size) {
|
||||
size_t size,
|
||||
uint64_t storage_alignment) {
|
||||
return self.getRecordOffsetNoRead(
|
||||
zipfile_header_offset, filename, size);
|
||||
zipfile_header_offset, filename, size, storage_alignment);
|
||||
});
|
||||
|
||||
// Used by torch.Package to coordinate deserialization of storages across
|
||||
|
@ -211,6 +211,20 @@ def get_default_mmap_options() -> Optional[int]:
|
||||
return config.load.mmap_flags
|
||||
|
||||
|
||||
def _get_storage_alignment() -> int:
|
||||
"""
|
||||
Gets alignment for storages in torch.save files.
|
||||
|
||||
Defaults to 64.
|
||||
|
||||
Returns:
|
||||
storage_alignment: int
|
||||
"""
|
||||
from torch.utils.serialization import config
|
||||
|
||||
return config.save.storage_alignment
|
||||
|
||||
|
||||
class set_default_mmap_options:
|
||||
"""
|
||||
Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
|
||||
@ -767,10 +781,16 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
|
||||
# for writing out the file.
|
||||
self.file_stream = io.FileIO(self.name, mode="w")
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options())
|
||||
torch._C.PyTorchFileWriter(
|
||||
self.file_stream, get_crc32_options(), _get_storage_alignment()
|
||||
)
|
||||
)
|
||||
else:
|
||||
super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options()))
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(
|
||||
self.name, get_crc32_options(), _get_storage_alignment()
|
||||
)
|
||||
)
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
@ -786,7 +806,11 @@ class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
|
||||
raise AttributeError(msg)
|
||||
raise TypeError(msg)
|
||||
self.buffer = buffer
|
||||
super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options()))
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(
|
||||
buffer, get_crc32_options(), _get_storage_alignment()
|
||||
)
|
||||
)
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
@ -1188,7 +1212,13 @@ def _save(
|
||||
# .format_version is used to track
|
||||
# 1. version 1 represents the order of storages being changed from
|
||||
# lexicographical based on keys to numerically ordered based on keys
|
||||
# 2. version 2 represents including storage_alignment as a record
|
||||
# within the zipfile
|
||||
zip_file.write_record(".format_version", "1", len("1"))
|
||||
storage_alignment = str(_get_storage_alignment())
|
||||
zip_file.write_record(
|
||||
".storage_alignment", storage_alignment, len(storage_alignment)
|
||||
)
|
||||
|
||||
# Write byte order marker
|
||||
if not _disable_byteorder_record:
|
||||
@ -1886,6 +1916,10 @@ def _load(
|
||||
else:
|
||||
raise ValueError("Invalid load endianness type")
|
||||
|
||||
storage_alignment = 64
|
||||
if zip_file.has_record(".storage_alignment"):
|
||||
storage_alignment = int(zip_file.get_record(".storage_alignment"))
|
||||
|
||||
if (
|
||||
not zip_file.has_record(byteordername)
|
||||
and get_default_load_endianness() is None
|
||||
@ -1939,7 +1973,7 @@ def _load(
|
||||
storage_offset = current_offset
|
||||
else:
|
||||
storage_offset = zip_file.get_record_offset_no_read(
|
||||
current_offset, name, numel
|
||||
current_offset, name, numel, storage_alignment
|
||||
)
|
||||
local_header_offset = current_offset
|
||||
|
||||
|
@ -19,6 +19,7 @@ class load:
|
||||
class save:
|
||||
compute_crc32: bool = True
|
||||
use_pinned_memory_for_d2h: bool = False
|
||||
storage_alignment: int = 64
|
||||
|
||||
|
||||
_install_config_module(sys.modules[__name__])
|
||||
|
Reference in New Issue
Block a user