mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
reimport pr137735 due to merging check issues (#138959)
This is a cherry-pick from #137735 by @mikaylagawarecki , that cannot be merged due to a (wrongly) failing check for codev @diff-train-skip-merge Pull Request resolved: https://github.com/pytorch/pytorch/pull/138959 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
committed by
PyTorch MergeBot
parent
144d75d934
commit
bae3426af7
@ -36,7 +36,7 @@ def define_targets(rules):
|
||||
"caffe2/serialize/istream_adapter.cc",
|
||||
"caffe2/serialize/read_adapter_interface.cc",
|
||||
],
|
||||
copts = ["-fexceptions"],
|
||||
copts = ["-fexceptions", "-DFBCODE_CAFFE2"],
|
||||
tags = [
|
||||
"-fbcode",
|
||||
"supermodule:android/default/pytorch",
|
||||
|
@ -621,15 +621,17 @@ size_t ostream_write_func(
|
||||
return ret;
|
||||
}
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name)
|
||||
: archive_name_(basename(file_name)) {
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32)
|
||||
: archive_name_(basename(file_name)),
|
||||
compute_crc32_(compute_crc32) {
|
||||
setup(file_name);
|
||||
}
|
||||
|
||||
PyTorchStreamWriter::PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)> writer_func)
|
||||
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32)
|
||||
: archive_name_("archive"),
|
||||
writer_func_(writer_func) {
|
||||
writer_func_(writer_func),
|
||||
compute_crc32_(compute_crc32) {
|
||||
setup(archive_name_);
|
||||
}
|
||||
|
||||
@ -695,6 +697,11 @@ void PyTorchStreamWriter::writeRecord(
|
||||
size_t padding_size =
|
||||
detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
|
||||
uint32_t flags = compress ? MZ_BEST_COMPRESSION : 0;
|
||||
if (!compute_crc32_) {
|
||||
#if (!defined(FBCODE_CAFFE2))
|
||||
flags |= MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32;
|
||||
#endif
|
||||
}
|
||||
mz_zip_writer_add_mem_ex_v2(
|
||||
/*pZip=*/ar_.get(),
|
||||
/*pArchive_name=*/full_name.c_str(),
|
||||
|
@ -205,9 +205,9 @@ class TORCH_API PyTorchStreamReader final {
|
||||
|
||||
class TORCH_API PyTorchStreamWriter final {
|
||||
public:
|
||||
explicit PyTorchStreamWriter(const std::string& archive_name);
|
||||
explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true);
|
||||
explicit PyTorchStreamWriter(
|
||||
const std::function<size_t(const void*, size_t)> writer_func);
|
||||
const std::function<size_t(const void*, size_t)> writer_func, bool compute_crc32 = true);
|
||||
|
||||
void setMinVersion(const uint64_t version);
|
||||
|
||||
@ -248,6 +248,7 @@ class TORCH_API PyTorchStreamWriter final {
|
||||
std::function<size_t(const void*, size_t)> writer_func_;
|
||||
uint64_t combined_uncomp_crc32_ = 0;
|
||||
std::string serialization_id_;
|
||||
bool compute_crc32_;
|
||||
|
||||
// This number will be updated when the model has operators
|
||||
// that have valid upgraders.
|
||||
|
@ -390,6 +390,8 @@ The following utility functions are related to serialization:
|
||||
.. currentmodule:: torch.serialization
|
||||
|
||||
.. autofunction:: register_package
|
||||
.. autofunction:: get_crc32_options
|
||||
.. autofunction:: set_crc32_options
|
||||
.. autofunction:: get_default_load_endianness
|
||||
.. autofunction:: set_default_load_endianness
|
||||
.. autofunction:: get_default_mmap_options
|
||||
|
@ -4334,6 +4334,35 @@ class TestSerialization(TestCase, SerializationMixin):
|
||||
else:
|
||||
os.environ[env_var] = old_value
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "miniz version differs between fbcode and oss")
|
||||
@parametrize("compute_crc32", (True, False))
|
||||
@parametrize("filename", (True, False))
|
||||
def test_crc32_options(self, compute_crc32, filename):
|
||||
# test both path and buffer case
|
||||
file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile
|
||||
sd = torch.nn.Linear(3, 5).state_dict()
|
||||
with file_creation_func() as f:
|
||||
try:
|
||||
torch.serialization.set_crc32_options(compute_crc32)
|
||||
torch.save(sd, f)
|
||||
if not filename:
|
||||
f.seek(0)
|
||||
sd_loaded = torch.load(f, weights_only=True)
|
||||
self.assertEqual(sd_loaded, sd)
|
||||
finally:
|
||||
torch.serialization.set_crc32_options(True)
|
||||
|
||||
args = () if compute_crc32 else (zipfile.BadZipFile, "Bad CRC-32 for file")
|
||||
ctx = contextlib.nullcontext if compute_crc32 else self.assertRaisesRegex
|
||||
|
||||
if not filename:
|
||||
f.seek(0)
|
||||
# zip_file.extractall() will raise BadZipFile if CRC32 is not populated
|
||||
# we use the context manager to check whether CRC32 was populated
|
||||
with ctx(*args), tempfile.TemporaryDirectory() as temp_dir:
|
||||
with zipfile.ZipFile(f) as zip_file:
|
||||
zip_file.extractall(path=temp_dir)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
with serialization_method(use_zip=True):
|
||||
return super().run(*args, **kwargs)
|
||||
|
3
third_party/miniz-2.1.0/miniz.c
vendored
3
third_party/miniz-2.1.0/miniz.c
vendored
@ -6251,6 +6251,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
|
||||
mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE];
|
||||
mz_uint16 bit_flags = 0;
|
||||
mz_bool write_metadata_only = buf_size && !pBuf;
|
||||
mz_bool skip_crc32 = write_metadata_only || (level_and_flags & MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32);
|
||||
|
||||
if ((int)level_and_flags < 0)
|
||||
level_and_flags = MZ_DEFAULT_LEVEL;
|
||||
@ -6309,7 +6310,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n
|
||||
|
||||
if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))
|
||||
{
|
||||
if (!write_metadata_only) {
|
||||
if (!skip_crc32) {
|
||||
uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size);
|
||||
}
|
||||
uncomp_size = buf_size;
|
||||
|
3
third_party/miniz-2.1.0/miniz.h
vendored
3
third_party/miniz-2.1.0/miniz.h
vendored
@ -1001,7 +1001,8 @@ typedef enum {
|
||||
MZ_ZIP_FLAG_VALIDATE_HEADERS_ONLY = 0x2000, /* validate the local headers, but don't decompress the entire file and check the crc32 */
|
||||
MZ_ZIP_FLAG_WRITE_ZIP64 = 0x4000, /* always use the zip64 file format, instead of the original zip file format with automatic switch to zip64. Use as flags parameter with mz_zip_writer_init*_v2 */
|
||||
MZ_ZIP_FLAG_WRITE_ALLOW_READING = 0x8000,
|
||||
MZ_ZIP_FLAG_ASCII_FILENAME = 0x10000
|
||||
MZ_ZIP_FLAG_ASCII_FILENAME = 0x10000,
|
||||
MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32 = 0x20000, /* don't compute the crc32 of file data that's being added. */
|
||||
} mz_zip_flags;
|
||||
|
||||
typedef enum {
|
||||
|
@ -1447,9 +1447,9 @@ class PyTorchFileReader:
|
||||
|
||||
class PyTorchFileWriter:
|
||||
@overload
|
||||
def __init__(self, name: str) -> None: ...
|
||||
def __init__(self, name: str, compute_crc32 = True) -> None: ...
|
||||
@overload
|
||||
def __init__(self, buffer: BinaryIO) -> None: ...
|
||||
def __init__(self, buffer: BinaryIO, compute_crc32 = True) -> None: ...
|
||||
def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ...
|
||||
def write_end_of_file(self) -> None: ...
|
||||
def set_min_version(self, version: _int) -> None: ...
|
||||
|
@ -1389,8 +1389,12 @@ void initJITBindings(PyObject* module) {
|
||||
"fallback", [](GraphExecutorState& s) { return s.fallback; });
|
||||
|
||||
py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
|
||||
.def(py::init<std::string>())
|
||||
.def(py::init([](const py::object& buffer) {
|
||||
.def(
|
||||
py::init<std::string, bool>(),
|
||||
py::arg("file_name"),
|
||||
py::arg("compute_crc32") = true)
|
||||
.def(
|
||||
py::init([](const py::object& buffer, bool compute_crc32 = true) {
|
||||
auto writer_func = [=](const void* data, size_t size) {
|
||||
// Writing an empty file is a noop
|
||||
if (size == 0) {
|
||||
@ -1408,9 +1412,15 @@ void initJITBindings(PyObject* module) {
|
||||
}
|
||||
return size;
|
||||
};
|
||||
return std::make_unique<PyTorchStreamWriter>(std::move(writer_func));
|
||||
}))
|
||||
.def(py::init<const std::function<size_t(const void*, size_t)>&>())
|
||||
return std::make_unique<PyTorchStreamWriter>(
|
||||
std::move(writer_func), compute_crc32);
|
||||
}),
|
||||
py::arg("buffer"),
|
||||
py::arg("compute_crc32") = true)
|
||||
.def(
|
||||
py::init<const std::function<size_t(const void*, size_t)>&, bool>(),
|
||||
py::arg("writer_func"),
|
||||
py::arg("compute_crc32") = true)
|
||||
// [Note: write_record_metadata]
|
||||
// The write_record_metadata function is intended to write metadata (i.e.
|
||||
// the zipfile header and end of central directory record) for a file
|
||||
|
@ -53,6 +53,8 @@ __all__ = [
|
||||
"load",
|
||||
"StorageType",
|
||||
"LoadEndianness",
|
||||
"get_crc32_options",
|
||||
"set_crc32_options",
|
||||
"get_default_load_endianness",
|
||||
"set_default_load_endianness",
|
||||
"get_default_mmap_options",
|
||||
@ -167,6 +169,34 @@ def set_default_load_endianness(endianness):
|
||||
_default_load_endian = endianness
|
||||
|
||||
|
||||
_compute_crc32: bool = True
|
||||
|
||||
|
||||
def get_crc32_options() -> bool:
|
||||
"""
|
||||
Get whether :func:`torch.save` computes and writes crc32 for each record.
|
||||
|
||||
Defaults to ``True``.
|
||||
"""
|
||||
return _compute_crc32
|
||||
|
||||
|
||||
def set_crc32_options(compute_crc32: bool):
|
||||
"""
|
||||
Set whether :func:`torch.save` computes and writes crc32 for each record.
|
||||
|
||||
.. note::
|
||||
Setting this to ``False`` may make unzipping of the ``torch.save`` output
|
||||
fail or warn due to corrupted CRC32. However ``torch.load`` will be
|
||||
able to load the file.
|
||||
|
||||
Args:
|
||||
compute_crc32 (bool): set crc32 compuation flag
|
||||
"""
|
||||
global _compute_crc32
|
||||
_compute_crc32 = compute_crc32
|
||||
|
||||
|
||||
_default_mmap_options: int = MAP_PRIVATE
|
||||
|
||||
|
||||
@ -682,9 +712,11 @@ class _open_zipfile_writer_file(_opener):
|
||||
# For filenames with non-ascii characters, we rely on Python
|
||||
# for writing out the file.
|
||||
self.file_stream = io.FileIO(self.name, mode="w")
|
||||
super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
|
||||
super().__init__(
|
||||
torch._C.PyTorchFileWriter(self.file_stream, _compute_crc32)
|
||||
)
|
||||
else:
|
||||
super().__init__(torch._C.PyTorchFileWriter(self.name))
|
||||
super().__init__(torch._C.PyTorchFileWriter(self.name, _compute_crc32))
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
@ -700,7 +732,7 @@ class _open_zipfile_writer_buffer(_opener):
|
||||
raise AttributeError(msg)
|
||||
raise TypeError(msg)
|
||||
self.buffer = buffer
|
||||
super().__init__(torch._C.PyTorchFileWriter(buffer))
|
||||
super().__init__(torch._C.PyTorchFileWriter(buffer, _compute_crc32))
|
||||
|
||||
def __exit__(self, *args) -> None:
|
||||
self.file_like.write_end_of_file()
|
||||
|
Reference in New Issue
Block a user