diff --git a/caffe2/serialize/inline_container.cc b/caffe2/serialize/inline_container.cc index e39a78c62dd5..015c480cf04f 100644 --- a/caffe2/serialize/inline_container.cc +++ b/caffe2/serialize/inline_container.cc @@ -27,6 +27,10 @@ #include "caffe2/serialize/versions.h" #include "miniz.h" +#ifdef _WIN32 +#include +#endif // _WIN32 + namespace caffe2 { namespace serialize { constexpr std::string_view kDebugPklSuffix(".debug_pkl"); @@ -711,21 +715,35 @@ void PyTorchStreamWriter::setup(const string& file_name) { if (archive_name_.size() == 0) { CAFFE_THROW("invalid file name: ", file_name); } - if (!writer_func_) { - file_stream_.open( - file_name, - std::ofstream::out | std::ofstream::trunc | std::ofstream::binary); - valid("opening archive ", file_name.c_str()); - const std::string dir_name = parentdir(file_name); - if (!dir_name.empty()) { - struct stat st; - bool dir_exists = - (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR)); - TORCH_CHECK( - dir_exists, "Parent directory ", dir_name, " does not exist."); + const std::string dir_name = parentdir(file_name); + if (!dir_name.empty()) { + struct stat st; + bool dir_exists = + (stat(dir_name.c_str(), &st) == 0 && (st.st_mode & S_IFDIR)); + TORCH_CHECK( + dir_exists, "Parent directory ", dir_name, " does not exist."); + } + TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened."); + + if (!writer_func_) { + valid("opening archive ", file_name.c_str()); + try { + file_stream_.exceptions(std::ios_base::failbit | std::ios_base::badbit); + file_stream_.open( + file_name, + std::ofstream::out | std::ofstream::trunc | std::ofstream::binary + ); + } catch (const std::ios_base::failure& e) { +#ifdef _WIN32 + // Windows have verbose error code, we prefer to use it than std errno. + uint32_t error_code = GetLastError(); + CAFFE_THROW("open file failed with error code: ", error_code); +#else // !_WIN32 + CAFFE_THROW("open file failed with strerror: ", strerror(errno)); +#endif // _WIN32 } - TORCH_CHECK(file_stream_, "File ", file_name, " cannot be opened."); + writer_func_ = [this](const void* buf, size_t nbytes) -> size_t { if (!buf) { // See [Note: write_record_metadata]