From fc66521ebddeb2f0cf711a0bddabae412bf92923 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Fri, 3 Jun 2022 06:25:54 +0000 Subject: [PATCH] [cuDNN] [cuDNN v8 API] Support cuDNN Errata Filter (#73934) Not originally mentioned in the tracking issue #58414, but is a nice-to-have feature. In summary, the errata filter allows known problematic kernels to be skipped instead of irrecoverably crashing a CUDA context (e.g., via an illegal memory access) via a JSON file supplied at run time. cuDNN frontend description: https://github.com/NVIDIA/cudnn-frontend#errata-filter Sample errata filter JSON: ``` { "version" : 1, "rules" : [ { "rule_id" : "avoid_bad_bwd_data", "operation" : "ConvBwdData", "engine" : 12, "cudnn_version_start" : 8000, "cudnn_version_end" : 9000 } ] } ``` CC @ngimel @zasdfgbnm @ptrblck Pull Request resolved: https://github.com/pytorch/pytorch/pull/73934 Approved by: https://github.com/ngimel --- aten/src/ATen/native/cudnn/Conv_v8.cpp | 44 +++++++++++++++++++------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index 24c5f3c2e3d6..843fb5297050 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -305,9 +305,20 @@ size_t get_available_workspace() { return max_block_size; } +static nlohmann::json errata_json_handle; + +bool plan_errata_exception(const cudnnHandle_t handle, const std::string & executionPlanTag) { + static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, ""); + if (!has_json) { + return false; + } else { + return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle, [](){return true;}); + } +} + void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::OperationGraph& opGraph, cudnn_frontend::EngineConfigGenerator& generator, const Tensor& x, cudnn_frontend::executionPlans_t& valid_plans, at::DataPtr& workspace_ptr, unsigned int max_plans = 0) { auto initial_predicate_function = [&](cudnn_frontend::ExecutionPlan const& plan) -> bool { - return false; + return plan_errata_exception(handle, plan.getTag()); }; auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function); size_t max_block_size = get_available_workspace(); @@ -407,8 +418,9 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle, // We only get configs from this stage to avoid building unnecessary plans that are never executed -auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) { +auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) { auto opGraph = build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation); + opgraph_tag = opGraph.getTag(); auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT; auto sources = get_generator_sources(desc, x, deterministic, allow_tf32, heuristic_mode); @@ -417,8 +429,9 @@ auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendD return configs; } -auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) { +auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) { auto opGraph = build_opgraph_fused(handle, x, y, w, z, b, alpha, key, padding, stride, dilation); + opgraph_tag = opGraph.getTag(); auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT; auto sources = get_generator_sources(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, x, deterministic, allow_tf32, heuristic_mode); @@ -455,13 +468,16 @@ void try_plans_fused(cudnn_frontend::executionPlans_t& plans, const CacheKeyFuse TORCH_CHECK(false, "FIND was unable to find an engine to execute this computation"); } -void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) { +void try_configs(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) { for (auto & config : configs) { try { auto plan = cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle) - .setEngineConfig(config) + .setEngineConfig(config, opgraph_tag) .build(); + if (plan_errata_exception(handle, plan.getTag())) { + continue; + } run_conv_plan(handle, x, y, w, plan); benchmark_cache.emplace(key, plan); return; @@ -473,13 +489,16 @@ void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key, TORCH_CHECK(false, "GET was unable to find an engine to execute this computation"); } -void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) { +void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) { for (auto & config : configs) { try { auto plan = cudnn_frontend::ExecutionPlanBuilder() .setHandle(handle) - .setEngineConfig(config) + .setEngineConfig(config, opgraph_tag) .build(); + if (plan_errata_exception(handle, plan.getTag())) { + continue; + } run_conv_plan_fused(handle, x, y, w, z, b, plan); benchmark_cache_fused.emplace(key, plan); return; @@ -496,7 +515,6 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups, const bool benchmark, const bool deterministic, const bool allow_tf32) { cudnnHandle_t handle = getCudnnHandle(); - CacheKey key; setCacheKey(key, operation, y, x, w, padding, stride, dilation, groups, deterministic, allow_tf32); // TODO: is this thread safe if cache is updated? is pointer stale? @@ -509,13 +527,14 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation, cudaGetLastError(); // clear CUDA error } } - if (!benchmark) { + std::string opgraph_tag; // extra data needed for errata filter cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics(handle, operation, + opgraph_tag, x, y, w, key, padding, stride, dilation, deterministic, allow_tf32); - try_configs(configs, key, handle, x, y, w); + try_configs(configs, opgraph_tag, key, handle, x, y, w); } else { cudnn_frontend::executionPlans_t plans = get_plans_from_find(handle, operation, x, y, w, key, @@ -544,13 +563,14 @@ void run_fused_conv(const Tensor& x, const Tensor& y, const Tensor& w, const Ten cudaGetLastError(); // clear CUDA error } } - if (!benchmark) { + std::string opgraph_tag; // extra data needed for errata filter cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics_fused(handle, + opgraph_tag, x, y, w, z, b, alpha, key, padding, stride, dilation, deterministic, allow_tf32); - try_configs_fused(configs, key, handle, x, y, w, z, b); + try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b); } else { cudnn_frontend::executionPlans_t plans = get_plans_from_find_fused(handle, x, y, w, z, b, alpha, key,