[jit] PyTorchStreamReader::getAllRecord should omit archive name prefix (#43317)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43317

Previous version was returning the path with a prefix so subsequent `getRecord` would fail.

There's only one place in PyTorch codebase that uses this function (introduced in https://github.com/pytorch/pytorch/pull/29339 ) and it's unlikely that anyone else is using it - it's not a public API anyway.

Test Plan: unittest

Reviewed By: houseroad

Differential Revision: D23235241

fbshipit-source-id: 6f7363e6981623aa96320f5e39c54e65d716240b
This commit is contained in:
Dmytro Dzhulgakov
2020-08-21 10:34:51 -07:00
committed by Facebook GitHub Bot
parent 0bd35de30e
commit 478fb925e6
2 changed files with 12 additions and 6 deletions

View File

@ -198,7 +198,17 @@ std::vector<std::string> PyTorchStreamReader::getAllRecords() {
char buf[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE];
for (size_t i = 0; i < num_files; i++) {
mz_zip_reader_get_filename(ar_.get(), i, buf, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE);
out.push_back(buf);
if (strncmp(
buf,
archive_name_plus_slash_.data(),
archive_name_plus_slash_.size()) != 0) {
CAFFE_THROW(
"file in archive is not in a subdirectory ",
archive_name_plus_slash_,
": ",
buf);
}
out.push_back(buf + archive_name_plus_slash_.size());
}
return out;
}

View File

@ -859,8 +859,4 @@ def _load(zip_file, map_location, pickle_module, **pickle_load_args):
def _is_torchscript_zip(zip_file):
for file_name in zip_file.get_all_records():
parts = file_name.split(os.sep)
if len(parts) > 1 and parts[1] == 'constants.pkl':
return True
return False
return 'constants.pkl' in zip_file.get_all_records()