mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Added directory check before saving in C++ API
Fixes #75177 Couldn't find any utility method to get directory name in pytorch repo, hence creating a function for that. Let me know if a new function is not needed. I also referred [this](https://github.com/pytorch/pytorch/blob/master/c10/test/util/tempfile_test.cpp#L15) for directory check. Also I am using TORCH_CHECK to show the error. This is highly verbose with the entire stack visible. Is there any alternative for the same so that it is easier to read? This could happen a frequently, so small and concise error would be more helpful here. Pull Request resolved: https://github.com/pytorch/pytorch/pull/75681 Approved by: https://github.com/albanD
This commit is contained in:
@ -5,6 +5,9 @@
|
||||
#include <ostream>
|
||||
#include <fstream>
|
||||
#include <algorithm>
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
@ -49,6 +52,17 @@ static std::string basename(const std::string& name) {
|
||||
return name.substr(start, end - start);
|
||||
}
|
||||
|
||||
static std::string parentdir(const std::string& name) {
|
||||
size_t end = name.find_last_of('/');
|
||||
if(end == std::string::npos)
|
||||
end = name.find_last_of('\\');
|
||||
|
||||
if(end == std::string::npos)
|
||||
return "";
|
||||
|
||||
return name.substr(0, end);
|
||||
}
|
||||
|
||||
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
|
||||
return in_->read(pos, buf, n, "reading file");
|
||||
}
|
||||
@ -338,6 +352,13 @@ void PyTorchStreamWriter::setup(const string& file_name) {
|
||||
file_name,
|
||||
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
|
||||
valid("opening archive ", file_name.c_str());
|
||||
|
||||
const std::string dir_name = parentdir(file_name);
|
||||
if(!dir_name.empty()) {
|
||||
struct stat st;
|
||||
bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
|
||||
TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
|
||||
}
|
||||
TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
|
||||
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
|
||||
file_stream_.write(static_cast<const char*>(buf), nbytes);
|
||||
|
@ -27,6 +27,18 @@ Module roundtripThroughMobile(const Module& m) {
|
||||
mobilem._ivalue(), files, constants, 8);
|
||||
}
|
||||
|
||||
template <class Functor>
|
||||
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
|
||||
try {
|
||||
std::forward<Functor>(functor)();
|
||||
} catch (const Error& e) {
|
||||
EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
|
||||
return;
|
||||
}
|
||||
ADD_FAILURE() << "Expected to throw exception with message \""
|
||||
<< expectedMessage << "\" but didn't throw";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(SerializationTest, ExtraFilesHookPreference) {
|
||||
@ -238,5 +250,14 @@ TEST(TestSourceRoundTrip,
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SerializationTest, ParentDirNotExist) {
|
||||
expectThrowsEq(
|
||||
[]() {
|
||||
auto t = torch::nn::Linear(5, 5);
|
||||
torch::save(t, "./doesnotexist/file.pt");
|
||||
},
|
||||
"Parent directory ./doesnotexist does not exist.");
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace torch
|
||||
|
Reference in New Issue
Block a user