mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Last commit for the day. With all the previous changes this should give an exact reference speed that TensorFlow with CuDNN3 should achieve in the end.
This commit is contained in:
		
				
					committed by
					
						 Yangqing Jia
						Yangqing Jia
					
				
			
			
				
	
			
			
			
						parent
						
							896e8e5274
						
					
				
				
					commit
					05eda208a5
				
			| @ -58,10 +58,11 @@ class Blob { | ||||
|    */ | ||||
|   template <class T> | ||||
|   T* GetMutable() { | ||||
|     if (!IsType<T>()) { | ||||
|     if (IsType<T>()) { | ||||
|       return static_cast<T*>(pointer_); | ||||
|     } else { | ||||
|       return Reset<T>(new T()); | ||||
|     } | ||||
|     return static_cast<T*>(pointer_); | ||||
|   } | ||||
|  | ||||
|   /** | ||||
|  | ||||
| @ -171,7 +171,7 @@ bool Caffe2EnableCudaPeerAccess() { | ||||
|       int can_access; | ||||
|       CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, i, j)); | ||||
|       if (can_access) { | ||||
|         CAFFE_LOG_INFO << "Enabling peer access from " << i << " to " << j; | ||||
|         CAFFE_VLOG(1) << "Enabling peer access from " << i << " to " << j; | ||||
|         // Note: just for future reference, the 0 here is not a gpu id, it is | ||||
|         // a reserved flag for cudaDeviceEnablePeerAccess that should always be | ||||
|         // zero currently. | ||||
|  | ||||
| @ -58,12 +58,13 @@ class CUDAContext { | ||||
|   inline bool FinishDeviceComputation() { | ||||
|     cudaStreamSynchronize(cuda_stream_); | ||||
|     cudaError_t error = cudaGetLastError(); | ||||
|     if (error != cudaSuccess) { | ||||
|     if (error == cudaSuccess) { | ||||
|       return true; | ||||
|     } else { | ||||
|       CAFFE_LOG_ERROR << "Encountered CUDA error: " | ||||
|                       << cudaGetErrorString(error); | ||||
|       return false; | ||||
|     } | ||||
|     return true; | ||||
|   } | ||||
|  | ||||
|   int cuda_gpu_id() { return cuda_gpu_id_; } | ||||
| @ -71,7 +72,9 @@ class CUDAContext { | ||||
|   inline cudaStream_t& cuda_stream() { return cuda_stream_; } | ||||
|  | ||||
|   cublasHandle_t& cublas_handle() { | ||||
|     if (!cublas_handle_) { | ||||
|     if (cublas_handle_) { | ||||
|       return cublas_handle_; | ||||
|     } else { | ||||
|       CUBLAS_CHECK(cublasCreate(&cublas_handle_)); | ||||
|       // The default is CUBLAS_POINTER_MODE_HOST. You can override | ||||
|       // it after obtaining the cublas handle, but do that with | ||||
| @ -79,19 +82,22 @@ class CUDAContext { | ||||
|       CUBLAS_CHECK(cublasSetPointerMode( | ||||
|           cublas_handle_, CUBLAS_POINTER_MODE_HOST)); | ||||
|       CUBLAS_CHECK(cublasSetStream(cublas_handle_, cuda_stream_)); | ||||
|       return cublas_handle_; | ||||
|     } | ||||
|     return cublas_handle_; | ||||
|   } | ||||
|  | ||||
|   curandGenerator_t& curand_generator() { | ||||
|     if (!curand_generator_) { | ||||
|     if (curand_generator_) { | ||||
|       return curand_generator_; | ||||
|     } else { | ||||
|       CURAND_CHECK(curandCreateGenerator( | ||||
|           &curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); | ||||
|       CURAND_CHECK(curandSetPseudoRandomGeneratorSeed( | ||||
|           curand_generator_, random_seed_)); | ||||
|       CURAND_CHECK(curandSetStream(curand_generator_, cuda_stream_)); | ||||
|       return curand_generator_; | ||||
|     } | ||||
|     return curand_generator_; | ||||
|  | ||||
|   } | ||||
|  | ||||
|   static inline void* New(size_t nbytes) { | ||||
| @ -106,8 +112,6 @@ class CUDAContext { | ||||
|   inline void Memcpy(size_t nbytes, const void* src, void* dst) { | ||||
|     CUDA_CHECK(cudaMemcpyAsync( | ||||
|         dst, src, nbytes, cudaMemcpyDefault, cuda_stream_)); | ||||
|     // TODO(Yangqing): do we want to synchronize inside copy? | ||||
|     CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); | ||||
|   } | ||||
|  | ||||
|   template <typename T, class SrcContext, class DstContext> | ||||
|  | ||||
| @ -177,8 +177,8 @@ class Tensor { | ||||
|     CAFFE_CHECK_EQ(src.size_, size_) | ||||
|         << "Size mismatch - did you call reshape before sharing the data?"; | ||||
|     // It is possible that the source tensor hasn't called mutable_data() yet, | ||||
|     // in which case ShareData() does make much sense since we don't really know | ||||
|     // what to share yet. | ||||
|     // in which case ShareData() doesn't make much sense since we don't really | ||||
|     // know what to share yet. | ||||
|     CAFFE_CHECK(src.data_.get()) << "Source tensor has no content yet."; | ||||
|     // Finally, do sharing. | ||||
|     data_ = src.data_; | ||||
| @ -218,14 +218,14 @@ class Tensor { | ||||
|    * and a new storage will be created. | ||||
|    */ | ||||
|   inline void* raw_mutable_data(const TypeMeta& meta) { | ||||
|     if (!data_.get() || meta_ != meta) { | ||||
|     if (meta_ == meta && data_.get()) { | ||||
|       return data_.get(); | ||||
|     } else { | ||||
|       meta_ = meta; | ||||
|       CAFFE_CHECK_GT(size_, 0); | ||||
|       data_.reset(static_cast<void*>(Context::New(size_ * meta_.itemsize())), | ||||
|                   Context::Delete); | ||||
|       return data_.get(); | ||||
|     } else { | ||||
|       return data_.get(); | ||||
|     } | ||||
|   } | ||||
|  | ||||
| @ -294,8 +294,8 @@ class Tensor { | ||||
|    * this function will produce a fatal message. | ||||
|    */ | ||||
|   inline int dim(const int i) const { | ||||
|     CAFFE_CHECK_LT(i, dims_.size()) << "Exceeding ndim limit " << dims_.size(); | ||||
|     CAFFE_CHECK_GE(i, 0) << "Cannot have negative index"; | ||||
|     CAFFE_DCHECK_LT(i, dims_.size()) << "Exceeding ndim limit " << dims_.size(); | ||||
|     CAFFE_DCHECK_GE(i, 0) << "Cannot have negative index"; | ||||
|     return dims_[i]; | ||||
|   } | ||||
|  | ||||
|  | ||||
| @ -47,7 +47,7 @@ class CuDNNPoolOp : public ConvPoolOpBase<CUDAContext> { | ||||
|  | ||||
|     if (cudnn_input_dims_ != X.dims()) { | ||||
|       // Dimensions changed; we will need to re-initialize things. | ||||
|       CAFFE_LOG_INFO << "Changing the cudnn descriptor configurations."; | ||||
|       CAFFE_VLOG(1) << "Changing the cudnn descriptor configurations."; | ||||
|       cudnn_input_dims_ = X.dims(); | ||||
|       CUDNN_CHECK(cudnnSetTensor4dDescriptor( | ||||
|           bottom_desc_, GetCudnnTensorFormat(order_), | ||||
| @ -135,7 +135,7 @@ class CuDNNPoolGradientOp : public ConvPoolOpBase<CUDAContext> { | ||||
|  | ||||
|     if (cudnn_input_dims_ != X.dims()) { | ||||
|       // Dimensions changed; we will need to re-initialize things. | ||||
|       CAFFE_LOG_INFO << "Changing the cudnn descriptor configurations."; | ||||
|       CAFFE_VLOG(1) << "Changing the cudnn descriptor configurations."; | ||||
|       cudnn_input_dims_ = X.dims(); | ||||
|       CUDNN_CHECK(cudnnSetTensor4dDescriptor( | ||||
|           bottom_desc_, GetCudnnTensorFormat(order_), | ||||
|  | ||||
| @ -1,3 +1,37 @@ | ||||
| """ | ||||
| Benchmark for common convnets. | ||||
|  | ||||
| Speeds on Titan X, with 10 warmup steps and 10 main steps and with different | ||||
| versions of cudnn, are as follows: | ||||
|  | ||||
|                           v3              v4 | ||||
| AlexNet         32.5 / 108.0    27.4 /  90.1 | ||||
| OverFeat       113.0 / 342.3    91.7 / 276.5 | ||||
| Inception      134.5 / 485.8   125.7 / 450.6 | ||||
| VGG (batch 64) 200.8 / 650.0   164.1 / 551.7 | ||||
|  | ||||
| (Note that these numbers involve a "full" backprop, i.e. the gradient | ||||
| with respect to the input image is also computed.) | ||||
|  | ||||
| To get the numbers, simply run: | ||||
|  | ||||
| for MODEL in AlexNet OverFeat Inception; do | ||||
|   PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \ | ||||
|     --batch_size 128 --model $MODEL --forward_only True | ||||
| done | ||||
| for MODEL in AlexNet OverFeat Inception; do | ||||
|   PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \ | ||||
|     --batch_size 128 --model $MODEL | ||||
| done | ||||
| PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \ | ||||
|   --batch_size 64 --model VGGA --forward_only True | ||||
| PYTHONPATH=../gen:$PYTHONPATH python convnet_benchmarks.py \ | ||||
|   --batch_size 64 --model VGGA | ||||
|  | ||||
| Note that VGG needs to be run at batch 64 due to the memory limit on the | ||||
| backward pass. | ||||
| """ | ||||
|  | ||||
| import argparse | ||||
| import numpy as np | ||||
| import time | ||||
| @ -200,23 +234,22 @@ def Inception(order): | ||||
|   return model, 224 | ||||
|  | ||||
|  | ||||
| def Benchmark(model_gen, order, batch_size, cudnn_limit, forward_only, | ||||
|               iterations): | ||||
|   model, input_size = model_gen(order) | ||||
| def Benchmark(model_gen, arg): | ||||
|   model, input_size = model_gen(arg.order) | ||||
|   for op in model.net._net.op: | ||||
|     if op.type == 'Conv': | ||||
|       op.engine = 'CUDNN' | ||||
|       #op.arg.add().CopyFrom(utils.MakeArgument('ws_nbytes_limit', cudnn_limit)) | ||||
|       #op.arg.add().CopyFrom(utils.MakeArgument('ws_nbytes_limit', arg.cudnn_limit)) | ||||
|       op.arg.add().CopyFrom(utils.MakeArgument('exhaustive_search', 1)) | ||||
|       op.arg.add().CopyFrom(utils.MakeArgument('shared_ws_name', 'cudnn_workspace')) | ||||
|     elif op.type in ['MaxPool', 'AveragePool', 'Relu', 'Softmax']: | ||||
|       op.engine = 'CUDNN' | ||||
|   if forward_only: | ||||
|     print 'Running forward only.' | ||||
|   if arg.forward_only: | ||||
|     print arg.model, ': running forward only.' | ||||
|   else: | ||||
|     print 'Running forward-backward.' | ||||
|     print arg.model, ': running forward-backward.' | ||||
|     model.AddGradientOperators() | ||||
|     if order == 'NHWC': | ||||
|     if arg.order == 'NHWC': | ||||
|       print ('==WARNING==\n' | ||||
|              'NHWC order with CuDNN may not be supported yet, so I might\n' | ||||
|              'exit suddenly.') | ||||
| @ -224,49 +257,58 @@ def Benchmark(model_gen, order, batch_size, cudnn_limit, forward_only, | ||||
|   model.net.RunAllOnGPU() | ||||
|  | ||||
|   workspace.ResetWorkspace() | ||||
|   if order == 'NCHW': | ||||
|     data_shape = (batch_size, 3, input_size, input_size) | ||||
|   if arg.order == 'NCHW': | ||||
|     data_shape = (arg.batch_size, 3, input_size, input_size) | ||||
|   else: | ||||
|     data_shape = (batch_size, input_size, input_size, 3) | ||||
|     data_shape = (arg.batch_size, input_size, input_size, 3) | ||||
|   device_option = model.net.Proto().device_option | ||||
|   workspace.FeedBlob("data", np.random.randn(*data_shape).astype(np.float32), | ||||
|                      device_option) | ||||
|   workspace.FeedBlob("label", np.asarray(range(batch_size)).astype(np.int32), | ||||
|   workspace.FeedBlob("label", np.asarray(range(arg.batch_size)).astype(np.int32), | ||||
|                      device_option) | ||||
|  | ||||
|   workspace.RunNetOnce(model.param_init_net) | ||||
|   workspace.CreateNet(model.net) | ||||
|   workspace.RunNet(model.net.Proto().name) | ||||
|  | ||||
|   # Print out all the tensors. | ||||
|   #for name in workspace.Blobs(): | ||||
|   #  content = workspace.FetchBlob(name) | ||||
|   #  print name, content if type(content) is str else content.shape | ||||
|   for i in range(arg.warmup_iterations): | ||||
|     workspace.RunNet(model.net.Proto().name) | ||||
|  | ||||
|   start = time.time() | ||||
|   for i in range(iterations): | ||||
|   for i in range(arg.iterations): | ||||
|     workspace.RunNet(model.net.Proto().name) | ||||
|   print 'Spent: ', (time.time() - start) / iterations | ||||
|   print 'Layer-wise benchmark.' | ||||
|   workspace.BenchmarkNet(model.net.Proto().name, 10, 50, True) | ||||
|   print 'Done.' | ||||
|   print 'Spent: ', (time.time() - start) / arg.iterations | ||||
|   if arg.layer_wise_benchmark: | ||||
|     print 'Layer-wise benchmark.' | ||||
|     workspace.BenchmarkNet( | ||||
|         model.net.Proto().name, 1, arg.iterations, True) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   parser = argparse.ArgumentParser(description="Caffe2 benchmark.") | ||||
|   parser.add_argument("--batch_size", type=int, help="The batch size.") | ||||
|   parser.add_argument("--model", type=str, help="The model to benchmark.") | ||||
|   parser.add_argument("--order", type=str, help="The order to evaluate.") | ||||
|   parser.add_argument("--cudnn_ws", type=int, help="The cudnn workspace size.") | ||||
|   parser.add_argument("--iterations", type=int, default=100, | ||||
|   parser.add_argument("--batch_size", type=int, default=128, | ||||
|                       help="The batch size.") | ||||
|   parser.add_argument("--model", type=str, | ||||
|                       help="The model to benchmark.") | ||||
|   parser.add_argument("--order", type=str, default="NCHW", | ||||
|                       help="The order to evaluate.") | ||||
|   parser.add_argument("--cudnn_ws", type=int, default=-1, | ||||
|                       help="The cudnn workspace size.") | ||||
|   parser.add_argument("--iterations", type=int, default=10, | ||||
|                       help="Number of iterations to run the network.") | ||||
|   parser.add_argument("--warmup_iterations", type=int, default=10, | ||||
|                       help="Number of warm-up iterations before benchmarking.") | ||||
|   parser.add_argument("--forward_only", type=bool, default=False, | ||||
|                       help="If set, only run the forward pass.") | ||||
|   parser.add_argument("--layer_wise_benchmark", type=bool, default=False, | ||||
|                       help="If True, run the layer-wise benchmark as well.") | ||||
|   args = parser.parse_args() | ||||
|   if (not args.batch_size or not args.model or not args.order or not args.cudnn_ws): | ||||
|     parser.print_help() | ||||
|  | ||||
|   workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) | ||||
|   model_map = {'AlexNet': AlexNet, 'OverFeat': OverFeat, 'VGGA': VGGA, 'Inception': Inception} | ||||
|   Benchmark(model_map[args.model], args.order, args.batch_size, args.cudnn_ws, | ||||
|             args.forward_only, args.iterations) | ||||
|   model_map = { | ||||
|       'AlexNet': AlexNet, | ||||
|       'OverFeat': OverFeat, | ||||
|       'VGGA': VGGA, | ||||
|       'Inception': Inception | ||||
|       } | ||||
|   Benchmark(model_map[args.model], args) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user