Summary:
This is the fix for the reverted https://github.com/pytorch/pytorch/issues/26426
houseroad bddppq soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28715

Reviewed By: hl475

Differential Revision: D18146731

Pulled By: houseroad

fbshipit-source-id: 247366451a6334e84df82d00339521f797b33130
Sergei Nikolaev
2019-11-01 12:51:28 -07:00
committed by Facebook Github Bot
parent 4a94eaa60b
commit 1e2049c566
12 changed files with 1238 additions and 5 deletions


@@ -124,6 +124,7 @@ pip install --user pytest-sugar
  --ignore "$caffe2_pypath/python/operator_test/matmul_op_test.py" \
  --ignore "$caffe2_pypath/python/operator_test/pack_ops_test.py" \
  --ignore "$caffe2_pypath/python/mkl/mkl_sbn_speed_test.py" \
  --ignore "$caffe2_pypath/python/trt/test_pt_onnx_trt.py" \
  ${rocm_ignore_test[@]} \
  "$caffe2_pypath/python" \
  "${EXTRA_TESTS[@]}"


@@ -15,6 +15,10 @@ if(NOT CMAKE_VERSION VERSION_LESS 3.15.0)
  cmake_policy(SET CMP0092 NEW)
endif()
if(NOT CMAKE_VERSION VERSION_LESS 3.10)
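  # The FindCUDA module is deprecated as of CMake 3.10; this flag gates the
  # onnx-tensorrt FindCUDA workaround applied later in this diff.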
  set(FIND_CUDA_MODULE_DEPRECATED ON)
endif()
# ---[ Project and semantic versioning.
project(Caffe2 CXX C)


@@ -11,7 +11,15 @@ std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
    size_t max_workspace_size,
    bool debug_builder) {
  auto trt_builder = TrtObject(nvinfer1::createInferBuilder(*logger));
#if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6)
  auto trt_builder_cfg = TrtObject(trt_builder->createBuilderConfig());
  // TensorRTOp doesn't support dynamic shapes yet
  auto trt_network = TrtObject(trt_builder->createNetworkV2(
      1U << static_cast<uint32_t>(
          nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
#else
  auto trt_network = TrtObject(trt_builder->createNetwork());
#endif
  auto trt_parser =
      TrtObject(nvonnxparser::createParser(*trt_network, *logger));
  auto status = trt_parser->parse(onnx_model_str.data(), onnx_model_str.size());
@@ -36,9 +44,19 @@ std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
    }
  }
  trt_builder->setMaxBatchSize(max_batch_size);
#if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6)
  trt_builder_cfg->setMaxWorkspaceSize(max_workspace_size);
  if (debug_builder) {
    trt_builder_cfg->setFlag(nvinfer1::BuilderFlag::kDEBUG);
  }
  trt_builder_cfg->setDefaultDeviceType(nvinfer1::DeviceType::kGPU);
  return TrtObject(trt_builder->buildEngineWithConfig(
      *trt_network.get(), *trt_builder_cfg));
#else
  trt_builder->setMaxWorkspaceSize(max_workspace_size);
  trt_builder->setDebugSync(debug_builder);
  return TrtObject(trt_builder->buildCudaEngine(*trt_network.get()));
#endif
}
} // namespace tensorrt
} // namespace caffe2
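
For comparison, here is a minimal sketch of the same TensorRT 6 builder-config flow expressed with the TensorRT Python API (an illustration, not code from this commit; it assumes a builder created from trt.Builder and an already-parsed network):

import tensorrt as trt

def build_engine_trt6(builder, network, max_batch_size,
                      max_workspace_size, debug_builder):
    # IBuilderConfig replaces the deprecated builder-level knobs in TRT 6.
    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size
    if debug_builder:
        config.set_flag(trt.BuilderFlag.DEBUG)  # nvinfer1::BuilderFlag::kDEBUG
    config.default_device_type = trt.DeviceType.GPU
    builder.max_batch_size = max_batch_size
    # Python counterpart of buildEngineWithConfig() in the C++ API.
    return builder.build_engine(network, config)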

Three binary image files added (33 KiB, 81 KiB, and 108 KiB), the test images referenced by the test below, plus one large text file (the class-labels list) whose diff is suppressed because it is too large.


@@ -0,0 +1,190 @@
###################################################################################################
# ATTENTION! This test will most probably fail if you install the TensorRT
# 6.0.1 release only. That's because it ships with an older version of the
# ONNX parser that does not support some required features. To make it work,
# please use the newer version: https://github.com/onnx/onnx-tensorrt
# Just clone it and do something like this:
#
# ~/pt/third_party/onnx-tensorrt$ mkdir build/
# ~/pt/third_party/onnx-tensorrt$ cd build/
# ~/pt/third_party/onnx-tensorrt/build$ cmake ..
# ~/pt/third_party/onnx-tensorrt/build$ make
# ~/pt/third_party/onnx-tensorrt/build$ sudo cp libnvonnxparser.so.6.0.1 /usr/lib/x86_64-linux-gnu
#
# This note is valid for the 6.0.1 release only (September 18th, 2019).
###################################################################################################
import os
import unittest
from typing import List, Any
from PIL import Image
import numpy as np
import torch
from torch.onnx import OperatorExportTypes
import torchvision.models as models
import pycuda.driver as cuda
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit
import tensorrt as trt
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
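
# Allocate page-locked host buffers and matching device buffers for the
# engine's single input (binding 0) and output (binding 1), plus a CUDA
# stream for the async copies below.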
def allocate_buffers(engine):
    h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)),
                                    dtype=trt.nptype(trt.float32))
    h_output = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(1)),
                                     dtype=trt.nptype(trt.float32))
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)
    stream = cuda.Stream()
    return h_input, d_input, h_output, d_output, stream
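
# Resize to the network input shape, transpose HWC -> CHW, normalize, and copy
# the result into the page-locked input buffer; returns the image path so the
# caller can check the predicted label against the file name.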
def load_normalized_test_case(input_shape, test_image, pagelocked_buffer, normalization_hint):
    def normalize_image(image):
        c, h, w = input_shape
        image_arr = np.asarray(image.resize((w, h), Image.ANTIALIAS)) \
            .transpose([2, 0, 1]).astype(trt.nptype(trt.float32)).ravel()
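        # Hint 0: approximate ImageNet normalization (mean 0.45, std 0.225,
        # shared across channels); hint 1: scale to roughly [-0.5, 0.5).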
        if normalization_hint == 0:
            return (image_arr / 255.0 - 0.45) / 0.225
        elif normalization_hint == 1:
            return (image_arr / 256.0 - 0.5)
    np.copyto(pagelocked_buffer, normalize_image(Image.open(test_image)))
    return test_image
class Test_PT_ONNX_TRT(unittest.TestCase):
    def __enter__(self):
        return self

    def setUp(self):
        data_path = os.path.join(os.path.dirname(__file__), 'data')
        self.image_files = ["binoculars.jpeg", "reflex_camera.jpeg", "tabby_tiger_cat.jpg"]
        for index, f in enumerate(self.image_files):
            self.image_files[index] = os.path.abspath(os.path.join(data_path, f))
            if not os.path.exists(self.image_files[index]):
                raise FileNotFoundError(self.image_files[index] + " does not exist.")
        self.labels = open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r').read().split('\n')
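
    # Parse the ONNX file into a TRT network and build an engine from it;
    # create_network(flags=1) sets bit 0, the EXPLICIT_BATCH creation flag.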
    def build_engine_onnx(self, model_file):
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(flags=1) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            builder.max_workspace_size = 1 << 33
            with open(model_file, 'rb') as model:
                if not parser.parse(model.read()):
                    for error in range(parser.num_errors):
                        self.fail("ERROR: {}".format(parser.get_error(error)))
            return builder.build_cuda_engine(network)
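
    # Export the pretrained torchvision model to ONNX (opset 9), build an
    # engine from it, then classify the three test images, tolerating at most
    # one top-1 label that does not match the image file name.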
    def _test_model(self, model_name, input_shape=(3, 224, 224), normalization_hint=0):
        model = getattr(models, model_name)(pretrained=True)
        shape = (1,) + input_shape
        dummy_input = (torch.randn(shape),)
        onnx_name = model_name + ".onnx"
        torch.onnx.export(model,
                          dummy_input,
                          onnx_name,
                          input_names=[],
                          output_names=[],
                          verbose=False,
                          export_params=True,
                          opset_version=9)
        with self.build_engine_onnx(onnx_name) as engine:
            h_input, d_input, h_output, d_output, stream = allocate_buffers(engine)
            with engine.create_execution_context() as context:
                err_count = 0
                for index, f in enumerate(self.image_files):
                    test_case = load_normalized_test_case(input_shape, f,
                                                          h_input, normalization_hint)
                    cuda.memcpy_htod_async(d_input, h_input, stream)
                    context.execute_async_v2(bindings=[d_input, d_output],
                                             stream_handle=stream.handle)
                    cuda.memcpy_dtoh_async(h_output, d_output, stream)
                    stream.synchronize()
                    amax = np.argmax(h_output)
                    pred = self.labels[amax]
                    if "_".join(pred.split()) not in \
                            os.path.splitext(os.path.basename(test_case))[0]:
                        err_count = err_count + 1
                self.assertLessEqual(err_count, 1, "Too many recognition errors")
    def test_alexnet(self):
        self._test_model("alexnet", (3, 227, 227))

    def test_resnet18(self):
        self._test_model("resnet18")

    def test_resnet34(self):
        self._test_model("resnet34")

    def test_resnet50(self):
        self._test_model("resnet50")

    def test_resnet101(self):
        self._test_model("resnet101")

    @unittest.skip("Takes 2m")
    def test_resnet152(self):
        self._test_model("resnet152")

    def test_resnet50_2(self):
        self._test_model("wide_resnet50_2")

    @unittest.skip("Takes 2m")
    def test_resnet101_2(self):
        self._test_model("wide_resnet101_2")

    def test_squeezenet1_0(self):
        self._test_model("squeezenet1_0")

    def test_squeezenet1_1(self):
        self._test_model("squeezenet1_1")

    def test_googlenet(self):
        self._test_model("googlenet")

    def test_inception_v3(self):
        self._test_model("inception_v3")

    def test_mnasnet0_5(self):
        self._test_model("mnasnet0_5", normalization_hint=1)

    def test_mnasnet1_0(self):
        self._test_model("mnasnet1_0", normalization_hint=1)

    def test_mobilenet_v2(self):
        self._test_model("mobilenet_v2", normalization_hint=1)

    def test_shufflenet_v2_x0_5(self):
        self._test_model("shufflenet_v2_x0_5")

    def test_shufflenet_v2_x1_0(self):
        self._test_model("shufflenet_v2_x1_0")

    def test_vgg11(self):
        self._test_model("vgg11")

    def test_vgg11_bn(self):
        self._test_model("vgg11_bn")

    def test_vgg13(self):
        self._test_model("vgg13")

    def test_vgg13_bn(self):
        self._test_model("vgg13_bn")

    def test_vgg16(self):
        self._test_model("vgg16")

    def test_vgg16_bn(self):
        self._test_model("vgg16_bn")

    def test_vgg19(self):
        self._test_model("vgg19")

    def test_vgg19_bn(self):
        self._test_model("vgg19_bn")

    @unittest.skip("Takes 13m")
    def test_densenet121(self):
        self._test_model("densenet121")

    @unittest.skip("Takes 25m")
    def test_densenet161(self):
        self._test_model("densenet161")

    @unittest.skip("Takes 27m")
    def test_densenet169(self):
        self._test_model("densenet169")

    @unittest.skip("Takes 44m")
    def test_densenet201(self):
        self._test_model("densenet201")


if __name__ == '__main__':
    unittest.main()
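
For reference, a minimal sketch of driving a single test from this file programmatically (not part of the commit; it assumes TensorRT, pycuda, torchvision, and Pillow are installed and the data/ images are present):

import unittest
from test_pt_onnx_trt import Test_PT_ONNX_TRT

# Build a one-test suite; unittest.main() above also accepts the test name on
# the command line, e.g. "python test_pt_onnx_trt.py Test_PT_ONNX_TRT.test_resnet18".
suite = unittest.TestSuite([Test_PT_ONNX_TRT("test_resnet18")])
unittest.TextTestRunner(verbosity=2).run(suite)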


@@ -1171,7 +1171,13 @@ function (add_onnx_tensorrt_subdir)
  # We pass the paths we found to onnx tensorrt.
  set(CUDNN_INCLUDE_DIR "${CUDNN_INCLUDE_PATH}")
  set(CUDNN_LIBRARY "${CUDNN_LIBRARY_PATH}")
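  # Save CMAKE_VERSION and temporarily report an older one, so that
  # onnx-tensorrt's deprecated FindCUDA-based logic still works; the real
  # value is restored right after add_subdirectory().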
  set(CMAKE_VERSION_ORIG "${CMAKE_VERSION}")
  if (FIND_CUDA_MODULE_DEPRECATED)
    # TODO: this WAR is for https://github.com/pytorch/pytorch/issues/18524
    set(CMAKE_VERSION "3.9.0")
  endif()
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx-tensorrt EXCLUDE_FROM_ALL)
  set(CMAKE_VERSION "${CMAKE_VERSION_ORIG}")
endfunction()
if (CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
  if (USE_TENSORRT)


@@ -123,9 +123,23 @@ if(CAFFE2_USE_TENSORRT)
    PATH_SUFFIXES lib lib64 lib/x64)
  find_package_handle_standard_args(
    TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIBRARY)
  if(TENSORRT_FOUND)
    execute_process(COMMAND /bin/sh -c "[ -r \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\" ] && awk '/^\#define NV_TENSORRT_MAJOR/ {print $3}' \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\"" OUTPUT_VARIABLE TENSORRT_VERSION_MAJOR)
    execute_process(COMMAND /bin/sh -c "[ -r \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\" ] && awk '/^\#define NV_TENSORRT_MINOR/ {print $3}' \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\"" OUTPUT_VARIABLE TENSORRT_VERSION_MINOR)
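    # NvInferVersion.h contains lines of the form "#define NV_TENSORRT_MAJOR 6";
    # awk prints the third field, leaving the bare version numbers (whitespace
    # is stripped below).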
    if(TENSORRT_VERSION_MAJOR)
      string(STRIP ${TENSORRT_VERSION_MAJOR} TENSORRT_VERSION_MAJOR)
      string(STRIP ${TENSORRT_VERSION_MINOR} TENSORRT_VERSION_MINOR)
      set(TENSORRT_VERSION "${TENSORRT_VERSION_MAJOR}.${TENSORRT_VERSION_MINOR}")
      # CAFFE2_USE_TRT is set in Dependencies
      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTENSORRT_VERSION_MAJOR=${TENSORRT_VERSION_MAJOR}")
      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTENSORRT_VERSION_MINOR=${TENSORRT_VERSION_MINOR}")
    else()
      message(WARNING "Caffe2: Cannot find ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h. Assuming TRT 5.0 which is no longer supported. Turning the option off.")
      set(CAFFE2_USE_TENSORRT OFF)
    endif()
  else()
    message(WARNING
        "Caffe2: Cannot find TensorRT library. Turning the option off.")
    set(CAFFE2_USE_TENSORRT OFF)
  endif()
endif()


@@ -77,8 +77,8 @@ def _parse_arg(value, desc):
    if desc == 'is':
        for v in value.node().inputs():
            if v.node().kind() != 'onnx::Constant':
                raise RuntimeError("Failed to export an ONNX attribute '" + v.node().kind() +
                                   "', since it's not constant, please try to make "
                                   "things (e.g., kernel size) static if possible")
        return [int(v.node()['value']) for v in value.node().inputs()]
    else:
else: