// Mirror of https://github.com/pytorch/pytorch.git
// (synced 2025-10-20 21:14:14 +08:00).
//
// Last change: Pull Request resolved: https://github.com/pytorch/pytorch/pull/10946
//   codemod -d . --extensions cc,cpp,cu,cuh,h caffe2/proto/caffe2.pb.h caffe2/proto/caffe2_pb.h
// Reviewed By: houseroad
// Differential Revision: D9539945
// fbshipit-source-id: 497d04720e8e7e61c05ffe1b23733d0cb774de7e
//
// File size: 68 lines, 1.8 KiB (C++).
#pragma once
|
|
|
|
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "onnx/onnx_pb.h"

#include "caffe2/core/common.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"
#include "caffe2/onnx/onnxifi_init.h"
#include "caffe2/proto/caffe2_pb.h"
|
|
|
|
namespace caffe2 {
|
|
namespace onnx {
|
|
class OnnxExporter;
|
|
}
|
|
|
|
class CAFFE2_API OnnxifiTransformer {
|
|
public:
|
|
explicit OnnxifiTransformer(bool debug);
|
|
|
|
void Transform(
|
|
Workspace* ws,
|
|
NetDef* pred_net,
|
|
const std::unordered_map<std::string, TensorShape>& shape_hints);
|
|
|
|
private:
|
|
// Note that we have two workspaces here as inputs. The first mapped_ws is
|
|
// used to mapped SSA names back to c2 original names. The second one is
|
|
// actually used to inject more weights into the original workspace
|
|
caffe2::NetDef SubnetToOnnxifiOp(
|
|
const caffe2::NetDef& net,
|
|
const Workspace& mapped_ws,
|
|
Workspace* ws,
|
|
onnx::OnnxExporter* exporter,
|
|
std::unordered_map<std::string, TensorShape>* shape_hints);
|
|
|
|
OperatorDef BuildOnnxifiOp(
|
|
const std::string& onnx_model_str,
|
|
const std::unordered_map<std::string, std::vector<int>>&
|
|
output_size_hints,
|
|
const std::unordered_set<std::string>& initialization_list,
|
|
const caffe2::NetDef& net);
|
|
|
|
CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames(
|
|
Workspace* ws,
|
|
NetDef* pred_net,
|
|
const std::unordered_map<std::string, TensorShape>& input_shape_hints);
|
|
|
|
// Dump onnx model for debugging
|
|
bool debug_{false};
|
|
|
|
// Pointer to loaded onnxifi library
|
|
onnxifi_library* lib_{nullptr};
|
|
|
|
// Number of backends
|
|
size_t num_backends_{0};
|
|
|
|
// Backned IDs
|
|
std::vector<onnxBackendID> backend_ids_;
|
|
// Input mapping
|
|
std::unordered_map<std::string, std::string> input_mapping_;
|
|
};
|
|
} // namespace caffe2
|