mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[inductor] Fix CudaStreamGuard in AOTInductor ABI compatible mode (#109471)
Summary: Use an RAII class to wrap at::cuda::CUDAStreamGuard. The previous implementation didn't follow the exact CUDAStreamGuard behavior.
Test Plan: CI
Differential Revision: D49355542
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109471
Approved by: https://github.com/chenyang78
This commit is contained in:
committed by
PyTorch MergeBot
parent
d2ca5fa6c5
commit
6ffa59031a
@ -358,24 +358,18 @@ inline RAIIAtenTensorHandle create_raii_tensor_handle_for_temp(
|
||||
class AOTICudaStreamGuard {
|
||||
public:
|
||||
AOTICudaStreamGuard(cudaStream_t stream, int32_t device_index) {
|
||||
// store the current stream and set the new stream as current
|
||||
cudaStream_t current_stream;
|
||||
AOTI_TORCH_ERROR_CHECK(aoti_torch_get_current_cuda_stream(
|
||||
reinterpret_cast<void**>(&current_stream), device_index));
|
||||
stream_ = current_stream;
|
||||
CUDAStreamGuardHandle ptr;
|
||||
AOTI_TORCH_ERROR_CHECK(
|
||||
aoti_torch_set_current_cuda_stream(stream, device_index));
|
||||
}
|
||||
|
||||
~AOTICudaStreamGuard() noexcept(false) {
|
||||
// restore the previous stream as current
|
||||
AOTI_TORCH_ERROR_CHECK(
|
||||
aoti_torch_set_current_cuda_stream(stream_, device_index_));
|
||||
aoti_torch_create_cuda_stream_guard(&ptr, stream, device_index));
|
||||
guard_ =
|
||||
std::unique_ptr<void, std::function<void(void*)>>(ptr, [](void* ptr) {
|
||||
AOTI_TORCH_ERROR_CHECK(aoti_torch_delete_cuda_stream_guard(
|
||||
reinterpret_cast<CUDAStreamGuardHandle>(ptr)));
|
||||
});
|
||||
}
|
||||
|
||||
private:
|
||||
cudaStream_t stream_;
|
||||
int32_t device_index_;
|
||||
std::unique_ptr<void, std::function<void(void*)>> guard_;
|
||||
};
|
||||
|
||||
} // namespace aot_inductor
|
||||
|
@ -163,11 +163,18 @@ AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError aoti_torch_mm_out(
|
||||
AtenTensorHandle mat2);
|
||||
|
||||
#ifdef USE_CUDA
|
||||
AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError
|
||||
aoti_torch_get_current_cuda_stream(void** ret, int32_t device_index);
|
||||
|
||||
struct CUDAStreamGuardOpaque;
|
||||
using CUDAStreamGuardHandle = CUDAStreamGuardOpaque*;
|
||||
|
||||
AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError
|
||||
aoti_torch_set_current_cuda_stream(void* stream, int32_t device_index);
|
||||
aoti_torch_create_cuda_stream_guard(
|
||||
CUDAStreamGuardHandle* ret_guard, // returns new reference
|
||||
void* stream,
|
||||
int32_t device_index);
|
||||
|
||||
AOTI_TORCH_EXPORT AOTI_TORCH_NOINLINE AOTITorchError
|
||||
aoti_torch_delete_cuda_stream_guard(CUDAStreamGuardHandle guard);
|
||||
#endif
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
@ -2,23 +2,23 @@
|
||||
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
|
||||
#include <torch/csrc/inductor/aoti_torch/utils.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
// Looks up the current CUDA stream for `device_index` via
// c10::cuda::getCurrentCUDAStream and stores it in `*ret` as an untyped
// pointer.
//
// ret:          output location; receives the cudaStream_t cast to void*.
// device_index: device whose current stream is queried.
//
// The body is wrapped in AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE, which
// supplies the AOTITorchError return value.
// NOTE(review): presumably the macro maps thrown C++ exceptions to an error
// code -- confirm against its definition.
AOTITorchError aoti_torch_get_current_cuda_stream(
|
||||
void** ret,
|
||||
int32_t device_index) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
cudaStream_t cuda_stream = c10::cuda::getCurrentCUDAStream(device_index);
|
||||
// Hand the raw stream back through the type-erased output parameter.
*ret = reinterpret_cast<void*>(cuda_stream);
|
||||
});
|
||||
}
|
||||
|
||||
AOTITorchError aoti_torch_set_current_cuda_stream(
|
||||
AOTITorchError aoti_torch_create_cuda_stream_guard(
|
||||
CUDAStreamGuardHandle* ret_guard,
|
||||
void* stream,
|
||||
int32_t device_index) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
|
||||
cudaStream_t cuda_stream = static_cast<cudaStream_t>(stream);
|
||||
c10::cuda::setCurrentCUDAStream(
|
||||
at::cuda::getStreamFromExternal(cuda_stream, device_index));
|
||||
at::cuda::CUDAStreamGuard* guard =
|
||||
new at::cuda::CUDAStreamGuard(at::cuda::getStreamFromExternal(
|
||||
static_cast<cudaStream_t>(stream), device_index));
|
||||
*ret_guard = reinterpret_cast<CUDAStreamGuardHandle>(guard);
|
||||
});
|
||||
}
|
||||
|
||||
// Destroys a stream guard referred to by an opaque handle: the handle is cast
// back to at::cuda::CUDAStreamGuard* and released with `delete`.
//
// guard: handle to the guard object to destroy.
//
// The body is wrapped in AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE, which
// supplies the AOTITorchError return value.
// NOTE(review): presumably the macro maps thrown C++ exceptions to an error
// code -- confirm against its definition.
AOTITorchError aoti_torch_delete_cuda_stream_guard(
|
||||
CUDAStreamGuardHandle guard) {
|
||||
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(
|
||||
{ delete reinterpret_cast<at::cuda::CUDAStreamGuard*>(guard); });
|
||||
}
|
||||
|
Reference in New Issue
Block a user