[Prototype] [PyTorch Edge] Speed up model loading by 12% by directly calling the C file API from FileAdapter (#61997)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61997

Profiling model loading latency on AI Bench (Android Galaxy S8 US) showed that a significant amount of time was spent reading data through FileAdapter, which internally calls IStreamAdapter. IStreamAdapter uses `std::istream` under the hood, which is comparatively inefficient for these reads. This change makes FileAdapter call the C file API directly, reducing the model loading time from [~293ms](https://www.internalfb.com/intern/aibench/details/600870874797229) to [~254ms](https://www.internalfb.com/intern/aibench/details/163731416457694), which is a reduction of ~12%.
ghstack-source-id: 134634610

Test Plan: See the AI Bench links above.

Reviewed By: raziel

Differential Revision: D29812191

fbshipit-source-id: 57810fdc1ac515305f5504f88ac5e9e4319e9d28
This commit is contained in:
Dhruv Matani
2021-07-29 20:09:07 -07:00
committed by Facebook GitHub Bot
parent 693d8f2f07
commit 725d98bab6
2 changed files with 59 additions and 12 deletions

View File

@ -1,29 +1,68 @@
#include "caffe2/serialize/file_adapter.h"
#include <c10/util/Exception.h>
#include <cstdio>
#include "caffe2/core/common.h"
namespace caffe2 {
namespace serialize {
// NOTE(review): the lines below appear to interleave two constructor bodies
// (one opening file_stream_, one calling fopen) — confirm against the real file.
FileAdapter::FileAdapter(const std::string& file_name) {
file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
if (!file_stream_) {
// Opens file_name with fopen in binary read mode ("rb") and keeps the
// resulting handle in fp_.
FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) {
fp_ = fopen(file_name.c_str(), "rb");
// A null handle means the open failed; report it together with the path.
if (fp_ == nullptr) {
AT_ERROR("open file failed, file path: ", file_name);
}
istream_adapter_ = std::make_unique<IStreamAdapter>(&file_stream_);
}
// Releases the FILE handle held in fp_, if there is one.
FileAdapter::RAIIFile::~RAIIFile() {
  if (fp_ == nullptr) {
    return;
  }
  fclose(fp_);
}
// FileAdapter talks to the C file API directly. Construction opens the file
// (through file_) and measures its size once, storing it in size_.
FileAdapter::FileAdapter(const std::string& file_name) : file_(file_name) {
  FILE* const handle = file_.fp_;
  // Seek to the end so that the reported offset is the file size.
  const int fseek_ret = fseek(handle, 0L, SEEK_END);
  TORCH_CHECK(fseek_ret == 0, "fseek returned ", fseek_ret);
#if defined(_MSC_VER)
  const int64_t ftell_ret = _ftelli64(handle);
#else
  const off_t ftell_ret = ftello(handle);
#endif
  TORCH_CHECK(ftell_ret != -1L, "ftell returned ", ftell_ret);
  size_ = static_cast<uint64_t>(ftell_ret);
  // Go back to the beginning of the file now that the size is known.
  rewind(handle);
}
// Returns the size of the file in bytes. size_ is measured once, when the
// adapter is constructed.
size_t FileAdapter::size() const {
return istream_adapter_->size();
return size_;
}
// Reads up to n bytes starting at byte offset pos into buf and returns the
// number of bytes actually read. `what` is a caller-supplied context string
// that is appended to the error message if the seek fails.
size_t FileAdapter::read(uint64_t pos, void* buf, size_t n, const char* what)
const {
return istream_adapter_->read(pos, buf, n, what);
// Clamp pos so that it never exceeds size_.
pos = std::min(pos, size_);
// With pos <= size_, the unsigned subtraction size_ - pos cannot wrap around
// to a very large value. Clamp n to the number of bytes remaining after pos,
// so a request that extends past the end of the file stops at the end of the
// file.
n = std::min(static_cast<size_t>(size_ - pos), n);
// Position the stream at pos: _fseeki64 on MSVC, fseeko elsewhere.
#if defined(_MSC_VER)
const int fseek_ret = _fseeki64(file_.fp_, pos, SEEK_SET);
#else
const int fseek_ret = fseeko(file_.fp_, pos, SEEK_SET);
#endif
TORCH_CHECK(
fseek_ret == 0,
"fseek returned ",
fseek_ret,
", context: ",
what);
// The element size is 1, so fread's return value is a byte count.
return fread(buf, 1, n, file_.fp_);
}
// No explicit cleanup is performed here; the FILE handle is closed by
// RAIIFile's destructor.
// NOLINTNEXTLINE(modernize-use-equals-default)
FileAdapter::~FileAdapter() {}
FileAdapter::~FileAdapter() = default;
} // namespace serialize
} // namespace caffe2

View File

@ -2,8 +2,8 @@
#include <fstream>
#include <memory>
#include <c10/macros/Macros.h>
#include "c10/macros/Macros.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"
@ -17,11 +17,19 @@ class TORCH_API FileAdapter final : public ReadAdapterInterface {
// Returns the size of the opened file in bytes.
size_t size() const override;
// Reads up to n bytes starting at byte offset pos into buf and returns the
// number of bytes read. `what` is a context string included in error messages.
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override;
~FileAdapter();
~FileAdapter() override;
private:
std::ifstream file_stream_;
std::unique_ptr<IStreamAdapter> istream_adapter_;
// An RAII wrapper for a FILE pointer. The constructor opens file_name for
// reading and the destructor closes the handle.
//
// Copying is disabled: a copy would hold the same FILE*, and both objects
// would call fclose() on it when destroyed.
struct RAIIFile {
  // The owned handle; null until the constructor has opened the file.
  FILE* fp_{nullptr};
  explicit RAIIFile(const std::string& file_name);
  RAIIFile(const RAIIFile&) = delete;
  RAIIFile& operator=(const RAIIFile&) = delete;
  ~RAIIFile();
};
RAIIFile file_;
// The size of the opened file in bytes
uint64_t size_;
};
} // namespace serialize