Added directory check before saving in C++ API

Fixes #75177

Couldn't find any utility method to get directory name in pytorch repo, hence creating a function for that.
Let me know if a new function is not needed.

I also referred [this](https://github.com/pytorch/pytorch/blob/master/c10/test/util/tempfile_test.cpp#L15) for directory check.

Also I am using TORCH_CHECK to show the error. This is highly verbose with the entire stack visible. Is there any alternative for the same so that it is easier to read? This could happen a frequently, so small and concise error would be more helpful here.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75681
Approved by: https://github.com/albanD
This commit is contained in:
Prem
2022-04-22 20:04:41 +00:00
committed by PyTorch MergeBot
parent 79891abf40
commit 7557407653
2 changed files with 42 additions and 0 deletions

View File

@ -5,6 +5,9 @@
#include <ostream>
#include <fstream>
#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>
#include <c10/core/Allocator.h>
#include <c10/core/CPUAllocator.h>
@ -49,6 +52,17 @@ static std::string basename(const std::string& name) {
return name.substr(start, end - start);
}
static std::string parentdir(const std::string& name) {
size_t end = name.find_last_of('/');
if(end == std::string::npos)
end = name.find_last_of('\\');
if(end == std::string::npos)
return "";
return name.substr(0, end);
}
size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
return in_->read(pos, buf, n, "reading file");
}
@ -338,6 +352,13 @@ void PyTorchStreamWriter::setup(const string& file_name) {
file_name,
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
valid("opening archive ", file_name.c_str());
const std::string dir_name = parentdir(file_name);
if(!dir_name.empty()) {
struct stat st;
bool dir_exists = (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR));
TORCH_CHECK(dir_exists, "Parent directory ", dir_name, " does not exist.");
}
TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened.");
writer_func_ = [this](const void* buf, size_t nbytes) -> size_t {
file_stream_.write(static_cast<const char*>(buf), nbytes);

View File

@ -27,6 +27,18 @@ Module roundtripThroughMobile(const Module& m) {
mobilem._ivalue(), files, constants, 8);
}
template <class Functor>
inline void expectThrowsEq(Functor&& functor, const char* expectedMessage) {
try {
std::forward<Functor>(functor)();
} catch (const Error& e) {
EXPECT_STREQ(e.what_without_backtrace(), expectedMessage);
return;
}
ADD_FAILURE() << "Expected to throw exception with message \""
<< expectedMessage << "\" but didn't throw";
}
} // namespace
TEST(SerializationTest, ExtraFilesHookPreference) {
@ -238,5 +250,14 @@ TEST(TestSourceRoundTrip,
}
}
TEST(SerializationTest, ParentDirNotExist) {
expectThrowsEq(
[]() {
auto t = torch::nn::Linear(5, 5);
torch::save(t, "./doesnotexist/file.pt");
},
"Parent directory ./doesnotexist does not exist.");
}
} // namespace jit
} // namespace torch