mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Move hasPrimaryContext
to c10::cuda
(#96800)
This method has to be accessible from `c10` to enable CUDA-12 integration. This is implemented by providing a private `c10::cuda::_internal::setHasPrimaryContext` that passes the pointer to the implementation (in `torch_cuda`) back to `c10`. A global object's constructor/destructor is used to guarantee RAII-style registration and cleanup. Pull Request resolved: https://github.com/pytorch/pytorch/pull/96800 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
cbd3df93c4
commit
24ce3a7c34
@ -22,7 +22,7 @@ using CaptureStatus = c10::cuda::CaptureStatus;
|
||||
inline CaptureStatus currentStreamCaptureStatus() {
|
||||
#if !defined(USE_ROCM) || ROCM_VERSION >= 50300
|
||||
// don't create a context if we don't have to
|
||||
if (at::cuda::detail::hasPrimaryContext(c10::cuda::current_device())) {
|
||||
if (c10::cuda::hasPrimaryContext(c10::cuda::current_device())) {
|
||||
return c10::cuda::currentStreamCaptureStatusMayInitCtx();
|
||||
} else {
|
||||
return CaptureStatus::None;
|
||||
|
@ -168,7 +168,7 @@ class CUDAHostAllocator {
|
||||
// primary context, if available. See pytorch/pytorch#21081.
|
||||
at::OptionalDeviceGuard device_guard;
|
||||
auto primary_ctx_device_index =
|
||||
at::cuda::detail::getDeviceIndexWithPrimaryContext();
|
||||
c10::cuda::getDeviceIndexWithPrimaryContext();
|
||||
if (primary_ctx_device_index.has_value()) {
|
||||
device_guard.reset_device(
|
||||
at::Device(at::DeviceType::CUDA, *primary_ctx_device_index));
|
||||
|
@ -40,9 +40,11 @@
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
namespace detail {
|
||||
namespace c10::cuda::_internal {
|
||||
void setHasPrimaryContext(bool (*func)(int64_t));
|
||||
}
|
||||
|
||||
namespace at::cuda::detail {
|
||||
|
||||
const at::cuda::NVRTC& nvrtc();
|
||||
int64_t current_device();
|
||||
@ -53,6 +55,29 @@ void set_magma_init_fn(void (*fn)()) {
|
||||
magma_init_fn = fn;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Reports whether the primary context of `device_index` is active, according
// to the driver's cuDevicePrimaryCtxGetState.
//
// device_index: must satisfy 0 <= device_index < at::cuda::device_count();
//               otherwise the TORCH_CHECK below fails.
// Returns true only when the driver reports the "active" state as exactly 1.
bool _hasPrimaryContext(int64_t device_index) {
  TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
              "hasPrimaryContext expects a valid device index, but got device_index=", device_index);
  // Both out-parameters are zero-initialized. In standalone tests of
  // cuDevicePrimaryCtxGetState, the "active" argument was seen to end up with
  // weird (garbage-looking nonzero) values when the context is not active,
  // unless it was initialized to zero; the flags argument is initialized for
  // the same reason, so no indeterminate value is ever handed to the driver.
  unsigned int ctx_flags = 0;
  int ctx_is_active = 0;
  AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
  return ctx_is_active == 1;
}
|
||||
|
||||
// Register hasPrimaryContext back to c10::cuda
|
||||
struct _Initializer {
|
||||
_Initializer() {
|
||||
c10::cuda::_internal::setHasPrimaryContext(_hasPrimaryContext);
|
||||
}
|
||||
~_Initializer() {
|
||||
c10::cuda::_internal::setHasPrimaryContext(nullptr);
|
||||
}
|
||||
} initializer;
|
||||
} // anonymous namespace
|
||||
|
||||
// Sets the CUDA_MODULE_LOADING environment variable
|
||||
// if it's not set by the user.
|
||||
void maybe_set_cuda_module_loading(const std::string &def_value) {
|
||||
@ -209,36 +234,8 @@ int64_t CUDAHooks::current_device() const {
|
||||
return at::cuda::detail::current_device();
|
||||
}
|
||||
|
||||
bool hasPrimaryContext(int64_t device_index) {
|
||||
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
|
||||
"hasPrimaryContext expects a valid device index, but got device_index=", device_index);
|
||||
unsigned int ctx_flags;
|
||||
// In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird
|
||||
// (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero.
|
||||
int ctx_is_active = 0;
|
||||
AT_CUDA_DRIVER_CHECK(nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
|
||||
return ctx_is_active == 1;
|
||||
}
|
||||
|
||||
bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
|
||||
return at::cuda::detail::hasPrimaryContext(device_index);
|
||||
}
|
||||
|
||||
c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
|
||||
// check current device first
|
||||
int64_t current_device_index = current_device();
|
||||
if (current_device_index >= 0) {
|
||||
if (hasPrimaryContext(current_device_index)) {
|
||||
return current_device_index;
|
||||
}
|
||||
}
|
||||
for (const auto device_index : c10::irange(at::cuda::device_count())) {
|
||||
if (device_index == current_device_index) continue;
|
||||
if (hasPrimaryContext(device_index)) {
|
||||
return device_index;
|
||||
}
|
||||
}
|
||||
return c10::nullopt;
|
||||
return _hasPrimaryContext(device_index);
|
||||
}
|
||||
|
||||
Allocator* CUDAHooks::getPinnedMemoryAllocator() const {
|
||||
@ -443,6 +440,4 @@ using at::RegistererCUDAHooksRegistry;
|
||||
|
||||
REGISTER_CUDA_HOOKS(CUDAHooks);
|
||||
|
||||
} // namespace detail
|
||||
} // namespace cuda
|
||||
} // namespace at
|
||||
} // namespace at::cuda::detail
|
||||
|
@ -15,8 +15,6 @@ namespace at { namespace cuda { namespace detail {
|
||||
// in the same library where Magma will be used.
|
||||
TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
|
||||
|
||||
TORCH_CUDA_CPP_API bool hasPrimaryContext(int64_t device_index);
|
||||
TORCH_CUDA_CPP_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
|
||||
|
||||
// The real implementation of CUDAHooksInterface
|
||||
struct CUDAHooks : public at::CUDAHooksInterface {
|
||||
|
@ -1,7 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/Registry.h>
|
||||
@ -10,14 +9,14 @@
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
// Forward-declares at::cuda::NVRTC
|
||||
namespace at { namespace cuda {
|
||||
struct NVRTC;
|
||||
}} // at::cuda
|
||||
|
||||
// Forward-declares at::Context, at::Generator and at::cuda::NVRTC
|
||||
namespace at {
|
||||
class Context;
|
||||
}
|
||||
struct Generator;
|
||||
namespace cuda {
|
||||
struct NVRTC;
|
||||
} // namespace cuda
|
||||
} // namespace at
|
||||
|
||||
// NB: Class must live in `at` due to limitations of Registry.h.
|
||||
namespace at {
|
||||
|
@ -3,8 +3,7 @@
|
||||
|
||||
#include <limits>
|
||||
|
||||
namespace c10 {
|
||||
namespace cuda {
|
||||
namespace c10::cuda {
|
||||
|
||||
namespace {
|
||||
// returns -1 on failure
|
||||
@ -149,5 +148,38 @@ void warn_or_error_on_sync() {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace c10
|
||||
// Returns the index of a device for which hasPrimaryContext() reports true,
// or c10::nullopt when no such device is found.
//
// The current device is tried first (only if its index is non-negative); the
// remaining devices are then tried in the order produced by
// c10::irange(at::cuda::device_count()), skipping the current device so it is
// not queried twice.
c10::optional<int64_t> getDeviceIndexWithPrimaryContext() {
  // check current device first
  const int64_t current_idx = current_device();
  if (current_idx >= 0 && hasPrimaryContext(current_idx)) {
    return current_idx;
  }
  for (const auto idx : c10::irange(at::cuda::device_count())) {
    if (idx != current_idx && hasPrimaryContext(idx)) {
      return idx;
    }
  }
  return c10::nullopt;
}
|
||||
|
||||
namespace _internal {
|
||||
bool dummyHasPrimaryContext(C10_UNUSED int64_t device_index) {
|
||||
TORCH_CHECK(false, "Should never been called");
|
||||
}
|
||||
bool (*hasPrimaryContext)(int64_t) = dummyHasPrimaryContext;
|
||||
|
||||
// Private api to be called from CUDAHooks.cpp
|
||||
C10_CUDA_API void setHasPrimaryContext(bool (*func)(int64_t)) {
|
||||
hasPrimaryContext = func ? func : dummyHasPrimaryContext;
|
||||
}
|
||||
} // namespace _internal
|
||||
|
||||
bool hasPrimaryContext(int64_t device_index) {
|
||||
return _internal::hasPrimaryContext(device_index);
|
||||
}
|
||||
|
||||
} // namespace c10::cuda
|
||||
|
@ -96,5 +96,8 @@ C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
|
||||
}
|
||||
|
||||
C10_CUDA_API bool hasPrimaryContext(int64_t device_index);
|
||||
C10_CUDA_API c10::optional<int64_t> getDeviceIndexWithPrimaryContext();
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace c10
|
||||
|
@ -486,7 +486,7 @@ PyObject* THCPModule_hasPrimaryContext(PyObject* _unused, PyObject* arg) {
|
||||
THPUtils_assert(
|
||||
THPUtils_checkLong(arg), "invalid argument to has_primary_context");
|
||||
int64_t device_index = static_cast<int64_t>(THPUtils_unpackLong(arg));
|
||||
if (at::cuda::detail::hasPrimaryContext(device_index)) {
|
||||
if (c10::cuda::hasPrimaryContext(device_index)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
|
Reference in New Issue
Block a user