step 0 of cuDNN v8 convolution API integration (#51390)

Summary:
This PR is step 0 of adding PyTorch convolution bindings using the cuDNN frontend. The cuDNN frontend is the recommended way of using cuDNN v8 API. It is supposed to have faster release cycles, so that, for example, if people find a specific kernel has a bug, they can report it, and that kernel will be blocked in the cuDNN frontend and frameworks could just update that submodule without the need for waiting for a whole cuDNN release.

The work is not complete, and this PR is only step 0.

**What this PR does:**
- Add cudnn-frontend as a submodule.
- Modify cmake to build that submodule.
- Add bindings for convolution forward in `Conv_v8.cpp`, which is disabled by a macro by default.
- Tested manually by enabling the macro and run `test_nn.py`. All tests pass except those mentioned below.

**What this PR doesn't:**
- Only convolution forward, no backward. The backward will use v7 API.
- No 64bit-indexing support for some configuration. This is a known issue of cuDNN, and will be fixed in a later cuDNN version. PyTorch will not implement any workaround for issue, but instead, v8 API should be disabled on problematic cuDNN versions.
- No test beyond PyTorch's unit tests.
  - Not tested for correctness on real models.
  - Not benchmarked for performance.
- Benchmark cache is not thread-safe. (This is marked as `FIXME` in the code, and will be fixed in a follow-up PR)
- cuDNN benchmark is not supported.
- There are failing tests, which will be resolved later:
  ```
  FAILED test/test_nn.py::TestNNDeviceTypeCUDA::test_conv_cudnn_nhwc_cuda_float16 - AssertionError: False is not true : Tensors failed to compare as equal!With rtol=0.001 and atol=1e-05, found 32 element(s) (out of 32) whose difference(s) exceeded the margin of error (in...
  FAILED test/test_nn.py::TestNNDeviceTypeCUDA::test_conv_cudnn_nhwc_cuda_float32 - AssertionError: False is not true : Tensors failed to compare as equal!With rtol=1.3e-06 and atol=1e-05, found 32 element(s) (out of 32) whose difference(s) exceeded the margin of error (...
  FAILED test/test_nn.py::TestNNDeviceTypeCUDA::test_conv_large_cuda - RuntimeError: CUDNN_BACKEND_OPERATION: cudnnFinalize Failed cudnn_status: 9
  FAILED test/test_nn.py::TestNN::test_Conv2d_depthwise_naive_groups_cuda - AssertionError: False is not true : Tensors failed to compare as equal!With rtol=0 and atol=1e-05, found 64 element(s) (out of 64) whose difference(s) exceeded the margin of error (including 0 an...
  FAILED test/test_nn.py::TestNN::test_Conv2d_deterministic_cudnn - RuntimeError: not supported yet
  FAILED test/test_nn.py::TestNN::test_ConvTranspose2d_groups_cuda_fp32 - RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
  FAILED test/test_nn.py::TestNN::test_ConvTranspose2d_groups_cuda_tf32 - RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM
  ```

Although this is not a complete implementation of cuDNN v8 API binding, I still want to merge this first. This would allow me to do small and incremental work, for the ease of development and review.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51390

Reviewed By: malfet

Differential Revision: D28513167

Pulled By: ngimel

fbshipit-source-id: 9cc20c9dec5bbbcb1f94ac9e0f59b10c34f62740
This commit is contained in:
Xiang Gao
2021-05-19 12:53:03 -07:00
committed by Facebook GitHub Bot
parent 954d39ba38
commit 6c70cbedb6
10 changed files with 217 additions and 4 deletions

3
.gitmodules vendored
View File

@ -130,6 +130,9 @@
ignore = dirty
path = third_party/tensorpipe
url = https://github.com/pytorch/tensorpipe.git
[submodule "third_party/cudnn_frontend"]
path = third_party/cudnn_frontend
url = https://github.com/NVIDIA/cudnn-frontend.git
[submodule "third_party/kineto"]
path = third_party/kineto
url = https://github.com/pytorch/kineto

View File

@ -194,6 +194,9 @@ cmake_dependent_option(
cmake_dependent_option(
USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
"USE_CUDNN" OFF)
cmake_dependent_option(
USE_EXPERIMENTAL_CUDNN_V8_API "Use experimental cuDNN v8 API" OFF
"USE_CUDNN" OFF)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" OFF)

View File

@ -2,6 +2,8 @@
#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/Macros.h>
#include <limits>
#include <vector>
#include <sstream>
@ -614,6 +616,8 @@ if (args.params.dataType == CUDNN_DATA_FLOAT) {
//
// ---------------------------------------------------------------------
#if !HAS_CUDNN_V8()
void raw_cudnn_convolution_forward_out_32bit(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
@ -665,6 +669,8 @@ void raw_cudnn_convolution_forward_out(
split_batch_dim_to_32bit_out(output, input, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, 1024 * 1024 * 256, raw_cudnn_convolution_forward_out_32bit);
}
#endif // !HAS_CUDNN_V8()
// ---------------------------------------------------------------------
//
// Convolution backward / Transposed convolution forward

View File

@ -1,5 +1,177 @@
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED
#if AT_CUDNN_ENABLED() && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
// Coming soon
#endif // AT_CUDNN_ENABLED and CUDNN_VERSION
#if AT_CUDNN_ENABLED()
#include <ATen/native/cudnn/Macros.h>
#if HAS_CUDNN_V8()
#include <ATen/cudnn/cudnn-wrapper.h>
#include <cudnn_frontend.h>
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/native/utils/ParamsHash.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/TensorUtils.h>
#include <unordered_map>
namespace at { namespace native{
namespace {
uint8_t getAlignment(const Tensor &t) {
// alignment are in bytes
uint8_t alignment = 1;
uint64_t address = reinterpret_cast<uint64_t>(t.data_ptr());
while (address % alignment == 0 && alignment < 16) alignment *= 2;
return alignment;
}
cudnn_frontend::Tensor getTensorDescriptor(const Tensor &t, int64_t id, uint8_t alignment) {
auto shape = t.sizes();
auto strides = t.strides();
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(alignment)
.setDataType(getCudnnDataType(t))
.build();
}
cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation) {
uint64_t convDim = stride.size();
return cudnn_frontend::ConvDescBuilder()
.setDataType(dataType)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, stride.data())
.setPrePadding(convDim, padding.data())
.setPostPadding(convDim, padding.data())
.setDilation(convDim, dilation.data())
.build();
}
void filterEngineConfigs(
cudnn_frontend::EngineConfigList &from,
cudnn_frontend::EngineConfigList &to,
bool deterministic, bool allow_tf32, c10::ScalarType scalar_type)
{
auto filter = [=](cudnnBackendDescriptor_t c) {
if (deterministic) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) return true;
}
if (scalar_type == kFloat || !allow_tf32) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) return true;
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) return true;
}
return false;
};
cudnn_frontend::filter(from, to, filter);
}
struct CacheKey {
ConvolutionParams params;
uint8_t input_alignment;
uint8_t weight_alignment;
uint8_t output_alignment;
};
// FIXME: make this thread-safe by reusing the benchmark cache in Conv_v7.cpp
std::unordered_map<CacheKey, cudnn_frontend::ManagedOpaqueDescriptor, ParamsHash<CacheKey>, ParamsEqual<CacheKey>> engine_cache;
}
void raw_cudnn_convolution_forward_out(
const Tensor& output, const Tensor& input, const Tensor& weight,
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
bool benchmark, bool deterministic, bool allow_tf32)
{
TORCH_CHECK(!benchmark, "not supported yet");
if (output.numel() == 0) {
return;
}
cudnnHandle_t handle = getCudnnHandle();
CacheKey key;
setConvolutionParams(&key.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32);
key.input_alignment = getAlignment(input);
key.output_alignment = getAlignment(output);
key.weight_alignment = getAlignment(weight);
auto run = [&](cudnn_frontend::ManagedOpaqueDescriptor cfg) {
auto plan = cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle)
.setEngineConfig(cfg)
.build();
auto workspace_size = plan.getWorkspaceSize();
auto workspace = at::empty({workspace_size}, input.options().dtype(kByte));
void *data_ptrs[] = {input.data_ptr(), output.data_ptr(), weight.data_ptr()};
// std::cout << plan.describe() << " requires workspace " << workspace_size << std::endl;
int64_t uids[] = {'x', 'y', 'w'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace.data_ptr())
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
AT_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc()));
};
auto search = engine_cache.find(key);
if (search != engine_cache.end()) {
run(search->second);
return;
}
auto op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(getTensorDescriptor(input, 'x', key.input_alignment))
.setyDesc(getTensorDescriptor(output, 'y', key.output_alignment))
.setwDesc(getTensorDescriptor(weight, 'w', key.weight_alignment))
.setcDesc(getConvDescriptor(key.params.dataType, padding, stride, dilation))
.build();
// std::cout << op.describe() << std::endl;
std::array<cudnn_frontend::Operation const *, 1> ops = {&op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle)
.setOperationGraph(1, ops.data())
.build();
// std::cout << opGraph.describe() << std::endl;
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
auto fallback = cudnn_frontend::EngineFallbackListBuilder()
.setOperationGraph(opGraph)
.setOperation(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.build();
auto& engine_configs = heuristics.getEngineConfig(heuristics.getEngineConfigCount());
auto& fallback_list = fallback.getFallbackList();
cudnn_frontend::EngineConfigList filtered_configs;
filterEngineConfigs(engine_configs, filtered_configs, deterministic, allow_tf32, input.scalar_type());
filterEngineConfigs(fallback_list, filtered_configs, deterministic, allow_tf32, input.scalar_type());
for (auto &cfg : filtered_configs) {
try {
run(cfg);
engine_cache[key] = cfg;
return;
} catch (cudnn_frontend::cudnnException &e) {} catch(CuDNNError &e) {}
}
TORCH_CHECK(false, "Unable to find an engine to execute this computation");
}
}} // at::native
#endif // HAS_CUDNN_V8
#endif // AT_CUDNN_ENABLED

View File

@ -0,0 +1,12 @@
#pragma once
#include <ATen/cudnn/cudnn-wrapper.h>
// Note: The version below should not actually be 8000. Instead, it should
// be whatever version of cuDNN that v8 API work with PyTorch correctly.
// The version is set to 8000 today for convenience of debugging.
#if defined(USE_EXPERIMENTAL_CUDNN_V8_API) && defined(CUDNN_VERSION) && CUDNN_VERSION >= 8000
#define HAS_CUDNN_V8() true
#else
#define HAS_CUDNN_V8() false
#endif

View File

@ -1304,6 +1304,15 @@ elseif(USE_ROCM)
target_compile_definitions(torch_hip PRIVATE "-DTORCH_HIP_BUILD_MAIN_LIB")
endif()
if(USE_EXPERIMENTAL_CUDNN_V8_API)
if(BUILD_SPLIT_CUDA)
target_compile_definitions(torch_cuda_cu PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
target_compile_definitions(torch_cuda_cpp PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
elseif(USE_CUDA)
target_compile_definitions(torch_cuda PRIVATE "-DUSE_EXPERIMENTAL_CUDNN_V8_API")
endif()
endif()
set(EXPERIMENTAL_SINGLE_THREAD_POOL "0" CACHE STRING
"Experimental option to use a single thread pool for inter- and intra-op parallelism")
if("${EXPERIMENTAL_SINGLE_THREAD_POOL}")

View File

@ -1189,6 +1189,12 @@ if(USE_CUDA)
endif()
endif()
# ---[ cuDNN
if(USE_CUDNN)
set(CUDNN_FRONTEND_INCLUDE_DIR ${CMAKE_CURRENT_LIST_DIR}/../third_party/cudnn_frontend/include)
include_directories(${CUDNN_FRONTEND_INCLUDE_DIR})
endif()
# ---[ HIP
if(USE_ROCM)
include(${CMAKE_CURRENT_LIST_DIR}/public/LoadHIP.cmake)

View File

@ -74,6 +74,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " Split CUDA : ${BUILD_SPLIT_CUDA}")
message(STATUS " CUDA static link : ${CAFFE2_STATIC_LINK_CUDA}")
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}")
message(STATUS " CUDA version : ${CUDA_VERSION}")
if(${USE_CUDNN})
message(STATUS " cuDNN version : ${CUDNN_VERSION}")

View File

@ -332,7 +332,7 @@ def check_submodules():
print('Please run:\n\tgit submodule update --init --recursive')
sys.exit(1)
for folder in folders:
check_for_files(folder, ["CMakeLists.txt", "Makefile", "setup.py", "LICENSE"])
check_for_files(folder, ["CMakeLists.txt", "Makefile", "setup.py", "LICENSE", "LICENSE.txt"])
check_for_files(os.path.join(third_party_path, 'fbgemm', 'third_party',
'asmjit'), ['CMakeLists.txt'])
check_for_files(os.path.join(third_party_path, 'onnx', 'third_party',

1
third_party/cudnn_frontend vendored Submodule