mirror of https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00

Make record/storage alignment in torch.save configurable (#147788)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147788
Approved by: https://github.com/albanD
ghstack dependencies: #147786, #147787

committed by PyTorch MergeBot
parent 209977e6e5
commit be0ceee1c3
@@ -252,7 +252,11 @@ constexpr int MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50;
 namespace detail {

-std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size) {
+std::tuple<size_t, size_t> getOffset(
+    size_t cursor,
+    size_t filename_size,
+    size_t size,
+    uint64_t alignment) {
   size_t start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size +
       sizeof(mz_uint16) * 2;
   if (size >= MZ_UINT32_MAX || cursor >= MZ_UINT32_MAX) {
@@ -264,8 +268,8 @@ std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t
       start += sizeof(mz_uint64);
     }
   }
-  size_t mod = start % kFieldAlignment;
-  size_t next_offset = (mod == 0) ? start : (start + kFieldAlignment - mod);
+  size_t mod = start % alignment;
+  size_t next_offset = (mod == 0) ? start : (start + alignment - mod);
   std::tuple<size_t, size_t> result(next_offset, start);
   return result;
 }
@@ -274,8 +278,9 @@ size_t getPadding(
     size_t cursor,
     size_t filename_size,
     size_t size,
-    std::string& padding_buf) {
-  auto [next_offset, start] = getOffset(cursor, filename_size, size);
+    std::string& padding_buf,
+    uint64_t alignment) {
+  auto [next_offset, start] = getOffset(cursor, filename_size, size, alignment);
   size_t padding_size = next_offset - start;
   size_t padding_size_plus_fbxx = padding_size + 4;
   if (padding_buf.size() < padding_size_plus_fbxx) {
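The arithmetic above decides where a record's payload starts: the data begins after the zip local file header, the record name, and the two 16-bit extra-field length words (plus zip64 fields for large records), and is then rounded up to the next multiple of the requested alignment. A rough Python sketch of the same rounding, not part of the patch and ignoring the zip64 case, looks like:

    # Illustrative sketch of the rounding performed by detail::getOffset.
    MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30  # fixed size of a zip local file header, per the zip spec

    def next_aligned_offset(cursor, filename_size, alignment):
        # where the payload would start without padding:
        # header + record name + two 16-bit extra-field length words
        start = cursor + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_size + 2 * 2
        mod = start % alignment
        # round up to the next multiple of `alignment`
        return start if mod == 0 else start + alignment - mod

getPadding then emits (next_offset - start) bytes of extra-field padding so that the payload lands exactly on that boundary.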
@@ -410,8 +415,7 @@ size_t PyTorchStreamReader::getRecordMultiReaders(
       }
       readSizes[i] = size;
       LOG(INFO) << "Thread " << i << " read [" << startPos << "-" << endPos
-                << "] "
-                << "from " << name << " of size " << n;
+                << "] " << "from " << name << " of size " << n;
       TORCH_CHECK(
           threadReadSize == size,
           "record size ",
@@ -629,10 +633,12 @@ size_t PyTorchStreamReader::getRecordSize(const std::string& name) {
 size_t PyTorchStreamReader::getRecordOffsetNoRead(
     size_t cursor,
     std::string filename,
-    size_t size) {
+    size_t size,
+    uint64_t alignment) {
   std::string full_name = archive_name_plus_slash_ + filename;
   size_t full_name_size = full_name.size();
-  std::tuple<size_t, size_t> result = detail::getOffset(cursor, full_name_size, size);
+  std::tuple<size_t, size_t> result =
+      detail::getOffset(cursor, full_name_size, size, alignment);
   size_t offset = std::get<0>(result);
   return offset;
 }
@@ -673,17 +679,22 @@ size_t ostream_write_func(

 PyTorchStreamWriter::PyTorchStreamWriter(
     const std::string& file_name,
-    bool compute_crc32)
-    : archive_name_(basename(file_name)), compute_crc32_(compute_crc32) {
+    bool compute_crc32,
+    uint64_t alignment)
+    : archive_name_(basename(file_name)),
+      compute_crc32_(compute_crc32),
+      alignment_(alignment) {
   setup(file_name);
 }

 PyTorchStreamWriter::PyTorchStreamWriter(
     const std::function<size_t(const void*, size_t)> writer_func,
-    bool compute_crc32)
+    bool compute_crc32,
+    uint64_t alignment)
     : archive_name_("archive"),
       writer_func_(writer_func),
-      compute_crc32_(compute_crc32) {
+      compute_crc32_(compute_crc32),
+      alignment_(alignment) {
   setup(archive_name_);
 }

@@ -748,8 +759,12 @@ void PyTorchStreamWriter::writeRecord(
     return;
   }
   std::string full_name = archive_name_plus_slash_ + name;
-  size_t padding_size =
-      detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
+  size_t padding_size = detail::getPadding(
+      ar_->m_archive_size,
+      full_name.size(),
+      size,
+      padding_,
+      alignment_);
   uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
   if (!compute_crc32_) {
 #if (!defined(FBCODE_CAFFE2))
@@ -174,8 +174,11 @@ class TORCH_API PyTorchStreamReader final {
   size_t getRecordSize(const std::string& name);
   size_t getRecordHeaderOffset(const std::string& name);
   size_t getRecordOffset(const std::string& name);
-  size_t
-  getRecordOffsetNoRead(size_t cursor, std::string filename, size_t size);
+  size_t getRecordOffsetNoRead(
+      size_t cursor,
+      std::string filename,
+      size_t size,
+      uint64_t alignment);
   bool hasRecord(const std::string& name);
   std::vector<std::string> getAllRecords();

@@ -222,10 +225,12 @@ class TORCH_API PyTorchStreamWriter final {
  public:
  explicit PyTorchStreamWriter(
      const std::string& archive_name,
-     bool compute_crc32 = true);
+     bool compute_crc32 = true,
+     uint64_t alignment = 64);
  explicit PyTorchStreamWriter(
      const std::function<size_t(const void*, size_t)> writer_func,
-     bool compute_crc32 = true);
+     bool compute_crc32 = true,
+     uint64_t alignment = 64);

  void setMinVersion(const uint64_t version);

@@ -267,6 +272,7 @@ class TORCH_API PyTorchStreamWriter final {
   uint64_t combined_uncomp_crc32_ = 0;
   std::string serialization_id_;
   bool compute_crc32_;
+  uint64_t alignment_;

   // This number will be updated when the model has operators
   // that have valid upgraders.
@@ -281,8 +287,6 @@ class TORCH_API PyTorchStreamWriter final {
 };

 namespace detail {
-// Writer-specific constants
-constexpr uint64_t kFieldAlignment = 64;

 // Returns a record to be appended to the local user extra data entry in order
 // to make data beginning aligned at kFieldAlignment bytes boundary.
@@ -290,9 +294,11 @@ size_t getPadding(
     size_t cursor,
     size_t filename_size,
     size_t size,
-    std::string& padding_buf);
+    std::string& padding_buf,
+    uint64_t alignment);

-std::tuple<size_t, size_t> getOffset(size_t cursor, size_t filename_size, size_t size);
+std::tuple<size_t, size_t>
+getOffset(size_t cursor, size_t filename_size, size_t size, uint64_t alignment);

 } // namespace detail

@@ -505,6 +505,7 @@ Config
   See :func:`~torch.serialization.set_crc32_options`.
 * ``use_pinned_memory_for_d2h``: for storages that are on an accelerator when passed to ``torch.save``, whether to
   move storage to pinned memory or pageable memory on CPU within ``torch.save``. (Default: ``False`` (i.e. pageable))
+* ``storage_alignment``: alignment of storages in the checkpoint during ``torch.save`` in bytes. (Default ``64``)

 ``torch.utils.serialization.config.load`` contains options that control the behavior of ``torch.load``.

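The new ``storage_alignment`` save option is read from ``torch.utils.serialization.config.save`` at save time, so it can be toggled like the other save options. A small usage sketch, mirroring the test further down (the output path is hypothetical):

    import torch
    from torch.utils.serialization import config as serialization_config

    sd = torch.nn.Linear(10, 10).state_dict()

    # Align each storage in the checkpoint to 4096 bytes instead of the default 64.
    serialization_config.save.storage_alignment = 4096
    try:
        torch.save(sd, "model_aligned.pt")  # hypothetical output path
    finally:
        # restore the default so later saves are unaffected
        serialization_config.save.storage_alignment = 64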
@@ -4653,6 +4653,34 @@ class TestSerialization(TestCase, SerializationMixin):
             self.assertTrue(opened_zipfile.has_record(".format_version"))
             self.assertEqual(opened_zipfile.get_record(".format_version"), b'1')

+    def test_storage_alignment(self):
+        sd = torch.nn.Linear(10, 10).state_dict()
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(sd, f)
+            f.seek(0)
+            with FakeTensorMode():
+                sd_fake = torch.load(f)
+            self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 832)
+            self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 1344)
+
+        storage_alignment_before = serialization_config.save.storage_alignment
+        with tempfile.NamedTemporaryFile() as f:
+            try:
+                serialization_config.save.storage_alignment = 4096
+                torch.save(sd, f)
+                f.seek(0)
+                with FakeTensorMode():
+                    sd_fake = torch.load(f)
+                self.assertEqual(sd_fake['weight'].untyped_storage()._checkpoint_offset, 20480)
+                self.assertEqual(sd_fake['bias'].untyped_storage()._checkpoint_offset, 24576)
+                f.seek(0)
+                sd_loaded = torch.load(f)
+                self.assertEqual(sd_loaded, sd)
+            finally:
+                serialization_config.save.storage_alignment = storage_alignment_before
+
     @parametrize('path_type', (str, Path))
     @unittest.skipIf(IS_WINDOWS, "TemporaryFileName on windows")
     def test_mmap_load_offset_calculation(self, path_type):

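The test above relies on loading under FakeTensorMode, which records where each storage lives in the file (exposed through the private ``_checkpoint_offset`` attribute) without reading tensor data. A standalone sketch, assuming that attribute behaves as in the test:

    import tempfile
    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    sd = torch.nn.Linear(10, 10).state_dict()
    with tempfile.NamedTemporaryFile() as f:
        torch.save(sd, f)
        f.seek(0)
        with FakeTensorMode():
            sd_fake = torch.load(f)
        for name, tensor in sd_fake.items():
            offset = tensor.untyped_storage()._checkpoint_offset
            # with the default 64-byte alignment every storage payload starts on a 64-byte boundary
            print(name, offset, offset % 64 == 0)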
@@ -1486,9 +1486,9 @@ class PyTorchFileReader:

 class PyTorchFileWriter:
     @overload
-    def __init__(self, name: str, compute_crc32 = True) -> None: ...
+    def __init__(self, name: str, compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ...
     @overload
-    def __init__(self, buffer: IO[bytes], compute_crc32 = True) -> None: ...
+    def __init__(self, buffer: IO[bytes], compute_crc32: _bool = True, storage_alignment: _int = 64) -> None: ...
    def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
    def write_end_of_file(self) -> None: ...
    def set_min_version(self, version: _int) -> None: ...

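With the updated stubs above, the low-level writer can also be given an alignment directly, independent of the config module. A minimal sketch, assuming an in-memory buffer and an arbitrary record name:

    import io
    import torch

    buf = io.BytesIO()
    # storage_alignment is the new keyword; 128 is just an arbitrary power of two for illustration
    writer = torch._C.PyTorchFileWriter(buf, compute_crc32=True, storage_alignment=128)
    writer.write_record("data/0", b"\x00" * 16, 16)  # hypothetical record name
    writer.write_end_of_file()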
@@ -1390,11 +1390,14 @@ void initJITBindings(PyObject* module) {

   py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
       .def(
-          py::init<std::string, bool>(),
+          py::init<std::string, bool, uint64_t>(),
           py::arg("file_name"),
-          py::arg("compute_crc32") = true)
+          py::arg("compute_crc32") = true,
+          py::arg("storage_alignment") = 64)
       .def(
-          py::init([](const py::object& buffer, bool compute_crc32 = true) {
+          py::init([](const py::object& buffer,
+                      bool compute_crc32 = true,
+                      uint64_t storage_alignment = 64) {
            auto writer_func = [=](const void* data, size_t size) {
              // Writing an empty file is a noop
              if (size == 0) {
@@ -1413,14 +1416,19 @@ void initJITBindings(PyObject* module) {
              return size;
            };
            return std::make_unique<PyTorchStreamWriter>(
-               std::move(writer_func), compute_crc32);
+               std::move(writer_func), compute_crc32, storage_alignment);
          }),
          py::arg("buffer"),
-         py::arg("compute_crc32") = true)
+         py::arg("compute_crc32") = true,
+         py::arg("storage_alignment") = 64)
      .def(
-         py::init<const std::function<size_t(const void*, size_t)>&, bool>(),
+         py::init<
+             const std::function<size_t(const void*, size_t)>&,
+             bool,
+             uint64_t>(),
          py::arg("writer_func"),
-         py::arg("compute_crc32") = true)
+         py::arg("compute_crc32") = true,
+         py::arg("storage_alignment") = 64)
      // [Note: write_record_metadata]
      // The write_record_metadata function is intended to write metadata (i.e.
      // the zipfile header and end of central directory record) for a file
@@ -1630,9 +1638,10 @@ void initJITBindings(PyObject* module) {
          [](PyTorchStreamReader& self,
             size_t zipfile_header_offset,
             const std::string filename,
-            size_t size) {
+            size_t size,
+            uint64_t storage_alignment) {
            return self.getRecordOffsetNoRead(
-               zipfile_header_offset, filename, size);
+               zipfile_header_offset, filename, size, storage_alignment);
          });

  // Used by torch.Package to coordinate deserialization of storages across

@@ -211,6 +211,20 @@ def get_default_mmap_options() -> Optional[int]:
     return config.load.mmap_flags


+def _get_storage_alignment() -> int:
+    """
+    Gets alignment for storages in torch.save files.
+
+    Defaults to 64.
+
+    Returns:
+        storage_alignment: int
+    """
+    from torch.utils.serialization import config
+
+    return config.save.storage_alignment
+
+
 class set_default_mmap_options:
     """
     Context manager or function to set default mmap options for :func:`torch.load` with ``mmap=True`` to flags.
@@ -767,10 +781,16 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
             # for writing out the file.
             self.file_stream = io.FileIO(self.name, mode="w")
             super().__init__(
-                torch._C.PyTorchFileWriter(self.file_stream, get_crc32_options())
+                torch._C.PyTorchFileWriter(
+                    self.file_stream, get_crc32_options(), _get_storage_alignment()
+                )
             )
         else:
-            super().__init__(torch._C.PyTorchFileWriter(self.name, get_crc32_options()))
+            super().__init__(
+                torch._C.PyTorchFileWriter(
+                    self.name, get_crc32_options(), _get_storage_alignment()
+                )
+            )

     def __exit__(self, *args) -> None:
         self.file_like.write_end_of_file()
@@ -786,7 +806,11 @@ class _open_zipfile_writer_buffer(_opener[torch._C.PyTorchFileWriter]):
                raise AttributeError(msg)
            raise TypeError(msg)
        self.buffer = buffer
-       super().__init__(torch._C.PyTorchFileWriter(buffer, get_crc32_options()))
+       super().__init__(
+           torch._C.PyTorchFileWriter(
+               buffer, get_crc32_options(), _get_storage_alignment()
+           )
+       )

    def __exit__(self, *args) -> None:
        self.file_like.write_end_of_file()
@@ -1188,7 +1212,13 @@ def _save(
     # .format_version is used to track
     # 1. version 1 represents the order of storages being changed from
     #    lexicographical based on keys to numerically ordered based on keys
+    # 2. version 2 represents including storage_alignment as a record
+    #    within the zipfile
     zip_file.write_record(".format_version", "1", len("1"))
+    storage_alignment = str(_get_storage_alignment())
+    zip_file.write_record(
+        ".storage_alignment", storage_alignment, len(storage_alignment)
+    )

     # Write byte order marker
     if not _disable_byteorder_record:
@@ -1886,6 +1916,10 @@ def _load(
     else:
         raise ValueError("Invalid load endianness type")

+    storage_alignment = 64
+    if zip_file.has_record(".storage_alignment"):
+        storage_alignment = int(zip_file.get_record(".storage_alignment"))
+
     if (
         not zip_file.has_record(byteordername)
         and get_default_load_endianness() is None
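Because the alignment used at save time is written into the archive as the ``.storage_alignment`` record, loading needs no out-of-band information; checkpoints produced before this change simply fall back to 64. Roughly, using the existing PyTorchFileReader bindings (the checkpoint path is hypothetical):

    import torch

    reader = torch._C.PyTorchFileReader("model_aligned.pt")  # hypothetical checkpoint path
    if reader.has_record(".storage_alignment"):
        alignment = int(reader.get_record(".storage_alignment"))
    else:
        alignment = 64  # default for older checkpoints
    print(alignment)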
@@ -1939,7 +1973,7 @@ def _load(
                 storage_offset = current_offset
             else:
                 storage_offset = zip_file.get_record_offset_no_read(
-                    current_offset, name, numel
+                    current_offset, name, numel, storage_alignment
                 )
             local_header_offset = current_offset

@@ -19,6 +19,7 @@ class load:

 class save:
     compute_crc32: bool = True
     use_pinned_memory_for_d2h: bool = False
+    storage_alignment: int = 64


 _install_config_module(sys.modules[__name__])