mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21233 It is possible that OnnxifiOp is created in a thread where the weights have already been cleaned from the workspace. This is a legitimate use case, since we can create the backend once and lower all the weights. Therefore we need to extract the weight shape info the first time we create the backend, and save it. Reviewed By: bertmaher, rdzhabarov Differential Revision: D15587237 fbshipit-source-id: 1f264dc32c0398c42b618e9c41c119eb13e1c9f1
115 lines
3.5 KiB
C++
115 lines
3.5 KiB
C++
#pragma once
|
|
|
|
#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>

#include "caffe2/core/logging.h"
#include "caffe2/opt/shape_info.h"
#include "foxi/onnxifi_loader.h"
|
|
|
|
namespace caffe2 {
|
|
namespace onnx {
|
|
|
|
struct BackendGraphInfo {
|
|
onnxBackendID backend_id;
|
|
onnxBackend backend;
|
|
onnxGraph graph;
|
|
onnxifi_library* lib{nullptr};
|
|
std::unordered_map<std::string, ShapeInfo> weight_shape_info;
|
|
|
|
BackendGraphInfo(
|
|
onnxBackendID backend_id,
|
|
onnxBackend backend,
|
|
onnxGraph graph,
|
|
onnxifi_library* lib,
|
|
std::unordered_map<std::string, ShapeInfo>&& s)
|
|
: backend_id(backend_id),
|
|
backend(backend),
|
|
graph(graph),
|
|
lib(lib),
|
|
weight_shape_info(std::move(s)) {}
|
|
|
|
BackendGraphInfo(const BackendGraphInfo& other) = delete;
|
|
|
|
BackendGraphInfo& operator=(const BackendGraphInfo& other) = delete;
|
|
|
|
BackendGraphInfo(BackendGraphInfo&& other) noexcept {
|
|
backend_id = other.backend_id;
|
|
backend = other.backend;
|
|
graph = other.graph;
|
|
lib = other.lib;
|
|
weight_shape_info = std::move(other.weight_shape_info);
|
|
other.backend_id = other.backend = other.graph = other.lib = nullptr;
|
|
}
|
|
|
|
BackendGraphInfo& operator=(BackendGraphInfo&& other) {
|
|
backend_id = other.backend_id;
|
|
backend = other.backend;
|
|
graph = other.graph;
|
|
lib = other.lib;
|
|
weight_shape_info = std::move(other.weight_shape_info);
|
|
other.backend_id = other.backend = other.graph = other.lib = nullptr;
|
|
return *this;
|
|
}
|
|
|
|
~BackendGraphInfo() {
|
|
if (lib) {
|
|
onnxStatus err;
|
|
if (graph) {
|
|
err = lib->onnxReleaseGraph(graph);
|
|
if (err != ONNXIFI_STATUS_SUCCESS) {
|
|
LOG(ERROR) << "Error when calling onnxReleaseGraph";
|
|
}
|
|
}
|
|
if (backend) {
|
|
err = lib->onnxReleaseBackend(backend);
|
|
if (err != ONNXIFI_STATUS_SUCCESS) {
|
|
LOG(ERROR) << "Error when calling onnxReleaseBackend";
|
|
}
|
|
}
|
|
if (backend_id) {
|
|
err = lib->onnxReleaseBackendID(backend_id);
|
|
if (err != ONNXIFI_STATUS_SUCCESS) {
|
|
LOG(ERROR) << "Error when calling onnxReleaseBackendID";
|
|
}
|
|
}
|
|
}
|
|
}
|
|
};
|
|
using SharedPtrBackendGraphInfo = std::shared_ptr<BackendGraphInfo>;
|
|
|
|
// This class maintains a map of already created graph for nets+ops
|
|
class OnnxBackendGraphMap {
|
|
public:
|
|
OnnxBackendGraphMap() {}
|
|
// Make class noncopyable and nomovable.
|
|
OnnxBackendGraphMap(const OnnxBackendGraphMap&) = delete;
|
|
OnnxBackendGraphMap(OnnxBackendGraphMap&&) = delete;
|
|
OnnxBackendGraphMap operator=(const OnnxBackendGraphMap&) = delete;
|
|
OnnxBackendGraphMap operator=(OnnxBackendGraphMap&&) = delete;
|
|
|
|
SharedPtrBackendGraphInfo lookup(const std::string& key);
|
|
|
|
// If corresponding BackendGraphInfo already exists, return it directly.
|
|
// Otherwise we use creator to create the BackendGraphInfo shared_ptr and
|
|
// insert it into the map and return it. The whole process should be guarded
|
|
// by a lock. Note that since it will create the backend while holding the
|
|
// lock, expect latency during initialization phase when there are lots of
|
|
// models to compile.
|
|
SharedPtrBackendGraphInfo insert(
|
|
const std::string& key,
|
|
std::function<SharedPtrBackendGraphInfo()> creator);
|
|
|
|
void remove(const std::string& key);
|
|
|
|
private:
|
|
std::mutex backend_graph_map_lock_;
|
|
std::unordered_map<std::string, SharedPtrBackendGraphInfo> backend_graph_map_;
|
|
};
|
|
|
|
OnnxBackendGraphMap* getOnnxBackendGraphMap();
|
|
} // namespace onnx
|
|
} // namespace caffe2
|