[cudnn] Support v8 API in fbcode (#96512)

Summary: It turns out we never turned on the cuDNN v8 API in fbcode, which blocks bf16 convolutions. Enable the new v8 API.

Test Plan: buck run mode/dev-nosan scripts/xdwang/example:fc_pytorch

Reviewed By: ngimel

Differential Revision: D43784279

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96512
Approved by: https://github.com/malfet
commit 788300cc2a
parent fe0afc5852
Author: Xiaodong Wang
Date: 2023-03-23 01:41:04 +00:00
Committed by: PyTorch MergeBot

3 changed files with 9 additions and 8 deletions

View File

@@ -45,7 +45,7 @@ namespace at { namespace native {
 namespace {
 // TODO: remove duplicate code in Conv_v7.cpp
-constexpr size_t operator "" _TiB(unsigned long long n) {
+constexpr int64_t operator "" _TiB(unsigned long long n) {
   return size_t(n) << 40;
 }
@@ -323,12 +323,12 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
   }
 }

-size_t get_available_workspace() {
+int64_t get_available_workspace() {
   int device;
   C10_CUDA_CHECK(cudaGetDevice(&device));
   size_t max_block_size = 0;
   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
-  return max_block_size;
+  return static_cast<int64_t>(max_block_size);
 }

 static nlohmann::json errata_json_handle;
@@ -347,10 +347,10 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
     return plan_errata_exception(handle, plan.getTag());
   };
   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
-  size_t max_block_size = get_available_workspace();
-  size_t max_workspace_size = 0u;
+  int64_t max_block_size = get_available_workspace();
+  int64_t max_workspace_size = 0;
   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) {
-    size_t curr_workspace_size = plan.getWorkspaceSize();
+    int64_t curr_workspace_size = plan.getWorkspaceSize();
     if (curr_workspace_size <= max_block_size) {
       if (curr_workspace_size > max_workspace_size) {
         max_workspace_size = plan.getWorkspaceSize();
@@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
   if (remove_invalid) {
     cudnn_frontend::executionPlans_t new_valid_plans;
     for (auto &plan : valid_plans) {
-      if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
+      if (plan.getWorkspaceSize() <= max_workspace_size) {
         new_valid_plans.emplace_back(std::move(plan));
       }
     }
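
The point of the type changes in this file is to make the workspace bookkeeping uniformly signed: plan.getWorkspaceSize() hands back a signed 64-bit value (hence the static_cast<size_t> that the last hunk removes), and mixing signed and unsigned operands in a comparison silently converts a negative value into an enormous unsigned one. A standalone sketch of that pitfall, with illustrative values rather than code from this commit:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Same idea as the user-defined literal in the diff: n tebibytes, signed.
constexpr int64_t operator""_TiB(unsigned long long n) {
  return static_cast<int64_t>(n) << 40;
}

int main() {
  const int64_t workspace = -1;  // hypothetical sentinel/error value for a plan
  const int64_t limit = 1_TiB;

  // All-signed comparison (what the diff moves to): -1 stays under the limit.
  std::printf("signed:   %s\n", workspace <= limit ? "fits" : "rejected");

  // Unsigned comparison (what the diff removes): -1 converts to SIZE_MAX,
  // so the same check gives the opposite answer.
  std::printf("unsigned: %s\n",
              static_cast<std::size_t>(workspace) <= static_cast<std::size_t>(limit)
                  ? "fits" : "rejected");
  return 0;
}

Using int64_t end to end also avoids sign-compare warnings, which internal builds typically promote to errors.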

View File

@@ -5,7 +5,7 @@
 // Note: The version below should not actually be 8000. Instead, it should
 // be whatever version of cuDNN that v8 API work with PyTorch correctly.
 // The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
+#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
 #define HAS_CUDNN_V8() true
 #else
 #define HAS_CUDNN_V8() false
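
With the guard bumped from 8000 to 8300, HAS_CUDNN_V8() only turns on against cuDNN 8.3+ headers, resolving the placeholder called out in the note above. A self-contained sketch of how the compile-time gate behaves; the two defines at the top are stand-ins for the real build flag and for cudnn.h:

#include <cstdio>

// Stand-ins for illustration: in a real build, the first define comes from
// the build system (see the Buck change below) and the second from cudnn.h.
#define USE_EXPERIMENTAL_CUDNN_V8_API
#define CUDNN_VERSION 8300

// The guard from the diff, after the bump to 8.3.
#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
#define HAS_CUDNN_V8() true
#else
#define HAS_CUDNN_V8() false
#endif

int main() {
  // Call sites can branch on the macro; both arms still have to compile,
  // since this is an ordinary runtime if over a compile-time constant.
  if (HAS_CUDNN_V8()) {
    std::printf("using the cuDNN v8 frontend (bf16 conv available)\n");
  } else {
    std::printf("falling back to the legacy v7 bindings\n");
  }
  return 0;
}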

View File

@@ -34,6 +34,7 @@ default_compiler_flags = [
     "-DTH_INDEX_BASE=0",
     "-DMAGMA_V2",
     "-DNO_CUDNN_DESTROY_HANDLE",
+    "-DUSE_EXPERIMENTAL_CUDNN_V8_API", # enable cudnn v8 api
     "-DUSE_FBGEMM",
     "-DUSE_QNNPACK",
     "-DUSE_PYTORCH_QNNPACK",