mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Summary: Currently the C10_CUDA_CHECK only shows source location in CUDAException like below: ``` Exception raised from c10_cuda_check_implementation at fbcode/caffe2/c10/cuda/CUDAException.cpp:44 ``` which is not terribly useful. By checking the original diff D39619861 that introduced c10_cuda_check_implementation, it seems the original macro would show the source location correctly but c10_cuda_check_implementation broke it. This diff will propagate caller source location to c10_cuda_check_implementation to fix the issue. Test Plan: CI Observed desired error message after the change: ``` CUDA error: an illegal memory access was encountered Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1 Device-side assertion tracking was not enabled by user. Exception raised from operator() at fbcode/sigrid/predictor/aed/AedContainer.cpp:659 (most recent call first): ``` Note the last line reports actual caller location. Rollback Plan: Reviewed By: Raymo111 Differential Revision: D81880552 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162808 Approved by: https://github.com/janeyx99
98 lines · 4.3 KiB · C++
#pragma once
|
|
|
|
#include <c10/cuda/CUDADeviceAssertionHost.h>
|
|
#include <c10/cuda/CUDAMacros.h>
|
|
#include <c10/cuda/CUDAMiscFunctions.h>
|
|
#include <c10/macros/Macros.h>
|
|
#include <c10/util/Exception.h>
|
|
#include <c10/util/irange.h>
|
|
#include <cuda.h>
|
|
|
|
// Note [CHECK macro]
// ~~~~~~~~~~~~~~~~~~
// This is a macro so that AT_ERROR can get accurate __LINE__
// and __FILE__ information. We could split this into a short
// macro and a function implementation if we pass along __LINE__
// and __FILE__, but no one has found this worth doing.
|
// Used to denote errors from the CUDA framework.
// This needs to be declared here instead of in util/Exception.h for proper
// conversion during hipify.
|
namespace c10 {

// Exception type raised for errors reported by the CUDA runtime/driver.
// Inherits all constructors from c10::Error; carries no extra state, the
// distinct type exists so callers can catch CUDA failures specifically.
class C10_CUDA_API CUDAError : public c10::Error {
  using Error::Error;
};

} // namespace c10
|
// Evaluates EXPR (a CUDA runtime call) and forwards the resulting error code,
// together with the caller's source location, to
// c10::cuda::c10_cuda_check_implementation, which formats an error message on
// failure and also checks for device-side assertion failures.
// See Note [CHECK macro] for why this must stay a macro (__FILE__/__LINE__
// accuracy at the call site).
#define C10_CUDA_CHECK(EXPR)                                        \
  do {                                                              \
    /* EXPR is parenthesized so operator precedence inside any */   \
    /* argument expression is preserved (macro hygiene). */         \
    const cudaError_t __err = (EXPR);                               \
    c10::cuda::c10_cuda_check_implementation(                       \
        static_cast<int32_t>(__err),                                \
        __FILE__,                                                   \
        __func__, /* Line number data type not well-defined between \
                     compilers, so we perform an explicit cast */   \
        static_cast<uint32_t>(__LINE__),                            \
        true);                                                      \
  } while (0)
|
// Evaluates EXPR (a CUDA runtime call) and, if it failed, emits a warning via
// TORCH_WARN instead of throwing. The sticky CUDA error state is consumed
// with cudaGetLastError() so the failure is not re-reported by a later
// C10_CUDA_CHECK.
#define C10_CUDA_CHECK_WARN(EXPR)                                \
  do {                                                           \
    /* EXPR parenthesized for macro hygiene. */                  \
    const cudaError_t __err = (EXPR);                            \
    if (C10_UNLIKELY(__err != cudaSuccess)) {                    \
      /* Clear the sticky error so it doesn't leak forward. */   \
      [[maybe_unused]] auto error_unused = cudaGetLastError();   \
      TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err));   \
    }                                                            \
  } while (0)
|
// Indicates that a CUDA error is handled in a non-standard way.
// Expands to EXPR unchanged; it serves purely as a call-site marker that the
// author deliberately handles the error without C10_CUDA_CHECK.
#define C10_CUDA_ERROR_HANDLED(EXPR) EXPR
|
// Intentionally ignore a CUDA error from EXPR. Unlike C10_CUDA_ERROR_HANDLED,
// this also consumes the sticky CUDA error state via cudaGetLastError() so
// the ignored failure is not picked up by a later C10_CUDA_CHECK.
#define C10_CUDA_IGNORE_ERROR(EXPR)                                   \
  do {                                                                \
    /* EXPR parenthesized for macro hygiene. */                       \
    const cudaError_t __err = (EXPR);                                 \
    if (C10_UNLIKELY(__err != cudaSuccess)) {                         \
      [[maybe_unused]] cudaError_t error_unused = cudaGetLastError(); \
    }                                                                 \
  } while (0)
|
// Clear the last CUDA error: calls cudaGetLastError() solely for its
// side effect of resetting the sticky error state, discarding the result.
#define C10_CUDA_CLEAR_ERROR()               \
  do {                                       \
    static_cast<void>(cudaGetLastError());   \
  } while (0)
|
// This should be used directly after every kernel launch to ensure
// the launch happened correctly and provide an early, close-to-source
// diagnostic if it didn't.
// cudaGetLastError() retrieves any pending error from the preceding launch,
// and C10_CUDA_CHECK raises (with this call site's location) if it is not
// cudaSuccess.
#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
|
/// Launches a CUDA kernel, appending to its argument list the information
/// needed to handle device-side assertion failures, then checks that the
/// launch was successful.
/// NOTE(review): the kernel must therefore accept two trailing parameters
/// after __VA_ARGS__ — the UVM assertion buffer pointer and the value
/// returned by launch_registry.insert() — confirm exact types against
/// CUDAKernelLaunchRegistry.
#define TORCH_DSA_KERNEL_LAUNCH(                                      \
    kernel, blocks, threads, shared_mem, stream, ...)                 \
  do {                                                                \
    auto& launch_registry =                                           \
        c10::cuda::CUDAKernelLaunchRegistry::get_singleton_ref();     \
    kernel<<<blocks, threads, shared_mem, stream>>>(                  \
        __VA_ARGS__,                                                  \
        /* buffer the device code writes assertion reports into */    \
        launch_registry.get_uvm_assertions_ptr_for_current_device(),  \
        /* record this launch's source location and kernel name */    \
        launch_registry.insert(                                       \
            __FILE__, __FUNCTION__, __LINE__, #kernel, stream.id())); \
    C10_CUDA_KERNEL_LAUNCH_CHECK();                                   \
  } while (0)
|
namespace c10::cuda {

/// In the event of a CUDA failure, formats a nice error message about that
/// failure and also checks for device-side assertion failures.
/// Invoked by C10_CUDA_CHECK, which supplies the caller's source location.
/// \param err CUDA error code (cudaError_t widened to int32_t).
/// \param filename __FILE__ at the call site.
/// \param function_name __func__ at the call site.
/// \param line_number __LINE__ at the call site; uint32_t because the line
///        number's data type is not well-defined between compilers.
/// \param include_device_assertions whether to include device-side assertion
///        information in the failure report.
C10_CUDA_API void c10_cuda_check_implementation(
    const int32_t err,
    const char* filename,
    const char* function_name,
    const uint32_t line_number,
    const bool include_device_assertions);

} // namespace c10::cuda