[package] Correct usage of miniz API in PyTorchStreamReader (#55725)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55725

We were previously checking m_last_error on the miniz struct directly,
which fails to preserve internal invariants and can leave the reader
broken in specific situations (reading a non-existent file).

Using the provided error checking API fixes this.

Differential Revision: D27693105

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Pulled By: suo

fbshipit-source-id: 20c520bb1d590fb75751bca1e970df4f2b7eb043
This commit is contained in:
Michael Suo
2021-04-13 11:46:55 -07:00
committed by Facebook GitHub Bot
parent c3a49cb30c
commit 0517222dc8
2 changed files with 59 additions and 17 deletions

View File

@ -140,15 +140,14 @@ void PyTorchStreamReader::init() {
}
// Raises a c10::Error if miniz has recorded an error on the archive.
//
// `what` and `info` are spliced into the message as
// "PytorchStreamReader failed <what><info>: <miniz error string>".
//
// The error is fetched with mz_zip_get_last_error() rather than by reading
// the struct field. In miniz that accessor returns the pending error and
// resets it to MZ_ZIP_NO_ERROR, so a failure reported here does not linger
// and break later, unrelated calls on the same reader.
void PyTorchStreamReader::valid(const char* what, const char* info) {
  const auto err = mz_zip_get_last_error(ar_.get());
  TORCH_CHECK(
      err == MZ_ZIP_NO_ERROR,
      "PytorchStreamReader failed ",
      what,
      info,
      ": ",
      mz_zip_get_error_string(err));
}
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
@ -192,12 +191,17 @@ bool PyTorchStreamReader::hasRecord(const std::string& name) {
std::lock_guard<std::mutex> guard(reader_lock_);
std::string ss = archive_name_plus_slash_ + name;
mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
bool result = ar_->m_last_error != MZ_ZIP_FILE_NOT_FOUND;
if (!result) {
ar_->m_last_error = MZ_ZIP_NO_ERROR;
const mz_zip_error err = mz_zip_get_last_error(ar_.get());
if (err == MZ_ZIP_NO_ERROR) {
return true;
} else if (err == MZ_ZIP_FILE_NOT_FOUND) {
return false;
} else {
// A different error happened, raise it.
valid("attempting to locate file ", name.c_str());
}
valid("attempting to locate file ", name.c_str());
return result;
TORCH_INTERNAL_ASSERT(false, "should not reach here");
}
std::vector<std::string> PyTorchStreamReader::getAllRecords() {
@ -229,9 +233,6 @@ const std::vector<std::string>& PyTorchStreamWriter::getAllWrittenRecords() {
// Looks up the miniz index of the record called `name`.
//
// Returns the value produced by mz_zip_reader_locate_file(). If miniz
// recorded any error during the lookup (including a missing file), valid()
// raises a c10::Error before the index is returned.
//
// The error state is deliberately not inspected via ar_->m_last_error:
// reading the field directly leaves the error pending, which makes every
// subsequent valid() call on this reader fail.
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
  const std::string ss = archive_name_plus_slash_ + name;
  const size_t result =
      mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
  valid("locating file ", name.c_str());
  return result;
}

View File

@ -68,6 +68,47 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
}
// Requesting a record that was never written must throw c10::Error, and the
// reader must remain usable afterwards.
TEST(PyTorchStreamWriterAndReader, GetNonexistentRecordThrows) {
  std::ostringstream oss;
  // write records through writers
  PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
    oss.write(static_cast<const char*>(b), n);
    return oss ? n : 0;
  });

  // Fill each payload with a descending byte pattern. The index is size_t to
  // match size(), and the narrowing to char is spelled out.
  std::array<char, 127> data1;
  for (size_t i = 0; i < data1.size(); ++i) {
    data1[i] = static_cast<char>(data1.size() - i);
  }
  writer.writeRecord("key1", data1.data(), data1.size());

  std::array<char, 64> data2;
  for (size_t i = 0; i < data2.size(); ++i) {
    data2[i] = static_cast<char>(data2.size() - i);
  }
  writer.writeRecord("key2", data2.data(), data2.size());

  const std::vector<std::string>& written_records =
      writer.getAllWrittenRecords();
  ASSERT_EQ(written_records[0], "key1");
  ASSERT_EQ(written_records[1], "key2");

  writer.writeEndOfFile();

  std::string the_file = oss.str();
  // Dump the archive to disk as well. Nothing below reads it back
  // (presumably kept for manual inspection).
  std::ofstream foo("output2.zip");
  foo.write(the_file.c_str(), the_file.size());
  foo.close();

  std::istringstream iss(the_file);

  // read records through readers
  PyTorchStreamReader reader(&iss);
  EXPECT_THROW(reader.getRecord("key3"), c10::Error);

  // Reader should still work after throwing
  EXPECT_TRUE(reader.hasRecord("key1"));
}
} // namespace
} // namespace serialize
} // namespace caffe2