Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:34:18 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19278

Based on discussion in https://github.com/pytorch/pytorch/pull/19278 and https://github.com/pytorch/pytorch/pull/18687, changes to replicate.py will be reverted to disallow replicating multi-device modules.

Reviewed By: pietern
Differential Revision: D14940018
fbshipit-source-id: 7504c0f4325c2639264c52dcbb499e61c9ad2c26
35 lines · 1.3 KiB · Python
r"""This file is allowed to initialize CUDA context when imported."""
|
|
|
|
import torch
|
|
import torch.cuda
|
|
from common_utils import TEST_WITH_ROCM, TEST_NUMBA
|
|
|
|
|
|
TEST_CUDA = torch.cuda.is_available()
|
|
TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
|
|
CUDA_DEVICE = TEST_CUDA and torch.device("cuda:0")
|
|
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
|
|
TEST_CUDNN = TEST_CUDA and (TEST_WITH_ROCM or torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
|
|
TEST_CUDNN_VERSION = TEST_CUDNN and torch.backends.cudnn.version()
|
|
|
|
if TEST_NUMBA:
|
|
import numba.cuda
|
|
TEST_NUMBA_CUDA = numba.cuda.is_available()
|
|
else:
|
|
TEST_NUMBA_CUDA = False
|
|
|
|
# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
|
|
# RNG have been initialized.
|
|
__cuda_ctx_rng_initialized = False
|
|
|
|
|
|
# after this call, CUDA context and RNG must have been initialized on each GPU
|
|
def initialize_cuda_context_rng():
|
|
global __cuda_ctx_rng_initialized
|
|
assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
|
|
if not __cuda_ctx_rng_initialized:
|
|
# initialize cuda context and rng for memory tests
|
|
for i in range(torch.cuda.device_count()):
|
|
torch.randn(1, device="cuda:{}".format(i))
|
|
__cuda_ctx_rng_initialized = True
|
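
A minimal usage sketch of these helpers, assuming the file above is importable as `common_cuda`; the module path, test class, and test method names are illustrative and not part of the source:

import unittest

import torch
from common_cuda import TEST_CUDA, TEST_MULTIGPU, initialize_cuda_context_rng


# Illustrative only: a hypothetical test module exercising the helpers above.
class ExampleCudaTest(unittest.TestCase):
    @unittest.skipIf(not TEST_CUDA, "CUDA is not available")
    def test_randn_on_primary_device(self):
        # Ensure the CUDA context and RNG exist on every visible GPU before
        # allocating tensors; repeated calls are cheap because the helper
        # remembers that initialization already happened.
        initialize_cuda_context_rng()
        x = torch.randn(4, device="cuda:0")
        self.assertEqual(x.device.type, "cuda")

    @unittest.skipIf(not TEST_MULTIGPU, "at least two GPUs are required")
    def test_randn_on_second_device(self):
        initialize_cuda_context_rng()
        y = torch.randn(4, device="cuda:1")
        self.assertEqual(y.device.index, 1)


if __name__ == "__main__":
    unittest.main()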