Extend torch.cuda.is_available() to attempt an NVML-based CUDA availability assessment when explicitly requested by the user (#85951)
Fixes #83973 (This is a substitute PR for https://github.com/pytorch/pytorch/pull/85024)
First of all, thanks for your invaluable contributions to PyTorch everyone!
Given how extensively `torch.cuda.is_available` is used in the PyTorch ecosystem, IMHO it's worthwhile to provide downstream libraries/frameworks/users the ability to alter the default behavior of `torch.cuda.is_available` in the context of their PyTorch usage.
I'm confident there are many current and future use cases that could benefit from leveraging a weakened, NVML-based `torch.cuda.is_available` assessment at a downstream framework's explicit direction (thanks @malfet, 81da50a972!). Though one could always patch out the `torch.cuda.is_available` function with another implementation in a downstream library, I think this environment-variable-based configuration option is more convenient, and the cost of including it is quite low.
As discussed in https://github.com/pytorch/pytorch/pull/85024#issuecomment-1261542045, this PR gates the new NVML-based behavior behind an environment variable (PYTORCH_NVML_BASED_CUDA_CHECK) that lets a user or framework opt into non-default, NVML-based `is_available()` assessments if desired.
Thanks again for your work everyone!
@ngimel @malfet @awaelchli
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85951
Approved by: https://github.com/ngimel
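
To make the opt-in concrete, here is a minimal sketch (not part of the PR itself) of how a downstream script or framework might request the NVML-based assessment; the only assumption is that the variable is set before `torch.cuda.is_available()` first runs:

    import os

    # Opt into the NVML-based availability check; this must happen before
    # torch.cuda.is_available() is first evaluated in this process.
    os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"

    import torch

    # With the flag set, is_available() attempts an NVML device count first and
    # only falls back to cudaGetDeviceCount if NVML discovery/initialization fails.
    print(torch.cuda.is_available())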
Commit ce56ee11fd (parent cd7c86eaa4), committed by PyTorch MergeBot.
@ -502,6 +502,26 @@ have a flag that can be used to disable CUDA, in combination with

    else:
        args.device = torch.device('cpu')

.. note::
    When assessing the availability of CUDA in a given environment (:meth:`~torch.cuda.is_available`), PyTorch's default
    behavior is to call the CUDA Runtime API method `cudaGetDeviceCount`_. Because this call in turn initializes the
    CUDA Driver API (via `cuInit`_) if it is not already initialized, subsequent forks of a process that has run
    :meth:`~torch.cuda.is_available` will fail with a CUDA initialization error.

    You can set ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` in your environment before importing PyTorch modules that execute
    :meth:`~torch.cuda.is_available` (or before executing it directly) in order to direct
    :meth:`~torch.cuda.is_available` to attempt an NVML-based assessment (`nvmlDeviceGetCount_v2`_). If the
    NVML-based assessment is successful (i.e. NVML discovery/initialization does not fail),
    :meth:`~torch.cuda.is_available` calls will not poison subsequent forks.

    If NVML discovery/initialization fails, :meth:`~torch.cuda.is_available` will fall back to the standard CUDA Runtime
    API assessment and the aforementioned fork constraint will apply.

    Note that the above NVML-based CUDA availability assessment provides a weaker guarantee than the default CUDA
    Runtime API approach (which requires CUDA initialization to succeed). In some circumstances, the NVML-based check
    may succeed while later CUDA initialization fails.

Now that we have ``args.device``, we can use it to create a Tensor on the
desired device.
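
To illustrate the fork behavior described in the note above, here is a minimal sketch (assuming a Linux host where the ``fork`` start method is available; it mirrors the check used by the test added in this PR and is not part of the doc change itself):

    import multiprocessing
    import os

    os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"  # opt into the NVML-based check

    import torch


    def child_is_healthy() -> bool:
        # Because the parent's availability check was NVML-based, cuInit was never
        # called there, so this forked child is not in a "bad fork" state and can
        # initialize CUDA itself.
        return not torch.cuda._is_in_bad_fork()


    if __name__ == "__main__":
        torch.cuda.is_available()  # NVML-based assessment; no CUDA Driver API init here
        with multiprocessing.get_context("fork").Pool(1) as pool:
            print(pool.apply(child_is_healthy))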
@ -582,6 +602,7 @@ also preserve :class:`torch.device` and :class:`torch.dtype` of a Tensor).

    y_cpu = torch.ones_like(x_cpu)
    y_gpu = torch.zeros_like(x_gpu)


.. _cuda-memory-pinning:

Use pinned memory buffers

@ -633,6 +654,15 @@ by GIL of Python interpreter.

If you use :class:`~torch.nn.parallel.DistributedDataParallel`, you could use
`torch.distributed.launch` utility to launch your program, see :ref:`distributed-launch`.

.. _cudaGetDeviceCount:
    https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g18808e54893cfcaafefeab31a73cc55f

.. _cuInit:
    https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html#group__CUDA__INITIALIZE_1g0a2f1517e1bd8502c7194c3a8c134bc3

.. _nvmlDeviceGetCount_v2:
    https://docs.nvidia.com/deploy/nvml-api/group__nvmlDeviceQueries.html#group__nvmlDeviceQueries_1ga93623b195bff04bbe3490ca33c8a42d

.. _cuda-graph-semantics:

CUDA Graphs
@ -251,6 +251,7 @@ ROCM_BLOCKLIST = [
    "distributed/_shard/test_replicated_tensor",
    "test_determination",
    "test_jit_legacy",
    "test_cuda_nvml_based_avail",
]

RUN_PARALLEL_BLOCKLIST = [
@ -266,6 +267,7 @@ RUN_PARALLEL_BLOCKLIST = [
    "test_tensorexpr",
    "test_cuda_primary_ctx",
    "test_cuda_trace",
    "test_cuda_nvml_based_avail",
] + FSDP_TEST

CI_SERIAL_LIST = [
@ -749,6 +751,7 @@ def run_test_ops(test_module, test_directory, options):

CUSTOM_HANDLERS = {
    "test_cuda_primary_ctx": test_cuda_primary_ctx,
    "test_cuda_nvml_based_avail": get_run_test_with_subprocess_fn(),
    "test_cuda_trace": get_run_test_with_subprocess_fn(),
    "test_cpp_extensions_aot_no_ninja": test_cpp_extensions_aot_no_ninja,
    "test_cpp_extensions_aot_ninja": test_cpp_extensions_aot_ninja,
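
For context on the `CUSTOM_HANDLERS` entry above: routing the new test through `get_run_test_with_subprocess_fn()` gives each test case a fresh interpreter, so the parent runner never initializes CUDA. A rough, hypothetical sketch of that idea (not the actual run_test.py helper):

    import subprocess
    import sys


    def run_module_in_fresh_process(test_file: str) -> int:
        # Hypothetical helper: launch the test module in a new interpreter and pass
        # "--subprocess" (the flag mentioned in the new test's reminder message below)
        # so the PyTorch test harness also isolates each individual test case in its
        # own child process, keeping CUDA uninitialized in every parent process.
        return subprocess.call([sys.executable, test_file, "--subprocess"])


    if __name__ == "__main__":
        raise SystemExit(run_module_in_fresh_process("test/test_cuda_nvml_based_avail.py"))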
test/test_cuda_nvml_based_avail.py (new file, 69 lines)
@ -0,0 +1,69 @@
# Owner(s): ["module: cuda"]

import os
import sys
import multiprocessing
import torch
import unittest
from unittest.mock import patch

# NOTE: Each of the tests in this module needs to be run in a brand new process to ensure CUDA is uninitialized
# prior to test initiation.
with patch.dict(os.environ, {"PYTORCH_NVML_BASED_CUDA_CHECK": "1"}):
    # Before executing the desired tests, we need to disable CUDA initialization and fork_handler additions that would
    # otherwise be triggered by the `torch.testing._internal.common_utils` module import
    from torch.testing._internal.common_utils import (parametrize, instantiate_parametrized_tests, run_tests, TestCase,
                                                      IS_WINDOWS)
    # NOTE: Because `remove_device_and_dtype_suffixes` initializes CUDA context (triggered via the import of
    # `torch.testing._internal.common_device_type` which imports `torch.testing._internal.common_cuda`) we need
    # to bypass that method here, which should be irrelevant to the parameterized tests in this module.
    torch.testing._internal.common_utils.remove_device_and_dtype_suffixes = lambda x: x

    TEST_CUDA = torch.cuda.is_available()
    if not TEST_CUDA:
        print('CUDA not available, skipping tests', file=sys.stderr)
        TestCase = object  # type: ignore[misc, assignment] # noqa: F811


class TestExtendedCUDAIsAvail(TestCase):
    SUBPROCESS_REMINDER_MSG = (
        "\n REMINDER: Tests defined in test_cuda_nvml_based_avail.py must be run in a process "
        "where the CUDA Driver API has not been initialized. Before further debugging, ensure you are either using "
        "run_test.py or have added --subprocess to run each test in a different subprocess.")

    def setUp(self):
        super().setUp()
        torch.cuda.device_count.cache_clear()  # clear the lru_cache on this method before our test

    @staticmethod
    def in_bad_fork_test() -> bool:
        _ = torch.cuda.is_available()
        return torch.cuda._is_in_bad_fork()

    # These tests validate the behavior and activation of the weaker, NVML-based, user-requested
    # `torch.cuda.is_available()` assessment. The NVML-based assessment should be attempted when
    # `PYTORCH_NVML_BASED_CUDA_CHECK` is set to 1, reverting to the default CUDA Runtime API check otherwise.
    # If the NVML-based assessment is attempted but fails, the CUDA Runtime API check should be executed.
    @unittest.skipIf(IS_WINDOWS, "Needs fork")
    @parametrize("nvml_avail", [True, False])
    @parametrize("avoid_init", ['1', '0', None])
    def test_cuda_is_available(self, avoid_init, nvml_avail):
        patch_env = {"PYTORCH_NVML_BASED_CUDA_CHECK": avoid_init} if avoid_init else {}
        with patch.dict(os.environ, **patch_env):
            if nvml_avail:
                _ = torch.cuda.is_available()
            else:
                with patch.object(torch.cuda, '_device_count_nvml', return_value=-1):
                    _ = torch.cuda.is_available()
            with multiprocessing.get_context("fork").Pool(1) as pool:
                in_bad_fork = pool.apply(TestExtendedCUDAIsAvail.in_bad_fork_test)
            if os.getenv('PYTORCH_NVML_BASED_CUDA_CHECK') == '1' and nvml_avail:
                self.assertFalse(in_bad_fork, TestExtendedCUDAIsAvail.SUBPROCESS_REMINDER_MSG)
            else:
                assert in_bad_fork


instantiate_parametrized_tests(TestExtendedCUDAIsAvail)

if __name__ == '__main__':
    run_tests()
@ -79,14 +79,25 @@ def _is_compiled() -> bool:
    r"""Returns true if compile with CUDA support."""
    return hasattr(torch._C, '_cuda_getDeviceCount')

def _nvml_based_avail() -> bool:
    return os.getenv('PYTORCH_NVML_BASED_CUDA_CHECK') == '1'

def is_available() -> bool:
    r"""Returns a bool indicating if CUDA is currently available."""
    if not _is_compiled():
        return False
    # This function never throws and returns 0 if driver is missing or can't
    # be initialized
    if _nvml_based_avail():
        # The user has set an env variable to request this availability check that attempts to avoid fork poisoning by
        # using NVML at the cost of a weaker CUDA availability assessment. Note that if NVML discovery/initialization
        # fails, this assessment falls back to the default CUDA Runtime API assessment (`cudaGetDeviceCount`)
        return device_count() > 0
    else:
        # The default availability inspection never throws and returns 0 if the driver is missing or can't
        # be initialized. This uses the CUDA Runtime API `cudaGetDeviceCount` which in turn initializes the CUDA Driver
        # API via `cuInit`
        return torch._C._cuda_getDeviceCount() > 0


def is_bf16_supported():
    r"""Returns a bool indicating if the current CUDA/ROCm device supports dtype bfloat16"""
    # Check for ROCm, if true return true, no ROCM_VERSION check required,
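
The NVML path relies on a `_device_count_nvml` helper (the one patched in the new test above), which is not shown in this hunk. Purely as an illustration, and not PyTorch's implementation, an NVML-based device count can be obtained via ctypes without ever touching the CUDA Driver API:

    import ctypes


    def nvml_device_count() -> int:
        # Illustrative only: query the device count through NVML so cuInit is never
        # called. A negative return value signals "NVML unavailable"; a caller would
        # then fall back to the CUDA Runtime API (cudaGetDeviceCount).
        try:
            nvml = ctypes.CDLL("libnvidia-ml.so.1")
        except OSError:
            return -1
        if nvml.nvmlInit_v2() != 0:
            return -1
        count = ctypes.c_uint(0)
        rc = nvml.nvmlDeviceGetCount_v2(ctypes.byref(count))
        nvml.nvmlShutdown()
        return count.value if rc == 0 else -1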