mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Prototype] [PyTorch Edge] Speed up model loading by 12% by directly calling the C file API from FileAdapter (#61997)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61997 After profiling the model loading latency on AI Bench (Android Galaxy S8 US), it seems like a significant amount of time was spent reading data using FileAdapter, which internally calls IStreamAdapter. However, IStreamAdapter uses `std::istream` under the hood, which is not that efficient. This change reduces the model loading time from [~293ms](https://www.internalfb.com/intern/aibench/details/600870874797229) to [~254ms](https://www.internalfb.com/intern/aibench/details/163731416457694), which is a reduction of ~12%. ghstack-source-id: 134634610 Test Plan: See the AI Bench links above. Reviewed By: raziel Differential Revision: D29812191 fbshipit-source-id: 57810fdc1ac515305f5504f88ac5e9e4319e9d28
This commit is contained in:
committed by
Facebook GitHub Bot
parent
693d8f2f07
commit
725d98bab6
@ -1,29 +1,68 @@
|
||||
#include "caffe2/serialize/file_adapter.h"
|
||||
#include <c10/util/Exception.h>
|
||||
#include <cstdio>
|
||||
|
||||
#include "caffe2/core/common.h"
|
||||
|
||||
namespace caffe2 {
|
||||
namespace serialize {
|
||||
|
||||
FileAdapter::FileAdapter(const std::string& file_name) {
|
||||
file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
|
||||
if (!file_stream_) {
|
||||
FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) {
|
||||
fp_ = fopen(file_name.c_str(), "rb");
|
||||
if (fp_ == nullptr) {
|
||||
AT_ERROR("open file failed, file path: ", file_name);
|
||||
}
|
||||
istream_adapter_ = std::make_unique<IStreamAdapter>(&file_stream_);
|
||||
}
|
||||
|
||||
FileAdapter::RAIIFile::~RAIIFile() {
|
||||
if (fp_ != nullptr) {
|
||||
fclose(fp_);
|
||||
}
|
||||
}
|
||||
|
||||
// FileAdapter directly calls C file API.
|
||||
FileAdapter::FileAdapter(const std::string& file_name): file_(file_name) {
|
||||
const int fseek_ret = fseek(file_.fp_, 0L, SEEK_END);
|
||||
TORCH_CHECK(fseek_ret == 0, "fseek returned ", fseek_ret);
|
||||
#if defined(_MSC_VER)
|
||||
const int64_t ftell_ret = _ftelli64(file_.fp_);
|
||||
#else
|
||||
const off_t ftell_ret = ftello(file_.fp_);
|
||||
#endif
|
||||
TORCH_CHECK(ftell_ret != -1L, "ftell returned ", ftell_ret);
|
||||
size_ = ftell_ret;
|
||||
rewind(file_.fp_);
|
||||
}
|
||||
|
||||
size_t FileAdapter::size() const {
|
||||
return istream_adapter_->size();
|
||||
return size_;
|
||||
}
|
||||
|
||||
size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
|
||||
const {
|
||||
return istream_adapter_->read(pos, buf, n, what);
|
||||
// Ensure that pos doesn't exceed size_.
|
||||
pos = std::min(pos, size_);
|
||||
// If pos doesn't exceed size_, then size_ - pos can never be negative (in
|
||||
// signed math) or since these are unsigned values, a very large value.
|
||||
// Clamp 'n' to the smaller of 'size_ - pos' and 'n' itself. i.e. if the
|
||||
// user requested to read beyond the end of the file, we clamp to just the
|
||||
// end of the file.
|
||||
n = std::min(static_cast<size_t>(size_ - pos), n);
|
||||
#if defined(_MSC_VER)
|
||||
const int fseek_ret = _fseeki64(file_.fp_, pos, SEEK_SET);
|
||||
#else
|
||||
const int fseek_ret = fseeko(file_.fp_, pos, SEEK_SET);
|
||||
#endif
|
||||
TORCH_CHECK(
|
||||
fseek_ret == 0,
|
||||
"fseek returned ",
|
||||
fseek_ret,
|
||||
", context: ",
|
||||
what);
|
||||
return fread(buf, 1, n, file_.fp_);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(modernize-use-equals-default)
|
||||
FileAdapter::~FileAdapter() {}
|
||||
FileAdapter::~FileAdapter() = default;
|
||||
|
||||
} // namespace serialize
|
||||
} // namespace caffe2
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
#include "c10/macros/Macros.h"
|
||||
#include "caffe2/serialize/istream_adapter.h"
|
||||
#include "caffe2/serialize/read_adapter_interface.h"
|
||||
|
||||
@ -17,11 +17,19 @@ class TORCH_API FileAdapter final : public ReadAdapterInterface {
|
||||
size_t size() const override;
|
||||
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
|
||||
const override;
|
||||
~FileAdapter();
|
||||
~FileAdapter() override;
|
||||
|
||||
private:
|
||||
std::ifstream file_stream_;
|
||||
std::unique_ptr<IStreamAdapter> istream_adapter_;
|
||||
// An RAII Wrapper for a FILE pointer. Closes on destruction.
|
||||
struct RAIIFile {
|
||||
FILE* fp_;
|
||||
explicit RAIIFile(const std::string& file_name);
|
||||
~RAIIFile();
|
||||
};
|
||||
|
||||
RAIIFile file_;
|
||||
// The size of the opened file in bytes
|
||||
uint64_t size_;
|
||||
};
|
||||
|
||||
} // namespace serialize
|
||||
|
Reference in New Issue
Block a user