mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[cudnn] Support v8 API in fbcode (#96512)
Summary: It turns out we never turned on the cuDNN v8 API, which blocks bf16 conv. Enable the new v8 API.
Test Plan: buck run mode/dev-nosan scripts/xdwang/example:fc_pytorch
Reviewed By: ngimel
Differential Revision: D43784279
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96512
Approved by: https://github.com/malfet
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							fe0afc5852
						
					
				
				
					commit
					788300cc2a
				
			| @ -45,7 +45,7 @@ namespace at { namespace native { | |||||||
| namespace { | namespace { | ||||||
|  |  | ||||||
| // TODO: remove duplicate code in Conv_v7.cpp | // TODO: remove duplicate code in Conv_v7.cpp | ||||||
| constexpr size_t operator "" _TiB(unsigned long long n) { | constexpr int64_t operator "" _TiB(unsigned long long n) { | ||||||
|   return size_t(n) << 40; |   return size_t(n) << 40; | ||||||
| } | } | ||||||
|  |  | ||||||
| @ -323,12 +323,12 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| size_t get_available_workspace() { | int64_t get_available_workspace() { | ||||||
|   int device; |   int device; | ||||||
|   C10_CUDA_CHECK(cudaGetDevice(&device)); |   C10_CUDA_CHECK(cudaGetDevice(&device)); | ||||||
|   size_t max_block_size = 0; |   size_t max_block_size = 0; | ||||||
|   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size); |   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size); | ||||||
|   return max_block_size; |   return static_cast<int64_t>(max_block_size); | ||||||
| } | } | ||||||
|  |  | ||||||
| static nlohmann::json errata_json_handle; | static nlohmann::json errata_json_handle; | ||||||
| @ -347,10 +347,10 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera | |||||||
|     return plan_errata_exception(handle, plan.getTag()); |     return plan_errata_exception(handle, plan.getTag()); | ||||||
|   }; |   }; | ||||||
|   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function); |   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function); | ||||||
|   size_t max_block_size = get_available_workspace(); |   int64_t max_block_size = get_available_workspace(); | ||||||
|   size_t max_workspace_size = 0u; |   int64_t max_workspace_size = 0; | ||||||
|   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) { |   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) { | ||||||
|     size_t curr_workspace_size = plan.getWorkspaceSize(); |     int64_t curr_workspace_size = plan.getWorkspaceSize(); | ||||||
|     if (curr_workspace_size <= max_block_size) { |     if (curr_workspace_size <= max_block_size) { | ||||||
|       if (curr_workspace_size > max_workspace_size) { |       if (curr_workspace_size > max_workspace_size) { | ||||||
|         max_workspace_size = plan.getWorkspaceSize(); |         max_workspace_size = plan.getWorkspaceSize(); | ||||||
| @ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera | |||||||
|   if (remove_invalid) { |   if (remove_invalid) { | ||||||
|     cudnn_frontend::executionPlans_t new_valid_plans; |     cudnn_frontend::executionPlans_t new_valid_plans; | ||||||
|     for (auto &plan : valid_plans) { |     for (auto &plan : valid_plans) { | ||||||
|       if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) { |       if (plan.getWorkspaceSize() <= max_workspace_size) { | ||||||
|         new_valid_plans.emplace_back(std::move(plan)); |         new_valid_plans.emplace_back(std::move(plan)); | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|  | |||||||
| @ -5,7 +5,7 @@ | |||||||
| // Note: The version below should not actually be 8000. Instead, it should | // Note: The version below should not actually be 8000. Instead, it should | ||||||
| // be whatever version of cuDNN that v8 API work with PyTorch correctly. | // be whatever version of cuDNN that v8 API work with PyTorch correctly. | ||||||
| // The version is set to 8000 today for convenience of debugging. | // The version is set to 8000 today for convenience of debugging. | ||||||
| #if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000 | #if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300 | ||||||
| #define HAS_CUDNN_V8() true | #define HAS_CUDNN_V8() true | ||||||
| #else | #else | ||||||
| #define HAS_CUDNN_V8() false | #define HAS_CUDNN_V8() false | ||||||
|  | |||||||
							
								
								
									
										1
									
								
								defs.bzl
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								defs.bzl
									
									
									
									
									
								
							| @ -34,6 +34,7 @@ default_compiler_flags = [ | |||||||
|     "-DTH_INDEX_BASE=0", |     "-DTH_INDEX_BASE=0", | ||||||
|     "-DMAGMA_V2", |     "-DMAGMA_V2", | ||||||
|     "-DNO_CUDNN_DESTROY_HANDLE", |     "-DNO_CUDNN_DESTROY_HANDLE", | ||||||
|  |     "-DUSE_EXPERIMENTAL_CUDNN_V8_API",  # enable cudnn v8 api | ||||||
|     "-DUSE_FBGEMM", |     "-DUSE_FBGEMM", | ||||||
|     "-DUSE_QNNPACK", |     "-DUSE_QNNPACK", | ||||||
|     "-DUSE_PYTORCH_QNNPACK", |     "-DUSE_PYTORCH_QNNPACK", | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user