mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[cuDNN] [cuDNN v8 API] Support cuDNN Errata Filter (#73934)
Not originally mentioned in the tracking issue #58414, but is a nice-to-have feature. In summary, the errata filter allows known problematic kernels to be skipped instead of irrecoverably crashing a CUDA context (e.g., via an illegal memory access) via a JSON file supplied at run time. cuDNN frontend description: https://github.com/NVIDIA/cudnn-frontend#errata-filter Sample errata filter JSON: ``` { "version" : 1, "rules" : [ { "rule_id" : "avoid_bad_bwd_data", "operation" : "ConvBwdData", "engine" : 12, "cudnn_version_start" : 8000, "cudnn_version_end" : 9000 } ] } ``` CC @ngimel @zasdfgbnm @ptrblck Pull Request resolved: https://github.com/pytorch/pytorch/pull/73934 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
c29df68f95
commit
fc66521ebd
@ -305,9 +305,20 @@ size_t get_available_workspace() {
|
||||
return max_block_size;
|
||||
}
|
||||
|
||||
static nlohmann::json errata_json_handle;
|
||||
|
||||
bool plan_errata_exception(const cudnnHandle_t handle, const std::string & executionPlanTag) {
|
||||
static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, "");
|
||||
if (!has_json) {
|
||||
return false;
|
||||
} else {
|
||||
return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle, [](){return true;});
|
||||
}
|
||||
}
|
||||
|
||||
void generate_and_filter_plans(const cudnnHandle_t handle, cudnn_frontend::OperationGraph& opGraph, cudnn_frontend::EngineConfigGenerator& generator, const Tensor& x, cudnn_frontend::executionPlans_t& valid_plans, at::DataPtr& workspace_ptr, unsigned int max_plans = 0) {
|
||||
auto initial_predicate_function = [&](cudnn_frontend::ExecutionPlan const& plan) -> bool {
|
||||
return false;
|
||||
return plan_errata_exception(handle, plan.getTag());
|
||||
};
|
||||
auto plans = generator.cudnnGetPlan(handle, opGraph, initial_predicate_function);
|
||||
size_t max_block_size = get_available_workspace();
|
||||
@ -407,8 +418,9 @@ auto get_plans_from_find_fused(const cudnnHandle_t handle,
|
||||
|
||||
|
||||
// We only get configs from this stage to avoid building unnecessary plans that are never executed
|
||||
auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
|
||||
auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const CacheKey& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
|
||||
auto opGraph = build_opgraph(handle, desc, x, y, w, key, padding, stride, dilation);
|
||||
opgraph_tag = opGraph.getTag();
|
||||
auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
|
||||
auto sources = get_generator_sources(desc, x, deterministic, allow_tf32, heuristic_mode);
|
||||
|
||||
@ -417,8 +429,9 @@ auto get_configs_from_heuristics(const cudnnHandle_t handle, const cudnnBackendD
|
||||
return configs;
|
||||
}
|
||||
|
||||
auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
|
||||
auto get_configs_from_heuristics_fused(const cudnnHandle_t handle, std::string& opgraph_tag, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b, const float alpha, const CacheKeyFused& key, const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const bool deterministic, const bool allow_tf32) {
|
||||
auto opGraph = build_opgraph_fused(handle, x, y, w, z, b, alpha, key, padding, stride, dilation);
|
||||
opgraph_tag = opGraph.getTag();
|
||||
auto heuristic_mode = at::native::cudnnv8_use_heur_mode_b() ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT;
|
||||
auto sources = get_generator_sources(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, x, deterministic, allow_tf32, heuristic_mode);
|
||||
|
||||
@ -455,13 +468,16 @@ void try_plans_fused(cudnn_frontend::executionPlans_t& plans, const CacheKeyFuse
|
||||
TORCH_CHECK(false, "FIND was unable to find an engine to execute this computation");
|
||||
}
|
||||
|
||||
void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
|
||||
void try_configs(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKey& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w) {
|
||||
for (auto & config : configs) {
|
||||
try {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(handle)
|
||||
.setEngineConfig(config)
|
||||
.setEngineConfig(config, opgraph_tag)
|
||||
.build();
|
||||
if (plan_errata_exception(handle, plan.getTag())) {
|
||||
continue;
|
||||
}
|
||||
run_conv_plan(handle, x, y, w, plan);
|
||||
benchmark_cache.emplace(key, plan);
|
||||
return;
|
||||
@ -473,13 +489,16 @@ void try_configs(cudnn_frontend::EngineConfigList& configs, const CacheKey& key,
|
||||
TORCH_CHECK(false, "GET was unable to find an engine to execute this computation");
|
||||
}
|
||||
|
||||
void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
|
||||
void try_configs_fused(cudnn_frontend::EngineConfigList& configs, const std::string& opgraph_tag, const CacheKeyFused& key, const cudnnHandle_t handle, const Tensor& x, const Tensor& y, const Tensor& w, const Tensor& z, const Tensor& b) {
|
||||
for (auto & config : configs) {
|
||||
try {
|
||||
auto plan = cudnn_frontend::ExecutionPlanBuilder()
|
||||
.setHandle(handle)
|
||||
.setEngineConfig(config)
|
||||
.setEngineConfig(config, opgraph_tag)
|
||||
.build();
|
||||
if (plan_errata_exception(handle, plan.getTag())) {
|
||||
continue;
|
||||
}
|
||||
run_conv_plan_fused(handle, x, y, w, z, b, plan);
|
||||
benchmark_cache_fused.emplace(key, plan);
|
||||
return;
|
||||
@ -496,7 +515,6 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
|
||||
const IntArrayRef padding, const IntArrayRef stride, const IntArrayRef dilation, const int64_t groups,
|
||||
const bool benchmark, const bool deterministic, const bool allow_tf32) {
|
||||
cudnnHandle_t handle = getCudnnHandle();
|
||||
|
||||
CacheKey key;
|
||||
setCacheKey(key, operation, y, x, w, padding, stride, dilation, groups, deterministic, allow_tf32);
|
||||
// TODO: is this thread safe if cache is updated? is pointer stale?
|
||||
@ -509,13 +527,14 @@ void run_single_conv(const cudnnBackendDescriptorType_t operation,
|
||||
cudaGetLastError(); // clear CUDA error
|
||||
}
|
||||
}
|
||||
|
||||
if (!benchmark) {
|
||||
std::string opgraph_tag; // extra data needed for errata filter
|
||||
cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics(handle, operation,
|
||||
opgraph_tag,
|
||||
x, y, w, key,
|
||||
padding, stride, dilation,
|
||||
deterministic, allow_tf32);
|
||||
try_configs(configs, key, handle, x, y, w);
|
||||
try_configs(configs, opgraph_tag, key, handle, x, y, w);
|
||||
} else {
|
||||
cudnn_frontend::executionPlans_t plans = get_plans_from_find(handle, operation,
|
||||
x, y, w, key,
|
||||
@ -544,13 +563,14 @@ void run_fused_conv(const Tensor& x, const Tensor& y, const Tensor& w, const Ten
|
||||
cudaGetLastError(); // clear CUDA error
|
||||
}
|
||||
}
|
||||
|
||||
if (!benchmark) {
|
||||
std::string opgraph_tag; // extra data needed for errata filter
|
||||
cudnn_frontend::EngineConfigList configs = get_configs_from_heuristics_fused(handle,
|
||||
opgraph_tag,
|
||||
x, y, w, z, b, alpha, key,
|
||||
padding, stride, dilation,
|
||||
deterministic, allow_tf32);
|
||||
try_configs_fused(configs, key, handle, x, y, w, z, b);
|
||||
try_configs_fused(configs, opgraph_tag, key, handle, x, y, w, z, b);
|
||||
} else {
|
||||
cudnn_frontend::executionPlans_t plans = get_plans_from_find_fused(handle,
|
||||
x, y, w, z, b, alpha, key,
|
||||
|
Reference in New Issue
Block a user