mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Save the weight shape info the first time we have chance to extract it (#21233)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/21233 It is possible that OnnxifiOp is created in a thread where weights have been cleaned from the workspace, which is legit use case as we can create the backend once and lower all the weights. So we need to extract the weight shape info the first time we create the backend and save it. Reviewed By: bertmaher, rdzhabarov Differential Revision: D15587237 fbshipit-source-id: 1f264dc32c0398c42b618e9c41c119eb13e1c9f1
This commit is contained in:
committed by
Facebook Github Bot
parent
0efc527dd1
commit
7c40576c61
@ -6,6 +6,7 @@
|
||||
#include <unordered_map>
|
||||
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/opt/shape_info.h"
|
||||
#include "foxi/onnxifi_loader.h"
|
||||
|
||||
namespace caffe2 {
|
||||
@ -16,13 +17,19 @@ struct BackendGraphInfo {
|
||||
onnxBackend backend;
|
||||
onnxGraph graph;
|
||||
onnxifi_library* lib{nullptr};
|
||||
std::unordered_map<std::string, ShapeInfo> weight_shape_info;
|
||||
|
||||
BackendGraphInfo(
|
||||
onnxBackendID backend_id,
|
||||
onnxBackend backend,
|
||||
onnxGraph graph,
|
||||
onnxifi_library* lib)
|
||||
: backend_id(backend_id), backend(backend), graph(graph), lib(lib) {}
|
||||
onnxifi_library* lib,
|
||||
std::unordered_map<std::string, ShapeInfo>&& s)
|
||||
: backend_id(backend_id),
|
||||
backend(backend),
|
||||
graph(graph),
|
||||
lib(lib),
|
||||
weight_shape_info(std::move(s)) {}
|
||||
|
||||
BackendGraphInfo(const BackendGraphInfo& other) = delete;
|
||||
|
||||
@ -33,6 +40,7 @@ struct BackendGraphInfo {
|
||||
backend = other.backend;
|
||||
graph = other.graph;
|
||||
lib = other.lib;
|
||||
weight_shape_info = std::move(other.weight_shape_info);
|
||||
other.backend_id = other.backend = other.graph = other.lib = nullptr;
|
||||
}
|
||||
|
||||
@ -41,6 +49,7 @@ struct BackendGraphInfo {
|
||||
backend = other.backend;
|
||||
graph = other.graph;
|
||||
lib = other.lib;
|
||||
weight_shape_info = std::move(other.weight_shape_info);
|
||||
other.backend_id = other.backend = other.graph = other.lib = nullptr;
|
||||
return *this;
|
||||
}
|
||||
|
@ -121,29 +121,6 @@ class OnnxifiOp final : public Operator<Context> {
|
||||
// process.
|
||||
buildBackendAndGraph(ws, property_pointers, onnx_model_str);
|
||||
|
||||
// Get the weights (initializer) a second time to fill in its shape info
|
||||
std::vector<std::string> weight_names;
|
||||
std::vector<std::vector<uint64_t>> weight_shapes;
|
||||
auto initializers =
|
||||
this->template GetRepeatedArgument<std::string>("initializers");
|
||||
std::vector<std::vector<float>> dummy_scales;
|
||||
std::vector<std::vector<float>> dummy_offsets;
|
||||
buildInitializationList(
|
||||
ws,
|
||||
initializers,
|
||||
&weight_names,
|
||||
&weight_shapes,
|
||||
&dummy_scales,
|
||||
&dummy_offsets);
|
||||
for (int i = 0; i < weight_names.size(); ++i) {
|
||||
TensorShape shape;
|
||||
const auto& shape0 = weight_shapes[i];
|
||||
for (const auto d : shape0) {
|
||||
shape.add_dims(d);
|
||||
}
|
||||
input_shape_info_[weight_names[i]] =
|
||||
ShapeInfo(ShapeInfo::DimType::CONSTANT, std::move(shape));
|
||||
}
|
||||
}
|
||||
|
||||
~OnnxifiOp() {
|
||||
@ -250,6 +227,18 @@ class OnnxifiOp final : public Operator<Context> {
|
||||
&all_scales_,
|
||||
&all_offsets_);
|
||||
|
||||
// Extra weight shapes
|
||||
std::unordered_map<std::string, ShapeInfo> weight_shape_info;
|
||||
for (int i = 0; i < weight_names.size(); ++i) {
|
||||
TensorShape shape;
|
||||
const auto& shape0 = weight_shapes[i];
|
||||
for (const auto d : shape0) {
|
||||
shape.add_dims(d);
|
||||
}
|
||||
weight_shape_info[weight_names[i]] =
|
||||
ShapeInfo(ShapeInfo::DimType::CONSTANT, std::move(shape));
|
||||
}
|
||||
|
||||
onnxGraph graph{nullptr};
|
||||
CAFFE_ENFORCE_EQ(
|
||||
lib_->onnxInitGraph(
|
||||
@ -263,7 +252,7 @@ class OnnxifiOp final : public Operator<Context> {
|
||||
ONNXIFI_STATUS_SUCCESS);
|
||||
|
||||
return std::make_shared<onnx::BackendGraphInfo>(
|
||||
backend_id, backend, graph, lib_);
|
||||
backend_id, backend, graph, lib_, std::move(weight_shape_info));
|
||||
};
|
||||
backend_graph_shared_ptr_ =
|
||||
backend_graph_map_ptr_->insert(op_id_string_, creator);
|
||||
@ -271,6 +260,7 @@ class OnnxifiOp final : public Operator<Context> {
|
||||
backend_id_ = backend_graph_shared_ptr_->backend_id;
|
||||
backend_ = backend_graph_shared_ptr_->backend;
|
||||
graph_ = backend_graph_shared_ptr_->graph;
|
||||
input_shape_info_ = backend_graph_shared_ptr_->weight_shape_info;
|
||||
|
||||
getExtFunctionPointers();
|
||||
}
|
||||
|
Reference in New Issue
Block a user