Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Check all CUDA API calls for errors in caffe2/ (#81816)
Test Plan: Sandcastle

Differential Revision: D35194868

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81816
Approved by: https://github.com/ezyang
Committed by: PyTorch MergeBot
parent 3ece9fb45d
commit 7a3afe61d2
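
Every hunk below applies one idiom: a raw CUDA runtime call is either wrapped in a check-or-throw macro (C10_CUDA_CHECK) or routed through C10_CUDA_ERROR_HANDLED when the caller inspects the returned cudaError_t itself. As a rough sketch of what such macros do (these stand-ins are illustrative, not the real c10 definitions):

    #include <cuda_runtime.h>
    #include <sstream>
    #include <stdexcept>

    // Stand-in for C10_CUDA_CHECK: any failure becomes an exception.
    #define CHECK_CUDA(expr)                                    \
      do {                                                      \
        cudaError_t err_ = (expr);                              \
        if (err_ != cudaSuccess) {                              \
          std::ostringstream ss_;                               \
          ss_ << "CUDA error: " << cudaGetErrorString(err_)     \
              << " at " << __FILE__ << ":" << __LINE__;         \
          throw std::runtime_error(ss_.str());                  \
        }                                                       \
      } while (0)

    // Stand-in for C10_CUDA_ERROR_HANDLED: forwards the error code and
    // documents that the caller takes responsibility for inspecting it.
    #define HANDLED_CUDA(expr) (expr)

With this split, cudaFree(ptr) becomes CHECK_CUDA(cudaFree(ptr)) where any failure is fatal, while cudaError_t err = HANDLED_CUDA(cudaFreeHost(data)) is used where the code branches on specific error values, mirroring the C10_CUDA_CHECK / C10_CUDA_ERROR_HANDLED split in the hunks below.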
@@ -235,7 +235,7 @@ static void Caffe2InitializeCuda() {
         // a reserved flag for cudaDeviceEnablePeerAccess that should always be
         // zero currently.
         // It is ok if peer access is already enabled...
-        cudaError_t err = cudaDeviceEnablePeerAccess(j, 0);
+        cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaDeviceEnablePeerAccess(j, 0));
         if ((err != cudaErrorPeerAccessAlreadyEnabled) &&
             (err != cudaSuccess)) {
           CAFFE_THROW(cudaGetErrorString(err));
@@ -351,7 +351,7 @@ struct CAFFE2_CUDA_API PinnedCPUAllocator final : public at::Allocator {
       CUDA_ENFORCE(cudaHostUnregister(data));
       GetDefaultCPUAllocator()->raw_deleter()(data);
     } else {
-      cudaError_t err = cudaFreeHost(data);
+      cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaFreeHost(data));
       profiledCPUMemoryReporter().Delete(data);
       if (err == cudaErrorInvalidValue) {
         free(data);
@@ -598,7 +598,7 @@ struct DefaultCUDAAllocator final : public at::Allocator {
     switch (g_cuda_memory_pool_type) {
       case CudaMemoryPoolType::NONE: {
         // If memory pool is not set up, use simple cudaFree.
-        cudaError_t error = cudaFree(ptr);
+        cudaError_t error = C10_CUDA_ERROR_HANDLED(cudaFree(ptr));
         // For some reason, in Python runtime we sometimes delete a data pointer
         // after the cuda runtime exits - this is odd but is probably caused by
         // a static workspace that pycaffe2 uses, and the destruction got

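The three hunks above use C10_CUDA_ERROR_HANDLED rather than a throwing check because certain error codes are expected and tolerated. One subtlety is that CUDA errors are sticky, so a tolerated error should be consumed with cudaGetLastError() before the next checked call. A stand-alone sketch of the peer-access case (plain CUDA calls; the function name and structure are illustrative):

    #include <cuda_runtime.h>
    #include <stdexcept>

    // Enable peer access from the current device to `peer`; being
    // already enabled is fine, anything else is fatal.
    void enable_peer_access(int peer) {
      cudaError_t err = cudaDeviceEnablePeerAccess(peer, /*flags=*/0);
      if (err == cudaErrorPeerAccessAlreadyEnabled) {
        (void)cudaGetLastError();  // clear the sticky, expected error
        return;
      }
      if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
      }
    }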
@@ -195,10 +195,6 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
   // SwitchToDevice()
   void FinishDeviceComputation() override {
     CUDA_ENFORCE(cudaStreamSynchronize(getCudaObjects().GetStream(gpu_id_)));
-    cudaError_t error = cudaGetLastError();
-    if (error != cudaSuccess) {
-      CAFFE_THROW("Encountered CUDA error: ", cudaGetErrorString(error));
-    }
   }

   inline int device_id() const {
@@ -309,11 +305,13 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
   }

   static bool IsStreamFree(const DeviceOption& option, StreamId stream_id) {
-    auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id);
-    auto status = cudaStreamQuery(stream);
+    const auto stream = CUDAContext::cuda_stream(option.device_id(), stream_id);
+    const auto status = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream));
     if (status == cudaErrorNotReady) {
       // ignore and clear the error if not ready
-      (void)cudaGetLastError();
+      C10_CUDA_CLEAR_ERROR();
+    } else {
+      C10_CUDA_CHECK(status); // Reraise error
     }
     return status == cudaSuccess;
   }

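The IsStreamFree hunk is the same discipline applied to a query API: cudaErrorNotReady is a normal answer ("still busy"), not a failure, but it still sets the sticky error state, which is what C10_CUDA_CLEAR_ERROR resets. A stand-alone sketch of the pattern with plain CUDA calls in place of the c10 macros:

    #include <cuda_runtime.h>
    #include <stdexcept>

    // True iff all work queued on `stream` has completed.
    bool is_stream_free(cudaStream_t stream) {
      const cudaError_t status = cudaStreamQuery(stream);
      if (status == cudaErrorNotReady) {
        (void)cudaGetLastError();  // expected: clear it, report busy
        return false;
      }
      if (status != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(status));  // reraise
      }
      return true;
    }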
@@ -145,15 +145,15 @@ void nms_gpu_upright(
   // Overlapping CPU computes and D2H memcpy
   // both take about the same time
   cudaEvent_t copy_done;
-  cudaEventCreate(&copy_done);
+  C10_CUDA_CHECK(cudaEventCreate(&copy_done));
   int nto_copy = std::min(CHUNK_SIZE, N);
-  CUDA_CHECK(cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       &h_delete_mask[0],
       &d_delete_mask[0],
       nto_copy * mask_ld * sizeof(int),
       cudaMemcpyDeviceToHost,
       context->cuda_stream()));
-  CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+  C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
   int offset = 0;
   std::vector<int> h_keep_sorted_list;
   std::vector<int> rmv(mask_ld, 0);
@@ -162,7 +162,7 @@ void nms_gpu_upright(
     int next_offset = offset + ncopied;
     nto_copy = std::min(CHUNK_SIZE, N - next_offset);
     if (nto_copy > 0) {
-      CUDA_CHECK(cudaMemcpyAsync(
+      C10_CUDA_CHECK(cudaMemcpyAsync(
           &h_delete_mask[next_offset * mask_ld],
           &d_delete_mask[next_offset * mask_ld],
           nto_copy * mask_ld * sizeof(int),
@@ -170,9 +170,10 @@ void nms_gpu_upright(
           context->cuda_stream()));
     }
     // Waiting for previous copy
-    CUDA_CHECK(cudaEventSynchronize(copy_done));
-    if (nto_copy > 0)
-      cudaEventRecord(copy_done, context->cuda_stream());
+    C10_CUDA_CHECK(cudaEventSynchronize(copy_done));
+    if (nto_copy > 0){
+      C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+    }
     for (int i = offset; i < next_offset; ++i) {
       int iblock = i / BOXES_PER_THREAD;
       int inblock = i % BOXES_PER_THREAD;
@@ -186,15 +187,15 @@ void nms_gpu_upright(
     }
     offset = next_offset;
   }
-  cudaEventDestroy(copy_done);
+  C10_CUDA_CHECK(cudaEventDestroy(copy_done));

   const int nkeep = h_keep_sorted_list.size();
-  cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       d_keep_sorted_list,
       &h_keep_sorted_list[0],
       nkeep * sizeof(int),
       cudaMemcpyHostToDevice,
-      context->cuda_stream());
+      context->cuda_stream()));

   *h_nkeep = nkeep;
 }

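The nms_gpu_upright hunks (and the identical nms_gpu_rotated ones below) all sit in one copy/compute overlap scheme: while the CPU scans chunk k of the delete mask, the D2H copy of chunk k+1 is already in flight, with an event marking copy completion. A compressed sketch of that scheme with every call checked, reusing the CHECK_CUDA stand-in from the first sketch (function and buffer names are illustrative):

    #include <algorithm>
    #include <cuda_runtime.h>

    void chunked_d2h(const int* device, int* host, int n, int chunk,
                     cudaStream_t stream) {
      cudaEvent_t copy_done;
      CHECK_CUDA(cudaEventCreate(&copy_done));
      int nto_copy = std::min(chunk, n);
      CHECK_CUDA(cudaMemcpyAsync(host, device, nto_copy * sizeof(int),
                                 cudaMemcpyDeviceToHost, stream));
      CHECK_CUDA(cudaEventRecord(copy_done, stream));
      int offset = 0;
      while (offset < n) {
        const int next = offset + nto_copy;
        nto_copy = std::min(chunk, n - next);
        if (nto_copy > 0) {  // queue chunk k+1 before waiting on chunk k
          CHECK_CUDA(cudaMemcpyAsync(host + next, device + next,
                                     nto_copy * sizeof(int),
                                     cudaMemcpyDeviceToHost, stream));
        }
        CHECK_CUDA(cudaEventSynchronize(copy_done));  // chunk k has landed
        if (nto_copy > 0) {
          CHECK_CUDA(cudaEventRecord(copy_done, stream));
        }
        // ... consume host[offset, next) on the CPU here ...
        offset = next;
      }
      CHECK_CUDA(cudaEventDestroy(copy_done));
    }

Checking every call matters here in particular: an unchecked cudaMemcpyAsync that fails to enqueue leaves the host buffer unwritten, the subsequent cudaEventSynchronize still succeeds, and the bug surfaces as silently wrong results rather than an error.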
@@ -502,15 +503,15 @@ void nms_gpu_rotated(
   // Overlapping CPU computes and D2H memcpy
   // both take about the same time
   cudaEvent_t copy_done;
-  cudaEventCreate(&copy_done);
+  C10_CUDA_CHECK(cudaEventCreate(&copy_done));
   int nto_copy = std::min(CHUNK_SIZE, N);
-  CUDA_CHECK(cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       &h_delete_mask[0],
       &d_delete_mask[0],
       nto_copy * mask_ld * sizeof(int),
       cudaMemcpyDeviceToHost,
       context->cuda_stream()));
-  CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+  C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
   int offset = 0;
   std::vector<int> h_keep_sorted_list;
   std::vector<int> rmv(mask_ld, 0);
@@ -519,7 +520,7 @@ void nms_gpu_rotated(
     int next_offset = offset + ncopied;
     nto_copy = std::min(CHUNK_SIZE, N - next_offset);
     if (nto_copy > 0) {
-      CUDA_CHECK(cudaMemcpyAsync(
+      C10_CUDA_CHECK(cudaMemcpyAsync(
           &h_delete_mask[next_offset * mask_ld],
           &d_delete_mask[next_offset * mask_ld],
           nto_copy * mask_ld * sizeof(int),
@@ -527,9 +528,10 @@ void nms_gpu_rotated(
           context->cuda_stream()));
     }
     // Waiting for previous copy
-    CUDA_CHECK(cudaEventSynchronize(copy_done));
-    if (nto_copy > 0)
-      cudaEventRecord(copy_done, context->cuda_stream());
+    C10_CUDA_CHECK(cudaEventSynchronize(copy_done));
+    if (nto_copy > 0){
+      C10_CUDA_CHECK(cudaEventRecord(copy_done, context->cuda_stream()));
+    }
     for (int i = offset; i < next_offset; ++i) {
       int iblock = i / BOXES_PER_THREAD;
       int inblock = i % BOXES_PER_THREAD;
@@ -543,15 +545,15 @@ void nms_gpu_rotated(
     }
     offset = next_offset;
   }
-  cudaEventDestroy(copy_done);
+  C10_CUDA_CHECK(cudaEventDestroy(copy_done));

   const int nkeep = h_keep_sorted_list.size();
-  cudaMemcpyAsync(
+  C10_CUDA_CHECK(cudaMemcpyAsync(
       d_keep_sorted_list,
       &h_keep_sorted_list[0],
       nkeep * sizeof(int),
       cudaMemcpyHostToDevice,
-      context->cuda_stream());
+      context->cuda_stream()));

   *h_nkeep = nkeep;
 }

@@ -691,7 +691,7 @@ TEST(UtilsNMSTest, TestPerfRotatedNMS) {
   //     list_nitems * sizeof(int),
   //     cudaMemcpyDeviceToHost,
   //     cuda_context.cuda_stream()));
-  // CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream()));
+  // CUDA_CHECK(cudaStreamSynchronize(cuda_context.cuda_stream());

   // ASSERT_EQ(keep.size(), gpu_keep.size());
   // std::sort(keep.begin(), keep.end());

@@ -130,8 +130,7 @@ void CUDARecurrentNetworkExecutor::_ExecRange(int from, int to) {
     for (int stream_id = 0; stream_id <= std::min(stream_seq, max_streams - 1);
          stream_id++) {
       VLOG(1) << "Wait for stream:" << stream_id;
-      CUDA_CHECK(
-          cudaStreamSynchronize(CUDAContext::cuda_stream(gpu_id, stream_id)));
+      CUDA_CHECK(cudaStreamSynchronize(CUDAContext::cuda_stream(gpu_id, stream_id)));
     }
   }

@@ -138,9 +138,9 @@ REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp<CUDAContext>);
       }
     }
   }
-  cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize);
-  cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize,
-      cudaMemcpyHostToDevice);
+  C10_CUDA_CHECK(cudaMalloc(&dStartCoorArr, sizeof(int) * coorArrSize));
+  C10_CUDA_CHECK(cudaMemcpy(dStartCoorArr, startCoorArr, sizeof(int) * coorArrSize,
+      cudaMemcpyHostToDevice));

   // ScaleBlobsCUDAKernelBalanced kernel launch
   ScaleBlobsCUDAKernelBalanced<T>
@@ -150,7 +150,7 @@ REGISTER_CUDA_OPERATOR(ScaleBlobs, ScaleBlobsOp<CUDAContext>);
       dOutputArr);
   C10_CUDA_KERNEL_LAUNCH_CHECK();

-  cudaFree(dStartCoorArr);
+  C10_CUDA_CHECK(cudaFree(dStartCoorArr));
   */

   template <typename T>

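The context above shows C10_CUDA_KERNEL_LAUNCH_CHECK, which covers the one case a return-code wrapper cannot: kernel launches return nothing, so a launch failure is only visible through cudaGetLastError() afterwards. A minimal illustrative equivalent (the kernel and names are invented for the example):

    #include <cuda_runtime.h>
    #include <stdexcept>

    __global__ void scale_kernel(float* y, const float* x, float a, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        y[i] = a * x[i];
      }
    }

    void launch_scale(float* y, const float* x, float a, int n,
                      cudaStream_t stream) {
      const int block = 256;
      const int grid = (n + block - 1) / block;
      scale_kernel<<<grid, block, 0, stream>>>(y, x, a, n);
      // Bad launch configs or earlier sticky errors only show up here:
      cudaError_t err = cudaGetLastError();
      if (err != cudaSuccess) {
        throw std::runtime_error(cudaGetErrorString(err));
      }
    }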
@@ -305,7 +305,7 @@ CAFFE2_SPECIALIZED_HALF_SCALE_CUDA_KERNEL(float)
     return;                                                                   \
   }                                                                           \
   if (alpha == T(0)) {                                                        \
-    cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream());             \
+    C10_CUDA_CHECK(cudaMemsetAsync(Y, 0, sizeof(T) * N, context->cuda_stream())); \
   } else {                                                                    \
     thrust::fill(                                                             \
         thrust::cuda::par.on(context->cuda_stream()), Y, Y + N, alpha);       \

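This hunk lives inside a multi-line #define, which is why the replacement keeps a trailing backslash: every physical line of a macro body needs one, including a wrapped call that grows past the old column. A trimmed illustration of the same shape (the macro and its parameters are invented for the example; CHECK_CUDA is the stand-in from the first sketch):

    #include <thrust/execution_policy.h>
    #include <thrust/fill.h>

    #define SET_ZERO_OR_FILL(T, Y, N, alpha, stream)                    \
      if ((alpha) == T(0)) {                                            \
        CHECK_CUDA(cudaMemsetAsync((Y), 0, sizeof(T) * (N), (stream))); \
      } else {                                                          \
        thrust::fill(                                                   \
            thrust::cuda::par.on(stream), (Y), (Y) + (N), (alpha));     \
      }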
|
@ -418,12 +418,12 @@ void MomentsCUDA(
|
||||
return;
|
||||
}
|
||||
if (std::equal(X_dims, X_dims + ndim, Y_dims)) {
|
||||
cudaMemcpyAsync(
|
||||
C10_CUDA_CHECK(cudaMemcpyAsync(
|
||||
mean,
|
||||
X,
|
||||
sizeof(T) * X_size,
|
||||
cudaMemcpyDeviceToDevice,
|
||||
context->cuda_stream());
|
||||
context->cuda_stream()));
|
||||
Set<T, CUDAContext>(Y_size, T(0), var, context);
|
||||
return;
|
||||
}
|
||||
|
@ -2685,12 +2685,12 @@ CAFFE2_CUDA_EXPORT void CopyVector<float, CUDAContext>(
|
||||
float* dst,
|
||||
CUDAContext* context) {
|
||||
if (src != dst && N > 0) {
|
||||
cudaMemcpyAsync(
|
||||
C10_CUDA_CHECK(cudaMemcpyAsync(
|
||||
dst,
|
||||
src,
|
||||
sizeof(float) * N,
|
||||
cudaMemcpyDeviceToDevice,
|
||||
context->cuda_stream());
|
||||
context->cuda_stream()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2701,12 +2701,12 @@ CAFFE2_CUDA_EXPORT void CopyVector<int, CUDAContext>(
     int* dst,
     CUDAContext* context) {
   if (src != dst && N > 0) {
-    cudaMemcpyAsync(
+    C10_CUDA_CHECK(cudaMemcpyAsync(
         dst,
         src,
         sizeof(int) * N,
         cudaMemcpyDeviceToDevice,
-        context->cuda_stream());
+        context->cuda_stream()));
   }
 }