mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Currently, `C10_CUDA_CHECK` reports only the location of the helper function in the `CUDAException`, as shown below:
```
Exception raised from c10_cuda_check_implementation at fbcode/caffe2/c10/cuda/CUDAException.cpp:44
```
This is not very useful. The original diff D39619861, which introduced `c10_cuda_check_implementation`, shows that the original macro reported the source location correctly and that `c10_cuda_check_implementation` broke it. This diff fixes the issue by propagating the caller's source location to `c10_cuda_check_implementation`.

Test Plan: CI. Observed the desired error message after the change:
```
CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Device-side assertion tracking was not enabled by user.
Exception raised from operator() at fbcode/sigrid/predictor/aed/AedContainer.cpp:659 (most recent call first):
```
Note that the last line now reports the actual caller location.

Rollback Plan:

Reviewed By: Raymo111

Differential Revision: D81880552

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162808
Approved by: https://github.com/janeyx99
48 lines · 1.5 KiB · C++
#include <c10/cuda/CUDAException.h>
|
|
|
|
#include <c10/cuda/CUDADeviceAssertionHost.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <cuda_runtime.h>
|
|
|
|
#include <string>
|
|
|
|
namespace c10::cuda {
|
|
|
|
void c10_cuda_check_implementation(
|
|
const int32_t err,
|
|
const char* filename,
|
|
const char* function_name,
|
|
const uint32_t line_number,
|
|
const bool include_device_assertions) {
|
|
const auto cuda_error = static_cast<cudaError_t>(err);
|
|
const auto cuda_kernel_failure = include_device_assertions
|
|
? c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref().has_failed()
|
|
: false;
|
|
|
|
if (C10_LIKELY(cuda_error == cudaSuccess && !cuda_kernel_failure)) {
|
|
return;
|
|
}
|
|
|
|
[[maybe_unused]] auto error_unused = cudaGetLastError();
|
|
|
|
std::string check_message;
|
|
#ifndef STRIP_ERROR_MESSAGES
|
|
check_message.append("CUDA error: ");
|
|
const char* error_string = cudaGetErrorString(cuda_error);
|
|
check_message.append(error_string);
|
|
check_message.append(c10::cuda::get_cuda_error_help(cuda_error));
|
|
check_message.append(c10::cuda::get_cuda_check_suffix());
|
|
check_message.append("\n");
|
|
if (include_device_assertions) {
|
|
check_message.append(c10_retrieve_device_side_assertion_info());
|
|
} else {
|
|
check_message.append(
|
|
"Device-side assertions were explicitly omitted for this error check; the error probably arose while initializing the DSA handlers.");
|
|
}
|
|
#endif
|
|
throw c10::AcceleratorError(
|
|
{function_name, filename, line_number}, err, check_message);
|
|
}
|
|
|
|
} // namespace c10::cuda
|