ban PyTorchStreamWriter from writing the same file twice (#61805)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61805

Similar in spirit to https://github.com/pytorch/pytorch/pull/61371.
While writing two files with the same name is allowed by the ZIP format,
most tools (including our own) handle this poorly. Previously I banned
this within `PackageExporter`, but that doesn't cover other uses of the
ZIP format, such as TorchScript.

Given that there are no valid use cases, and that debugging the issues caused
by duplicate file writes is fiendishly difficult, this change bans the behavior entirely.
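
Concretely, the now-banned pattern looks like the following minimal sketch, built on the in-memory writer-function constructor that the unit tests use (`oss` and `payload` are hypothetical names):

```cpp
#include <cstddef>
#include <sstream>

#include "caffe2/serialize/inline_container.h"

using caffe2::serialize::PyTorchStreamWriter;

int main() {
  std::ostringstream oss;
  // Writer that appends every chunk to an in-memory stream.
  PyTorchStreamWriter writer([&](const void* data, size_t size) -> size_t {
    oss.write(static_cast<const char*>(data), size);
    return oss ? size : 0;
  });

  char payload[] = {'a', 'b', 'c', 'd'};
  writer.writeRecord("key1", payload, sizeof(payload));
  // Previously this quietly produced a second "key1" entry in the archive;
  // with this change it trips the new internal assert:
  //   "Tried to serialize file twice: key1"
  writer.writeRecord("key1", payload, sizeof(payload));
  writer.writeEndOfFile();
}
```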

Differential Revision: D29748968

Test Plan: Imported from OSS

Reviewed By: Lilyjjo

Pulled By: suo

fbshipit-source-id: 0afee1506c59c0f283ef41e4be562f9c22f21023
Authored by Michael Suo on 2021-07-19 18:20:53 -07:00; committed by Facebook GitHub Bot
parent 04043d681e
commit f02cfcc802
4 changed files with 22 additions and 13 deletions

View File

@@ -9,6 +9,7 @@
 #include <c10/core/Allocator.h>
 #include <c10/core/CPUAllocator.h>
 #include <c10/core/Backend.h>
+#include <c10/util/Exception.h>
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
@@ -235,8 +236,9 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
   return out;
 }
-const std::vector<std::string>& PyTorchStreamWriter::getAllWrittenRecords() {
-  return files_written;
+const std::unordered_set<std::string>&
+PyTorchStreamWriter::getAllWrittenRecords() {
+  return files_written_;
 }
 size_t PyTorchStreamReader::getRecordID(const std::string& name) {
@@ -356,6 +358,8 @@ void PyTorchStreamWriter::writeRecord(
     bool compress) {
   AT_ASSERT(!finalized_);
   AT_ASSERT(!archive_name_plus_slash_.empty());
+  TORCH_INTERNAL_ASSERT(
+      files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
   std::string full_name = archive_name_plus_slash_ + name;
   size_t padding_size =
       detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
@@ -376,7 +380,7 @@ void PyTorchStreamWriter::writeRecord(
       nullptr,
       0);
   valid("writing file ", name.c_str());
-  files_written.push_back(name);
+  files_written_.insert(name);
 }
 void PyTorchStreamWriter::writeEndOfFile() {

View File

@@ -7,6 +7,7 @@
 #include <istream>
 #include <mutex>
 #include <ostream>
+#include <unordered_set>
 #include <c10/core/Allocator.h>
 #include <c10/core/Backend.h>
@@ -140,7 +141,7 @@ class TORCH_API PyTorchStreamWriter final {
       bool compress = false);
   void writeEndOfFile();
-  const std::vector<std::string>& getAllWrittenRecords();
+  const std::unordered_set<std::string>& getAllWrittenRecords();
   bool finalized() const {
     return finalized_;
@@ -156,7 +157,7 @@ class TORCH_API PyTorchStreamWriter final {
   void setup(const std::string& file_name);
   void valid(const char* what, const char* info = "");
   size_t current_pos_ = 0;
-  std::vector<std::string> files_written;
+  std::unordered_set<std::string> files_written_;
   std::unique_ptr<mz_zip_archive> ar_;
   std::string archive_name_;
   std::string archive_name_plus_slash_;
@@ -184,7 +185,7 @@ size_t getPadding(
     size_t filename_size,
     size_t size,
     std::string& padding_buf);
-}
+} // namespace detail
 } // namespace serialize
 } // namespace caffe2

View File

@@ -35,9 +35,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   }
   writer.writeRecord("key2", data2.data(), data2.size());
-  const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
-  ASSERT_EQ(written_records[0], "key1");
-  ASSERT_EQ(written_records[1], "key2");
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
   writer.writeEndOfFile();
@@ -95,9 +97,11 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
   }
   writer.writeRecord("key2", data2.data(), data2.size());
-  const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
-  ASSERT_EQ(written_records[0], "key1");
-  ASSERT_EQ(written_records[1], "key2");
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
   writer.writeEndOfFile();

View File

@@ -244,7 +244,7 @@ void writeArchiveV5(
   std::string prefix = archive_name + "/";
   TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
-  const std::vector<std::string>& pre_serialized_files =
+  const std::unordered_set<std::string>& pre_serialized_files =
       writer.getAllWrittenRecords();
   for (const auto& td : data_pickle.tensorData()) {