diff --git a/build.bzl b/build.bzl index dbb1866ac548..3ba83e4578ca 100644 --- a/build.bzl +++ b/build.bzl @@ -36,7 +36,7 @@ def define_targets(rules): "caffe2/serialize/istream_adapter.cc", "caffe2/serialize/read_adapter_interface.cc", ], - copts = ["-fexceptions"], + copts = ["-fexceptions", "-DFBCODE_CAFFE2"], tags = [ "-fbcode", "supermodule:android/default/pytorch", diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 2761147cf333..70c13791da68 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -621,15 +621,17 @@ size_t ostream_write_func( return ret; } -PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name) - : archive_name_(basename(file_name)) { +PyTorchStreamWriter::PyTorchStreamWriter(const std::string& file_name, bool compute_crc32) + : archive_name_(basename(file_name)), + compute_crc32_(compute_crc32) { setup(file_name); } PyTorchStreamWriter::PyTorchStreamWriter( - const std::function writer_func) + const std::function writer_func, bool compute_crc32) : archive_name_("archive"), - writer_func_(writer_func) { + writer_func_(writer_func), + compute_crc32_(compute_crc32) { setup(archive_name_); } @@ -695,6 +697,11 @@ void PyTorchStreamWriter::writeRecord( size_t padding_size = detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_); uint32_t flags = compress ? 
MZ_BEST_COMPRESSION : 0; + if (!compute_crc32_) { +#if (!defined(FBCODE_CAFFE2)) + flags |= MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32; +#endif + } mz_zip_writer_add_mem_ex_v2( /*pZip=*/ar_.get(), /*pArchive_name=*/full_name.c_str(), diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h index 6a13d414feb9..55a723f3b891 100644 --- a/caffe2/serialize/inline_container.h +++ b/caffe2/serialize/inline_container.h @@ -205,9 +205,9 @@ class TORCH_API PyTorchStreamReader final { class TORCH_API PyTorchStreamWriter final { public: - explicit PyTorchStreamWriter(const std::string& archive_name); + explicit PyTorchStreamWriter(const std::string& archive_name, bool compute_crc32 = true); explicit PyTorchStreamWriter( - const std::function writer_func); + const std::function writer_func, bool compute_crc32 = true); void setMinVersion(const uint64_t version); @@ -248,6 +248,7 @@ class TORCH_API PyTorchStreamWriter final { std::function writer_func_; uint64_t combined_uncomp_crc32_ = 0; std::string serialization_id_; + bool compute_crc32_; // This number will be updated when the model has operators // that have valid upgraders. diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index c05dc028a471..255fa2dbaa57 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -390,6 +390,8 @@ The following utility functions are related to serialization: .. currentmodule:: torch.serialization .. autofunction:: register_package +.. autofunction:: get_crc32_options +.. autofunction:: set_crc32_options .. autofunction:: get_default_load_endianness .. autofunction:: set_default_load_endianness .. 
autofunction:: get_default_mmap_options diff --git a/test/test_serialization.py b/test/test_serialization.py index a58e47c08317..59d6e21bd3a1 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -4334,6 +4334,35 @@ class TestSerialization(TestCase, SerializationMixin): else: os.environ[env_var] = old_value + @unittest.skipIf(IS_FBCODE, "miniz version differs between fbcode and oss") + @parametrize("compute_crc32", (True, False)) + @parametrize("filename", (True, False)) + def test_crc32_options(self, compute_crc32, filename): + # test both path and buffer case + file_creation_func = TemporaryFileName if filename else tempfile.NamedTemporaryFile + sd = torch.nn.Linear(3, 5).state_dict() + with file_creation_func() as f: + try: + torch.serialization.set_crc32_options(compute_crc32) + torch.save(sd, f) + if not filename: + f.seek(0) + sd_loaded = torch.load(f, weights_only=True) + self.assertEqual(sd_loaded, sd) + finally: + torch.serialization.set_crc32_options(True) + + args = () if compute_crc32 else (zipfile.BadZipFile, "Bad CRC-32 for file") + ctx = contextlib.nullcontext if compute_crc32 else self.assertRaisesRegex + + if not filename: + f.seek(0) + # zip_file.extractall() will raise BadZipFile if CRC32 is not populated + # we use the context manager to check whether CRC32 was populated + with ctx(*args), tempfile.TemporaryDirectory() as temp_dir: + with zipfile.ZipFile(f) as zip_file: + zip_file.extractall(path=temp_dir) + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/third_party/miniz-2.1.0/miniz.c b/third_party/miniz-2.1.0/miniz.c index dc790d9e36b7..043a11b1d45f 100755 --- a/third_party/miniz-2.1.0/miniz.c +++ b/third_party/miniz-2.1.0/miniz.c @@ -6251,6 +6251,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE]; mz_uint16 bit_flags = 0; mz_bool write_metadata_only = 
buf_size && !pBuf; + mz_bool skip_crc32 = write_metadata_only || (level_and_flags & MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32); if ((int)level_and_flags < 0) level_and_flags = MZ_DEFAULT_LEVEL; @@ -6309,7 +6310,7 @@ mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_n if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) { - if (!write_metadata_only) { + if (!skip_crc32) { uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size); } uncomp_size = buf_size; diff --git a/third_party/miniz-2.1.0/miniz.h b/third_party/miniz-2.1.0/miniz.h index 2cad1370c638..0d5e73071f82 100755 --- a/third_party/miniz-2.1.0/miniz.h +++ b/third_party/miniz-2.1.0/miniz.h @@ -1001,7 +1001,8 @@ typedef enum { MZ_ZIP_FLAG_VALIDATE_HEADERS_ONLY = 0x2000, /* validate the local headers, but don't decompress the entire file and check the crc32 */ MZ_ZIP_FLAG_WRITE_ZIP64 = 0x4000, /* always use the zip64 file format, instead of the original zip file format with automatic switch to zip64. Use as flags parameter with mz_zip_writer_init*_v2 */ MZ_ZIP_FLAG_WRITE_ALLOW_READING = 0x8000, - MZ_ZIP_FLAG_ASCII_FILENAME = 0x10000 + MZ_ZIP_FLAG_ASCII_FILENAME = 0x10000, + MZ_ZIP_FLAG_DO_NOT_COMPUTE_CRC32 = 0x20000, /* don't compute the crc32 of file data that's being added. */ } mz_zip_flags; typedef enum { diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 930e4be2420e..f54f2617f548 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1447,9 +1447,9 @@ class PyTorchFileReader: class PyTorchFileWriter: @overload - def __init__(self, name: str) -> None: ... + def __init__(self, name: str, compute_crc32 = True) -> None: ... @overload - def __init__(self, buffer: BinaryIO) -> None: ... + def __init__(self, buffer: BinaryIO, compute_crc32 = True) -> None: ... def write_record(self, name: str, data: Union[Storage, bytes, _int], size: _int) -> None: ... def write_end_of_file(self) -> None: ... 
def set_min_version(self, version: _int) -> None: ... diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 4ac84dedb544..588c13c21bb5 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1389,28 +1389,38 @@ void initJITBindings(PyObject* module) { "fallback", [](GraphExecutorState& s) { return s.fallback; }); py::class_(m, "PyTorchFileWriter") - .def(py::init()) - .def(py::init([](const py::object& buffer) { - auto writer_func = [=](const void* data, size_t size) { - // Writing an empty file is a noop - if (size == 0) { - return size; - } - py::gil_scoped_acquire acquire; - if (!data) { - // See [Note: write_record_metadata] - buffer.attr("seek")( - size, py::module::import("os").attr("SEEK_CUR")); - } else { - auto memory_view = py::memoryview::from_memory( - reinterpret_cast(data), size); - buffer.attr("write")(std::move(memory_view)); - } - return size; - }; - return std::make_unique(std::move(writer_func)); - })) - .def(py::init&>()) + .def( + py::init(), + py::arg("file_name"), + py::arg("compute_crc32") = true) + .def( + py::init([](const py::object& buffer, bool compute_crc32 = true) { + auto writer_func = [=](const void* data, size_t size) { + // Writing an empty file is a noop + if (size == 0) { + return size; + } + py::gil_scoped_acquire acquire; + if (!data) { + // See [Note: write_record_metadata] + buffer.attr("seek")( + size, py::module::import("os").attr("SEEK_CUR")); + } else { + auto memory_view = py::memoryview::from_memory( + reinterpret_cast(data), size); + buffer.attr("write")(std::move(memory_view)); + } + return size; + }; + return std::make_unique( + std::move(writer_func), compute_crc32); + }), + py::arg("buffer"), + py::arg("compute_crc32") = true) + .def( + py::init&, bool>(), + py::arg("writer_func"), + py::arg("compute_crc32") = true) // [Note: write_record_metadata] // The write_record_metadata function is intended to write metadata (i.e. 
// the zipfile header and end of central directory record) for a file diff --git a/torch/serialization.py b/torch/serialization.py index 17517db6e7fd..a87230e824aa 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -53,6 +53,8 @@ __all__ = [ "load", "StorageType", "LoadEndianness", + "get_crc32_options", + "set_crc32_options", "get_default_load_endianness", "set_default_load_endianness", "get_default_mmap_options", @@ -167,6 +169,34 @@ def set_default_load_endianness(endianness): _default_load_endian = endianness +_compute_crc32: bool = True + + +def get_crc32_options() -> bool: + """ + Get whether :func:`torch.save` computes and writes crc32 for each record. + + Defaults to ``True``. + """ + return _compute_crc32 + + +def set_crc32_options(compute_crc32: bool): + """ + Set whether :func:`torch.save` computes and writes crc32 for each record. + + .. note:: + Setting this to ``False`` may make unzipping of the ``torch.save`` output + fail or warn due to corrupted CRC32. However ``torch.load`` will be + able to load the file. + + Args: + compute_crc32 (bool): set crc32 computation flag + """ + global _compute_crc32 + _compute_crc32 = compute_crc32 + + _default_mmap_options: int = MAP_PRIVATE @@ -682,9 +712,11 @@ class _open_zipfile_writer_file(_opener): # For filenames with non-ascii characters, we rely on Python # for writing out the file. 
self.file_stream = io.FileIO(self.name, mode="w") - super().__init__(torch._C.PyTorchFileWriter(self.file_stream)) + super().__init__( + torch._C.PyTorchFileWriter(self.file_stream, _compute_crc32) + ) else: - super().__init__(torch._C.PyTorchFileWriter(self.name)) + super().__init__(torch._C.PyTorchFileWriter(self.name, _compute_crc32)) def __exit__(self, *args) -> None: self.file_like.write_end_of_file() @@ -700,7 +732,7 @@ class _open_zipfile_writer_buffer(_opener): raise AttributeError(msg) raise TypeError(msg) self.buffer = buffer - super().__init__(torch._C.PyTorchFileWriter(buffer)) + super().__init__(torch._C.PyTorchFileWriter(buffer, _compute_crc32)) def __exit__(self, *args) -> None: self.file_like.write_end_of_file()