diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc
index d74dfb4a4a1c..cb36fe0a4bc1 100644
--- a/caffe2/serialize/inline_container.cc
+++ b/caffe2/serialize/inline_container.cc
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
@@ -235,8 +236,9 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
   return out;
 }
 
-const std::vector<std::string>& PyTorchStreamWriter::getAllWrittenRecords() {
-  return files_written;
+const std::unordered_set<std::string>&
+PyTorchStreamWriter::getAllWrittenRecords() {
+  return files_written_;
 }
 
 size_t PyTorchStreamReader::getRecordID(const std::string& name) {
@@ -356,6 +358,8 @@ void PyTorchStreamWriter::writeRecord(
     bool compress) {
   AT_ASSERT(!finalized_);
   AT_ASSERT(!archive_name_plus_slash_.empty());
+  TORCH_INTERNAL_ASSERT(
+      files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
   std::string full_name = archive_name_plus_slash_ + name;
   size_t padding_size =
       detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
@@ -376,7 +380,7 @@ void PyTorchStreamWriter::writeRecord(
       nullptr,
       0);
   valid("writing file ", name.c_str());
-  files_written.push_back(name);
+  files_written_.insert(name);
 }
 
 void PyTorchStreamWriter::writeEndOfFile() {
diff --git a/caffe2/serialize/inline_container.h b/caffe2/serialize/inline_container.h
index 281d1756d75f..4eb1b8e71ce6 100644
--- a/caffe2/serialize/inline_container.h
+++ b/caffe2/serialize/inline_container.h
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include <unordered_set>
 #include
 #include
 
@@ -140,7 +141,7 @@ class TORCH_API PyTorchStreamWriter final {
       bool compress = false);
   void writeEndOfFile();
 
-  const std::vector<std::string>& getAllWrittenRecords();
+  const std::unordered_set<std::string>& getAllWrittenRecords();
 
   bool finalized() const {
     return finalized_;
@@ -156,7 +157,7 @@
   void setup(const std::string& file_name);
   void valid(const char* what, const char* info = "");
   size_t current_pos_ = 0;
-  std::vector<std::string> files_written;
+  std::unordered_set<std::string> files_written_;
   std::unique_ptr<mz_zip_archive> ar_;
   std::string archive_name_;
   std::string archive_name_plus_slash_;
@@ -184,7 +185,7 @@ size_t getPadding(
     size_t filename_size,
     size_t size,
     std::string& padding_buf);
-}
+} // namespace detail
 
 } // namespace serialize
 } // namespace caffe2
diff --git a/caffe2/serialize/inline_container_test.cc b/caffe2/serialize/inline_container_test.cc
index 3a9f511ee9cf..7a65bf1ab45d 100644
--- a/caffe2/serialize/inline_container_test.cc
+++ b/caffe2/serialize/inline_container_test.cc
@@ -35,9 +35,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   }
   writer.writeRecord("key2", data2.data(), data2.size());
 
-  const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
-  ASSERT_EQ(written_records[0], "key1");
-  ASSERT_EQ(written_records[1], "key2");
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
 
   writer.writeEndOfFile();
 
@@ -95,9 +97,11 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
   }
   writer.writeRecord("key2", data2.data(), data2.size());
 
-  const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
-  ASSERT_EQ(written_records[0], "key1");
-  ASSERT_EQ(written_records[1], "key2");
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
 
   writer.writeEndOfFile();
 
diff --git a/torch/csrc/jit/mobile/backport_manager.cpp b/torch/csrc/jit/mobile/backport_manager.cpp
index 91c8548ee7d8..3cd815626c0f 100644
--- a/torch/csrc/jit/mobile/backport_manager.cpp
+++ b/torch/csrc/jit/mobile/backport_manager.cpp
@@ -244,7 +244,7 @@ void writeArchiveV5(
   std::string prefix = archive_name + "/";
   TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
 
-  const std::vector<std::string>& pre_serialized_files =
+  const std::unordered_set<std::string>& pre_serialized_files =
      writer.getAllWrittenRecords();
 
   for (const auto& td : data_pickle.tensorData()) {
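
For reference, a minimal usage sketch of the behavior this patch introduces. This is not part of the patch; it assumes a PyTorch build with caffe2/serialize/inline_container.h on the include path, and the archive path is a hypothetical example.

// Sketch: writing the same record name twice now trips the new
// TORCH_INTERNAL_ASSERT ("Tried to serialize file twice: ...") instead of
// silently appending a duplicate entry to the zip archive.
#include <string>
#include <vector>
#include <caffe2/serialize/inline_container.h>

int main() {
  caffe2::serialize::PyTorchStreamWriter writer("/tmp/model.pt"); // hypothetical path
  std::vector<char> payload(16, 0);
  writer.writeRecord("key1", payload.data(), payload.size());
  // Uncommenting the next line would now fail the duplicate-record assert:
  // writer.writeRecord("key1", payload.data(), payload.size());
  writer.writeEndOfFile();
  return 0;
}

Because files_written_ is now an std::unordered_set rather than an std::vector, the duplicate check in writeRecord is an O(1) membership test; the trade-off, as the updated tests reflect, is that getAllWrittenRecords() no longer reports records in insertion order.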