mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[AOTI] normalize path and process model files. (#158705)
Continued to https://github.com/pytorch/pytorch/pull/158702 , split `zip_filename_str` and real file path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158705 Approved by: https://github.com/desertfire
This commit is contained in:
@ -478,27 +478,31 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||
std::string so_filename;
|
||||
std::string cpp_filename;
|
||||
std::vector<std::string> obj_filenames;
|
||||
std::string model_directory = file_prefix + "data" + k_separator +
|
||||
"aotinductor" + k_separator + model_name;
|
||||
std::string const_directory =
|
||||
file_prefix + "data" + k_separator + "constants";
|
||||
std::string model_directory = normalize_path_separator(
|
||||
file_prefix + "data" + k_separator + "aotinductor" + k_separator +
|
||||
model_name);
|
||||
std::string const_directory = normalize_path_separator(
|
||||
file_prefix + "data" + k_separator + "constants");
|
||||
|
||||
for (const std::string& filename_str : found_filenames) {
|
||||
// zip_filename_str can't be normalize_path_separator, because it should be
|
||||
// as index for mz_zip_reader_extract_file_to_file.
|
||||
for (auto zip_filename_str : found_filenames) {
|
||||
auto cur_filename = normalize_path_separator(zip_filename_str);
|
||||
// Only compile files in the specified model directory
|
||||
if (c10::starts_with(filename_str, model_directory) ||
|
||||
c10::starts_with(filename_str, const_directory)) {
|
||||
if (c10::starts_with(cur_filename, model_directory) ||
|
||||
c10::starts_with(cur_filename, const_directory)) {
|
||||
std::string output_path_str = temp_dir_;
|
||||
|
||||
if (c10::starts_with(filename_str, model_directory)) {
|
||||
if (c10::starts_with(cur_filename, model_directory)) {
|
||||
output_path_str += k_separator;
|
||||
output_path_str += filename_str;
|
||||
} else { // startsWith(filename_str, const_directory)
|
||||
output_path_str += cur_filename;
|
||||
} else { // startsWith(zip_filename_str, const_directory)
|
||||
// Extract constants to the same directory as the rest of the files
|
||||
// to be consistent with internal implementation
|
||||
size_t lastSlash = filename_str.find_last_of(k_separator);
|
||||
std::string filename = filename_str;
|
||||
size_t lastSlash = cur_filename.find_last_of(k_separator);
|
||||
std::string filename = cur_filename;
|
||||
if (lastSlash != std::string::npos) {
|
||||
filename = filename_str.substr(lastSlash + 1);
|
||||
filename = cur_filename.substr(lastSlash + 1);
|
||||
}
|
||||
output_path_str.append(k_separator)
|
||||
.append(model_directory)
|
||||
@ -507,7 +511,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||
}
|
||||
|
||||
std::string output_file_path = normalize_path_separator(output_path_str);
|
||||
LOG(INFO) << "Extract file: " << filename_str << " to "
|
||||
LOG(INFO) << "Extract file: " << zip_filename_str << " to "
|
||||
<< output_file_path;
|
||||
|
||||
// Create the parent directory if it doesn't exist
|
||||
@ -526,10 +530,12 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||
|
||||
// Extracts file to the temp directory
|
||||
mz_bool b_extract = mz_zip_reader_extract_file_to_file(
|
||||
&zip_archive, filename_str.c_str(), output_file_path.c_str(), 0);
|
||||
&zip_archive, zip_filename_str.c_str(), output_file_path.c_str(), 0);
|
||||
if (b_extract == MZ_FALSE) {
|
||||
throw std::runtime_error(fmt::format(
|
||||
"Failed to extract file {} to {}", filename_str, output_file_path));
|
||||
"Failed to extract file {} to {}",
|
||||
zip_filename_str,
|
||||
output_file_path));
|
||||
}
|
||||
|
||||
// Save the file for bookkeeping
|
||||
|
Reference in New Issue
Block a user