mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[package] Correct usage of miniz API in PyTorchStreamReader (#55725)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55725 We were previously checking m_last_error on the miniz struct directly, which fails to preserve internal invariants and can the leave the reader broken in specific situations (reading a non-existent file). Using the provided error checking API fixes this. Differential Revision: D27693105 Test Plan: Imported from OSS Reviewed By: SplitInfinity Pulled By: suo fbshipit-source-id: 20c520bb1d590fb75751bca1e970df4f2b7eb043
This commit is contained in:
committed by
Facebook GitHub Bot
parent
c3a49cb30c
commit
0517222dc8
@ -140,15 +140,14 @@ void PyTorchStreamReader::init() {
|
||||
}
|
||||
|
||||
void PyTorchStreamReader::valid(const char* what, const char* info) {
|
||||
auto err = mz_zip_get_last_error(ar_.get());
|
||||
if (err != MZ_ZIP_NO_ERROR) {
|
||||
CAFFE_THROW(
|
||||
"PytorchStreamReader failed ",
|
||||
what,
|
||||
info,
|
||||
": ",
|
||||
mz_zip_get_error_string(err));
|
||||
}
|
||||
const auto err = mz_zip_get_last_error(ar_.get());
|
||||
TORCH_CHECK(
|
||||
err == MZ_ZIP_NO_ERROR,
|
||||
"PytorchStreamReader failed ",
|
||||
what,
|
||||
info,
|
||||
": ",
|
||||
mz_zip_get_error_string(err));
|
||||
}
|
||||
|
||||
constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
|
||||
@ -192,12 +191,17 @@ bool PyTorchStreamReader::hasRecord(const std::string& name) {
|
||||
std::lock_guard<std::mutex> guard(reader_lock_);
|
||||
std::string ss = archive_name_plus_slash_ + name;
|
||||
mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
|
||||
bool result = ar_->m_last_error != MZ_ZIP_FILE_NOT_FOUND;
|
||||
if (!result) {
|
||||
ar_->m_last_error = MZ_ZIP_NO_ERROR;
|
||||
const mz_zip_error err = mz_zip_get_last_error(ar_.get());
|
||||
|
||||
if (err == MZ_ZIP_NO_ERROR) {
|
||||
return true;
|
||||
} else if (err == MZ_ZIP_FILE_NOT_FOUND) {
|
||||
return false;
|
||||
} else {
|
||||
// A different error happened, raise it.
|
||||
valid("attempting to locate file ", name.c_str());
|
||||
}
|
||||
valid("attempting to locate file ", name.c_str());
|
||||
return result;
|
||||
TORCH_INTERNAL_ASSERT(false, "should not reach here");
|
||||
}
|
||||
|
||||
std::vector<std::string> PyTorchStreamReader::getAllRecords() {
|
||||
@ -229,9 +233,6 @@ const std::vector<std::string>& PyTorchStreamWriter::getAllWrittenRecords() {
|
||||
size_t PyTorchStreamReader::getRecordID(const std::string& name) {
|
||||
std::string ss = archive_name_plus_slash_ + name;
|
||||
size_t result = mz_zip_reader_locate_file(ar_.get(), ss.c_str(), nullptr, 0);
|
||||
if (ar_->m_last_error == MZ_ZIP_FILE_NOT_FOUND) {
|
||||
CAFFE_THROW("file not found: ", ss);
|
||||
}
|
||||
valid("locating file ", name.c_str());
|
||||
return result;
|
||||
}
|
||||
|
@ -68,6 +68,47 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
|
||||
ASSERT_EQ(memcmp(the_file.c_str() + off2, data2.data(), data2.size()), 0);
|
||||
}
|
||||
|
||||
TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
|
||||
std::ostringstream oss;
|
||||
// write records through writers
|
||||
PyTorchStreamWriter writer([&](const void* b, size_t n) -> size_t {
|
||||
oss.write(static_cast<const char*>(b), n);
|
||||
return oss ? n : 0;
|
||||
});
|
||||
std::array<char, 127> data1;
|
||||
|
||||
for (int i = 0; i < data1.size(); ++i) {
|
||||
data1[i] = data1.size() - i;
|
||||
}
|
||||
writer.writeRecord("key1", data1.data(), data1.size());
|
||||
|
||||
std::array<char, 64> data2;
|
||||
for (int i = 0; i < data2.size(); ++i) {
|
||||
data2[i] = data2.size() - i;
|
||||
}
|
||||
writer.writeRecord("key2", data2.data(), data2.size());
|
||||
|
||||
const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
|
||||
ASSERT_EQ(written_records[0], "key1");
|
||||
ASSERT_EQ(written_records[1], "key2");
|
||||
|
||||
writer.writeEndOfFile();
|
||||
|
||||
std::string the_file = oss.str();
|
||||
std::ofstream foo("output2.zip");
|
||||
foo.write(the_file.c_str(), the_file.size());
|
||||
foo.close();
|
||||
|
||||
std::istringstream iss(the_file);
|
||||
|
||||
// read records through readers
|
||||
PyTorchStreamReader reader(&iss);
|
||||
EXPECT_THROW(reader.getRecord("key3"), c10::Error);
|
||||
|
||||
// Reader should still work after throwing
|
||||
EXPECT_TRUE(reader.hasRecord("key1"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
||||
|
Reference in New Issue
Block a user