mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-19 01:54:54 +08:00
Use filesystem functionality
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
This commit is contained in:
@ -10,7 +10,6 @@
|
||||
#include <fmt/format.h>
|
||||
#include <miniz.h>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <regex>
|
||||
@ -100,7 +99,7 @@ std::string create_temp_dir() {
|
||||
#endif
|
||||
}
|
||||
|
||||
const char* object_file_ext() {
|
||||
constexpr const char* object_file_ext() {
|
||||
#ifdef _WIN32
|
||||
return ".obj";
|
||||
#else
|
||||
@ -108,7 +107,7 @@ const char* object_file_ext() {
|
||||
#endif
|
||||
}
|
||||
|
||||
const char* extension_file_ext() {
|
||||
constexpr const char* extension_file_ext() {
|
||||
#ifdef _WIN32
|
||||
return ".pyd";
|
||||
#else
|
||||
@ -116,7 +115,7 @@ const char* extension_file_ext() {
|
||||
#endif
|
||||
}
|
||||
|
||||
const char* get_output_flags(bool compile_only) {
|
||||
constexpr const char* get_output_flags(bool compile_only) {
|
||||
if (compile_only) {
|
||||
#ifdef _WIN32
|
||||
return "/c /Fo"; // codespell:ignore
|
||||
@ -132,7 +131,7 @@ const char* get_output_flags(bool compile_only) {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool _is_windows_os() {
|
||||
constexpr bool _is_windows_os() {
|
||||
#ifdef _WIN32
|
||||
return true;
|
||||
#else
|
||||
@ -597,78 +596,66 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||
<< found_filenames[1];
|
||||
}
|
||||
|
||||
temp_dir_ = normalize_path_separator(create_temp_dir());
|
||||
temp_dir_ = c10::filesystem::path(create_temp_dir());
|
||||
|
||||
std::string so_filename;
|
||||
std::string cpp_filename;
|
||||
std::string weight_blob_filename;
|
||||
std::vector<std::string> obj_filenames;
|
||||
std::string model_directory = normalize_path_separator(
|
||||
file_prefix + "data" + k_separator + "aotinductor" + k_separator +
|
||||
model_name);
|
||||
std::string const_directory = normalize_path_separator(
|
||||
file_prefix + "data" + k_separator + "constants");
|
||||
auto model_directory =
|
||||
c10::filesystem::path(file_prefix + "data") / "aotinductor" / model_name;
|
||||
auto const_directory =
|
||||
c10::filesystem::path(file_prefix + "data") / "constants";
|
||||
|
||||
// zip_filename_str can't be normalize_path_separator, because it should be
|
||||
// as index for mz_zip_reader_extract_file_to_file.
|
||||
for (auto const& zip_filename_str : found_filenames) {
|
||||
auto cur_filename = normalize_path_separator(zip_filename_str);
|
||||
auto cur_filename = c10::filesystem::path(zip_filename_str);
|
||||
// Only compile files in the specified model directory
|
||||
if (c10::starts_with(cur_filename, model_directory) ||
|
||||
c10::starts_with(cur_filename, const_directory)) {
|
||||
std::string output_path_str = temp_dir_;
|
||||
if (c10::starts_with(cur_filename.c_str(), model_directory.c_str()) ||
|
||||
c10::starts_with(cur_filename.c_str(), const_directory.c_str())) {
|
||||
c10::filesystem::path output_file_path(temp_dir_);
|
||||
|
||||
if (c10::starts_with(cur_filename, model_directory)) {
|
||||
output_path_str += k_separator;
|
||||
output_path_str += cur_filename;
|
||||
if (c10::starts_with(cur_filename.c_str(), model_directory.c_str())) {
|
||||
output_file_path /= cur_filename;
|
||||
} else { // startsWith(zip_filename_str, const_directory)
|
||||
// Extract constants to the same directory as the rest of the files
|
||||
// to be consistent with internal implementation
|
||||
size_t lastSlash = cur_filename.find_last_of(k_separator);
|
||||
std::string filename = cur_filename;
|
||||
if (lastSlash != std::string::npos) {
|
||||
filename = cur_filename.substr(lastSlash + 1);
|
||||
}
|
||||
output_path_str.append(k_separator)
|
||||
.append(model_directory)
|
||||
.append(k_separator)
|
||||
.append(filename);
|
||||
output_file_path /= model_directory;
|
||||
output_file_path /= c10::filesystem::path(cur_filename).filename();
|
||||
}
|
||||
|
||||
std::string output_file_path = normalize_path_separator(output_path_str);
|
||||
LOG(INFO) << "Extract file: " << zip_filename_str << " to "
|
||||
<< output_file_path;
|
||||
|
||||
// Create the parent directory if it doesn't exist
|
||||
size_t parent_path_idx = output_file_path.find_last_of(k_separator);
|
||||
TORCH_CHECK(
|
||||
parent_path_idx != std::string::npos,
|
||||
"Failed to find parent path in " + output_file_path);
|
||||
output_file_path.has_parent_path(),
|
||||
"Failed to find parent path in " + output_file_path.string());
|
||||
|
||||
std::string parent_path = output_file_path.substr(0, parent_path_idx);
|
||||
// Create the parent directory if it doesn't exist
|
||||
auto parent_path = output_file_path.parent_path();
|
||||
std::error_code ec{};
|
||||
c10::filesystem::create_directories(parent_path, ec);
|
||||
TORCH_CHECK(
|
||||
ec.value() == 0,
|
||||
"Failed to create directory " + parent_path,
|
||||
"Failed to create directory " + parent_path.string(),
|
||||
": ",
|
||||
ec.message());
|
||||
|
||||
// Extracts file to the temp directory
|
||||
zip_archive.extract_file(zip_filename_str, output_path_str);
|
||||
zip_archive.extract_file(zip_filename_str, output_file_path.string());
|
||||
|
||||
// Save the file for bookkeeping
|
||||
size_t extension_idx = output_file_path.find_last_of('.');
|
||||
if (extension_idx != std::string::npos) {
|
||||
std::string filename_extension = output_file_path.substr(extension_idx);
|
||||
if (output_file_path.has_extension()) {
|
||||
std::string filename_extension = output_file_path.extension().string();
|
||||
if (filename_extension == ".cpp") {
|
||||
cpp_filename = output_file_path;
|
||||
cpp_filename = output_file_path.string();
|
||||
} else if (filename_extension == object_file_ext()) {
|
||||
obj_filenames.push_back(output_file_path);
|
||||
obj_filenames.push_back(output_file_path.string());
|
||||
} else if (filename_extension == extension_file_ext()) {
|
||||
so_filename = output_file_path;
|
||||
so_filename = output_file_path.string();
|
||||
} else if (filename_extension == ".blob") {
|
||||
weight_blob_filename = output_file_path;
|
||||
weight_blob_filename = output_file_path.string();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -719,7 +706,7 @@ AOTIModelPackageLoader::AOTIModelPackageLoader(
|
||||
c10::Device device = c10::Device(device_key);
|
||||
device.set_index(device_index);
|
||||
|
||||
std::string cubin_dir = temp_dir_ + k_separator + model_directory;
|
||||
std::string cubin_dir = temp_dir_ + k_separator + model_directory.string();
|
||||
runner_ = registered_aoti_runner[device_key](
|
||||
so_path, num_runners, device.str(), cubin_dir, run_single_threaded);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user