mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
ban PyTorchStreamWriter from writing the same file twice (#61805)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61805

Similar in spirit to https://github.com/pytorch/pytorch/pull/61371. While writing two files with the same name is allowed by the ZIP format, most tools (including our own) handle it poorly. Previously I banned this within `PackageExporter`, but that doesn't cover other uses of the ZIP format, like TorchScript. Given that there are no valid use cases, and that debugging the issues caused by duplicate file writes is fiendishly difficult, this bans the behavior entirely.

Differential Revision: D29748968

Test Plan: Imported from OSS

Reviewed By: Lilyjjo

Pulled By: suo

fbshipit-source-id: 0afee1506c59c0f283ef41e4be562f9c22f21023
Committed by: Facebook GitHub Bot
Parent: 04043d681e
Commit: f02cfcc802
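With this change in place, a second writeRecord for the same record name fails fast inside the writer instead of silently appending a duplicate ZIP entry. A minimal sketch of the new failure mode (the archive path, record name, and payload are made up for illustration; the internal assert surfaces as a c10::Error):

#include <caffe2/serialize/inline_container.h>

#include <string>

int main() {
  // Hypothetical archive path, for illustration only.
  caffe2::serialize::PyTorchStreamWriter writer("archive.zip");
  std::string payload = "hello";

  writer.writeRecord("key1", payload.data(), payload.size());
  // A second write of the same name now trips
  // "Tried to serialize file twice: key1" rather than producing a
  // ZIP archive containing two "key1" entries.
  writer.writeRecord("key1", payload.data(), payload.size());  // throws c10::Error
  writer.writeEndOfFile();
  return 0;
}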
@@ -9,6 +9,7 @@
 #include <c10/core/Allocator.h>
 #include <c10/core/CPUAllocator.h>
 #include <c10/core/Backend.h>
+#include <c10/util/Exception.h>
 
 #include "caffe2/core/common.h"
 #include "caffe2/core/logging.h"
@@ -235,8 +236,9 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
   return out;
 }
 
-const std::vector<std::string>& PyTorchStreamWriter::getAllWrittenRecords() {
-  return files_written;
+const std::unordered_set<std::string>&
+PyTorchStreamWriter::getAllWrittenRecords() {
+  return files_written_;
 }
 
 size_t PyTorchStreamReader::getRecordID(const std::string& name) {
@@ -356,6 +358,8 @@ void PyTorchStreamWriter::writeRecord(
     bool compress) {
   AT_ASSERT(!finalized_);
   AT_ASSERT(!archive_name_plus_slash_.empty());
+  TORCH_INTERNAL_ASSERT(
+      files_written_.count(name) == 0, "Tried to serialize file twice: ", name);
   std::string full_name = archive_name_plus_slash_ + name;
   size_t padding_size =
       detail::getPadding(ar_->m_archive_size, full_name.size(), size, padding_);
@@ -376,7 +380,7 @@ void PyTorchStreamWriter::writeRecord(
       nullptr,
       0);
   valid("writing file ", name.c_str());
-  files_written.push_back(name);
+  files_written_.insert(name);
 }
 
 void PyTorchStreamWriter::writeEndOfFile() {
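Note the ordering in the writeRecord hunks above: membership is asserted before any bytes reach the archive, and the name is inserted only after valid(...) confirms the write succeeded, so a failed write does not leave the record marked as written. A stripped-down sketch of that check-then-record pattern, with hypothetical names and none of the ZIP machinery:

#include <stdexcept>
#include <string>
#include <unordered_set>

class RecordWriter {
 public:
  void writeRecord(const std::string& name /*, data... */) {
    // Reject duplicates up front, before touching the archive.
    if (written_.count(name) != 0) {
      throw std::runtime_error("Tried to serialize file twice: " + name);
    }
    // ... perform the actual write here; any failure throws
    // before we reach the insert below ...
    written_.insert(name);  // record only after a successful write
  }

 private:
  std::unordered_set<std::string> written_;
};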
@@ -7,6 +7,7 @@
 #include <istream>
 #include <mutex>
 #include <ostream>
+#include <unordered_set>
 
 #include <c10/core/Allocator.h>
 #include <c10/core/Backend.h>
@@ -140,7 +141,7 @@ class TORCH_API PyTorchStreamWriter final {
       bool compress = false);
   void writeEndOfFile();
 
-  const std::vector<std::string>& getAllWrittenRecords();
+  const std::unordered_set<std::string>& getAllWrittenRecords();
 
   bool finalized() const {
     return finalized_;
@@ -156,7 +157,7 @@ class TORCH_API PyTorchStreamWriter final {
   void setup(const std::string& file_name);
   void valid(const char* what, const char* info = "");
   size_t current_pos_ = 0;
-  std::vector<std::string> files_written;
+  std::unordered_set<std::string> files_written_;
   std::unique_ptr<mz_zip_archive> ar_;
   std::string archive_name_;
   std::string archive_name_plus_slash_;
@@ -184,7 +185,7 @@ size_t getPadding(
     size_t filename_size,
     size_t size,
     std::string& padding_buf);
-}
+} // namespace detail
 
 } // namespace serialize
 } // namespace caffe2
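The header change above is what keeps the new duplicate check cheap: files_written becomes the std::unordered_set member files_written_, so the assert in writeRecord is an average O(1) hash lookup rather than an O(n) scan of a std::vector. The trade-off is that getAllWrittenRecords() no longer reflects insertion order, which the test updates below account for. An illustrative comparison of the two membership idioms (not code from this commit):

#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>

bool seen_vector(const std::vector<std::string>& v, const std::string& name) {
  // O(n): scans every record already written.
  return std::find(v.begin(), v.end(), name) != v.end();
}

bool seen_set(const std::unordered_set<std::string>& s, const std::string& name) {
  // Average O(1): one hash lookup, independent of archive size.
  return s.count(name) != 0;
}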
@@ -35,9 +35,11 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
   }
   writer.writeRecord("key2", data2.data(), data2.size());
 
-  const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
-  ASSERT_EQ(written_records[0], "key1");
-  ASSERT_EQ(written_records[1], "key2");
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
 
   writer.writeEndOfFile();
 
@@ -95,9 +97,11 @@ TEST(PytorchStreamWriterAndReader, GetNonexistentRecordThrows) {
   }
   writer.writeRecord("key2", data2.data(), data2.size());
 
-  const std::vector<std::string>& written_records = writer.getAllWrittenRecords();
-  ASSERT_EQ(written_records[0], "key1");
-  ASSERT_EQ(written_records[1], "key2");
+  const std::unordered_set<std::string>& written_records =
+      writer.getAllWrittenRecords();
+  ASSERT_EQ(written_records.size(), 2);
+  ASSERT_EQ(written_records.count("key1"), 1);
+  ASSERT_EQ(written_records.count("key2"), 1);
 
   writer.writeEndOfFile();
 
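Because std::unordered_set has no defined iteration order, the updated tests assert on size() and count() instead of indexing into the container. The diff does not add a test for the new assert itself; a sketch of what such a regression test could look like, assuming the gtest and writer-function setup this test file already uses, and that TORCH_INTERNAL_ASSERT surfaces as a c10::Error:

#include <gtest/gtest.h>

#include <c10/util/Exception.h>
#include <caffe2/serialize/inline_container.h>

#include <sstream>
#include <string>

TEST(PyTorchStreamWriterAndReader, DuplicateRecordNameThrows) {
  std::ostringstream oss;
  // Write the archive into an in-memory stream, as the existing tests do.
  caffe2::serialize::PyTorchStreamWriter writer(
      [&](const void* b, size_t n) -> size_t {
        oss.write(static_cast<const char*>(b), n);
        return oss ? n : 0;
      });
  std::string data = "abcd";
  writer.writeRecord("key1", data.data(), data.size());
  // Writing the same record name again should now fail the internal assert.
  EXPECT_THROW(
      writer.writeRecord("key1", data.data(), data.size()), c10::Error);
}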
@@ -244,7 +244,7 @@ void writeArchiveV5(
   std::string prefix = archive_name + "/";
 
   TORCH_INTERNAL_ASSERT(tensor_names.size() == data_pickle.tensorData().size());
-  const std::vector<std::string>& pre_serialized_files =
+  const std::unordered_set<std::string>& pre_serialized_files =
       writer.getAllWrittenRecords();
 
   for (const auto& td : data_pickle.tensorData()) {