# Add option to serialization config to reduce random reads from get_record_offset when loading with mmap=True (#143880)
## Background

This PR adds `torch.utils.serialization.config.load.calculate_storage_offsets`. The option relies on the previous PR in this stack, which changed the storage order to be non-lexicographical and added a `.format_version` entry to the zipfile; `calculate_storage_offsets` only works on checkpoints that contain `.format_version`.

When this is turned on, `torch.load(mmap=True)` calculates the offset of each storage record (other than the 0th) instead of relying on `miniz` APIs to determine it. The existing APIs issue multiple random reads (reading the end-of-central-directory record, then reading the zipfile header for the record) to determine the offset where the storage starts, which can greatly degrade `torch.load(mmap=True)` performance for non-filesystem cases. See 6aaae9d78f/caffe2/serialize/inline_container.cc (L589-L605).
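For illustration, a minimal usage sketch (the checkpoint path is a placeholder):

```python
import torch
from torch.utils.serialization import config

# Opt in to computed storage offsets. This only takes effect for
# checkpoints that carry the new ".format_version" record.
config.load.calculate_storage_offsets = True

# mmap=True is the path this option optimizes: every storage after data/0
# is located arithmetically instead of via random reads into zip metadata.
state_dict = torch.load("checkpoint.pt", mmap=True)
```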
## How does this work

The format of the checkpoint is as follows:

```
archive_name/
|_ data.pkl
|_ .format_version
|_ byteorder
|_ data/
   |_ 0
   |_ 1
   |_ 2
   |_ ...
```

Each `data/i` record represents a storage, and storages are written in the order that the Pickler encounters them. For each storage, our `persistent_load` logic saves the metadata `dtype, numel, key, location` to the pickle file, where `numel` is the number of bytes in the storage. Note that we always use the `miniz` writer in zip64 mode, per 7796e308d0/caffe2/serialize/inline_container.cc (L701).
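Since a checkpoint produced by `torch.save` is an ordinary zip archive, the layout above can be inspected with the standard library (a quick sketch; exact record names depend on the torch version and the archive name):

```python
import torch
import zipfile

torch.save({"w": torch.randn(3, 3), "b": torch.randn(3)}, "ckpt.pt")

# Each entry is "<archive_name>/<record>"; storages live under data/ in
# the order the pickler encountered them, e.g.
#   <archive_name>/data.pkl, <archive_name>/byteorder,
#   <archive_name>/data/0, <archive_name>/data/1, ...
print(zipfile.ZipFile("ckpt.pt").namelist())
```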
A zipfile record written by miniz looks as follows:

```
------------------------------------------------------------------------------------------------------------------
| 30 byte header | n byte filename | zip64_extra_data | m byte padding | storage | 16 or 24 byte local dir footer |
------------------------------------------------------------------------------------------------------------------
```

- The header size (30) is given by [`MZ_ZIP_LOCAL_DIR_HEADER_SIZE`](https://github.com/pytorch/pytorch/blob/main/third_party/miniz-3.0.2/miniz.c#L3290).
- The filename will be `"{archive_name}/{filepath}"`.
- `zip64_extra_data` is determined by `mz_zip_writer_create_zip64_extra_data` (7796e308d0/third_party/miniz-3.0.2/miniz.c (L6202)). Note that we only create `zip64_extra_data` if storage_size >= 0xFFFFFFFF or the offset of the start of the header >= 0xFFFFFFFF (7796e308d0/third_party/miniz-3.0.2/miniz.c (L6519-L6524)).
- `m` is determined by `getPadding` (7796e308d0/caffe2/serialize/inline_container.cc (L254)), which accounts for the filename and `zip64_extra_data` to choose `m` such that the start of `storage` is aligned to 64 bytes. The `m` padding bytes always start with `F B <padding_size>` as the first 4 bytes.
- The local dir footer size is determined based on this snippet (7796e308d0/third_party/miniz-3.0.2/miniz.c (L6610-L6632)): if the buffer size is 0 the footer is skipped; otherwise it is 24 bytes if `zip64_extra_data` was created and 16 bytes if not.

When `torch.utils.serialization.config.load.calculate_storage_offsets` is set, we do the following:

- We keep track of where the "cursor" is in the file using `current_offset`; after each `persistent_load` call, it sits at the offset where the header of the next record starts.
- For the 0th storage, `"data/0"`, we use the regular `get_record_offset` to determine the start of the storage.
- For any other storage (storages are encountered by the unpickler in order 0, 1, 2, 3, ...), we use `get_record_offset_no_read`, which reuses the `getPadding` logic to compute the offset of the storage. See the sketch after this list for the arithmetic.
- Note that `load_tensor` will only ever be called again with the same key if the storage's `._data_ptr()` is 0 ([pointer1](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1917-L1918), [pointer2](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1936-L1937)), so we cache the offsets for this edge case.
- After each storage, if the storage has non-zero size, we account for the local dir footer based on the logic described above.

## Testing strategy

The agreed-upon testing strategy was as follows:

- Add debug code, gated by an environment flag `TORCH_SERIALIZATION_DEBUG`, that runs this offset-calculation logic and verifies it against `getRecordOffset` for each storage (when `mmap=False`).
- This flag is set throughout CI, which means that every time `torch.load` is called, the offset-calculation logic is implicitly being tested.

Differential Revision: [D67673026](https://our.internmc.facebook.com/intern/diff/D67673026)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143880
Approved by: https://github.com/albanD
ghstack dependencies: #143879
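To make the offset arithmetic concrete, here is a small Python sketch of the non-zip64 case, mirroring the `getOffset`/`getPadding` logic described above (the zip64 branch, taken when the size or header offset reaches 0xFFFFFFFF, adds the `zip64_extra_data` field and uses the 24-byte footer instead; it is omitted here):

```python
MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30  # fixed miniz local directory header size
DATA_DESCRIPTOR_SIZE32 = 16        # local dir footer in the non-zip64 case
K_FIELD_ALIGNMENT = 64             # storages are aligned to 64 bytes

def storage_start(header_offset: int, filename: str) -> int:
    """Offset where the storage payload begins; `filename` includes the
    "{archive_name}/" prefix."""
    # header + filename + two uint16s (id and length of the padding field)
    start = header_offset + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + len(filename) + 2 * 2
    mod = start % K_FIELD_ALIGNMENT
    return start if mod == 0 else start + K_FIELD_ALIGNMENT - mod

def next_header_offset(header_offset: int, filename: str, numel: int) -> int:
    """Advance the cursor past this record: aligned storage plus footer."""
    cursor = storage_start(header_offset, filename) + numel
    if numel > 0:  # empty storages have no data descriptor
        cursor += DATA_DESCRIPTOR_SIZE32
    return cursor
```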
Committed by: PyTorch MergeBot
Parent: 98f87edd23
Commit: 001e355a56
```diff
@@ -18,6 +18,9 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available(
 fi
 popd
 
+# enable debug asserts in serialization
+export TORCH_SERIALIZATION_DEBUG=1
+
 setup_test_python() {
   # The CircleCI worker hostname doesn't resolve to an address.
   # This environment variable makes ProcessGroupGloo default to
```
```diff
@@ -46,6 +46,9 @@ BUILD_BIN_DIR="$BUILD_DIR"/bin
 SHARD_NUMBER="${SHARD_NUMBER:=1}"
 NUM_TEST_SHARDS="${NUM_TEST_SHARDS:=1}"
 
+# enable debug asserts in serialization
+export TORCH_SERIALIZATION_DEBUG=1
+
 export VALGRIND=ON
 # export TORCH_INDUCTOR_INSTALL_GXX=ON
 if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then
```
```diff
@@ -18,6 +18,9 @@ export PYTORCH_FINAL_PACKAGE_DIR="${PYTORCH_FINAL_PACKAGE_DIR:-/c/w/build-result
 PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}")
 export PYTORCH_FINAL_PACKAGE_DIR_WIN
 
+# enable debug asserts in serialization
+export TORCH_SERIALIZATION_DEBUG=1
+
 mkdir -p "$TMP_DIR"/build/torch
 
 export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers
```
```diff
@@ -251,11 +251,8 @@ constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
 constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;
 
 namespace detail {
-size_t getPadding(
-    size_t cursor,
-    size_t filename_size,
-    size_t size,
-    std::string& padding_buf) {
+
+std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size) {
   size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
       sizeof(mz_uint16) * 2;
   if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
```
```diff
@@ -269,6 +266,16 @@ size_t getPadding(
   }
   size_t mod = start % kFieldAlignment;
   size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
+  std::tuple<size_t, size_t> result(next_offset, start);
+  return result;
+}
+
+size_t getPadding(
+    size_t cursor,
+    size_t filename_size,
+    size_t size,
+    std::string& padding_buf) {
+  auto [next_offset, start] = getOffset(cursor, filename_size, size);
   size_t padding_size = next_offset - start;
   size_t padding_size_plus_fbxx = padding_size + 4;
   if (padding_buf.size() < padding_size_plus_fbxx) {
```
```diff
@@ -587,6 +594,14 @@ static int64_t read_le_16(uint8_t* buf) {
   return buf[0] + (buf[1] << 8);
 }
 
+size_t PyTorchStreamReader::getRecordHeaderOffset(const std::string& name) {
+  std::lock_guard<std::mutex> guard(reader_lock_);
+  mz_zip_archive_file_stat stat;
+  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
+  valid("retrieving file meta-data for ", name.c_str());
+  return stat.m_local_header_ofs;
+}
+
 size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
   std::lock_guard<std::mutex> guard(reader_lock_);
   mz_zip_archive_file_stat stat;
```
```diff
@@ -611,6 +626,17 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
   return stat.m_uncomp_size;
 }
 
+size_t PyTorchStreamReader::getRecordOffsetNoRead(
+    size_t cursor,
+    std::string filename,
+    size_t size) {
+  std::string full_name = archive_name_plus_slash_ + filename;
+  size_t full_name_size = full_name.size();
+  std::tuple<size_t, size_t> result = detail::getOffset(cursor, full_name_size, size);
+  size_t offset = std::get<0>(result);
+  return offset;
+}
+
 PyTorchStreamReader::~PyTorchStreamReader() {
   mz_zip_clear_last_error(ar_.get());
   mz_zip_reader_end(ar_.get());
```
```diff
@@ -172,8 +172,10 @@ class TORCH_API PyTorchStreamReader final {
       size_t n);
 
   size_t getRecordSize(const std::string& name);
 
+  size_t getRecordHeaderOffset(const std::string& name);
   size_t getRecordOffset(const std::string& name);
+  size_t
+  getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size);
   bool hasRecord(const std::string& name);
   std::vector<std::string> getAllRecords();
```
```diff
@@ -289,6 +291,9 @@ size_t getPadding(
     size_t filename_size,
     size_t size,
     std::string& padding_buf);
+
+std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size);
+
 } // namespace detail
 
 } // namespace serialize
```
```diff
@@ -515,3 +515,6 @@ Config
   (Default : ``torch.serialization.LoadEndianness.NATIVE``)
 * ``mmap_flags``: See :class:`~torch.serialization.set_default_mmap_options`.
   (Default : ``MAP_PRIVATE``)
+* ``calculate_storage_offsets``: If this config is set to ``True``, offsets for storages will be
+  calculated rather than read via random reads when using ``torch.load(mmap=True)``. This minimizes
+  random reads, which can be helpful when the file is being loaded over a network. (Default : ``False``)
```
```diff
@@ -45,6 +45,7 @@ from torch.testing._internal.common_utils import (
     BytesIOContext,
     download_file,
     instantiate_parametrized_tests,
+    IS_CI,
     IS_FBCODE,
     IS_FILESYSTEM_UTF8_ENCODING,
     IS_WINDOWS,
```
```diff
@@ -827,6 +828,11 @@ class SerializationMixin:
             loaded_data = torch.load(f, weights_only=True)
             self.assertEqual(data, loaded_data)
 
+    @unittest.skipIf(not IS_CI, "only check debug var is set in CI")
+    def test_debug_set_in_ci(self):
+        # This test is to make sure that the serialization debug flag is set in CI
+        self.assertTrue(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1")
+
 
 class serialization_method:
     def __init__(self, use_zip):
```
```diff
@@ -1041,6 +1047,23 @@ class TestSerialization(TestCase, SerializationMixin):
             f.seek(0)
             state = torch.load(f)
 
+    @serialTest()
+    def test_serialization_4gb_file(self):
+        '''
+        This is a specially engineered testcase that would fail if the data_descriptor size
+        had been incorrectly set as data_descriptor_size32 when it should be data_descriptor_size64
+        '''
+        # Run GC to clear up as much memory as possible before running this test
+        gc.collect()
+        big_model = torch.nn.ModuleList([torch.nn.Linear(1, int(1024 * 1024 * 1024) + 12, bias=False),
+                                         torch.nn.Linear(1, 1, bias=False).to(torch.float8_e4m3fn),
+                                         torch.nn.Linear(1, 2, bias=False).to(torch.float8_e4m3fn)])
+
+        with BytesIOContext() as f:
+            torch.save(big_model.state_dict(), f)
+            f.seek(0)
+            torch.load(f)
+
     @parametrize('weights_only', (True, False))
     def test_pathlike_serialization(self, weights_only):
         model = torch.nn.Conv2d(20, 3200, kernel_size=3)
```
```diff
@@ -4533,6 +4556,30 @@ class TestSerialization(TestCase, SerializationMixin):
             self.assertTrue(opened_zipfile.has_record(".format_version"))
             self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')
 
+    @parametrize('path_type', (str, Path))
+    @unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
+    def test_mmap_load_offset_calculation(self, path_type):
+        calculate_offsets_before = serialization_config.load.calculate_storage_offsets
+        try:
+            serialization_config.load.calculate_storage_offsets = True
+            m = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)])
+
+            with TemporaryFileName() as f:
+                f = path_type(f)
+                state_dict = m.state_dict()
+                torch.save(state_dict, f)
+                result = torch.load(f, mmap=True)
+                result_non_mmap = torch.load(f, mmap=False)
+
+            with torch.device("meta"):
+                model_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)])
+                model_non_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)])
+            model_mmap_state_dict.load_state_dict(result, assign=True)
+            model_non_mmap_state_dict.load_state_dict(result_non_mmap, assign=True)
+            inp = torch.randn(4, 4)
+            self.assertEqual(model_mmap_state_dict(inp), model_non_mmap_state_dict(inp.clone()))
+        finally:
+            serialization_config.load.calculate_storage_offsets = calculate_offsets_before
 
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
```
```diff
@@ -1619,6 +1619,20 @@ void initJITBindings(PyObject* module) {
           "get_record_offset",
           [](PyTorchStreamReader& self, const std::string& key) {
             return self.getRecordOffset(key);
           })
+      .def(
+          "get_record_header_offset",
+          [](PyTorchStreamReader& self, const std::string& key) {
+            return self.getRecordHeaderOffset(key);
+          })
+      .def(
+          "get_record_offset_no_read",
+          [](PyTorchStreamReader& self,
+             size_t zipfile_header_offset,
+             const std::string filename,
+             size_t size) {
+            return self.getRecordOffsetNoRead(
+                zipfile_header_offset, filename, size);
+          });
 
       // Used by torch.Package to coordinate deserialization of storages across
```
```diff
@@ -15,7 +15,7 @@ import threading
 import warnings
 from contextlib import closing, contextmanager
 from enum import Enum
-from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
+from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union
 from typing_extensions import TypeAlias, TypeIs
 
 import torch
```
```diff
@@ -1856,6 +1856,11 @@ def _load(
 
     loaded_storages = {}
 
+    can_calculate_storage_offsets = False
+    if zip_file.has_record(".format_version"):
+        version = zip_file.get_record(".format_version")
+        can_calculate_storage_offsets = version >= b"1"
+
     # check if byteswapping is needed
     byteordername = "byteorder"
     byteorderdata = None
```
```diff
@@ -1891,15 +1896,92 @@ def _load(
                 UserWarning,
             )
 
+    from torch.utils.serialization import config
+
+    calculate_storage_offsets = config.load.calculate_storage_offsets
+    run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1"
+    current_offset = None
+    # constants from miniz.h/miniz.c
+    data_descripter_size64 = 24
+    data_descripter_size32 = 16
+    mz_uint32_max = 0xFFFFFFFF
+    offsets: Dict[str, int] = dict()
+
+    def _get_offset(key, name, numel):
+        """
+        Return the offset of the storage associated with key with record name `name` and size numel.
+        It is expected that the zipfile header of this storage starts at current_offset.
+
+        WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular,
+        the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept
+        in sync with that of miniz!
+
+        After reading a storage of size numel that starts at storage_offset,
+        if it is the first time that storage was read, update nonlocal variable
+        current_offset to the start of the next zipfile header by incrementing
+        it by numel and the data descriptor size.
+        """
+        nonlocal current_offset, offsets
+        if name in offsets:
+            storage_offset = offsets[name]
+            return storage_offset
+
+        if current_offset is None:
+            assert key == "0"
+            current_offset = zip_file.get_record_offset(name)
+            local_header_offset = zip_file.get_record_header_offset(name)
+            storage_offset = current_offset
+        else:
+            storage_offset = zip_file.get_record_offset_no_read(
+                current_offset, name, numel
+            )
+            local_header_offset = current_offset
+
+        # This is only actually needed for storages that have typed_storage._data_ptr() == 0
+        # after being read. Otherwise persistent_load would never "re-call" load_tensor
+        # for a given key.
+        offsets[name] = storage_offset
+
+        # Increment current_offset to the offset where the next zipfile header starts
+        current_offset = storage_offset + numel
+        # add size of data descriptor after payload
+        if numel > 0:
+            if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max:
+                current_offset += data_descripter_size64
+            else:
+                current_offset += data_descripter_size32
+
+        return storage_offset
+
     def load_tensor(dtype, numel, key, location):
         name = f"data/{key}"
         if torch._guards.detect_fake_mode(None) is not None:
             nbytes = numel * torch._utils._element_size(dtype)
             storage = torch.UntypedStorage(nbytes, device="meta")
         elif overall_storage is not None:
-            storage_offset = zip_file.get_record_offset(name)
+            if can_calculate_storage_offsets and calculate_storage_offsets:
+                storage_offset = _get_offset(key, name, numel)
+                if run_debug_asserts:
+                    if storage_offset != zip_file.get_record_offset(name):
+                        raise RuntimeError(
+                            "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
+                            f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
+                            f"{zip_file.get_record_offset(name)}"
+                        )
+            else:
+                storage_offset = zip_file.get_record_offset(name)
             storage = overall_storage[storage_offset : storage_offset + numel]
         else:
+            if can_calculate_storage_offsets and run_debug_asserts:
+                # This is debug code that we use to test the validity of
+                # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI
+                storage_offset = _get_offset(key, name, numel)
+                if storage_offset != zip_file.get_record_offset(name):
+                    raise RuntimeError(
+                        "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
+                        f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
+                        f"{zip_file.get_record_offset(name)}"
+                    )
             storage = (
                 zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
                 ._typed_storage()
```
```diff
@@ -13,6 +13,7 @@ class load:
     endianness: _Optional["_LoadEndianess"] = None
     # MAP_PRIVATE = 2
     mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2
+    calculate_storage_offsets: bool = False
 
 
 class save:
```