Move hasPrimaryContext to c10::cuda (#96800)

This method has to be accessible from `c10` to enable CUDA-12 integration.
Implemented by providing private `c10::cuda:_internal::setHasPrimaryContext` that passes the pointer to the implementation (in `torch_cuda`) back to c10.
Use global class constructor/destructor to guarantee RAII.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96800
Approved by: https://github.com/ngimel
This commit is contained in:
Nikita Shulga
2023-03-17 04:50:31 +00:00
committed by PyTorch MergeBot
parent cbd3df93c4
commit 24ce3a7c34
8 changed files with 78 additions and 51 deletions

View File

@ -22,7 +22,7 @@ using CaptureStatus = c10::cuda::CaptureStatus;
inline CaptureStatus currentStreamCaptureStatus() {
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
// don't create a context if we don't have to
if (at::cuda::detail::hasPrimaryContext(c10::cuda::current_device())) {
if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
return c10::cuda::currentStreamCaptureStatusMayInitCtx();
} else {
return CaptureStatus::None;

View File

@ -168,7 +168,7 @@ class CUDAHostAllocator {
// primary context, if available. See pytorch/pytorch#21081.
at::OptionalDeviceGuard device_guard;
auto primary_ctx_device_index =
at::cuda::detail::getDeviceIndexWithPrimaryContext();
c10::cuda::getDeviceIndexWithPrimaryContext();
if (primary_ctx_device_index.has_value()) {
device_guard.reset_device(
at::Device(at::DeviceType::CUDA, *primary_ctx_device_index));

View File

@ -40,9 +40,11 @@
#include <functional>
#include <memory>
namespace at {
namespace cuda {
namespace detail {
namespace c10::cuda::_internal {
void setHasPrimaryContext(bool (*func)(int64_t));
}
namespace at::cuda::detail {
const at::cuda::NVRTC& nvrtc();
int64_t current_device();
@ -53,6 +55,29 @@ void set_magma_init_fn(void (*fn)()) {
magma_init_fn = fn;
}
namespace {
bool _hasPrimaryContext(int64_t device_index) {
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
"hasPrimaryContext expects a valid device index, but got device_index=", device_index);
unsigned int ctx_flags;
// In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird
// (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero.
int ctx_is_active = 0;
AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
return ctx_is_active == 1;
}
// Register hasPrimaryContext back to c10::cuda
struct _Initializer {
_Initializer() {
c10::cuda::_internal::setHasPrimaryContext(_hasPrimaryContext);
}
~_Initializer() {
c10::cuda::_internal::setHasPrimaryContext(nullptr);
}
} initializer;
} // anonymous namespace
// Sets the CUDA_MODULE_LOADING environment variable
// if it's not set by the user.
void maybe_set_cuda_module_loading(const std::string &def_value) {
@ -209,36 +234,8 @@ int64_t CUDAHooks::current_device() const {
return at::cuda::detail::current_device();
}
bool hasPrimaryContext(int64_t device_index) {
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
"hasPrimaryContext expects a valid device index, but got device_index=", device_index);
unsigned int ctx_flags;
// In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird
// (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero.
int ctx_is_active = 0;
AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
return ctx_is_active == 1;
}
bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
return at::cuda::detail::hasPrimaryContext(device_index);
}
c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
// check current device first
int64_t current_device_index = current_device();
if (current_device_index >= 0) {
if (hasPrimaryContext(current_device_index)) {
return current_device_index;
}
}
for (const auto device_index : c10::irange(at::cuda::device_count())) {
if (device_index == current_device_index) continue;
if (hasPrimaryContext(device_index)) {
return device_index;
}
}
return c10::nullopt;
return _hasPrimaryContext(device_index);
}
Allocator* CUDAHooks::getPinnedMemoryAllocator() const {
@ -443,6 +440,4 @@ using at::RegistererCUDAHooksRegistry;
REGISTER_CUDA_HOOKS(CUDAHooks);
} // namespace detail
} // namespace cuda
} // namespace at
} // namespace at::cuda::detail

View File

@ -15,8 +15,6 @@ namespace at { namespace cuda { namespace detail {
// in the same library where Magma will be used.
TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
TORCH_CUDA_CPP_API bool hasPrimaryContext(int64_t device_index);
TORCH_CUDA_CPP_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
// The real implementation of CUDAHooksInterface
struct CUDAHooks : public at::CUDAHooksInterface {

View File

@ -1,7 +1,6 @@
#pragma once
#include <c10/core/Allocator.h>
#include <ATen/core/Generator.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <c10/util/Registry.h>
@ -10,14 +9,14 @@
#include <functional>
#include <memory>
// Forward-declares at::cuda::NVRTC
namespace at { namespace cuda {
struct NVRTC;
}} // at::cuda
// Forward-declares at::Context, at::Generator and at::cuda::NVRTC
namespace at {
class Context;
}
struct Generator;
namespace cuda {
struct NVRTC;
} // namespace cuda
} // namespace at
// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

View File

@ -3,8 +3,7 @@
#include <limits>
namespace c10 {
namespace cuda {
namespace c10::cuda {
namespace {
// returns -1 on failure
@ -149,5 +148,38 @@ void warn_or_error_on_sync() {
}
}
} // namespace cuda
} // namespace c10
c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
// check current device first
int64_t current_device_index = current_device();
if (current_device_index >= 0) {
if (hasPrimaryContext(current_device_index)) {
return current_device_index;
}
}
for (const auto device_index : c10::irange(at::cuda::device_count())) {
if (device_index == current_device_index)
continue;
if (hasPrimaryContext(device_index)) {
return device_index;
}
}
return c10::nullopt;
}
namespace _internal {
bool dummyHasPrimaryContext(C10_UNUSED int64_t device_index) {
TORCH_CHECK(false, "Should never been called");
}
bool (*hasPrimaryContext)(int64_t) = dummyHasPrimaryContext;
// Private api to be called from CUDAHooks.cpp
C10_CUDA_API void setHasPrimaryContext(bool (*func)(int64_t)) {
hasPrimaryContext = func ? func : dummyHasPrimaryContext;
}
} // namespace _internal
bool hasPrimaryContext(int64_t device_index) {
return _internal::hasPrimaryContext(device_index);
}
} // namespace c10::cuda

View File

@ -96,5 +96,8 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
C10_CUDA_API bool hasPrimaryContext(int64_t device_index);
C10_CUDA_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
} // namespace cuda
} // namespace c10

View File

@ -486,7 +486,7 @@ PyObject* THCPModule_hasPrimaryContext(PyObject* _unused, PyObject* arg) {
THPUtils_assert(
THPUtils_checkLong(arg), "invalid argument to has_primary_context");
int64_t device_index = static_cast<int64_t>(THPUtils_unpackLong(arg));
if (at::cuda::detail::hasPrimaryContext(device_index)) {
if (c10::cuda::hasPrimaryContext(device_index)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;