[cudnn] Support v8 API in fbcode (#96512)
Summary: It turns out we never turned on the cuDNN v8 API, which blocks bf16 conv. Enable the new v8 API.

Test Plan: buck run mode/dev-nosan scripts/xdwang/example:fc_pytorch

Reviewed By: ngimel

Differential Revision: D43784279

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96512
Approved by: https://github.com/malfet
committed by PyTorch MergeBot
parent fe0afc5852
commit 788300cc2a
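The change itself is a build-level switch: when USE_EXPERIMENTAL_CUDNN_V8_API is defined (together with a new enough cuDNN, see the macro change in the second file below), the v8 graph-API convolution path is compiled in, which is what unblocks bf16 convolutions. A minimal standalone sketch of that kind of compile-time selection (an illustration only, not PyTorch's actual dispatch code):

#include <cstdio>

// Toy sketch: the build defines USE_EXPERIMENTAL_CUDNN_V8_API (or not),
// and that choice decides which convolution backend this translation
// unit reports. Compile with and without -DUSE_EXPERIMENTAL_CUDNN_V8_API
// to see both branches.
#if defined(USE_EXPERIMENTAL_CUDNN_V8_API)
static const char* conv_backend() { return "cuDNN v8 frontend (graph API; bf16 conv possible)"; }
#else
static const char* conv_backend() { return "cuDNN v7 legacy API (bf16 conv blocked)"; }
#endif

int main() {
  std::printf("convolution backend: %s\n", conv_backend());
  return 0;
}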
@@ -45,7 +45,7 @@ namespace at { namespace native {
 namespace {

 // TODO: remove duplicate code in Conv_v7.cpp
-constexpr size_t operator "" _TiB(unsigned long long n) {
+constexpr int64_t operator "" _TiB(unsigned long long n) {
   return size_t(n) << 40;
 }

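The first hunk only changes the literal's result type from size_t to int64_t, so values written with the _TiB suffix now take part in signed arithmetic. A self-contained example of how such a user-defined literal is used (the literal mirrors the hunk above; the usage around it is hypothetical):

#include <cstdint>
#include <iostream>

// Same shape as the literal in the hunk: converts a count of tebibytes
// into a byte count, now yielding a signed 64-bit value.
constexpr int64_t operator "" _TiB(unsigned long long n) {
  return static_cast<int64_t>(n) << 40;
}

int main() {
  // Hypothetical use: cap a workspace request at 1 TiB.
  int64_t requested = int64_t{3} << 30;                            // 3 GiB
  std::cout << "cap (bytes): " << 1_TiB << "\n";                   // 1099511627776
  std::cout << "within cap:  " << (requested <= 1_TiB) << "\n";    // 1
  return 0;
}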
@@ -323,12 +323,12 @@ auto get_generator_sources(const cudnnBackendDescriptorType_t& desc, const Tenso
   }
 }

-size_t get_available_workspace() {
+int64_t get_available_workspace() {
   int device;
   C10_CUDA_CHECK(cudaGetDevice(&device));
   size_t max_block_size = 0;
   c10::cuda::CUDACachingAllocator::cacheInfo(device, &max_block_size);
-  return max_block_size;
+  return static_cast<int64_t>(max_block_size);
 }

 static nlohmann::json errata_json_handle;
@@ -347,10 +347,10 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
     return plan_errata_exception(handle, plan.getTag());
   };
   auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
-  size_t max_block_size = get_available_workspace();
-  size_t max_workspace_size = 0u;
+  int64_t max_block_size = get_available_workspace();
+  int64_t max_workspace_size = 0;
   std::for_each(plans.begin(), plans.end(), [&] (cudnn_frontend::ExecutionPlan& plan) {
-    size_t curr_workspace_size = plan.getWorkspaceSize();
+    int64_t curr_workspace_size = plan.getWorkspaceSize();
     if (curr_workspace_size <= max_block_size) {
       if (curr_workspace_size > max_workspace_size) {
         max_workspace_size = plan.getWorkspaceSize();
@@ -373,7 +373,7 @@ void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::Opera
   if (remove_invalid) {
     cudnn_frontend::executionPlans_t new_valid_plans;
     for (auto &plan : valid_plans) {
-      if (static_cast<size_t>(plan.getWorkspaceSize()) <= max_workspace_size) {
+      if (plan.getWorkspaceSize() <= max_workspace_size) {
        new_valid_plans.emplace_back(std::move(plan));
      }
    }
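The commit message does not spell out why the workspace bookkeeping moves from size_t to int64_t. A plausible reading (an assumption, not stated in the diff) is that the frontend reports workspace sizes as signed 64-bit values, and keeping every operand signed avoids the usual mixed signed/unsigned comparison trap, sketched below with made-up values:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical values: a signed size as a planner might report it
  // (negative used as a sentinel here) against an unsigned budget.
  int64_t reported = -1;
  size_t  budget   = 1024;

  // Mixed comparison: on a 64-bit platform the signed value converts to
  // size_t, -1 becomes a huge number, and the check comes out false.
  std::printf("mixed : %d\n", reported <= budget ? 1 : 0);
  // With both sides signed, the comparison behaves as intended.
  std::printf("signed: %d\n", reported <= static_cast<int64_t>(budget) ? 1 : 0);
  return 0;
}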
@@ -5,7 +5,7 @@
 // Note: The version below should not actually be 8000. Instead, it should
 // be whatever version of cuDNN that v8 API work with PyTorch correctly.
 // The version is set to 8000 today for convenience of debugging.
-#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
+#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
 #define HAS_CUDNN_V8() true
 #else
 #define HAS_CUDNN_V8() false
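The only functional change in this header is the version threshold: the v8 path now additionally requires cuDNN 8.3 or newer rather than any 8.x build. A standalone sketch of how the guard resolves (USE_EXPERIMENTAL_CUDNN_V8_API and CUDNN_VERSION are faked here so the snippet compiles without cuDNN; in the real header CUDNN_VERSION comes from cudnn.h):

#include <cstdio>

// Stand-in definitions for this demo only.
#define USE_EXPERIMENTAL_CUDNN_V8_API
#define CUDNN_VERSION 8200   // pretend we built against cuDNN 8.2

// Same shape as the guard in the hunk above: the build flag must be set
// AND the headers must report at least 8.3 for the v8 path to count as
// available.
#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8300
#define HAS_CUDNN_V8() true
#else
#define HAS_CUDNN_V8() false
#endif

int main() {
  // With the faked 8.2 version this prints 0: the raised threshold keeps
  // the v8 path off for cuDNN builds older than 8.3.
  std::printf("HAS_CUDNN_V8() = %d\n", HAS_CUDNN_V8());
  return 0;
}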