From 0517222dc847729e22be81bff2069d9df3581378 Mon Sep 17 00:00:00 2001 From: Michael Suo Date: Tue, 13 Apr 2021 11:46:55 -0700 Subject: [PATCH] [package] Correct usage of miniz API in PyTorchStreamReader (#55725) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55725 We were previously checking m_last_error on the miniz struct directly, which fails to preserve internal invariants and can the leave the reader broken in specific situations (reading a non-existent file). Using the provided error checking API fixes this. Differential Revision: D27693105 Test Plan: Imported from OSS Reviewed By: SplitInfinity Pulled By: suo fbshipit-source-id: 20c520bb1d590fb75751bca1e970df4f2b7eb043 --- caffe2/serialize/inline_container.cc | 35 +++++++++---------- caffe2/serialize/inline_container_test.cc | 41 +++++++++++++++++++++++ 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index 69b800fee71f..e0e03736a7e3 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -140,15 +140,14 @@ void PyTorchStreamReader::init() { } void PyTorchStreamReader::valid(const char* what, const char* info) { - auto err = mz_zip_get_last_error(ar_.get()); - if (err != MZ_ZIP_NO_ERROR) { - CAFFE_THROW( - "PytorchStreamReader failed ", - what, - info, - ": ", - mz_zip_get_error_string(err)); - } + const auto err = mz_zip_get_last_error(ar_.get()); + TORCH_CHECK( + err == MZ_ZIP_NO_ERROR, + "PytorchStreamReader failed ", + what, + info, + ": ", + mz_zip_get_error_string(err)); } constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30; @@ -192,12 +191,17 @@ bool PyTorchStreamReader::hasRecord(const std::string& name) { std::lock_guard guard(reader_lock_); std::string ss = archive_name_plus_slash_ + name; mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0); - bool result = ar_->m_last_error != MZ_ZIP_FILE_NOT_FOUND; - if (!result) { - ar_->m_last_error = MZ_ZIP_NO_ERROR; + const mz_zip_error err = mz_zip_get_last_error(ar_.get()); + + if (err == MZ_ZIP_NO_ERROR) { + return true; + } else if (err == MZ_ZIP_FILE_NOT_FOUND) { + return false; + } else { + // A different error happened, raise it. + valid("attempting to locate file ", name.c_str()); } - valid("attempting to locate file ", name.c_str()); - return result; + TORCH_INTERNAL_ASSERT(false, "should not reach here"); } std::vector PyTorchStreamReader::getAllRecords() { @@ -229,9 +233,6 @@ const std::vector& PyTorchStreamWriter::getAllWrittenRecords() { size_t PyTorchStreamReader::getRecordID(const std::string& name) { std::string ss = archive_name_plus_slash_ + name; size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0); - if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) { - CAFFE_THROW("file not found: ", ss); - } valid("locating file ", name.c_str()); return result; } diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc index 1b796f91c38f..3aaca12fa9d3 100644 --- a/caffe2/serialize/inline_container_test.cc +++ b/caffe2/serialize/inline_container_test.cc @@ -68,6 +68,47 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) { ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0); } +TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) { + std::ostringstream oss; + // write records through writers + PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t { + oss.write(static_cast(b), n); + return oss ? n : 0; + }); + std::array data1; + + for (int i = 0; i < data1.size(); ++i) { + data1[i] = data1.size() - i; + } + writer.writeRecord("key1", data1.data(), data1.size()); + + std::array data2; + for (int i = 0; i < data2.size(); ++i) { + data2[i] = data2.size() - i; + } + writer.writeRecord("key2", data2.data(), data2.size()); + + const std::vector& written_records = writer.getAllWrittenRecords(); + ASSERT_EQ(written_records[0], "key1"); + ASSERT_EQ(written_records[1], "key2"); + + writer.writeEndOfFile(); + + std::string the_file = oss.str(); + std::ofstream foo("output2.zip"); + foo.write(the_file.c_str(), the_file.size()); + foo.close(); + + std::istringstream iss(the_file); + + // read records through readers + PyTorchStreamReader reader(&iss); + EXPECT_THROW(reader.getRecord("key3"), c10::Error); + + // Reader should still work after throwing + EXPECT_TRUE(reader.hasRecord("key1")); +} + } // namespace } // namespace serialize } // namespace caffe2