Files
pytorch/caffe2/contrib/tensorrt/trt_utils.cc
Li Yu a3933b87c6 Back out "Revert D14613517: [pytorch][PR] Updating onnxtrt submodule to master branch" (#18514)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18514

Original commit changeset: d6267ddfc339

Reviewed By: bddppq

Differential Revision: D14634476

fbshipit-source-id: 2633b0b4c512d71001e5c20cd79c0c0d7856f942
2019-03-26 23:44:33 -07:00

45 lines
1.4 KiB
C++

#include "caffe2/contrib/tensorrt/trt_utils.h"

#include <sstream>
#include <string>

#include <NvOnnxParser.h>
namespace caffe2 {
namespace tensorrt {
std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
const std::string& onnx_model_str,
TrtLogger* logger,
size_t max_batch_size,
size_t max_workspace_size,
bool debug_builder) {
auto trt_builder = TrtObject(nvinfer1::createInferBuilder(*logger));
auto trt_network = TrtObject(trt_builder->createNetwork());
auto trt_parser =
TrtObject(nvonnxparser::createParser(*trt_network, *logger));
auto status = trt_parser->parse(onnx_model_str.data(), onnx_model_str.size());
if (!status) {
const auto num_errors = trt_parser->getNbErrors();
if (num_errors > 0) {
const auto* error = trt_parser->getError(num_errors - 1);
CAFFE_THROW(
"TensorRTTransformer ERROR: ",
error->file(),
":",
error->line(),
" In function ",
error->func(),
":\n",
"[",
static_cast<int>(error->code()),
"] ",
error->desc());
} else {
CAFFE_THROW("TensorRTTransformer Unknown Error");
}
}
trt_builder->setMaxBatchSize(max_batch_size);
trt_builder->setMaxWorkspaceSize(max_workspace_size);
trt_builder->setDebugSync(debug_builder);
return TrtObject(trt_builder->buildCudaEngine(*trt_network.get()));
}
} // namespace tensorrt
} // namespace caffe2