Add option to serialization config to reduce random reads from get_record_offset when loading with mmap=True (#143880)

## Background

This PR adds `torch.utils.serialization.config.load.calculate_storage_offsets`. This option relies on the previous PR in this stack, which changed the storage order to be non-lexicographical. A `.format_version` entry was added to the zipfile, and `calculate_storage_offsets` will only work on checkpoints that contain `.format_version`.

When this is turned on, for `torch.load(mmap=True)`, the offset of each storage record (other than the 0th) is calculated instead of being determined via `miniz` APIs.

The existing APIs issue multiple random reads (reading the end-of-central-directory record, then reading the zipfile header for the record) to determine the offset where the storage starts. This can greatly degrade `torch.load(mmap=True)` performance for non-filesystem cases.
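For illustration, that lookup sequence can be reproduced against a plain zip archive. The helper below is a generic sketch (not the `miniz` API) that performs the same random reads a non-zip64 zip reader must do: the end-of-central-directory record, the central directory, and finally the record's local header.

```python
import io
import struct
import zipfile

def find_storage_offset(read_at, file_size, name):
    """Locate where `name`'s payload starts using only random reads."""
    # Read 1: the end-of-central-directory record (assumes no archive comment)
    eocd = read_at(file_size - 22, 22)
    assert eocd[:4] == b"PK\x05\x06"
    cd_size, cd_ofs = struct.unpack("<II", eocd[12:20])
    # Read 2: the central directory, scanned for the record's entry
    cd = read_at(cd_ofs, cd_size)
    pos = 0
    while pos < cd_size:
        assert cd[pos : pos + 4] == b"PK\x01\x02"
        n, m, k = struct.unpack("<HHH", cd[pos + 28 : pos + 34])
        (local_hdr_ofs,) = struct.unpack("<I", cd[pos + 42 : pos + 46])
        if cd[pos + 46 : pos + 46 + n].decode() == name:
            # Read 3: the record's local file header (30 fixed bytes)
            hdr = read_at(local_hdr_ofs, 30)
            n2, m2 = struct.unpack("<HH", hdr[26:30])
            return local_hdr_ofs + 30 + n2 + m2
        pos += 46 + n + m + k
    raise KeyError(name)

# Build a tiny stored (uncompressed) archive in memory and look up one record.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    z.writestr("archive/data/0", b"x" * 100)
data = buf.getvalue()
ofs = find_storage_offset(lambda o, ln: data[o : o + ln], len(data), "archive/data/0")
assert data[ofs : ofs + 100] == b"x" * 100
```

Each `read_at` call here would be a separate seek over the network, which is the cost this PR avoids for all storages after the 0th.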

6aaae9d78f/caffe2/serialize/inline_container.cc (L589-L605)

## How does this work

The format of the checkpoint is as follows:

```
archive_name/
|_ data.pkl
|_ .format_version
|_ byteorder
|_ data/
   |_ 0
   |_ 1
   |_ 2
   |_ ...
```

Each `data/i` record represents a storage, where storages are written in the order that the Pickler encounters them.

For each storage, our `persistent_load` logic saves the following metadata to the pickle file: `dtype, numel, key, location`, where `numel` is the number of bytes in the storage.

Note that we always use the `miniz` writer in zip64 mode per [here](7796e308d0/caffe2/serialize/inline_container.cc (L701)). A zipfile record written by miniz looks like this:

```
 ---------------- ----------------- ------------------ ---------------- --------- --------------------------------
| 30 byte header | n byte filename | zip64_extra_data | m byte padding | storage | 16 or 24 byte local dir footer |
 ---------------- ----------------- ------------------ ---------------- --------- --------------------------------
```

- The header size (30) is given by [`MZ_ZIP_LOCAL_DIR_HEADER_SIZE`](https://github.com/pytorch/pytorch/blob/main/third_party/miniz-3.0.2/miniz.c#L3290)
- The filename will be `"{archive_name}/{filepath}"`
- `zip64_extra_data` is determined by [`mz_zip_writer_create_zip64_extra_data`](7796e308d0/third_party/miniz-3.0.2/miniz.c (L6202)). Note that [we only create zip64_extra_data if storage_size >= 0xFFFFFFFF or the offset of the start of the header >= 0xFFFFFFFF](7796e308d0/third_party/miniz-3.0.2/miniz.c (L6519-L6524))
- `m` is determined by [`getPadding`](7796e308d0/caffe2/serialize/inline_container.cc (L254)), which accounts for the filename and zip64_extra_data to choose `m` such that the start of `storage` is aligned to 64 bytes. The `m` padding bytes always start with `F B <padding_size>` as the first 4 bytes
- The local dir footer size is determined by [this snippet](7796e308d0/third_party/miniz-3.0.2/miniz.c (L6610-L6632)): if the buffer size is 0 it is skipped; if zip64_extra_data was created it is 24 bytes, otherwise it is 16 bytes
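Putting the layout above together, the start of a record's storage can be computed from the header offset alone. A minimal sketch of that arithmetic follows; the 28-byte zip64 extra-data size is an assumption here, and the authoritative logic lives in `getOffset`/`getPadding` in `inline_container.cc`:

```python
MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30  # fixed local-file-header size in miniz
FIELD_ALIGNMENT = 64               # PyTorch aligns storage payloads to 64 bytes
MZ_UINT32_MAX = 0xFFFFFFFF
ZIP64_EXTRA_DATA_SIZE = 28         # assumed size of miniz's zip64 extra-data block

def storage_start(header_offset, filename_len, storage_size):
    """Offset where the storage payload begins for a record whose local
    header starts at `header_offset`, following the layout above."""
    # header + filename + the 2 x 2-byte header of the padding extra field
    start = header_offset + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + 2 * 2
    # zip64 extra data is only written for large sizes or large header offsets
    if storage_size >= MZ_UINT32_MAX or header_offset >= MZ_UINT32_MAX:
        start += ZIP64_EXTRA_DATA_SIZE
    # `m` bytes of padding round the payload up to the next 64-byte boundary
    remainder = start % FIELD_ALIGNMENT
    return start if remainder == 0 else start + FIELD_ALIGNMENT - remainder
```

For example, a record named `archive/data/1` (14 characters) whose header starts at byte 256 gives `start = 256 + 30 + 14 + 4 = 304`, which is padded up to 320.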

When `torch.utils.serialization.config.load.calculate_storage_offsets` is set, we do the following:
- We keep track of where the "cursor" is in the file using `current_offset`; after each `persistent_load` call, it will be at the offset where the header for the next record starts
- For the 0th storage, "data/0", we use the regular `get_record_offset` to determine the start of the storage
- For any other storage (the unpickler encounters storages in order: 0, 1, 2, 3, ...), we use `get_record_offset_no_read`, which reuses the `getPadding` logic to determine the offset of the storage
- Note that `load_tensor` will only ever be called again with the same key if the storage's `._data_ptr()` is 0 [[pointer1](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1917-L1918)][[pointer2](https://github.com/pytorch/pytorch/blob/main/torch/serialization.py#L1936-L1937)], so we cache the offsets for this edge case
- After each storage, if the storage size is non-zero, we account for the local dir footer based on the logic described above
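The steps above can be sketched end to end. This hypothetical walker (names and helpers are illustrative, not the real `torch.serialization` internals) derives every storage offset from the previous one, so only the 0th record needs an actual read:

```python
def walk_offsets(records, first_storage_offset, align=64, header=30):
    """records: (name, size) pairs in the order the unpickler meets them.
    Yields (name, storage_offset), chaining offsets so that only
    `first_storage_offset` requires touching the file."""
    cursor = None  # offset where the next record's local header starts
    for name, size in records:
        if cursor is None:
            header_ofs, offset = 0, first_storage_offset  # one real random read
        else:
            header_ofs = cursor
            # header + filename + padding extra-field header, 64-byte aligned
            start = cursor + header + len(name) + 4
            offset = -(-start // align) * align  # round up to the boundary
        cursor = offset + size
        # local dir footer: skipped for empty storages, 24 bytes in zip64
        # cases, otherwise 16 (as described above)
        if size > 0:
            zip64 = size >= 0xFFFFFFFF or header_ofs >= 0xFFFFFFFF
            cursor += 24 if zip64 else 16
        yield name, offset
```

With two 100-byte storages and the first payload at byte 64, the second lands at byte 256: the first record ends at 164, its 16-byte footer moves the cursor to 180, and 180 + 30 + 14 + 4 = 228 rounds up to 256.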

## Testing strategy

The agreed-upon testing strategy was as follows:
- Add debug code, gated by the environment flag `TORCH_SERIALIZATION_DEBUG`, that runs this offset-calculation logic and verifies it against `getRecordOffset` for each storage (when `mmap=False`)
- This flag is set throughout CI, which means that every time `torch.load` is called, the offset-calculation logic is implicitly tested

Differential Revision: [D67673026](https://our.internmc.facebook.com/intern/diff/D67673026)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143880
Approved by: https://github.com/albanD
ghstack dependencies: #143879
This commit is contained in:
Mikayla Gawarecki
2025-01-30 10:40:56 -08:00
committed by PyTorch MergeBot
parent 98f87edd23
commit 001e355a56
10 changed files with 195 additions and 8 deletions

```diff
@@ -18,6 +18,9 @@ if [[ ! $(python -c "import torch; print(int(torch.backends.openmp.is_available(
 fi
 popd

+# enable debug asserts in serialization
+export TORCH_SERIALIZATION_DEBUG=1
+
 setup_test_python() {
   # The CircleCI worker hostname doesn't resolve to an address.
   # This environment variable makes ProcessGroupGloo default to
```
```diff
@@ -46,6 +46,9 @@ BUILD_BIN_DIR="$BUILD_DIR"/bin
 SHARD_NUMBER="${SHARD_NUMBER:=1}"
 NUM_TEST_SHARDS="${NUM_TEST_SHARDS:=1}"

+# enable debug asserts in serialization
+export TORCH_SERIALIZATION_DEBUG=1
+
 export VALGRIND=ON
 # export TORCH_INDUCTOR_INSTALL_GXX=ON
 if [[ "$BUILD_ENVIRONMENT" == *clang9* || "$BUILD_ENVIRONMENT" == *xpu* ]]; then
```

```diff
@@ -18,6 +18,9 @@ export PYTORCH_FINAL_PACKAGE_DIR="${PYTORCH_FINAL_PACKAGE_DIR:-/c/w/build-result
 PYTORCH_FINAL_PACKAGE_DIR_WIN=$(cygpath -w "${PYTORCH_FINAL_PACKAGE_DIR}")
 export PYTORCH_FINAL_PACKAGE_DIR_WIN

+# enable debug asserts in serialization
+export TORCH_SERIALIZATION_DEBUG=1
+
 mkdir -p "$TMP_DIR"/build/torch
 export SCRIPT_HELPERS_DIR=$SCRIPT_PARENT_DIR/win-test-helpers
```

```diff
@@ -251,11 +251,8 @@ constexpr int MZ_ZIP_LDH_EXTRA_LEN_OFS = 28;
 constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;

 namespace detail {
-size_t getPadding(
-    size_t cursor,
-    size_t filename_size,
-    size_t size,
-    std::string& padding_buf) {
+std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size) {
   size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
       sizeof(mz_uint16) * 2;
   if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
@@ -269,6 +266,16 @@ size_t getPadding(
   }
   size_t mod = start % kFieldAlignment;
   size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
+  std::tuple<size_t, size_t> result(next_offset, start);
+  return result;
+}
+
+size_t getPadding(
+    size_t cursor,
+    size_t filename_size,
+    size_t size,
+    std::string& padding_buf) {
+  auto [next_offset, start] = getOffset(cursor, filename_size, size);
   size_t padding_size = next_offset - start;
   size_t padding_size_plus_fbxx = padding_size + 4;
   if (padding_buf.size() < padding_size_plus_fbxx) {
@@ -587,6 +594,14 @@ static int64_t read_le_16(uint8_t* buf) {
   return buf[0] + (buf[1] << 8);
 }

+size_t PyTorchStreamReader::getRecordHeaderOffset(const std::string& name) {
+  std::lock_guard<std::mutex> guard(reader_lock_);
+  mz_zip_archive_file_stat stat;
+  mz_zip_reader_file_stat(ar_.get(), getRecordID(name), &stat);
+  valid("retrieving file meta-data for ", name.c_str());
+  return stat.m_local_header_ofs;
+}
+
 size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
   std::lock_guard<std::mutex> guard(reader_lock_);
   mz_zip_archive_file_stat stat;
@@ -611,6 +626,17 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
   return stat.m_uncomp_size;
 }

+size_t PyTorchStreamReader::getRecordOffsetNoRead(
+    size_t cursor,
+    std::string filename,
+    size_t size) {
+  std::string full_name = archive_name_plus_slash_ + filename;
+  size_t full_name_size = full_name.size();
+  std::tuple<size_t, size_t> result =
+      detail::getOffset(cursor, full_name_size, size);
+  size_t offset = std::get<0>(result);
+  return offset;
+}
+
 PyTorchStreamReader::~PyTorchStreamReader() {
   mz_zip_clear_last_error(ar_.get());
   mz_zip_reader_end(ar_.get());
```

```diff
@@ -172,8 +172,10 @@ class TORCH_API PyTorchStreamReader final {
       size_t n);
   size_t getRecordSize(const std::string& name);
+  size_t getRecordHeaderOffset(const std::string& name);
   size_t getRecordOffset(const std::string& name);
+  size_t
+  getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size);
   bool hasRecord(const std::string& name);
   std::vector<std::string> getAllRecords();
@@ -289,6 +291,9 @@ size_t getPadding(
     size_t filename_size,
     size_t size,
     std::string& padding_buf);
+std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size);
 } // namespace detail
 } // namespace serialize
```

```diff
@@ -515,3 +515,6 @@ Config
       (Default : ``torch.serialization.LoadEndianness.NATIVE``)
     * ``mmap_flags``: See :class:`~torch.serialization.set_default_mmap_options`.
       (Default : ``MAP_PRIVATE``)
+    * ``calculate_storage_offsets``: If this config is set to ``True``, offsets for storages will be
+      calculated rather than read via random reads when using ``torch.load(mmap=True)``. This minimizes
+      random reads, which can be helpful when the file is being loaded over a network. (Default : ``False``)
```

```diff
@@ -45,6 +45,7 @@ from torch.testing._internal.common_utils import (
     BytesIOContext,
     download_file,
     instantiate_parametrized_tests,
+    IS_CI,
     IS_FBCODE,
     IS_FILESYSTEM_UTF8_ENCODING,
     IS_WINDOWS,
@@ -827,6 +828,11 @@ class SerializationMixin:
             loaded_data = torch.load(f, weights_only=True)
             self.assertEqual(data, loaded_data)

+    @unittest.skipIf(not IS_CI, "only check debug var is set in CI")
+    def test_debug_set_in_ci(self):
+        # This test is to make sure that the serialization debug flag is set in CI
+        self.assertTrue(os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1")
+

 class serialization_method:
     def __init__(self, use_zip):
@@ -1041,6 +1047,23 @@ class TestSerialization(TestCase, SerializationMixin):
             f.seek(0)
             state = torch.load(f)

+    @serialTest()
+    def test_serialization_4gb_file(self):
+        '''
+        This is a specially engineered testcase that would fail if the data_descriptor size
+        had been incorrectly set as data_descriptor_size32 when it should be data_descriptor_size64
+        '''
+        # Run GC to clear up as much memory as possible before running this test
+        gc.collect()
+        big_model = torch.nn.ModuleList([torch.nn.Linear(1, int(1024 * 1024 * 1024) + 12, bias=False),
+                                         torch.nn.Linear(1, 1, bias=False).to(torch.float8_e4m3fn),
+                                         torch.nn.Linear(1, 2, bias=False).to(torch.float8_e4m3fn)])
+
+        with BytesIOContext() as f:
+            torch.save(big_model.state_dict(), f)
+            f.seek(0)
+            torch.load(f)
+
     @parametrize('weights_only', (True, False))
     def test_pathlike_serialization(self, weights_only):
         model = torch.nn.Conv2d(20, 3200, kernel_size=3)
@@ -4533,6 +4556,30 @@ class TestSerialization(TestCase, SerializationMixin):
             self.assertTrue(opened_zipfile.has_record(".format_version"))
             self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')

+    @parametrize('path_type', (str, Path))
+    @unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
+    def test_mmap_load_offset_calculation(self, path_type):
+        calculate_offsets_before = serialization_config.load.calculate_storage_offsets
+        try:
+            serialization_config.load.calculate_storage_offsets = True
+            m = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)])
+            with TemporaryFileName() as f:
+                f = path_type(f)
+                state_dict = m.state_dict()
+                torch.save(state_dict, f)
+                result = torch.load(f, mmap=True)
+                result_non_mmap = torch.load(f, mmap=False)
+
+            with torch.device("meta"):
+                model_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)])
+                model_non_mmap_state_dict = torch.nn.Sequential(*[torch.nn.Linear(4, 4) for _ in range(20)])
+            model_mmap_state_dict.load_state_dict(result, assign=True)
+            model_non_mmap_state_dict.load_state_dict(result_non_mmap, assign=True)
+            inp = torch.randn(4, 4)
+            self.assertEqual(model_mmap_state_dict(inp), model_non_mmap_state_dict(inp.clone()))
+        finally:
+            serialization_config.load.calculate_storage_offsets = calculate_offsets_before
+
     def run(self, *args, **kwargs):
         with serialization_method(use_zip=True):
```

```diff
@@ -1619,6 +1619,20 @@ void initJITBindings(PyObject* module) {
           .def(
               "get_record_offset",
               [](PyTorchStreamReader& self, const std::string& key) {
                 return self.getRecordOffset(key);
-              });
+              })
+          .def(
+              "get_record_header_offset",
+              [](PyTorchStreamReader& self, const std::string& key) {
+                return self.getRecordHeaderOffset(key);
+              })
+          .def(
+              "get_record_offset_no_read",
+              [](PyTorchStreamReader& self,
+                 size_t zipfile_header_offset,
+                 const std::string filename,
+                 size_t size) {
+                return self.getRecordOffsetNoRead(
+                    zipfile_header_offset, filename, size);
+              });

   // Used by torch.Package to coordinate deserialization of storages across
```

```diff
@@ -15,7 +15,7 @@ import threading
 import warnings
 from contextlib import closing, contextmanager
 from enum import Enum
-from typing import Any, Callable, cast, Generic, IO, Optional, TypeVar, Union
+from typing import Any, Callable, cast, Dict, Generic, IO, Optional, TypeVar, Union
 from typing_extensions import TypeAlias, TypeIs

 import torch
@@ -1856,6 +1856,11 @@ def _load(
     loaded_storages = {}

+    can_calculate_storage_offsets = False
+    if zip_file.has_record(".format_version"):
+        version = zip_file.get_record(".format_version")
+        can_calculate_storage_offsets = version >= b"1"
+
     # check if byteswapping is needed
     byteordername = "byteorder"
     byteorderdata = None
@@ -1891,15 +1896,92 @@ def _load(
                 UserWarning,
             )

+    from torch.utils.serialization import config
+
+    calculate_storage_offsets = config.load.calculate_storage_offsets
+    run_debug_asserts = os.environ.get("TORCH_SERIALIZATION_DEBUG", "0") == "1"
+    current_offset = None
+    # constants from miniz.h/miniz.c
+    data_descripter_size64 = 24
+    data_descripter_size32 = 16
+    mz_uint32_max = 0xFFFFFFFF
+    offsets: Dict[str, int] = dict()
+
+    def _get_offset(key, name, numel):
+        """
+        Return the offset of the storage associated with key with record name `name` and size numel.
+        It is expected that the zipfile header of this storage starts at current_offset.
+
+        WARNING: This function relies on the behavior of the zipwriter in miniz.c. In particular,
+        the behavior of `mz_zip_writer_add_mem_ex_v2`. The behavior of this function must be kept
+        in sync with that of miniz!
+
+        After reading a storage of size numel that starts at storage_offset,
+        if it is the first time that storage was read, update nonlocal variable
+        current_offset to the start of the next zipfile header by incrementing
+        it by numel and the data descriptor size.
+        """
+        nonlocal current_offset, offsets
+        if name in offsets:
+            storage_offset = offsets[name]
+            return storage_offset
+        if current_offset is None:
+            assert key == "0"
+            current_offset = zip_file.get_record_offset(name)
+            local_header_offset = zip_file.get_record_header_offset(name)
+            storage_offset = current_offset
+        else:
+            storage_offset = zip_file.get_record_offset_no_read(
+                current_offset, name, numel
+            )
+            local_header_offset = current_offset
+        # This is only actually needed for storages that have typed_storage._data_ptr() == 0
+        # after being read. Otherwise persistent_load would never "re-call" load_tensor
+        # for a given key.
+        offsets[name] = storage_offset
+        # Increment current_offset to offset where next zipfile header starts
+        current_offset = storage_offset + numel
+        # add size of data descriptor after payload
+        if numel > 0:
+            if local_header_offset >= mz_uint32_max or numel >= mz_uint32_max:
+                current_offset += data_descripter_size64
+            else:
+                current_offset += data_descripter_size32
+        return storage_offset
+
     def load_tensor(dtype, numel, key, location):
         name = f"data/{key}"
         if torch._guards.detect_fake_mode(None) is not None:
             nbytes = numel * torch._utils._element_size(dtype)
             storage = torch.UntypedStorage(nbytes, device="meta")
         elif overall_storage is not None:
-            storage_offset = zip_file.get_record_offset(name)
+            if can_calculate_storage_offsets and calculate_storage_offsets:
+                storage_offset = _get_offset(key, name, numel)
+                if run_debug_asserts:
+                    if storage_offset != zip_file.get_record_offset(name):
+                        raise RuntimeError(
+                            "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
+                            f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
+                            f"{zip_file.get_record_offset(name)}"
+                        )
+            else:
+                storage_offset = zip_file.get_record_offset(name)
             storage = overall_storage[storage_offset : storage_offset + numel]
         else:
+            if can_calculate_storage_offsets and run_debug_asserts:
+                # This is debug code that we use to test the validity of
+                # torch.utils.serialization.config.load.calculate_storage_offsets throughout CI
+                storage_offset = _get_offset(key, name, numel)
+                if storage_offset != zip_file.get_record_offset(name):
+                    raise RuntimeError(
+                        "This is a debug assert that was run as the `TORCH_SERIALIZATION_DEBUG` environment "
+                        f"variable was set: Incorrect offset for {name}, got {storage_offset} expected "
+                        f"{zip_file.get_record_offset(name)}"
+                    )
             storage = (
                 zip_file.get_storage_from_record(name, numel, torch.UntypedStorage)
                 ._typed_storage()
```

```diff
@@ -13,6 +13,7 @@ class load:
     endianness: _Optional["_LoadEndianess"] = None
     # MAP_PRIVATE = 2
     mmap_flags: _Optional[int] = None if sys.platform == "win32" else 2
+    calculate_storage_offsets: bool = False

 class save:
```