diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index ef8bda66c5a1..0d0d36f989c4 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -86,6 +86,10 @@ jobs: if: always() run: | (! git --no-pager grep -I -no $'#include #include -#include - -#ifdef __HIP_PLATFORM_HCC__ -#include -#endif namespace at { namespace native { Scalar _local_scalar_dense_cuda(const Tensor& self) { Scalar r; -#if HIP_VERSION >= 301 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] { scalar_t value; cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_CUDA_CHECK(hipMemcpyWithStream(&value, self.data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream)); + at::cuda::memcpy_and_sync(&value, self.data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream); r = Scalar(value); }); -#else - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( - at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] { - scalar_t value; - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data_ptr(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - r = Scalar(value); - }); -#endif return r; } diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 968f7207d578..84e016016d9b 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -224,12 +224,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) { void* ptr = (dst_device == kCPU ? dst : src); AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(ptr, stream)); } else { -#if HIP_VERSION >= 301 - AT_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); -#else - AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); -#endif + at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream); } if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) { diff --git a/aten/src/ATen/native/cuda/MiscUtils.h b/aten/src/ATen/native/cuda/MiscUtils.h index 5373f61edb8f..b28f77c70174 100644 --- a/aten/src/ATen/native/cuda/MiscUtils.h +++ b/aten/src/ATen/native/cuda/MiscUtils.h @@ -75,14 +75,14 @@ struct MagmaStreamSyncGuard { MagmaStreamSyncGuard() { auto stream = at::cuda::getCurrentCUDAStream(); if (stream != at::cuda::getDefaultCUDAStream()) { - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + at::cuda::stream_synchronize(stream); } } ~MagmaStreamSyncGuard() noexcept(false) { auto default_stream = at::cuda::getDefaultCUDAStream(); if (at::cuda::getCurrentCUDAStream() != default_stream) { - AT_CUDA_CHECK(cudaStreamSynchronize(default_stream)); + at::cuda::stream_synchronize(default_stream); } } }; diff --git a/aten/src/ATen/native/cuda/Nonzero.cu b/aten/src/ATen/native/cuda/Nonzero.cu index 79ee419336a6..e72493738a40 100644 --- a/aten/src/ATen/native/cuda/Nonzero.cu +++ b/aten/src/ATen/native/cuda/Nonzero.cu @@ -61,9 +61,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){ auto temp_storage = allocator.allocate(temp_storage_bytes); cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream); int num_nonzeros_h; - C10_CUDA_CHECK(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream)); - //need to synchronize to make sure data is available on the host - C10_CUDA_CHECK(cudaStreamSynchronize(stream)); + at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream); //expected output size is num_nonzeros x ndim //we are producing output with size {num_nonzeros, ndim} and strides {num_nonzeros, 1} (that is, transposed ndim x num_nonzeros output) //we are able to directly use passed output with this size and strides, and we can also (per contract) diff --git a/aten/src/ATen/native/cuda/TensorModeKernel.cu b/aten/src/ATen/native/cuda/TensorModeKernel.cu index c7455f55760e..c24d7fb25037 100644 --- a/aten/src/ATen/native/cuda/TensorModeKernel.cu +++ b/aten/src/ATen/native/cuda/TensorModeKernel.cu @@ -105,9 +105,8 @@ void calculate_mode( AT_CUDA_CHECK(cudaMemcpyAsync( values_data, &mode, sizeof(scalar_t), cudaMemcpyHostToDevice, stream)); - AT_CUDA_CHECK(cudaMemcpyAsync( - indices_data, &index, sizeof(scalar_t), cudaMemcpyHostToDevice, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + //memcpy_and_sync will synchronize results + at::cuda::memcpy_and_sync(indices_data, &index, sizeof(scalar_t), cudaMemcpyHostToDevice, stream); } template diff --git a/aten/src/THC/generic/THCStorage.cpp b/aten/src/THC/generic/THCStorage.cpp index 59664404880d..63ce73388e83 100644 --- a/aten/src/THC/generic/THCStorage.cpp +++ b/aten/src/THC/generic/THCStorage.cpp @@ -5,9 +5,6 @@ #include #include -#ifdef __HIP_PLATFORM_HCC__ -#include -#endif scalar_t* THCStorage_(data)(THCState *state, const THCStorage *self) { @@ -26,16 +23,9 @@ void THCStorage_(set)(THCState *state, THCStorage *self, ptrdiff_t index, scalar 2, "index out of bounds"); cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); -#if HIP_VERSION >= 301 - THCudaCheck(hipMemcpyWithStream(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t), - cudaMemcpyHostToDevice, - stream)); -#else - THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t), + at::cuda::memcpy_and_sync(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t), cudaMemcpyHostToDevice, - stream)); - THCudaCheck(cudaStreamSynchronize(stream)); -#endif + stream); } scalar_t THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t index) @@ -46,14 +36,8 @@ scalar_t THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t ind "index out of bounds"); scalar_t value; cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); -#if HIP_VERSION >= 301 - THCudaCheck(hipMemcpyWithStream(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t), - cudaMemcpyDeviceToHost, stream)); -#else - THCudaCheck(cudaMemcpyAsync(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t), - cudaMemcpyDeviceToHost, stream)); - THCudaCheck(cudaStreamSynchronize(stream)); -#endif + at::cuda::memcpy_and_sync(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t), + cudaMemcpyDeviceToHost, stream); return value; } diff --git a/aten/src/THC/generic/THCStorageCopy.cpp b/aten/src/THC/generic/THCStorageCopy.cpp index 1411e30fd2c5..086b6d421a00 100644 --- a/aten/src/THC/generic/THCStorageCopy.cpp +++ b/aten/src/THC/generic/THCStorageCopy.cpp @@ -2,30 +2,18 @@ #define THC_GENERIC_FILE "THC/generic/THCStorageCopy.cpp" #else -#ifdef __HIP_PLATFORM_HCC__ -#include -#endif +#include void THCStorage_(copyCPU)(THCState *state, THCStorage *self, struct THStorage *src) { THArgCheck(self->nbytes() == src->nbytes(), 2, "size does not match"); cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); -#if HIP_VERSION >= 301 - THCudaCheck(hipMemcpyWithStream( - THCStorage_(data)(state, self), + at::cuda::memcpy_and_sync(THCStorage_(data)(state, self), THStorage_(data)(src), self->nbytes(), cudaMemcpyHostToDevice, - stream)); -#else - THCudaCheck(cudaMemcpyAsync( - THCStorage_(data)(state, self), - THStorage_(data)(src), - self->nbytes(), - cudaMemcpyHostToDevice, - stream)); - THCudaCheck(cudaStreamSynchronize(stream)); -#endif + stream); + } #define TH_CUDA_STORAGE_IMPLEMENT_COPY(TYPEC) \ @@ -61,22 +49,12 @@ void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *s { THArgCheck(self->nbytes() == src->nbytes(), 2, "size does not match"); cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); -#if HIP_VERSION >= 301 - THCudaCheck(hipMemcpyWithStream( + at::cuda::memcpy_and_sync( THStorage_(data)(self), THCStorage_(data)(state, src), self->nbytes(), cudaMemcpyDeviceToHost, - stream)); -#else - THCudaCheck(cudaMemcpyAsync( - THStorage_(data)(self), - THCStorage_(data)(state, src), - self->nbytes(), - cudaMemcpyDeviceToHost, - stream)); - THCudaCheck(cudaStreamSynchronize(stream)); -#endif + stream); } #define TH_CUDA_STORAGE_IMPLEMENT_COPYTO(TYPEC) \ diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index 256fc54b08a1..3803498b3352 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -22,10 +22,10 @@ configure_file( set(C10_CUDA_SRCS CUDAStream.cpp CUDAFunctions.cpp + CUDAMiscFunctions.cpp CUDACachingAllocator.cpp impl/CUDAGuardImpl.cpp impl/CUDATest.cpp - CUDAFunctions.cpp ) set(C10_CUDA_HEADERS CUDAException.h @@ -34,6 +34,7 @@ set(C10_CUDA_HEADERS CUDAMathCompat.h CUDAStream.h CUDAFunctions.h + CUDAMiscFunctions.h impl/CUDAGuardImpl.h impl/CUDATest.h ) diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h index 2c084d6cbf93..c3ff821bb332 100644 --- a/c10/cuda/CUDAException.h +++ b/c10/cuda/CUDAException.h @@ -1,6 +1,7 @@ #pragma once -#include +#include +#include #include #include #include diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 6838b05a7fa7..1f781b37099a 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -1,6 +1,3 @@ -#include - -#include #include #include @@ -141,18 +138,5 @@ void device_synchronize() { C10_CUDA_CHECK(cudaDeviceSynchronize()); } -const char* get_cuda_check_suffix() noexcept { - static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING"); - static bool blocking_enabled = - (device_blocking_flag && atoi(device_blocking_flag)); - if (blocking_enabled) { - return ""; - } else { - return "\nCUDA kernel errors might be asynchronously reported at some" - " other API call,so the stacktrace below might be incorrect." - "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1."; - } -} - } // namespace cuda } // namespace c10 diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h index 25f687827dfc..1464999d715c 100644 --- a/c10/cuda/CUDAFunctions.h +++ b/c10/cuda/CUDAFunctions.h @@ -8,7 +8,12 @@ // The naming convention used here matches the naming convention of torch.cuda #include +#include #include +#ifdef __HIP_PLATFORM_HCC__ +#include +#endif +#include namespace c10 { namespace cuda { @@ -30,7 +35,25 @@ C10_CUDA_API void set_device(DeviceIndex device); C10_CUDA_API void device_synchronize(); -C10_CUDA_API const char* get_cuda_check_suffix() noexcept; +// the subsequent functions are defined in the header because for performance +// reasons we want them to be inline +C10_CUDA_API void __inline__ memcpy_and_sync( + void* dst, + void* src, + int64_t nbytes, + cudaMemcpyKind kind, + cudaStream_t stream) { +#if defined(HIP_VERSION) && (HIP_VERSION >= 301) + C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); +#else + C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); + C10_CUDA_CHECK(cudaStreamSynchronize(stream)); +#endif +} + +C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) { + C10_CUDA_CHECK(cudaStreamSynchronize(stream)); +} } // namespace cuda } // namespace c10 diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp new file mode 100644 index 000000000000..7655ca8c6a60 --- /dev/null +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -0,0 +1,20 @@ +#include +#include + +namespace c10 { +namespace cuda { + +const char* get_cuda_check_suffix() noexcept { + static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING"); + static bool blocking_enabled = + (device_blocking_flag && atoi(device_blocking_flag)); + if (blocking_enabled) { + return ""; + } else { + return "\nCUDA kernel errors might be asynchronously reported at some" + " other API call,so the stacktrace below might be incorrect." + "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1."; + } +} +} // namespace cuda +} // namespace c10 diff --git a/c10/cuda/CUDAMiscFunctions.h b/c10/cuda/CUDAMiscFunctions.h new file mode 100644 index 000000000000..eca8fd042f61 --- /dev/null +++ b/c10/cuda/CUDAMiscFunctions.h @@ -0,0 +1,11 @@ +#pragma once +// this file is to avoid circular dependency between CUDAFunctions.h and +// CUDAExceptions.h + +#include + +namespace c10 { +namespace cuda { +C10_CUDA_API const char* get_cuda_check_suffix() noexcept; +} +} // namespace c10 diff --git a/c10/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h index f46ae5bf007a..07d4997a087f 100644 --- a/c10/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -7,8 +7,7 @@ #include #include -#include -#include +#include #include /* @@ -128,7 +127,7 @@ class C10_CUDA_API CUDAStream { void synchronize() const { DeviceGuard guard{stream_.device()}; - C10_CUDA_CHECK(cudaStreamSynchronize(stream())); + c10::cuda::stream_synchronize(stream()); } int priority() const { diff --git a/torch/csrc/CudaIPCTypes.cpp b/torch/csrc/CudaIPCTypes.cpp index 62574976b57f..13801bc1fa19 100644 --- a/torch/csrc/CudaIPCTypes.cpp +++ b/torch/csrc/CudaIPCTypes.cpp @@ -161,7 +161,7 @@ CudaIPCSentData::CudaIPCSentData( event_sync_required_ = true; } else { auto stream = c10::cuda::getCurrentCUDAStream(device.index()); - C10_CUDA_CHECK(cudaStreamSynchronize(stream)); + at::cuda::stream_synchronize(stream); event_ = nullptr; event_sync_required_ = false; } @@ -169,7 +169,7 @@ CudaIPCSentData::CudaIPCSentData( // cuIpcGetEventHandle with HIP is not supported, so we have to sync // stream instead of passing event auto stream = c10::cuda::getCurrentCUDAStream(device.index()); - C10_CUDA_CHECK(cudaStreamSynchronize(stream)); + at::cuda::stream_synchronize(stream); event_sync_required_ = false; #endif } diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 98706dda0f0e..30511534253d 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -426,7 +426,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args) // TODO: Instead of cudaStreamSynchronize it is possible to add Stream // Callback and release counter inside of it (need to check performance impact) - cudaStreamSynchronize(c10::cuda::getCurrentCUDAStream(device)); + at::cuda::stream_synchronize(c10::cuda::getCurrentCUDAStream(device)); // We don't want to break existing code, so resource deletion is best // effort basis. Exception expected if producer process terminated diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index d9ca0cc977ed..e2c40bd3d5d6 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -9,7 +9,6 @@ #include #include -#include #include #include #include @@ -529,7 +528,7 @@ std::vector FusionExecutor::runFusion( stream, kernel_arguments.getBuffer(), nullptr)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + at::cuda::stream_synchronize(stream); } return alloced_outputs; diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index 0494df31d899..5ab001c82c7e 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -8123,6 +8123,7 @@ C10_MAPPINGS = collections.OrderedDict( ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)), ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), + ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)), ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)), ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)),