Files
pytorch/caffe2/opt/onnxifi_transformer.h
Edward Yang 91797c0672 Replace direct include of caffe2.pb.h with an intermediary header caffe2_pb.h (#10946)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10946

```
codemod -d . --extensions cc,cpp,cu,cuh,h caffe2/proto/caffe2.pb.h caffe2/proto/caffe2_pb.h
```

Reviewed By: houseroad

Differential Revision: D9539945

fbshipit-source-id: 497d04720e8e7e61c05ffe1b23733d0cb774de7e
2018-08-28 11:57:08 -07:00

68 lines
1.8 KiB
C++

#pragma once
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "onnx/onnx_pb.h"

#include "caffe2/core/common.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"
#include "caffe2/onnx/onnxifi_init.h"
#include "caffe2/proto/caffe2_pb.h"
namespace caffe2 {
namespace onnx {
// Forward declaration only: this header refers to OnnxExporter solely through
// a pointer parameter, so the full definition is not needed here.
class OnnxExporter;
} // namespace onnx
// Transformer that operates on a Caffe2 prediction net together with the
// workspace holding its blobs, using backends exposed by a loaded onnxifi
// library.
// NOTE(review): the precise rewrite performed is defined in the .cc file;
// the per-method comments below only describe what the declarations show.
class CAFFE2_API OnnxifiTransformer {
 public:
  // Constructs a transformer.
  //
  // debug: presumably stored in debug_, which controls dumping the onnx model
  //        for debugging -- confirm against the constructor definition.
  explicit OnnxifiTransformer(bool debug);

  // Entry point of the transformation.
  //
  // ws:          workspace passed by mutable pointer.
  // pred_net:    net passed by mutable pointer; presumably rewritten in
  //              place -- confirm against the definition.
  // shape_hints: map from name to TensorShape, read-only.
  void Transform(
      Workspace* ws,
      NetDef* pred_net,
      const std::unordered_map<std::string, TensorShape>& shape_hints);

 private:
  // Note that there are two workspaces among the inputs. The first one,
  // mapped_ws, is used to map SSA names back to the original c2 names. The
  // second one, ws, is actually used to inject more weights into the original
  // workspace.
  //
  // net:         the subnet to convert, read-only.
  // exporter:    ONNX exporter passed by pointer.
  // shape_hints: passed by mutable pointer, so it may be updated by the call.
  // Returns a NetDef by value.
  caffe2::NetDef SubnetToOnnxifiOp(
      const caffe2::NetDef& net,
      const Workspace& mapped_ws,
      Workspace* ws,
      onnx::OnnxExporter* exporter,
      std::unordered_map<std::string, TensorShape>* shape_hints);

  // Builds an OperatorDef from a serialized onnx model string.
  //
  // onnx_model_str:      the onnx model as a string.
  // output_size_hints:   map from name to a vector of ints (presumably output
  //                      dimensions -- confirm against the definition).
  // initialization_list: set of names (presumably the initializers/weights --
  //                      confirm against the definition).
  // net:                 the net the operator is built for, read-only.
  OperatorDef BuildOnnxifiOp(
      const std::string& onnx_model_str,
      const std::unordered_map<std::string, std::vector<int>>&
          output_size_hints,
      const std::unordered_set<std::string>& initialization_list,
      const caffe2::NetDef& net);

  // Judging by its name, performs an SSA rewrite of pred_net and maps names;
  // returns a map from name to TensorShape. Both ws and pred_net are passed
  // by mutable pointer. NOTE(review): confirm details against the definition.
  CaffeMap<std::string, TensorShape> SsaRewriteAndMapNames(
      Workspace* ws,
      NetDef* pred_net,
      const std::unordered_map<std::string, TensorShape>& input_shape_hints);

  // Dump onnx model for debugging
  bool debug_{false};

  // Pointer to loaded onnxifi library
  onnxifi_library* lib_{nullptr};

  // Number of backends
  std::size_t num_backends_{0};

  // Backend IDs
  std::vector<onnxBackendID> backend_ids_;

  // Input mapping
  std::unordered_map<std::string, std::string> input_mapping_;
};
} // namespace caffe2