[cudnn] Support v8 API in fbcode (#96512)
Summary: It turns out we never turned on the cuDNN v8 API, which blocks bf16 conv. Enable the new v8 API.

Test Plan: buck run mode/dev-nosan scripts/xdwang/example:fc_pytorch

Reviewed By: ngimel

Differential Revision: D43784279

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96512
Approved by: https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: fe0afc5852
Commit: 788300cc2a
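
The bf16 convolution this unblocks can be exercised from the C++ frontend roughly as follows. This is an illustrative sketch, not the commit's test plan: it assumes a CUDA build of libtorch with cuDNN and simply runs conv2d on bf16 CUDA tensors, which should dispatch to the cuDNN convolution path.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      if (!torch::cuda::is_available()) {
        std::cout << "CUDA not available; nothing to exercise\n";
        return 0;
      }
      // bf16 NCHW tensors on the GPU; conv2d should route through cuDNN here.
      auto opts = torch::dtype(torch::kBFloat16).device(torch::kCUDA);
      auto input  = torch::randn({8, 3, 32, 32}, opts);
      auto weight = torch::randn({16, 3, 3, 3}, opts);
      auto out = torch::conv2d(input, weight);
      std::cout << "output sizes: " << out.sizes() << "\n";
    }

Built against libtorch and run on a cuDNN-enabled GPU, this prints a bf16 output of shape [8, 16, 30, 30].
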
@@ -45,7 +45,7 @@ namespace at { namespace native {
 namespace {

 // TODO: remove duplicate code in Conv_v7.cpp
-constexpr size_t operator "" _TiB(unsigned long long n) {
+constexpr int64_t operator "" _TiB(unsigned long long n) {
   return size_t(n) << 40;
 }

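
The _TiB suffix changed above is a user-defined literal; the hunk only changes its return type from size_t to int64_t so that later workspace arithmetic stays in a signed type. A stand-alone sketch of how such a literal behaves (illustrative, not code from the commit):

    #include <cstdint>
    #include <iostream>

    // Same idea as the literal in the hunk: 1_TiB is 1 << 40 bytes, returned
    // as a signed 64-bit value so comparisons with other int64_t sizes stay signed.
    constexpr int64_t operator "" _TiB(unsigned long long n) {
      return static_cast<int64_t>(n) << 40;
    }

    int main() {
      std::cout << 1_TiB << "\n";  // prints 1099511627776
    }
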
@@ -323,12 +323,12 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
   }
 }

-size_t get_available_workspace() {
+int64_t get_available_workspace() {
   int device;
   C10_CUDA_CHECK(cudaGetDevice(&device));
   size_t max_block_size = 0;
   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
-  return max_block_size;
+  return static_cast<int64_t>(max_block_size);
 }

 static nlohmann::json errata_json_handle;
@@ -347,10 +347,10 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
     return plan_errata_exception(handle, plan.getTag());
   };
   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
-  size_t max_block_size = get_available_workspace();
-  size_t max_workspace_size = 0u;
+  int64_t max_block_size = get_available_workspace();
+  int64_t max_workspace_size = 0;
   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) {
-    size_t curr_workspace_size = plan.getWorkspaceSize();
+    int64_t curr_workspace_size = plan.getWorkspaceSize();
     if (curr_workspace_size <= max_block_size) {
       if (curr_workspace_size > max_workspace_size) {
         max_workspace_size = plan.getWorkspaceSize();
@@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
   if (remove_invalid) {
     cudnn_frontend::executionPlans_t new_valid_plans;
     for (auto &plan : valid_plans) {
-      if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
+      if (plan.getWorkspaceSize() <= max_workspace_size) {
         new_valid_plans.emplace_back(std::move(plan));
       }
     }
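
Taken together, these hunks keep every workspace quantity in int64_t, the signed type returned by cudnn_frontend's plan.getWorkspaceSize(), so its values no longer have to be converted to an unsigned type before being compared. A minimal stand-alone illustration of why such a conversion can surprise (hypothetical values, not from the commit):

    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t curr_workspace_size = -1;         // hypothetical bogus plan size
      size_t max_block_size = size_t(1) << 20;  // hypothetical allocator bound

      // Converting the signed value to size_t wraps -1 to SIZE_MAX, so the
      // "fits in the largest cached block" check fails even though -1 < 2^20.
      std::cout << (static_cast<size_t>(curr_workspace_size) <= max_block_size) << "\n";   // prints 0

      // Keeping both sides signed compares the values as written.
      std::cout << (curr_workspace_size <= static_cast<int64_t>(max_block_size)) << "\n";  // prints 1
    }
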
@@ -5,7 +5,7 @@
 // Note: The version below should not actually be 8000. Instead, it should
 // be whatever version of cuDNN that v8 API work with PyTorch correctly.
 // The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
+#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
 #define HAS_CUDNN_V8() true
 #else
 #define HAS_CUDNN_V8() false
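
HAS_CUDNN_V8() is a compile-time predicate, so callers branch on it with the preprocessor. A self-contained sketch of that pattern; run_conv_v7 and run_conv_v8 are hypothetical stand-ins, and the fallback define only exists so the sketch compiles outside PyTorch:

    #include <cstdio>

    // Stand-in so this sketch builds on its own; in PyTorch the macro comes
    // from the header changed above.
    #ifndef HAS_CUDNN_V8
    #define HAS_CUDNN_V8() false
    #endif

    void run_conv_v7() { std::puts("legacy v7 descriptor-based path"); }  // hypothetical helper
    void run_conv_v8() { std::puts("v8 frontend / graph API path"); }     // hypothetical helper

    int main() {
    #if HAS_CUDNN_V8()
      run_conv_v8();  // taken when USE_EXPERIMENTAL_CUDNN_V8_API is set and cuDNN >= 8.3
    #else
      run_conv_v7();
    #endif
    }

With -DUSE_EXPERIMENTAL_CUDNN_V8_API (added to defs.bzl below) and cuDNN 8.3 or newer, the real macro evaluates to true and the v8 branch is compiled in.
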
defs.bzl | 1 +
@@ -34,6 +34,7 @@ default_compiler_flags = [
     "-DTH_INDEX_BASE=0",
     "-DMAGMA_V2",
     "-DNO_CUDNN_DESTROY_HANDLE",
+    "-DUSE_EXPERIMENTAL_CUDNN_V8_API", # enable cudnn v8 api
     "-DUSE_FBGEMM",
     "-DUSE_QNNPACK",
     "-DUSE_PYTORCH_QNNPACK",