[inductor] Fix CudaStreamGuard in AOTInductor ABI compatible mode (#109471)

Summary: Use an RAII class to wrap at::cuda::CUDAStreamGuard. The previous hand-rolled implementation didn't follow the exact CUDAStreamGuard behavior: it saved and restored only the current stream (and its destructor read a device_index_ member that the constructor never initialized), whereas at::cuda::CUDAStreamGuard also switches the current device to the stream's device and restores both the original device and the original stream on destruction.
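
For reference, a minimal sketch of what at::cuda::CUDAStreamGuard does when used directly (illustrative only; the helper function and variable names below are assumptions, not code from this PR):

  #include <c10/cuda/CUDAGuard.h>
  #include <c10/cuda/CUDAStream.h>

  // Hypothetical helper: run work on an externally created CUDA stream.
  void run_on_external_stream(cudaStream_t raw_stream, int device_index) {
    // The guard switches the current device to the stream's device and
    // makes `raw_stream` the current stream on that device.
    c10::cuda::CUDAStreamGuard guard(
        c10::cuda::getStreamFromExternal(raw_stream, device_index));
    // ... enqueue kernels here; they run on raw_stream ...
  } // destructor restores the original device and the original stream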

Test Plan: CI

Differential Revision: D49355542

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109471
Approved by: https://github.com/chenyang78
Author: Bin Bao
Date: 2023-09-18 15:54:58 +00:00
Committed by: PyTorch MergeBot
Parent: d2ca5fa6c5
Commit: 6ffa59031a

3 changed files with 31 additions and 30 deletions


@@ -358,24 +358,18 @@ inline RAIIAtenTensorHandle create_raii_tensor_handle_for_temp(
 class AOTICudaStreamGuard {
  public:
   AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) {
-    // store the current stream and set the new stream as current
-    cudaStream_t current_stream;
-    AOTI_TORCH_ERROR_CHECK(aoti_torch_get_current_cuda_stream(
-        reinterpret_cast<void**>(&current_stream), device_index));
-    stream_ = current_stream;
+    CUDAStreamGuardHandle ptr;
     AOTI_TORCH_ERROR_CHECK(
-        aoti_torch_set_current_cuda_stream(stream, device_index));
-  }
-
-  ~AOTICudaStreamGuard() noexcept(false) {
-    // restore the previous stream as current
-    AOTI_TORCH_ERROR_CHECK(
-        aoti_torch_set_current_cuda_stream(stream_, device_index_));
+        aoti_torch_create_cuda_stream_guard(&ptr, stream, device_index));
+    guard_ =
+        std::unique_ptr<void, std::function<void(void*)>>(ptr, [](void* ptr) {
+          AOTI_TORCH_ERROR_CHECK(aoti_torch_delete_cuda_stream_guard(
+              reinterpret_cast<CUDAStreamGuardHandle>(ptr)));
+        });
   }

  private:
-  cudaStream_t stream_;
-  int32_t device_index_;
+  std::unique_ptr<void, std::function<void(void*)>> guard_;
 };

 } // namespace aot_inductor
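
As a hedged usage sketch (not part of this diff), code generated by AOTInductor can scope a stream switch with the wrapper above; the surrounding function and variable names are hypothetical, and the caller is assumed to be in (or using) the same aot_inductor namespace:

  // Assumes the header defining AOTICudaStreamGuard above is included.
  void run_kernels(cudaStream_t stream, int32_t device_idx) {
    AOTICudaStreamGuard stream_guard(stream, device_idx);
    // ... kernel launches issued here use `stream` as the current stream ...
  } // guard_'s deleter destroys the underlying CUDAStreamGuard here,
    // restoring the previous device and stream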


@@ -163,11 +163,18 @@ AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError aoti_torch_mm_out(
     AtenTensorHandle mat2);

 #ifdef USE_CUDA
-AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError
-aoti_torch_get_current_cuda_stream(void** ret, int32_t device_index);
+struct CUDAStreamGuardOpaque;
+using CUDAStreamGuardHandle = CUDAStreamGuardOpaque*;
+
 AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError
-aoti_torch_set_current_cuda_stream(void* stream, int32_t device_index);
+aoti_torch_create_cuda_stream_guard(
+    CUDAStreamGuardHandle* ret_guard, // returns new reference
+    void* stream,
+    int32_t device_index);
+
+AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError
+aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);

 #endif

 #ifdef __cplusplus
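
For illustration, a sketch of calling this C interface directly (hedged: the caller function is hypothetical, and the error handling assumes the shim's convention that AOTI_TORCH_SUCCESS is 0):

  #include <torch/csrc/inductor/aoti_torch/c/shim.h>

  // Hypothetical caller that brackets some work with the new guard API.
  void with_stream_guard(void* stream, int32_t device_index) {
    CUDAStreamGuardHandle guard = nullptr;
    if (aoti_torch_create_cuda_stream_guard(&guard, stream, device_index) !=
        AOTI_TORCH_SUCCESS) {
      return; // creation failed; nothing to clean up
    }
    // ... work enqueued here sees `stream` as the current stream ...
    aoti_torch_delete_cuda_stream_guard(guard); // restores device and stream
  }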


@@ -2,23 +2,23 @@
 #include <torch/csrc/inductor/aoti_torch/c/shim.h>
 #include <torch/csrc/inductor/aoti_torch/utils.h>

+#include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAStream.h>

-AOTITorchError aoti_torch_get_current_cuda_stream(
-    void** ret,
-    int32_t device_index) {
-  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    cudaStream_t cuda_stream = c10::cuda::getCurrentCUDAStream(device_index);
-    *ret = reinterpret_cast<void*>(cuda_stream);
-  });
-}
-
-AOTITorchError aoti_torch_set_current_cuda_stream(
+AOTITorchError aoti_torch_create_cuda_stream_guard(
+    CUDAStreamGuardHandle* ret_guard,
     void* stream,
     int32_t device_index) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
-    cudaStream_t cuda_stream = static_cast<cudaStream_t>(stream);
-    c10::cuda::setCurrentCUDAStream(
-        at::cuda::getStreamFromExternal(cuda_stream, device_index));
+    at::cuda::CUDAStreamGuard* guard =
+        new at::cuda::CUDAStreamGuard(at::cuda::getStreamFromExternal(
+            static_cast<cudaStream_t>(stream), device_index));
+    *ret_guard = reinterpret_cast<CUDAStreamGuardHandle>(guard);
   });
 }
+
+AOTITorchError aoti_torch_delete_cuda_stream_guard(
+    CUDAStreamGuardHandle guard) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
+      { delete reinterpret_cast<at::cuda::CUDAStreamGuard*>(guard); });
+}
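
A hedged sanity-check sketch of the intended semantics (this is not the PR's test plan, which is just CI; the check function below is an assumption for illustration):

  #include <cassert>
  #include <c10/cuda/CUDAStream.h>
  #include <torch/csrc/inductor/aoti_torch/c/shim.h>

  // Illustrative check: the external stream is current while the guard lives,
  // and the previously current stream is restored after the guard is deleted.
  void check_guard_roundtrip(cudaStream_t ext, int32_t device_index) {
    c10::cuda::CUDAStream before =
        c10::cuda::getCurrentCUDAStream(device_index);

    CUDAStreamGuardHandle guard = nullptr;
    aoti_torch_create_cuda_stream_guard(&guard, ext, device_index);
    assert(c10::cuda::getCurrentCUDAStream(device_index).stream() == ext);

    aoti_torch_delete_cuda_stream_guard(guard);
    assert(c10::cuda::getCurrentCUDAStream(device_index) == before);
  }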