Check all CUDA API calls for errors in caffe2/ (#81816)

Test Plan: Sandcastle

Differential Revision: D35194868

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81816
Approved by: https://github.com/ezyang
Author:    Richard Barnes
Date:      2022-10-28 00:41:04 +00:00
Committer: PyTorch MergeBot
Parent:    3ece9fb45d
Commit:    7a3afe61d2

9 changed files with 43 additions and 44 deletions
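
Every hunk below follows the same recipe: a CUDA runtime call whose cudaError_t result was previously dropped on the floor is wrapped so that the result is always consumed, either by C10_CUDA_CHECK (throw on any failure) or by C10_CUDA_ERROR_HANDLED (the caller inspects the status itself). The sketch below is a rough, hypothetical stand-in for what such a checking wrapper does; CHECKED_CUDA is an illustrative name, not the real c10 macro, which does more (logging, richer error messages).

// Illustrative only: a minimal stand-in for a checked-call macro.
#include <cuda_runtime.h>
#include <sstream>
#include <stdexcept>

#define CHECKED_CUDA(expr)                                       \
  do {                                                           \
    const cudaError_t _err = (expr);                             \
    if (_err != cudaSuccess) {                                   \
      std::ostringstream _msg;                                   \
      _msg << "CUDA error at " << __FILE__ << ":" << __LINE__    \
           << ": " << cudaGetErrorString(_err);                  \
      throw std::runtime_error(_msg.str());                      \
    }                                                            \
  } while (0)

// Usage: the return value is always consumed, never silently dropped.
// CHECKED_CUDA(cudaMemsetAsync(ptr, 0, nbytes, stream));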

File 1 of 9

@@ -235,7 +235,7 @@ static void Caffe2InitializeCuda() {
         // a reserved flag for cudaDeviceEnablePeerAccess that should always be
         // zero currently.
         // It is ok if peer access is already enabled...
-        cudaError_t err = cudaDeviceEnablePeerAccess(j, 0);
+        cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaDeviceEnablePeerAccess(j, 0));
         if ((err != cudaErrorPeerAccessAlreadyEnabled) &&
             (err != cudaSuccess)) {
          CAFFE_THROW(cudaGetErrorString(err));
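
For context, cudaDeviceEnablePeerAccess reports cudaErrorPeerAccessAlreadyEnabled when access was turned on earlier, which is why the code above treats that status as success. Below is a self-contained sketch of the same tolerance, assuming a simple throw-on-error helper; enable_all_peer_access and check are illustrative names, not caffe2 APIs.

// Illustrative sketch: enable peer access between every pair of visible
// devices, treating "already enabled" as success. The flags argument of
// cudaDeviceEnablePeerAccess is reserved and must be zero.
#include <cuda_runtime.h>
#include <stdexcept>

inline void check(cudaError_t e) {
  if (e != cudaSuccess) throw std::runtime_error(cudaGetErrorString(e));
}

void enable_all_peer_access() {
  int n = 0;
  check(cudaGetDeviceCount(&n));
  for (int i = 0; i < n; ++i) {
    check(cudaSetDevice(i));
    for (int j = 0; j < n; ++j) {
      if (i == j) continue;
      int can_access = 0;
      check(cudaDeviceCanAccessPeer(&can_access, i, j));
      if (!can_access) continue;
      const cudaError_t err = cudaDeviceEnablePeerAccess(j, /*flags=*/0);
      if (err == cudaErrorPeerAccessAlreadyEnabled) {
        (void)cudaGetLastError();  // benign; clear the status and keep going
      } else {
        check(err);
      }
    }
  }
}
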
@@ -351,7 +351,7 @@ struct CAFFE2_CUDA_API PinnedCPUAllocator final : public at::Allocator {
       CUDA_ENFORCE(cudaHostUnregister(data));
       GetDefaultCPUAllocator()->raw_deleter()(data);
     } else {
-      cudaError_t err = cudaFreeHost(data);
+      cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaFreeHost(data));
       profiledCPUMemoryReporter().Delete(data);
       if (err == cudaErrorInvalidValue) {
         free(data);
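
The branch above tolerates cudaErrorInvalidValue from cudaFreeHost: the pointer reaching this deleter may not actually be a pinned CUDA allocation, in which case it is released with plain free() instead. A compact illustrative sketch of that fallback (free_maybe_pinned is a hypothetical name; the real code also reports the deletion to the memory profiler, as the hunk shows):

#include <cuda_runtime.h>
#include <cstdlib>
#include <stdexcept>

void free_maybe_pinned(void* data) {
  const cudaError_t err = cudaFreeHost(data);
  if (err == cudaErrorInvalidValue) {
    std::free(data);           // not a pinned allocation; plain host memory
    (void)cudaGetLastError();  // clear the sticky error state
  } else if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}
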
@@ -598,7 +598,7 @@ struct DefaultCUDAAllocator final : public at::Allocator {
     switch (g_cuda_memory_pool_type) {
       case CudaMemoryPoolType::NONE: {
         // If memory pool is not set up, use simple cudaFree.
-        cudaError_t error = cudaFree(ptr);
+        cudaError_t error = C10_CUDA_ERROR_HANDLED(cudaFree(ptr));
         // For some reason, in Python runtime we sometimes delete a data pointer
         // after the cuda runtime exits - this is odd but is probably caused by
         // a static workspace that pycaffe2 uses, and the destruction got
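
The surrounding comment explains why this call cannot simply throw on every failure: during interpreter teardown the CUDA runtime may already be unloaded when a static workspace frees its buffers, and cudaFree then reports cudaErrorCudartUnloading. A minimal sketch of that tolerance, assuming a standalone helper (free_device_memory is an illustrative name, not the caffe2 allocator):

#include <cuda_runtime.h>
#include <stdexcept>

void free_device_memory(void* ptr) {
  const cudaError_t err = cudaFree(ptr);
  if (err == cudaErrorCudartUnloading) {
    (void)cudaGetLastError();  // process is tearing down; nothing useful to do
    return;
  }
  if (err != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(err));
  }
}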

File 2 of 9

@@ -195,10 +195,6 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
   // SwitchToDevice()
   void FinishDeviceComputation() override {
     CUDA_ENFORCE(cudaStreamSynchronize(getCudaObjects().GetStream(gpu_id_)));
-    cudaError_t error = cudaGetLastError();
-    if (error != cudaSuccess) {
-      CAFFE_THROW("Encountered CUDA error: ", cudaGetErrorString(error));
-    }
   }

   inline int device_id() const {
@@ -309,11 +305,13 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
   }

   static bool IsStreamFree(const DeviceOption& option, StreamId stream_id) {
-    auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id);
-    auto status = cudaStreamQuery(stream);
+    const auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id);
+    const auto status = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream));
     if (status == cudaErrorNotReady) {
       // ignore and clear the error if not ready
-      (void)cudaGetLastError();
+      C10_CUDA_CLEAR_ERROR();
+    } else {
+      C10_CUDA_CHECK(status); // Reraise error
     }
     return status == cudaSuccess;
   }
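
cudaStreamQuery is unusual among the calls touched here: cudaErrorNotReady is an expected, non-exceptional answer, but it still lands in CUDA's sticky per-thread error state, so it has to be cleared (that is essentially what C10_CUDA_CLEAR_ERROR does), while any other failure status should be re-raised. A standalone sketch of the same logic (is_stream_free is an illustrative helper, not the caffe2 method):

#include <cuda_runtime.h>
#include <stdexcept>

bool is_stream_free(cudaStream_t stream) {
  const cudaError_t status = cudaStreamQuery(stream);
  if (status == cudaErrorNotReady) {
    (void)cudaGetLastError();  // expected while work is pending; clear it
    return false;
  }
  if (status != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(status));  // real failure
  }
  return true;
}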

File 3 of 9

@@ -145,15 +145,15 @@ void nms_gpu_upright(
   // Overlapping CPU computes and D2H memcpy
   // both take about the same time
   cudaEvent_t copy_done;
-  cudaEventCreate(&copy_done);
+  C10_CUDA_CHECK(cudaEventCreate(&copy_done));
   int nto_copy = std::min(CHUNK_SIZE, N);
-  CUDA_CHECK(cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       &h_delete_mask[0],
       &d_delete_mask[0],
       nto_copy * mask_ld * sizeof(int),
       cudaMemcpyDeviceToHost,
       context->cuda_stream()));
-  CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+  C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
   int offset = 0;
   std::vector<int> h_keep_sorted_list;
   std::vector<int> rmv(mask_ld, 0);
@@ -162,7 +162,7 @@
     int next_offset = offset + ncopied;
     nto_copy = std::min(CHUNK_SIZE, N - next_offset);
     if (nto_copy > 0) {
-      CUDA_CHECK(cudaMemcpyAsync(
+      C10_CUDA_CHECK(cudaMemcpyAsync(
           &h_delete_mask[next_offset * mask_ld],
           &d_delete_mask[next_offset * mask_ld],
           nto_copy * mask_ld * sizeof(int),
@@ -170,9 +170,10 @@
           context->cuda_stream()));
     }
     // Waiting for previous copy
-    CUDA_CHECK(cudaEventSynchronize(copy_done));
-    if (nto_copy > 0)
-      cudaEventRecord(copy_done, context->cuda_stream());
+    C10_CUDA_CHECK(cudaEventSynchronize(copy_done));
+    if (nto_copy > 0){
+      C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+    }
     for (int i = offset; i < next_offset; ++i) {
       int iblock = i / BOXES_PER_THREAD;
       int inblock = i % BOXES_PER_THREAD;
@@ -186,15 +187,15 @@
     }
     offset = next_offset;
   }
-  cudaEventDestroy(copy_done);
+  C10_CUDA_CHECK(cudaEventDestroy(copy_done));

   const int nkeep = h_keep_sorted_list.size();
-  cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       d_keep_sorted_list,
       &h_keep_sorted_list[0],
       nkeep * sizeof(int),
       cudaMemcpyHostToDevice,
-      context->cuda_stream());
+      context->cuda_stream()));

   *h_nkeep = nkeep;
 }
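
The function above overlaps CPU post-processing with device-to-host copies: while chunk k is consumed on the host, the copy of chunk k+1 is already in flight on the stream, and a single event marks when the data about to be consumed has actually landed. The nms_gpu_rotated hunks below apply the identical change to the same pattern. Here is a stripped-down, illustrative sketch of the scheme with every call checked; chunked_consume and ck are hypothetical names, and the payload is a plain int array rather than the NMS bitmask.

#include <cuda_runtime.h>
#include <algorithm>
#include <stdexcept>

static void ck(cudaError_t e) {
  if (e != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(e));
  }
}

void chunked_consume(const int* d_buf, int* h_buf, int n, int chunk,
                     cudaStream_t stream) {
  cudaEvent_t copy_done;
  ck(cudaEventCreate(&copy_done));
  int nto_copy = std::min(chunk, n);
  ck(cudaMemcpyAsync(h_buf, d_buf, nto_copy * sizeof(int),
                     cudaMemcpyDeviceToHost, stream));
  ck(cudaEventRecord(copy_done, stream));
  int offset = 0;
  while (offset < n) {
    const int next = offset + nto_copy;
    nto_copy = std::min(chunk, n - next);
    if (nto_copy > 0) {
      // Start the next chunk's copy before touching the current one.
      ck(cudaMemcpyAsync(h_buf + next, d_buf + next, nto_copy * sizeof(int),
                         cudaMemcpyDeviceToHost, stream));
    }
    // Wait for the previous copy; the event precedes the new memcpy in
    // stream order, so this does not wait for the chunk just enqueued.
    ck(cudaEventSynchronize(copy_done));
    if (nto_copy > 0) {
      ck(cudaEventRecord(copy_done, stream));
    }
    // ... consume h_buf[offset, next) on the CPU here ...
    offset = next;
  }
  ck(cudaEventDestroy(copy_done));
}
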
@@ -502,15 +503,15 @@ void nms_gpu_rotated(
   // Overlapping CPU computes and D2H memcpy
   // both take about the same time
   cudaEvent_t copy_done;
-  cudaEventCreate(&copy_done);
+  C10_CUDA_CHECK(cudaEventCreate(&copy_done));
   int nto_copy = std::min(CHUNK_SIZE, N);
-  CUDA_CHECK(cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       &h_delete_mask[0],
       &d_delete_mask[0],
       nto_copy * mask_ld * sizeof(int),
       cudaMemcpyDeviceToHost,
       context->cuda_stream()));
-  CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+  C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
   int offset = 0;
   std::vector<int> h_keep_sorted_list;
   std::vector<int> rmv(mask_ld, 0);
@@ -519,7 +520,7 @@
     int next_offset = offset + ncopied;
     nto_copy = std::min(CHUNK_SIZE, N - next_offset);
     if (nto_copy > 0) {
-      CUDA_CHECK(cudaMemcpyAsync(
+      C10_CUDA_CHECK(cudaMemcpyAsync(
          &h_delete_mask[next_offset * mask_ld],
          &d_delete_mask[next_offset * mask_ld],
          nto_copy * mask_ld * sizeof(int),
@@ -527,9 +528,10 @@
           context->cuda_stream()));
     }
     // Waiting for previous copy
-    CUDA_CHECK(cudaEventSynchronize(copy_done));
-    if (nto_copy > 0)
-      cudaEventRecord(copy_done, context->cuda_stream());
+    C10_CUDA_CHECK(cudaEventSynchronize(copy_done));
+    if (nto_copy > 0){
+      C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+    }
     for (int i = offset; i < next_offset; ++i) {
       int iblock = i / BOXES_PER_THREAD;
       int inblock = i % BOXES_PER_THREAD;
@@ -543,15 +545,15 @@
     }
     offset = next_offset;
   }
-  cudaEventDestroy(copy_done);
+  C10_CUDA_CHECK(cudaEventDestroy(copy_done));

   const int nkeep = h_keep_sorted_list.size();
-  cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       d_keep_sorted_list,
       &h_keep_sorted_list[0],
       nkeep * sizeof(int),
       cudaMemcpyHostToDevice,
-      context->cuda_stream());
+      context->cuda_stream()));

   *h_nkeep = nkeep;
 }

File 4 of 9

@@ -691,7 +691,7 @@ TEST(UtilsNMSTest, TestPerfRotatedNMS) {
   //     list_nitems * sizeof(int),
   //     cudaMemcpyDeviceToHost,
   //     cuda_context.cuda_stream()));
-  // CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream()));
+  // CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream());
   // ASSERT_EQ(keep.size(), gpu_keep.size());
   // std::sort(keep.begin(), keep.end());

File 5 of 9

@@ -130,8 +130,7 @@ void CUDARecurrentNetworkExecutor::_ExecRange(int from, int to) {
     for (int stream_id = 0; stream_id <= std::min(stream_seq, max_streams - 1);
          stream_id++) {
       VLOG(1) << "Wait for stream:" << stream_id;
-      CUDA_CHECK(
-          cudaStreamSynchronize(CUDAContext::cuda_stream(gpu_id, stream_id)));
+      CUDA_CHECK(cudaStreamSynchronize(CUDAContext::cuda_stream(gpu_id, stream_id)));
     }
   }
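
Finishing a range of timesteps here means draining every stream that was handed work. A minimal illustrative sketch of that drain loop, assuming the streams are already collected in a vector (drain_streams is a hypothetical helper, not the executor's actual interface):

#include <cuda_runtime.h>
#include <stdexcept>
#include <vector>

void drain_streams(const std::vector<cudaStream_t>& streams, int num_used) {
  for (int s = 0; s < num_used; ++s) {
    const cudaError_t err = cudaStreamSynchronize(streams[s]);
    if (err != cudaSuccess) {
      throw std::runtime_error(cudaGetErrorString(err));
    }
  }
}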

File 6 of 9

@@ -138,9 +138,9 @@ REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp<CUDAContext>);
       }
     }
   }
-  cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize);
-  cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize,
-             cudaMemcpyHostToDevice);
+  C10_CUDA_CHECK(cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize));
+  C10_CUDA_CHECK(cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize,
+                 cudaMemcpyHostToDevice));

   // ScaleBlobsCUDAKernelBalanced kernel launch
   ScaleBlobsCUDAKernelBalanced<T>
@@ -150,7 +150,7 @@ REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp<CUDAContext>);
       dOutputArr);
   C10_CUDA_KERNEL_LAUNCH_CHECK();

-  cudaFree(dStartCoorArr);
+  C10_CUDA_CHECK(cudaFree(dStartCoorArr));
 */

 template <typename T>
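
The C10_CUDA_KERNEL_LAUNCH_CHECK retained above covers the one case a return-value wrapper cannot: a <<<...>>> kernel launch returns void, so launch and configuration failures surface only through cudaGetLastError(). A sketch of the idea with a hypothetical stand-in macro (KERNEL_LAUNCH_CHECK is illustrative, not the real c10 macro):

#include <cuda_runtime.h>
#include <stdexcept>

#define KERNEL_LAUNCH_CHECK()                             \
  do {                                                    \
    const cudaError_t _e = cudaGetLastError();            \
    if (_e != cudaSuccess) {                              \
      throw std::runtime_error(cudaGetErrorString(_e));   \
    }                                                     \
  } while (0)

// Usage:
//   some_kernel<<<grid, block, 0, stream>>>(args...);
//   KERNEL_LAUNCH_CHECK();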

File 7 of 9

@@ -305,7 +305,7 @@ CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL(float)
       return;                                                             \
     }                                                                     \
     if (alpha == T(0)) {                                                  \
-      cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream());       \
+      C10_CUDA_CHECK(cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream())); \
     } else {                                                              \
       thrust::fill(                                                       \
           thrust::cuda::par.on(context->cuda_stream()), Y, Y + N, alpha); \
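
The macro body above picks between two fill strategies: when alpha is zero, an asynchronous memset suffices (all-zero bytes represent T(0) for the arithmetic types used here); otherwise a thrust::fill kernel runs on the same stream. A plain-function sketch of that choice, with the memset checked as in the diff (scale_fill is an illustrative name, not the caffe2 function):

#include <cuda_runtime.h>
#include <thrust/fill.h>
#include <thrust/system/cuda/execution_policy.h>
#include <stdexcept>

template <typename T>
void scale_fill(int n, T alpha, T* y, cudaStream_t stream) {
  if (n == 0) {
    return;
  }
  if (alpha == T(0)) {
    // Byte-wise zero is a valid representation of T(0) for these types.
    const cudaError_t err = cudaMemsetAsync(y, 0, sizeof(T) * n, stream);
    if (err != cudaSuccess) {
      throw std::runtime_error(cudaGetErrorString(err));
    }
  } else {
    thrust::fill(thrust::cuda::par.on(stream), y, y + n, alpha);
  }
}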

File 8 of 9

@@ -418,12 +418,12 @@ void MomentsCUDA(
     return;
   }
   if (std::equal(X_dims, X_dims + ndim, Y_dims)) {
-    cudaMemcpyAsync(
+    C10_CUDA_CHECK(cudaMemcpyAsync(
         mean,
         X,
         sizeof(T) * X_size,
         cudaMemcpyDeviceToDevice,
-        context->cuda_stream());
+        context->cuda_stream()));
     Set<T, CUDAContext>(Y_size, T(0), var, context);
     return;
   }

File 9 of 9

@@ -2685,12 +2685,12 @@ CAFFE2_CUDA_EXPORT void CopyVector<float, CUDAContext>(
     float* dst,
     CUDAContext* context) {
   if (src != dst && N > 0) {
-    cudaMemcpyAsync(
+    C10_CUDA_CHECK(cudaMemcpyAsync(
         dst,
         src,
         sizeof(float) * N,
         cudaMemcpyDeviceToDevice,
-        context->cuda_stream());
+        context->cuda_stream()));
   }
 }

@@ -2701,12 +2701,12 @@ CAFFE2_CUDA_EXPORT void CopyVector<int, CUDAContext>(
     int* dst,
     CUDAContext* context) {
   if (src != dst && N > 0) {
-    cudaMemcpyAsync(
+    C10_CUDA_CHECK(cudaMemcpyAsync(
         dst,
         src,
         sizeof(int) * N,
         cudaMemcpyDeviceToDevice,
-        context->cuda_stream());
+        context->cuda_stream()));
   }
 }