mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This is the fix for reverted https://github.com/pytorch/pytorch/issues/26426 houseroad bddppq soumith Pull Request resolved: https://github.com/pytorch/pytorch/pull/28715 Reviewed By: hl475 Differential Revision: D18146731 Pulled By: houseroad fbshipit-source-id: 247366451a6334e84df82d00339521f797b33130
This commit is contained in:
committed by
Facebook Github Bot
parent
4a94eaa60b
commit
1e2049c566
@ -124,6 +124,7 @@ pip install --user pytest-sugar
|
||||
--ignore "$caffe2_pypath/python/operator_test/matmul_op_test.py" \
|
||||
--ignore "$caffe2_pypath/python/operator_test/pack_ops_test.py" \
|
||||
--ignore "$caffe2_pypath/python/mkl/mkl_sbn_speed_test.py" \
|
||||
--ignore "$caffe2_pypath/python/trt/test_pt_onnx_trt.py" \
|
||||
${rocm_ignore_test[@]} \
|
||||
"$caffe2_pypath/python" \
|
||||
"${EXTRA_TESTS[@]}"
|
||||
|
@ -15,6 +15,10 @@ if(NOT CMAKE_VERSION VERSION_LESS 3.15.0)
|
||||
cmake_policy(SET CMP0092 NEW)
|
||||
endif()
|
||||
|
||||
if(NOT CMAKE_VERSION VERSION_LESS 3.10)
|
||||
set(FIND_CUDA_MODULE_DEPRECATED ON)
|
||||
endif()
|
||||
|
||||
# ---[ Project and semantic versioning.
|
||||
project(Caffe2 CXX C)
|
||||
|
||||
|
@ -11,7 +11,15 @@ std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
|
||||
size_t max_workspace_size,
|
||||
bool debug_builder) {
|
||||
auto trt_builder = TrtObject(nvinfer1::createInferBuilder(*logger));
|
||||
#if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6)
|
||||
auto trt_builder_cfg = TrtObject(trt_builder->createBuilderConfig());
|
||||
// TensorRTOp doesn't support dynamic shapes yet
|
||||
auto trt_network = TrtObject(trt_builder->createNetworkV2(
|
||||
1U << static_cast<uint32_t>(nvinfer1::
|
||||
NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
|
||||
#else
|
||||
auto trt_network = TrtObject(trt_builder->createNetwork());
|
||||
#endif
|
||||
auto trt_parser =
|
||||
TrtObject(nvonnxparser::createParser(*trt_network, *logger));
|
||||
auto status = trt_parser->parse(onnx_model_str.data(), onnx_model_str.size());
|
||||
@ -36,9 +44,19 @@ std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine(
|
||||
}
|
||||
}
|
||||
trt_builder->setMaxBatchSize(max_batch_size);
|
||||
#if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6)
|
||||
trt_builder_cfg->setMaxWorkspaceSize(max_workspace_size);
|
||||
if (debug_builder) {
|
||||
trt_builder_cfg->setFlag(nvinfer1::BuilderFlag::kDEBUG);
|
||||
}
|
||||
trt_builder_cfg->setDefaultDeviceType(nvinfer1::DeviceType::kGPU);
|
||||
return TrtObject(trt_builder->
|
||||
buildEngineWithConfig(*trt_network.get(), *trt_builder_cfg));
|
||||
#else
|
||||
trt_builder->setMaxWorkspaceSize(max_workspace_size);
|
||||
trt_builder->setDebugSync(debug_builder);
|
||||
return TrtObject(trt_builder->buildCudaEngine(*trt_network.get()));
|
||||
#endif
|
||||
}
|
||||
} // namespace tensorrt
|
||||
} // namespace caffe2
|
||||
|
BIN
caffe2/python/trt/data/binoculars.jpeg
Normal file
BIN
caffe2/python/trt/data/binoculars.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 33 KiB |
1000
caffe2/python/trt/data/class_labels.txt
Normal file
1000
caffe2/python/trt/data/class_labels.txt
Normal file
File diff suppressed because it is too large
Load Diff
BIN
caffe2/python/trt/data/reflex_camera.jpeg
Normal file
BIN
caffe2/python/trt/data/reflex_camera.jpeg
Normal file
Binary file not shown.
After Width: | Height: | Size: 81 KiB |
BIN
caffe2/python/trt/data/tabby_tiger_cat.jpg
Normal file
BIN
caffe2/python/trt/data/tabby_tiger_cat.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 108 KiB |
190
caffe2/python/trt/test_pt_onnx_trt.py
Normal file
190
caffe2/python/trt/test_pt_onnx_trt.py
Normal file
@ -0,0 +1,190 @@
|
||||
###################################################################################################
|
||||
# ATTENTION! This test will most probably fail if you install TensorRT 6.0.1 only.
|
||||
# That's because it's shipped with older version of ONNX parser not supporting some
|
||||
# required features. To make it work please use new version: https://github.com/onnx/onnx-tensorrt
|
||||
# Just clone it and do something like this:
|
||||
#
|
||||
# ~/pt/third_party/onnx-tensorrt$ mkdir build/
|
||||
# ~/pt/third_party/onnx-tensorrt$ cd build/
|
||||
# ~/pt/third_party/onnx-tensorrt/build$ cmake ..
|
||||
# ~/pt/third_party/onnx-tensorrt/build$ make
|
||||
# ~/pt/third_party/onnx-tensorrt/build$ sudo cp libnvonnxparser.so.6.0.1 /usr/lib/x86_64-linux-gnu
|
||||
#
|
||||
# This note is valid for 6.0.1 release only. September 18th, 2019.
|
||||
###################################################################################################
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from typing import List, Any
|
||||
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.onnx import OperatorExportTypes
|
||||
import torchvision.models as models
|
||||
|
||||
import pycuda.driver as cuda
|
||||
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
|
||||
import pycuda.autoinit
|
||||
|
||||
import tensorrt as trt
|
||||
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
||||
|
||||
def allocate_buffers(engine):
|
||||
h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)),
|
||||
dtype=trt.nptype(trt.float32))
|
||||
h_output = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(1)),
|
||||
dtype=trt.nptype(trt.float32))
|
||||
d_input = cuda.mem_alloc(h_input.nbytes)
|
||||
d_output = cuda.mem_alloc(h_output.nbytes)
|
||||
stream = cuda.Stream()
|
||||
return h_input, d_input, h_output, d_output, stream
|
||||
|
||||
def load_normalized_test_case(input_shape, test_image, pagelocked_buffer, normalization_hint):
|
||||
def normalize_image(image):
|
||||
c, h, w = input_shape
|
||||
image_arr = np.asarray(image.resize((w, h), Image.ANTIALIAS)).transpose([2, 0, 1])\
|
||||
.astype(trt.nptype(trt.float32)).ravel()
|
||||
if (normalization_hint == 0):
|
||||
return (image_arr / 255.0 - 0.45) / 0.225
|
||||
elif (normalization_hint == 1):
|
||||
return (image_arr / 256.0 - 0.5)
|
||||
np.copyto(pagelocked_buffer, normalize_image(Image.open(test_image)))
|
||||
return test_image
|
||||
|
||||
class Test_PT_ONNX_TRT(unittest.TestCase):
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def setUp(self):
|
||||
data_path = os.path.join(os.path.dirname(__file__), 'data')
|
||||
self.image_files=["binoculars.jpeg", "reflex_camera.jpeg", "tabby_tiger_cat.jpg"]
|
||||
for index, f in enumerate(self.image_files):
|
||||
self.image_files[index] = os.path.abspath(os.path.join(data_path, f))
|
||||
if not os.path.exists(self.image_files[index]):
|
||||
raise FileNotFoundError(self.image_files[index] + " does not exist.")
|
||||
self.labels = open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r').read().split('\n')
|
||||
|
||||
def build_engine_onnx(self, model_file):
|
||||
with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags = 1) as network, trt.OnnxParser(network, TRT_LOGGER) as parser:
|
||||
builder.max_workspace_size = 1 << 33
|
||||
with open(model_file, 'rb') as model:
|
||||
if not parser.parse(model.read()):
|
||||
for error in range(parser.num_errors):
|
||||
self.fail("ERROR: {}".format(parser.get_error(error)))
|
||||
return builder.build_cuda_engine(network)
|
||||
|
||||
def _test_model(self, model_name, input_shape = (3, 224, 224), normalization_hint = 0):
|
||||
|
||||
model = getattr(models, model_name)(pretrained=True)
|
||||
|
||||
shape = (1,) + input_shape
|
||||
dummy_input = (torch.randn(shape),)
|
||||
onnx_name = model_name + ".onnx"
|
||||
|
||||
torch.onnx.export(model,
|
||||
dummy_input,
|
||||
onnx_name,
|
||||
input_names = [],
|
||||
output_names = [],
|
||||
verbose=False,
|
||||
export_params=True,
|
||||
opset_version=9)
|
||||
|
||||
with self.build_engine_onnx(onnx_name) as engine:
|
||||
h_input, d_input, h_output, d_output, stream = allocate_buffers(engine)
|
||||
with engine.create_execution_context() as context:
|
||||
err_count = 0
|
||||
for index, f in enumerate(self.image_files):
|
||||
test_case = load_normalized_test_case(input_shape, f,\
|
||||
h_input, normalization_hint)
|
||||
cuda.memcpy_htod_async(d_input, h_input, stream)
|
||||
|
||||
context.execute_async_v2(bindings=[d_input, d_output],
|
||||
stream_handle=stream.handle)
|
||||
cuda.memcpy_dtoh_async(h_output, d_output, stream)
|
||||
stream.synchronize()
|
||||
|
||||
amax = np.argmax(h_output)
|
||||
pred = self.labels[amax]
|
||||
if "_".join(pred.split()) not in\
|
||||
os.path.splitext(os.path.basename(test_case))[0]:
|
||||
err_count = err_count + 1
|
||||
self.assertLessEqual(err_count, 1, "Too many recognition errors")
|
||||
|
||||
def test_alexnet(self):
|
||||
self._test_model("alexnet", (3, 227, 227))
|
||||
|
||||
def test_resnet18(self):
|
||||
self._test_model("resnet18")
|
||||
def test_resnet34(self):
|
||||
self._test_model("resnet34")
|
||||
def test_resnet50(self):
|
||||
self._test_model("resnet50")
|
||||
def test_resnet101(self):
|
||||
self._test_model("resnet101")
|
||||
@unittest.skip("Takes 2m")
|
||||
def test_resnet152(self):
|
||||
self._test_model("resnet152")
|
||||
|
||||
def test_resnet50_2(self):
|
||||
self._test_model("wide_resnet50_2")
|
||||
@unittest.skip("Takes 2m")
|
||||
def test_resnet101_2(self):
|
||||
self._test_model("wide_resnet101_2")
|
||||
|
||||
def test_squeezenet1_0(self):
|
||||
self._test_model("squeezenet1_0")
|
||||
def test_squeezenet1_1(self):
|
||||
self._test_model("squeezenet1_1")
|
||||
|
||||
def test_googlenet(self):
|
||||
self._test_model("googlenet")
|
||||
def test_inception_v3(self):
|
||||
self._test_model("inception_v3")
|
||||
|
||||
def test_mnasnet0_5(self):
|
||||
self._test_model("mnasnet0_5", normalization_hint = 1)
|
||||
def test_mnasnet1_0(self):
|
||||
self._test_model("mnasnet1_0", normalization_hint = 1)
|
||||
|
||||
def test_mobilenet_v2(self):
|
||||
self._test_model("mobilenet_v2", normalization_hint = 1)
|
||||
|
||||
def test_shufflenet_v2_x0_5(self):
|
||||
self._test_model("shufflenet_v2_x0_5")
|
||||
def test_shufflenet_v2_x1_0(self):
|
||||
self._test_model("shufflenet_v2_x1_0")
|
||||
|
||||
def test_vgg11(self):
|
||||
self._test_model("vgg11")
|
||||
def test_vgg11_bn(self):
|
||||
self._test_model("vgg11_bn")
|
||||
def test_vgg13(self):
|
||||
self._test_model("vgg13")
|
||||
def test_vgg13_bn(self):
|
||||
self._test_model("vgg13_bn")
|
||||
def test_vgg16(self):
|
||||
self._test_model("vgg16")
|
||||
def test_vgg16_bn(self):
|
||||
self._test_model("vgg16_bn")
|
||||
def test_vgg19(self):
|
||||
self._test_model("vgg19")
|
||||
def test_vgg19_bn(self):
|
||||
self._test_model("vgg19_bn")
|
||||
|
||||
@unittest.skip("Takes 13m")
|
||||
def test_densenet121(self):
|
||||
self._test_model("densenet121")
|
||||
@unittest.skip("Takes 25m")
|
||||
def test_densenet161(self):
|
||||
self._test_model("densenet161")
|
||||
@unittest.skip("Takes 27m")
|
||||
def test_densenet169(self):
|
||||
self._test_model("densenet169")
|
||||
@unittest.skip("Takes 44m")
|
||||
def test_densenet201(self):
|
||||
self._test_model("densenet201")
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -1171,7 +1171,13 @@ function (add_onnx_tensorrt_subdir)
|
||||
# We pass the paths we found to onnx tensorrt.
|
||||
set(CUDNN_INCLUDE_DIR "${CUDNN_INCLUDE_PATH}")
|
||||
set(CUDNN_LIBRARY "${CUDNN_LIBRARY_PATH}")
|
||||
set(CMAKE_VERSION_ORIG "{CMAKE_VERSION}")
|
||||
if (FIND_CUDA_MODULE_DEPRECATED)
|
||||
# TODO: this WAR is for https://github.com/pytorch/pytorch/issues/18524
|
||||
set(CMAKE_VERSION "3.9.0")
|
||||
endif()
|
||||
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx-tensorrt EXCLUDE_FROM_ALL)
|
||||
set(CMAKE_VERSION "{CMAKE_VERSION_ORIG}")
|
||||
endfunction()
|
||||
if (CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO)
|
||||
if (USE_TENSORRT)
|
||||
|
@ -123,9 +123,23 @@ if(CAFFE2_USE_TENSORRT)
|
||||
PATH_SUFFIXES lib lib64 lib/x64)
|
||||
find_package_handle_standard_args(
|
||||
TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIBRARY)
|
||||
if(NOT TENSORRT_FOUND)
|
||||
if(TENSORRT_FOUND)
|
||||
execute_process(COMMAND /bin/sh -c "[ -r \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\" ] && awk '/^\#define NV_TENSORRT_MAJOR/ {print $3}' \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\"" OUTPUT_VARIABLE TENSORRT_VERSION_MAJOR)
|
||||
execute_process(COMMAND /bin/sh -c "[ -r \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\" ] && awk '/^\#define NV_TENSORRT_MINOR/ {print $3}' \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\"" OUTPUT_VARIABLE TENSORRT_VERSION_MINOR)
|
||||
if(TENSORRT_VERSION_MAJOR)
|
||||
string(STRIP ${TENSORRT_VERSION_MAJOR} TENSORRT_VERSION_MAJOR)
|
||||
string(STRIP ${TENSORRT_VERSION_MINOR} TENSORRT_VERSION_MINOR)
|
||||
set(TENSORRT_VERSION "${TENSORRT_VERSION_MAJOR}.${TENSORRT_VERSION_MINOR}")
|
||||
#CAFFE2_USE_TRT is set in Dependencies
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTENSORRT_VERSION_MAJOR=${TENSORRT_VERSION_MAJOR}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTENSORRT_VERSION_MINOR=${TENSORRT_VERSION_MINOR}")
|
||||
else()
|
||||
message(WARNING "Caffe2: Cannot find ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h. Assuming TRT 5.0 which is no longer supported. Turning the option off.")
|
||||
set(CAFFE2_USE_TENSORRT OFF)
|
||||
endif()
|
||||
else()
|
||||
message(WARNING
|
||||
"Caffe2: Cannot find TensorRT library. Turning the option off")
|
||||
"Caffe2: Cannot find TensorRT library. Turning the option off.")
|
||||
set(CAFFE2_USE_TENSORRT OFF)
|
||||
endif()
|
||||
endif()
|
||||
|
2
third_party/onnx-tensorrt
vendored
2
third_party/onnx-tensorrt
vendored
Submodule third_party/onnx-tensorrt updated: cb3d8066f2...c153211418
@ -77,8 +77,8 @@ def _parse_arg(value, desc):
|
||||
if desc == 'is':
|
||||
for v in value.node().inputs():
|
||||
if v.node().kind() != 'onnx::Constant':
|
||||
raise RuntimeError("Failed to export an ONNX attribute, "
|
||||
"since it's not constant, please try to make "
|
||||
raise RuntimeError("Failed to export an ONNX attribute '" + v.node().kind() +
|
||||
"', since it's not constant, please try to make "
|
||||
"things (e.g., kernel size) static if possible")
|
||||
return [int(v.node()['value']) for v in value.node().inputs()]
|
||||
else:
|
||||
|
Reference in New Issue
Block a user