mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Summary: This is the fix for reverted https://github.com/pytorch/pytorch/issues/26426 houseroad bddppq soumith Pull Request resolved: https://github.com/pytorch/pytorch/pull/28715 Reviewed By: hl475 Differential Revision: D18146731 Pulled By: houseroad fbshipit-source-id: 247366451a6334e84df82d00339521f797b33130
This commit is contained in:
		
				
					committed by
					
						 Facebook Github Bot
						Facebook Github Bot
					
				
			
			
				
	
			
			
			
						parent
						
							4a94eaa60b
						
					
				
				
					commit
					1e2049c566
				
			| @ -124,6 +124,7 @@ pip install --user pytest-sugar | ||||
|   --ignore "$caffe2_pypath/python/operator_test/matmul_op_test.py" \ | ||||
|   --ignore "$caffe2_pypath/python/operator_test/pack_ops_test.py" \ | ||||
|   --ignore "$caffe2_pypath/python/mkl/mkl_sbn_speed_test.py" \ | ||||
|   --ignore "$caffe2_pypath/python/trt/test_pt_onnx_trt.py" \ | ||||
|   ${rocm_ignore_test[@]} \ | ||||
|   "$caffe2_pypath/python" \ | ||||
|   "${EXTRA_TESTS[@]}" | ||||
|  | ||||
| @ -15,6 +15,10 @@ if(NOT CMAKE_VERSION VERSION_LESS 3.15.0) | ||||
|   cmake_policy(SET CMP0092 NEW) | ||||
| endif() | ||||
|  | ||||
| if(NOT CMAKE_VERSION VERSION_LESS 3.10) | ||||
|   set(FIND_CUDA_MODULE_DEPRECATED ON) | ||||
| endif() | ||||
|  | ||||
| # ---[ Project and semantic versioning. | ||||
| project(Caffe2 CXX C) | ||||
|  | ||||
|  | ||||
| @ -11,7 +11,15 @@ std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine( | ||||
|     size_t max_workspace_size, | ||||
|     bool debug_builder) { | ||||
|   auto trt_builder = TrtObject(nvinfer1::createInferBuilder(*logger)); | ||||
| #if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6) | ||||
|   auto trt_builder_cfg = TrtObject(trt_builder->createBuilderConfig()); | ||||
|   // TensorRTOp doesn't support dynamic shapes yet | ||||
|   auto trt_network = TrtObject(trt_builder->createNetworkV2( | ||||
|       1U << static_cast<uint32_t>(nvinfer1:: | ||||
|       NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))); | ||||
| #else | ||||
|   auto trt_network = TrtObject(trt_builder->createNetwork()); | ||||
| #endif | ||||
|   auto trt_parser = | ||||
|       TrtObject(nvonnxparser::createParser(*trt_network, *logger)); | ||||
|   auto status = trt_parser->parse(onnx_model_str.data(), onnx_model_str.size()); | ||||
| @ -36,9 +44,19 @@ std::shared_ptr<nvinfer1::ICudaEngine> BuildTrtEngine( | ||||
|     } | ||||
|   } | ||||
|   trt_builder->setMaxBatchSize(max_batch_size); | ||||
| #if defined(TENSORRT_VERSION_MAJOR) && (TENSORRT_VERSION_MAJOR >= 6) | ||||
|   trt_builder_cfg->setMaxWorkspaceSize(max_workspace_size); | ||||
|   if (debug_builder) { | ||||
|     trt_builder_cfg->setFlag(nvinfer1::BuilderFlag::kDEBUG); | ||||
|   } | ||||
|   trt_builder_cfg->setDefaultDeviceType(nvinfer1::DeviceType::kGPU); | ||||
|   return TrtObject(trt_builder-> | ||||
|       buildEngineWithConfig(*trt_network.get(), *trt_builder_cfg)); | ||||
| #else | ||||
|   trt_builder->setMaxWorkspaceSize(max_workspace_size); | ||||
|   trt_builder->setDebugSync(debug_builder); | ||||
|   return TrtObject(trt_builder->buildCudaEngine(*trt_network.get())); | ||||
| #endif | ||||
| } | ||||
| } // namespace tensorrt | ||||
| } // namespace caffe2 | ||||
|  | ||||
							
								
								
									
										
											BIN
										
									
								
								caffe2/python/trt/data/binoculars.jpeg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								caffe2/python/trt/data/binoculars.jpeg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 33 KiB | 
							
								
								
									
										1000
									
								
								caffe2/python/trt/data/class_labels.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1000
									
								
								caffe2/python/trt/data/class_labels.txt
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								caffe2/python/trt/data/reflex_camera.jpeg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								caffe2/python/trt/data/reflex_camera.jpeg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 81 KiB | 
							
								
								
									
										
											BIN
										
									
								
								caffe2/python/trt/data/tabby_tiger_cat.jpg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								caffe2/python/trt/data/tabby_tiger_cat.jpg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| After Width: | Height: | Size: 108 KiB | 
							
								
								
									
										190
									
								
								caffe2/python/trt/test_pt_onnx_trt.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										190
									
								
								caffe2/python/trt/test_pt_onnx_trt.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,190 @@ | ||||
| ################################################################################################### | ||||
| # ATTENTION! This test will most probably fail if you install TensorRT 6.0.1 only. | ||||
| # That's because it ships with an older version of the ONNX parser that does not support some | ||||
| # required features. To make it work, please use a newer version: https://github.com/onnx/onnx-tensorrt | ||||
| # Just clone it and do something like this: | ||||
| # | ||||
| # ~/pt/third_party/onnx-tensorrt$ mkdir build/ | ||||
| # ~/pt/third_party/onnx-tensorrt$ cd build/ | ||||
| # ~/pt/third_party/onnx-tensorrt/build$ cmake .. | ||||
| # ~/pt/third_party/onnx-tensorrt/build$ make | ||||
| # ~/pt/third_party/onnx-tensorrt/build$ sudo cp libnvonnxparser.so.6.0.1 /usr/lib/x86_64-linux-gnu | ||||
| # | ||||
| # This note is valid for 6.0.1 release only. September 18th, 2019. | ||||
| ################################################################################################### | ||||
|  | ||||
| import os | ||||
| import unittest | ||||
| from typing import List, Any | ||||
|  | ||||
| from PIL import Image | ||||
| import numpy as np | ||||
| import torch | ||||
| from torch.onnx import OperatorExportTypes | ||||
| import torchvision.models as models | ||||
|  | ||||
| import pycuda.driver as cuda | ||||
| # This import causes pycuda to automatically manage CUDA context creation and cleanup. | ||||
| import pycuda.autoinit | ||||
|  | ||||
| import tensorrt as trt | ||||
| TRT_LOGGER = trt.Logger(trt.Logger.WARNING) | ||||
|  | ||||
def allocate_buffers(engine):
    """Allocate host/device buffers and a CUDA stream for a two-binding engine.

    Binding 0 is used for the input buffers and binding 1 for the output
    buffers (presumably the engine has exactly one input and one output --
    confirm for models with more bindings). Both host buffers are page-locked
    float32 arrays sized to the volume of the corresponding binding shape.

    Returns:
        (h_input, d_input, h_output, d_output, stream)
    """
    float_dtype = trt.nptype(trt.float32)
    input_volume = trt.volume(engine.get_binding_shape(0))
    output_volume = trt.volume(engine.get_binding_shape(1))

    # Page-locked host staging buffers.
    h_input = cuda.pagelocked_empty(input_volume, dtype=float_dtype)
    h_output = cuda.pagelocked_empty(output_volume, dtype=float_dtype)

    # Device buffers of matching byte size.
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)

    stream = cuda.Stream()
    return h_input, d_input, h_output, d_output, stream
|  | ||||
def load_normalized_test_case(input_shape, test_image, pagelocked_buffer, normalization_hint):
    """Load an image, resize/normalize it and copy it into ``pagelocked_buffer``.

    Args:
        input_shape: ``(channels, height, width)`` of the network input.
        test_image: path of the image file to load.
        pagelocked_buffer: destination array; receives the flattened,
            channel-first float32 image.
        normalization_hint: selects the normalization formula:
            0 -> ``(x / 255.0 - 0.45) / 0.225``
            1 -> ``x / 256.0 - 0.5``

    Returns:
        ``test_image`` unchanged, so the caller can match the prediction
        against the file name.

    Raises:
        ValueError: if ``normalization_hint`` is not 0 or 1. Previously an
            unknown hint silently produced ``None`` and failed later inside
            ``np.copyto`` with an unrelated error.
    """
    if normalization_hint not in (0, 1):
        raise ValueError(
            "Unsupported normalization_hint {!r}; expected 0 or 1".format(normalization_hint))

    _, h, w = input_shape
    # Close the image file deterministically instead of leaking the handle.
    with Image.open(test_image) as image:
        # Image.LANCZOS is the same filter as the removed Image.ANTIALIAS alias.
        # HWC -> CHW, then flatten to match the 1-D page-locked buffer.
        image_arr = np.asarray(image.resize((w, h), Image.LANCZOS)).transpose([2, 0, 1])\
            .astype(trt.nptype(trt.float32)).ravel()

    if normalization_hint == 0:
        normalized = (image_arr / 255.0 - 0.45) / 0.225
    else:
        normalized = image_arr / 256.0 - 0.5

    np.copyto(pagelocked_buffer, normalized)
    return test_image
|  | ||||
| class Test_PT_ONNX_TRT(unittest.TestCase): | ||||
|     def __enter__(self): | ||||
|         return self | ||||
|  | ||||
|     def setUp(self): | ||||
|         data_path = os.path.join(os.path.dirname(__file__), 'data') | ||||
|         self.image_files=["binoculars.jpeg", "reflex_camera.jpeg", "tabby_tiger_cat.jpg"] | ||||
|         for index, f in enumerate(self.image_files): | ||||
|             self.image_files[index] = os.path.abspath(os.path.join(data_path, f)) | ||||
|             if not os.path.exists(self.image_files[index]): | ||||
|                 raise FileNotFoundError(self.image_files[index] + " does not exist.") | ||||
|         self.labels = open(os.path.abspath(os.path.join(data_path, "class_labels.txt")), 'r').read().split('\n') | ||||
|  | ||||
|     def build_engine_onnx(self, model_file): | ||||
|         with trt.Builder(TRT_LOGGER) as builder, builder.create_network(flags = 1) as network, trt.OnnxParser(network, TRT_LOGGER) as parser: | ||||
|             builder.max_workspace_size = 1 << 33 | ||||
|             with open(model_file, 'rb') as model: | ||||
|                 if not parser.parse(model.read()): | ||||
|                     for error in range(parser.num_errors): | ||||
|                         self.fail("ERROR: {}".format(parser.get_error(error))) | ||||
|             return builder.build_cuda_engine(network) | ||||
|  | ||||
|     def _test_model(self, model_name, input_shape = (3, 224, 224), normalization_hint = 0): | ||||
|  | ||||
|         model = getattr(models, model_name)(pretrained=True) | ||||
|  | ||||
|         shape = (1,) + input_shape | ||||
|         dummy_input  = (torch.randn(shape),) | ||||
|         onnx_name = model_name + ".onnx" | ||||
|  | ||||
|         torch.onnx.export(model, | ||||
|                           dummy_input, | ||||
|                           onnx_name, | ||||
|                           input_names = [], | ||||
|                           output_names = [], | ||||
|                           verbose=False, | ||||
|                           export_params=True, | ||||
|                           opset_version=9) | ||||
|  | ||||
|         with self.build_engine_onnx(onnx_name) as engine: | ||||
|             h_input, d_input, h_output, d_output, stream = allocate_buffers(engine) | ||||
|             with engine.create_execution_context() as context: | ||||
|                 err_count = 0 | ||||
|                 for index, f in enumerate(self.image_files): | ||||
|                     test_case = load_normalized_test_case(input_shape, f,\ | ||||
|                         h_input, normalization_hint) | ||||
|                     cuda.memcpy_htod_async(d_input, h_input, stream) | ||||
|  | ||||
|                     context.execute_async_v2(bindings=[d_input, d_output], | ||||
|                                              stream_handle=stream.handle) | ||||
|                     cuda.memcpy_dtoh_async(h_output, d_output, stream) | ||||
|                     stream.synchronize() | ||||
|  | ||||
|                     amax = np.argmax(h_output) | ||||
|                     pred = self.labels[amax] | ||||
|                     if "_".join(pred.split()) not in\ | ||||
|                             os.path.splitext(os.path.basename(test_case))[0]: | ||||
|                         err_count = err_count + 1 | ||||
|                 self.assertLessEqual(err_count, 1, "Too many recognition errors") | ||||
|  | ||||
|     def test_alexnet(self): | ||||
|         self._test_model("alexnet", (3, 227, 227)) | ||||
|  | ||||
|     def test_resnet18(self): | ||||
|         self._test_model("resnet18") | ||||
|     def test_resnet34(self): | ||||
|         self._test_model("resnet34") | ||||
|     def test_resnet50(self): | ||||
|         self._test_model("resnet50") | ||||
|     def test_resnet101(self): | ||||
|         self._test_model("resnet101") | ||||
|     @unittest.skip("Takes 2m") | ||||
|     def test_resnet152(self): | ||||
|         self._test_model("resnet152") | ||||
|  | ||||
|     def test_resnet50_2(self): | ||||
|         self._test_model("wide_resnet50_2") | ||||
|     @unittest.skip("Takes 2m") | ||||
|     def test_resnet101_2(self): | ||||
|         self._test_model("wide_resnet101_2") | ||||
|  | ||||
|     def test_squeezenet1_0(self): | ||||
|         self._test_model("squeezenet1_0") | ||||
|     def test_squeezenet1_1(self): | ||||
|         self._test_model("squeezenet1_1") | ||||
|  | ||||
|     def test_googlenet(self): | ||||
|         self._test_model("googlenet") | ||||
|     def test_inception_v3(self): | ||||
|         self._test_model("inception_v3") | ||||
|  | ||||
|     def test_mnasnet0_5(self): | ||||
|         self._test_model("mnasnet0_5", normalization_hint = 1) | ||||
|     def test_mnasnet1_0(self): | ||||
|         self._test_model("mnasnet1_0", normalization_hint = 1) | ||||
|  | ||||
|     def test_mobilenet_v2(self): | ||||
|         self._test_model("mobilenet_v2", normalization_hint = 1) | ||||
|  | ||||
|     def test_shufflenet_v2_x0_5(self): | ||||
|         self._test_model("shufflenet_v2_x0_5") | ||||
|     def test_shufflenet_v2_x1_0(self): | ||||
|         self._test_model("shufflenet_v2_x1_0") | ||||
|  | ||||
|     def test_vgg11(self): | ||||
|         self._test_model("vgg11") | ||||
|     def test_vgg11_bn(self): | ||||
|         self._test_model("vgg11_bn") | ||||
|     def test_vgg13(self): | ||||
|         self._test_model("vgg13") | ||||
|     def test_vgg13_bn(self): | ||||
|         self._test_model("vgg13_bn") | ||||
|     def test_vgg16(self): | ||||
|         self._test_model("vgg16") | ||||
|     def test_vgg16_bn(self): | ||||
|         self._test_model("vgg16_bn") | ||||
|     def test_vgg19(self): | ||||
|         self._test_model("vgg19") | ||||
|     def test_vgg19_bn(self): | ||||
|         self._test_model("vgg19_bn") | ||||
|  | ||||
|     @unittest.skip("Takes 13m") | ||||
|     def test_densenet121(self): | ||||
|         self._test_model("densenet121") | ||||
|     @unittest.skip("Takes 25m") | ||||
|     def test_densenet161(self): | ||||
|         self._test_model("densenet161") | ||||
|     @unittest.skip("Takes 27m") | ||||
|     def test_densenet169(self): | ||||
|         self._test_model("densenet169") | ||||
|     @unittest.skip("Takes 44m") | ||||
|     def test_densenet201(self): | ||||
|         self._test_model("densenet201") | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
| @ -1171,7 +1171,13 @@ function (add_onnx_tensorrt_subdir) | ||||
|   # We pass the paths we found to onnx tensorrt. | ||||
|   set(CUDNN_INCLUDE_DIR "${CUDNN_INCLUDE_PATH}") | ||||
|   set(CUDNN_LIBRARY "${CUDNN_LIBRARY_PATH}") | ||||
|   set(CMAKE_VERSION_ORIG "{CMAKE_VERSION}") | ||||
|   if (FIND_CUDA_MODULE_DEPRECATED) | ||||
|     # TODO: this WAR is for https://github.com/pytorch/pytorch/issues/18524 | ||||
|     set(CMAKE_VERSION "3.9.0") | ||||
|   endif() | ||||
|   add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/onnx-tensorrt EXCLUDE_FROM_ALL) | ||||
|   set(CMAKE_VERSION "{CMAKE_VERSION_ORIG}") | ||||
| endfunction() | ||||
| if (CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO) | ||||
|   if (USE_TENSORRT) | ||||
|  | ||||
| @ -123,9 +123,23 @@ if(CAFFE2_USE_TENSORRT) | ||||
|     PATH_SUFFIXES lib lib64 lib/x64) | ||||
|   find_package_handle_standard_args( | ||||
|     TENSORRT DEFAULT_MSG TENSORRT_INCLUDE_DIR TENSORRT_LIBRARY) | ||||
|   if(NOT TENSORRT_FOUND) | ||||
|   if(TENSORRT_FOUND) | ||||
|     execute_process(COMMAND /bin/sh -c "[ -r \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\" ] && awk '/^\#define NV_TENSORRT_MAJOR/ {print $3}' \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\"" OUTPUT_VARIABLE TENSORRT_VERSION_MAJOR) | ||||
|     execute_process(COMMAND /bin/sh -c "[ -r \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\" ] && awk '/^\#define NV_TENSORRT_MINOR/ {print $3}' \"${TENSORRT_INCLUDE_DIR}/NvInferVersion.h\"" OUTPUT_VARIABLE TENSORRT_VERSION_MINOR) | ||||
|     if(TENSORRT_VERSION_MAJOR) | ||||
|       string(STRIP ${TENSORRT_VERSION_MAJOR} TENSORRT_VERSION_MAJOR) | ||||
|       string(STRIP ${TENSORRT_VERSION_MINOR} TENSORRT_VERSION_MINOR) | ||||
|       set(TENSORRT_VERSION "${TENSORRT_VERSION_MAJOR}.${TENSORRT_VERSION_MINOR}") | ||||
|       #CAFFE2_USE_TRT is set in Dependencies | ||||
|       set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTENSORRT_VERSION_MAJOR=${TENSORRT_VERSION_MAJOR}") | ||||
|       set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DTENSORRT_VERSION_MINOR=${TENSORRT_VERSION_MINOR}") | ||||
|     else() | ||||
|       message(WARNING "Caffe2: Cannot find ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h. Assuming TRT 5.0 which is no longer supported. Turning the option off.") | ||||
|       set(CAFFE2_USE_TENSORRT OFF) | ||||
|     endif() | ||||
|   else() | ||||
|     message(WARNING | ||||
|       "Caffe2: Cannot find TensorRT library. Turning the option off") | ||||
|       "Caffe2: Cannot find TensorRT library. Turning the option off.") | ||||
|     set(CAFFE2_USE_TENSORRT OFF) | ||||
|   endif() | ||||
| endif() | ||||
|  | ||||
							
								
								
									
										2
									
								
								third_party/onnx-tensorrt
									
									
									
									
										vendored
									
									
								
							
							
								
								
								
								
								
							
						
						
									
										2
									
								
								third_party/onnx-tensorrt
									
									
									
									
										vendored
									
									
								
							 Submodule third_party/onnx-tensorrt updated: cb3d8066f2...c153211418
									
								
							| @ -77,8 +77,8 @@ def _parse_arg(value, desc): | ||||
|         if desc == 'is': | ||||
|             for v in value.node().inputs(): | ||||
|                 if v.node().kind() != 'onnx::Constant': | ||||
|                     raise RuntimeError("Failed to export an ONNX attribute, " | ||||
|                                        "since it's not constant, please try to make " | ||||
|                     raise RuntimeError("Failed to export an ONNX attribute '" + v.node().kind() + | ||||
|                                        "', since it's not constant, please try to make " | ||||
|                                        "things (e.g., kernel size) static if possible") | ||||
|             return [int(v.node()['value']) for v in value.node().inputs()] | ||||
|         else: | ||||
|  | ||||
		Reference in New Issue
	
	Block a user