Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
wrap cudaStreamSynchronize calls (#61889)
Summary: This is a first step towards creating a context manager that errors out on synchronizing calls.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61889

Reviewed By: albanD

Differential Revision: D29805280

Pulled By: ngimel

fbshipit-source-id: b66400fbe0941b7daa51e6b30abe27b9cccd4e8a
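The pattern replaced throughout this diff is the raw "async copy, then synchronize" pair. Routing it through at::cuda::memcpy_and_sync (and at::cuda::stream_synchronize for bare syncs) leaves a single choke point that a later context manager can hook, and it is also what the new lint rule below enforces. A minimal before/after sketch of a call site; the function read_scalar and the buffer d_ptr are invented for illustration:

#include <ATen/cuda/CUDAContext.h>

float read_scalar(float* d_ptr) {
  float value;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Before this commit, call sites issued two raw CUDA API calls:
  //   AT_CUDA_CHECK(cudaMemcpyAsync(&value, d_ptr, sizeof(float),
  //                                 cudaMemcpyDeviceToHost, stream));
  //   AT_CUDA_CHECK(cudaStreamSynchronize(stream));
  // After, one wrapper copies and synchronizes (and picks
  // hipMemcpyWithStream on ROCm builds with HIP_VERSION >= 301):
  at::cuda::memcpy_and_sync(&value, d_ptr, sizeof(float),
                            cudaMemcpyDeviceToHost, stream);
  return value;
}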
commit 6284d2a82b (parent 3d6aa3a2f6), committed by Facebook GitHub Bot
.github/workflows/lint.yml (vendored): 4 changes

@@ -86,6 +86,10 @@ jobs:
         if: always()
         run: |
           (! git --no-pager grep -I -no $'#include <cub/' -- ./aten ':(exclude)aten/src/ATen/cuda/cub.cuh' || (echo "The above files have direct cub include; please include ATen/cuda/cub.cuh instead and wrap your cub calls in at::native namespace if necessary"; false))
+      - name: Ensure no raw cuda api calls
+        if: always()
+        run: |
+          (! git --no-pager grep -I -no $'cudaStreamSynchronize' -- ./aten ./c10 ':(exclude)aten/src/ATen/test' ':(exclude)c10/cuda/CUDAFunctions.h' || (echo "The above files call raw cuda APIs directly; please use at::cuda wrappers instead"; false))
 
   clang-format:
     runs-on: ubuntu-18.04
Makefile: 3 changes

@@ -87,7 +87,8 @@ quick_checks:
 	--step 'Ensure no unqualified noqa' \
 	--step 'Ensure no unqualified type ignore' \
 	--step 'Ensure no direct cub include' \
-	--step 'Ensure correct trailing newlines'
+	--step 'Ensure correct trailing newlines' \
+	--step 'Ensure no raw cuda api calls'
 
 flake8:
 	@$(PYTHON) tools/actions_local_runner.py \
@@ -2,35 +2,19 @@
 #include <ATen/NativeFunctions.h>
 
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 
-#ifdef __HIP_PLATFORM_HCC__
-#include <hip/hip_version.h>
-#endif
-
 namespace at {
 namespace native {
 
 Scalar _local_scalar_dense_cuda(const Tensor& self) {
   Scalar r;
-#if HIP_VERSION >= 301
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
     at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
       scalar_t value;
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-      AT_CUDA_CHECK(hipMemcpyWithStream(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
+      at::cuda::memcpy_and_sync(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream);
       r = Scalar(value);
     });
-#else
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-    at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, self.scalar_type(), "_local_scalar_dense_cuda", [&] {
-      scalar_t value;
-      cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-      AT_CUDA_CHECK(cudaMemcpyAsync(&value, self.data_ptr<scalar_t>(), sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
-      AT_CUDA_CHECK(cudaStreamSynchronize(stream));
-      r = Scalar(value);
-    });
-#endif
   return r;
 }
@@ -224,12 +224,7 @@ static void copy_kernel_cuda(TensorIterator& iter, bool non_blocking) {
       void* ptr = (dst_device == kCPU ? dst : src);
       AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(ptr, stream));
     } else {
-#if HIP_VERSION >= 301
-      AT_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
-#else
-      AT_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
-      AT_CUDA_CHECK(cudaStreamSynchronize(stream));
-#endif
+      at::cuda::memcpy_and_sync(dst, src, nbytes, kind, stream);
     }
 
   if (iter.tensor(0).is_conj() != iter.tensor(1).is_conj()) {
@@ -75,14 +75,14 @@ struct MagmaStreamSyncGuard {
   MagmaStreamSyncGuard() {
     auto stream = at::cuda::getCurrentCUDAStream();
     if (stream != at::cuda::getDefaultCUDAStream()) {
-      AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+      at::cuda::stream_synchronize(stream);
     }
   }
 
   ~MagmaStreamSyncGuard() noexcept(false) {
     auto default_stream = at::cuda::getDefaultCUDAStream();
     if (at::cuda::getCurrentCUDAStream() != default_stream) {
-      AT_CUDA_CHECK(cudaStreamSynchronize(default_stream));
+      at::cuda::stream_synchronize(default_stream);
     }
   }
 };
@@ -61,9 +61,7 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
   auto temp_storage = allocator.allocate(temp_storage_bytes);
   cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
   int num_nonzeros_h;
-  C10_CUDA_CHECK(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream));
-  //need to synchronize to make sure data is available on the host
-  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+  at::cuda::memcpy_and_sync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream);
   //expected output size is num_nonzeros x ndim
   //we are producing output with size {num_nonzeros, ndim} and strides {num_nonzeros, 1} (that is, transposed ndim x num_nonzeros output)
   //we are able to directly use passed output with this size and strides, and we can also (per contract)
@@ -105,9 +105,8 @@ void calculate_mode(
 
   AT_CUDA_CHECK(cudaMemcpyAsync(
       values_data, &mode, sizeof(scalar_t), cudaMemcpyHostToDevice, stream));
-  AT_CUDA_CHECK(cudaMemcpyAsync(
-      indices_data, &index, sizeof(scalar_t), cudaMemcpyHostToDevice, stream));
-  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+  //memcpy_and_sync will synchronize results
+  at::cuda::memcpy_and_sync(indices_data, &index, sizeof(scalar_t), cudaMemcpyHostToDevice, stream);
 }
 
 template <typename scalar_t>
@@ -5,9 +5,6 @@
 #include <c10/util/intrusive_ptr.h>
 #include <c10/util/typeid.h>
 
-#ifdef __HIP_PLATFORM_HCC__
-#include <hip/hip_version.h>
-#endif
 
 scalar_t* THCStorage_(data)(THCState *state, const THCStorage *self)
 {
@@ -26,16 +23,9 @@ void THCStorage_(set)(THCState *state, THCStorage *self, ptrdiff_t index, scalar
                2,
                "index out of bounds");
   cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
-#if HIP_VERSION >= 301
-  THCudaCheck(hipMemcpyWithStream(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t),
-                                  cudaMemcpyHostToDevice,
-                                  stream));
-#else
-  THCudaCheck(cudaMemcpyAsync(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t),
-                              cudaMemcpyHostToDevice,
-                              stream));
-  THCudaCheck(cudaStreamSynchronize(stream));
-#endif
+  at::cuda::memcpy_and_sync(THCStorage_(data)(state, self) + index, &value, sizeof(scalar_t),
+                            cudaMemcpyHostToDevice,
+                            stream);
 }
 
 scalar_t THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t index)
@@ -46,14 +36,8 @@ scalar_t THCStorage_(get)(THCState *state, const THCStorage *self, ptrdiff_t ind
              "index out of bounds");
   scalar_t value;
   cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
-#if HIP_VERSION >= 301
-  THCudaCheck(hipMemcpyWithStream(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t),
-                                  cudaMemcpyDeviceToHost, stream));
-#else
-  THCudaCheck(cudaMemcpyAsync(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t),
-                              cudaMemcpyDeviceToHost, stream));
-  THCudaCheck(cudaStreamSynchronize(stream));
-#endif
+  at::cuda::memcpy_and_sync(&value, THCStorage_(data)(state, self) + index, sizeof(scalar_t),
+                            cudaMemcpyDeviceToHost, stream);
   return value;
 }
@@ -2,30 +2,18 @@
 #define THC_GENERIC_FILE "THC/generic/THCStorageCopy.cpp"
 #else
 
-#ifdef __HIP_PLATFORM_HCC__
-#include <hip/hip_version.h>
-#endif
+#include <c10/cuda/CUDAFunctions.h>
 
 void THCStorage_(copyCPU)(THCState *state, THCStorage *self, struct THStorage *src)
 {
   THArgCheck(self->nbytes() == src->nbytes(), 2, "size does not match");
   cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
-#if HIP_VERSION >= 301
-  THCudaCheck(hipMemcpyWithStream(
-                THCStorage_(data)(state, self),
+  at::cuda::memcpy_and_sync(THCStorage_(data)(state, self),
                 THStorage_(data)(src),
                 self->nbytes(),
                 cudaMemcpyHostToDevice,
-                stream));
-#else
-  THCudaCheck(cudaMemcpyAsync(
-                THCStorage_(data)(state, self),
-                THStorage_(data)(src),
-                self->nbytes(),
-                cudaMemcpyHostToDevice,
-                stream));
-  THCudaCheck(cudaStreamSynchronize(stream));
-#endif
+                stream);
 
 }
 
 #define TH_CUDA_STORAGE_IMPLEMENT_COPY(TYPEC) \
@@ -61,22 +49,12 @@ void THStorage_(copyCuda)(THCState *state, THStorage *self, struct THCStorage *s
 {
   THArgCheck(self->nbytes() == src->nbytes(), 2, "size does not match");
   cudaStream_t stream = c10::cuda::getCurrentCUDAStream();
-#if HIP_VERSION >= 301
-  THCudaCheck(hipMemcpyWithStream(
+  at::cuda::memcpy_and_sync(
                 THStorage_(data)(self),
                 THCStorage_(data)(state, src),
                 self->nbytes(),
                 cudaMemcpyDeviceToHost,
-                stream));
-#else
-  THCudaCheck(cudaMemcpyAsync(
-                THStorage_(data)(self),
-                THCStorage_(data)(state, src),
-                self->nbytes(),
-                cudaMemcpyDeviceToHost,
-                stream));
-  THCudaCheck(cudaStreamSynchronize(stream));
-#endif
+                stream);
 }
 
 #define TH_CUDA_STORAGE_IMPLEMENT_COPYTO(TYPEC) \
@@ -22,10 +22,10 @@ configure_file(
 set(C10_CUDA_SRCS
     CUDAStream.cpp
+    CUDAFunctions.cpp
+    CUDAMiscFunctions.cpp
     CUDACachingAllocator.cpp
     impl/CUDAGuardImpl.cpp
     impl/CUDATest.cpp
-    CUDAFunctions.cpp
 )
 set(C10_CUDA_HEADERS
     CUDAException.h
@@ -34,6 +34,7 @@ set(C10_CUDA_HEADERS
     CUDAMathCompat.h
     CUDAStream.h
     CUDAFunctions.h
+    CUDAMiscFunctions.h
     impl/CUDAGuardImpl.h
     impl/CUDATest.h
 )
@@ -1,6 +1,7 @@
 #pragma once
 
-#include <c10/cuda/CUDAFunctions.h>
 #include <c10/cuda/CUDAMacros.h>
+#include <c10/cuda/CUDAMiscFunctions.h>
 #include <c10/macros/Macros.h>
 #include <c10/util/Exception.h>
+#include <cuda.h>
@@ -1,6 +1,3 @@
-#include <cuda_runtime_api.h>
-
-#include <c10/cuda/CUDAException.h>
 #include <c10/cuda/CUDAFunctions.h>
 #include <c10/macros/Macros.h>
 
@@ -141,18 +138,5 @@ void device_synchronize() {
   C10_CUDA_CHECK(cudaDeviceSynchronize());
 }
 
-const char* get_cuda_check_suffix() noexcept {
-  static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING");
-  static bool blocking_enabled =
-      (device_blocking_flag && atoi(device_blocking_flag));
-  if (blocking_enabled) {
-    return "";
-  } else {
-    return "\nCUDA kernel errors might be asynchronously reported at some"
-           " other API call,so the stacktrace below might be incorrect."
-           "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
-  }
-}
-
 } // namespace cuda
 } // namespace c10
@@ -8,7 +8,12 @@
 // The naming convention used here matches the naming convention of torch.cuda
 
 #include <c10/core/Device.h>
+#include <c10/cuda/CUDAException.h>
 #include <c10/cuda/CUDAMacros.h>
+#ifdef __HIP_PLATFORM_HCC__
+#include <hip/hip_version.h>
+#endif
+#include <cuda_runtime_api.h>
 
 namespace c10 {
 namespace cuda {
@@ -30,7 +35,25 @@ C10_CUDA_API void set_device(DeviceIndex device);
 
 C10_CUDA_API void device_synchronize();
 
-C10_CUDA_API const char* get_cuda_check_suffix() noexcept;
+// the subsequent functions are defined in the header because for performance
+// reasons we want them to be inline
+C10_CUDA_API void __inline__ memcpy_and_sync(
+    void* dst,
+    void* src,
+    int64_t nbytes,
+    cudaMemcpyKind kind,
+    cudaStream_t stream) {
+#if defined(HIP_VERSION) && (HIP_VERSION >= 301)
+  C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
+#else
+  C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream));
+  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+#endif
+}
+
+C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
+  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+}
 
 } // namespace cuda
 } // namespace c10
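These two inline helpers are now the single funnel for stream synchronization in c10 and ATen. As a rough illustration of where the "context manager that errors out on synchronizing calls" from the summary could eventually hook in, a scoped guard checked inside the wrapper might look like the sketch below; SyncDebugGuard and g_error_on_sync are hypothetical names, not part of this commit:

#include <cuda_runtime_api.h>
#include <atomic>
#include <stdexcept>

// Hypothetical sketch: once every sync goes through one wrapper,
// a scoped guard can make any synchronizing call throw.
static std::atomic<bool> g_error_on_sync{false};

inline void stream_synchronize_sketch(cudaStream_t stream) {
  if (g_error_on_sync.load()) {
    throw std::runtime_error("synchronizing CUDA call while sync-debug mode is on");
  }
  cudaStreamSynchronize(stream);  // the one raw call left at the choke point
}

struct SyncDebugGuard {  // RAII stand-in for the future context manager
  SyncDebugGuard() { g_error_on_sync.store(true); }
  ~SyncDebugGuard() { g_error_on_sync.store(false); }
};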
c10/cuda/CUDAMiscFunctions.cpp: 20 changes (new file)

@@ -0,0 +1,20 @@
+#include <c10/cuda/CUDAMiscFunctions.h>
+#include <stdlib.h>
+
+namespace c10 {
+namespace cuda {
+
+const char* get_cuda_check_suffix() noexcept {
+  static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING");
+  static bool blocking_enabled =
+      (device_blocking_flag && atoi(device_blocking_flag));
+  if (blocking_enabled) {
+    return "";
+  } else {
+    return "\nCUDA kernel errors might be asynchronously reported at some"
+           " other API call,so the stacktrace below might be incorrect."
+           "\nFor debugging consider passing CUDA_LAUNCH_BLOCKING=1.";
+  }
+}
+} // namespace cuda
+} // namespace c10
c10/cuda/CUDAMiscFunctions.h: 11 changes (new file)

@@ -0,0 +1,11 @@
+#pragma once
+// this file is to avoid circular dependency between CUDAFunctions.h and
+// CUDAExceptions.h
+
+#include <c10/cuda/CUDAMacros.h>
+
+namespace c10 {
+namespace cuda {
+C10_CUDA_API const char* get_cuda_check_suffix() noexcept;
+}
+} // namespace c10
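The comment in this new header names the reason it exists: C10_CUDA_CHECK (in CUDAException.h) appends get_cuda_check_suffix() to its error message, while CUDAFunctions.h now needs C10_CUDA_CHECK for the inline wrappers above, so the suffix helper had to move out of CUDAFunctions.h to break the include cycle. A paraphrased sketch of the consumer side, not the verbatim PyTorch macro, which uses the TORCH_CHECK machinery rather than fprintf/abort:

#include <cuda_runtime_api.h>
#include <cstdio>
#include <cstdlib>

// from CUDAMiscFunctions.h; declared here to keep the sketch self-contained
const char* get_cuda_check_suffix() noexcept;

#define CUDA_CHECK_SKETCH(EXPR)                       \
  do {                                                \
    cudaError_t err_ = (EXPR);                        \
    if (err_ != cudaSuccess) {                        \
      std::fprintf(stderr, "CUDA error: %s%s\n",      \
                   cudaGetErrorString(err_),          \
                   get_cuda_check_suffix());          \
      std::abort();                                   \
    }                                                 \
  } while (0)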
@@ -7,8 +7,7 @@
 
 #include <c10/core/DeviceGuard.h>
 #include <c10/core/Stream.h>
-#include <c10/cuda/CUDAException.h>
-#include <c10/cuda/CUDAMacros.h>
+#include <c10/cuda/CUDAFunctions.h>
 #include <c10/util/Exception.h>
 
 /*
@@ -128,7 +127,7 @@ class C10_CUDA_API CUDAStream {
 
   void synchronize() const {
     DeviceGuard guard{stream_.device()};
-    C10_CUDA_CHECK(cudaStreamSynchronize(stream()));
+    c10::cuda::stream_synchronize(stream());
   }
 
   int priority() const {
@@ -161,7 +161,7 @@ CudaIPCSentData::CudaIPCSentData(
     event_sync_required_ = true;
   } else {
     auto stream = c10::cuda::getCurrentCUDAStream(device.index());
-    C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+    at::cuda::stream_synchronize(stream);
     event_ = nullptr;
     event_sync_required_ = false;
   }
@@ -169,7 +169,7 @@ CudaIPCSentData::CudaIPCSentData(
   // cuIpcGetEventHandle with HIP is not supported, so we have to sync
   // stream instead of passing event
   auto stream = c10::cuda::getCurrentCUDAStream(device.index());
-  C10_CUDA_CHECK(cudaStreamSynchronize(stream));
+  at::cuda::stream_synchronize(stream);
   event_sync_required_ = false;
 #endif
 }
@@ -426,7 +426,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
 
   // TODO: Instead of cudaStreamSynchronize it is possible to add Stream
   // Callback and release counter inside of it (need to check performance impact)
-  cudaStreamSynchronize(c10::cuda::getCurrentCUDAStream(device));
+  at::cuda::stream_synchronize(c10::cuda::getCurrentCUDAStream(device));
 
   // We don't want to break existing code, so resource deletion is best
   // effort basis. Exception expected if producer process terminated
@@ -9,7 +9,6 @@
 
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
 #include <c10/core/DeviceGuard.h>
-#include <c10/cuda/CUDAFunctions.h>
@@ -529,7 +528,7 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
         stream,
         kernel_arguments.getBuffer(),
         nullptr));
-    AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+    at::cuda::stream_synchronize(stream);
   }
 
   return alloced_outputs;
@@ -8123,6 +8123,7 @@ C10_MAPPINGS = collections.OrderedDict(
         ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)),
         ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)),
         ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)),
+        ("c10/cuda/CUDAMiscFunctions.h", ("c10/hip/HIPMiscFunctions.h", API_C10)),
         ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)),
         ("c10/cuda/CUDAGraphsC10Utils.h", ("c10/hip/HIPGraphsC10Utils.h", API_C10)),
         ("c10/cuda/CUDACachingAllocator.h", ("c10/hip/HIPCachingAllocator.h", API_C10)),