mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:07:10 +08:00
[PyTorch Edge] Skip writing version during backport (#65842)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65842 During backport, only parts of the model (like bytecode.pkl) needs to be re-written, while the rest of the model is the same. However, `version` will always be re-written when `PyTorchStreamWriter` is destrcuted. Change version to optional and add an api to allow skipping writing version when closing the writer. ghstack-source-id: 139580386 Test Plan: buck run papaya/scripts/repro:save_load Reviewed By: iseeyuan, tugsbayasgalan Differential Revision: D31262904 fbshipit-source-id: 3b8a5e1aaa610ffb0fe8a616d9ad9d0987c03f23
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7941590a51
commit
8b8012a165
@ -384,13 +384,16 @@ void PyTorchStreamWriter::writeRecord(
|
||||
}
|
||||
|
||||
void PyTorchStreamWriter::writeEndOfFile() {
|
||||
// Rewrites version info
|
||||
std::string version = c10::to_string(version_);
|
||||
version.push_back('\n');
|
||||
if (version_ >= 0x6L) {
|
||||
writeRecord(".data/version", version.c_str(), version.size());
|
||||
} else {
|
||||
writeRecord("version", version.c_str(), version.size());
|
||||
auto allRecords = getAllWrittenRecords();
|
||||
// If no ".data/version" or "version" record in the output model, rewrites version info
|
||||
if(allRecords.find(".data/version") == allRecords.end() && allRecords.find("version") == allRecords.end()) {
|
||||
std::string version = c10::to_string(version_);
|
||||
version.push_back('\n');
|
||||
if (version_ >= 0x6L) {
|
||||
writeRecord(".data/version", version.c_str(), version.size());
|
||||
} else {
|
||||
writeRecord("version", version.c_str(), version.size());
|
||||
}
|
||||
}
|
||||
|
||||
AT_ASSERT(!finalized_);
|
||||
|
Reference in New Issue
Block a user