[PyTorch Edge] Skip writing version during backport (#65842)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65842

During backport, only parts of the model (like bytecode.pkl) needs to be re-written, while the rest of the model is the same. However, `version` will always be re-written when `PyTorchStreamWriter` is destrcuted.

Change version to optional and add an api to allow skipping writing version when closing the writer.
ghstack-source-id: 139580386

Test Plan: buck run papaya/scripts/repro:save_load

Reviewed By: iseeyuan, tugsbayasgalan

Differential Revision: D31262904

fbshipit-source-id: 3b8a5e1aaa610ffb0fe8a616d9ad9d0987c03f23
This commit is contained in:
Chen Lai
2021-10-01 21:16:58 -07:00
committed by Facebook GitHub Bot
parent 7941590a51
commit 8b8012a165
2 changed files with 15 additions and 17 deletions

View File

@ -384,13 +384,16 @@ void PyTorchStreamWriter::writeRecord(
}
void PyTorchStreamWriter::writeEndOfFile() {
// Rewrites version info
std::string version = c10::to_string(version_);
version.push_back('\n');
if (version_ >= 0x6L) {
writeRecord(".data/version", version.c_str(), version.size());
} else {
writeRecord("version", version.c_str(), version.size());
auto allRecords = getAllWrittenRecords();
// If no ".data/version" or "version" record in the output model, rewrites version info
if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
std::string version = c10::to_string(version_);
version.push_back('\n');
if (version_ >= 0x6L) {
writeRecord(".data/version", version.c_str(), version.size());
} else {
writeRecord("version", version.c_str(), version.size());
}
}
AT_ASSERT(!finalized_);