mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[C10][CUDA] Eagerly create context on torch.cuda.set_device(device) call (#155900)

Fixes #155668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155900
Approved by: https://github.com/ngimel
committed by PyTorch MergeBot
parent fc177801af
commit 4a26bb8a12
@@ -4,7 +4,11 @@ import sys
 
 import unittest
 
 import torch
-from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
+from torch.testing._internal.common_cuda import (
+    _get_torch_cuda_version,
+    TEST_CUDA,
+    TEST_MULTIGPU,
+)
 from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
 
@@ -31,6 +35,19 @@ class TestCudaPrimaryCtx(TestCase):
             TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG,
         )
 
+    def test_set_device_0(self):
+        # In CUDA 12 the behavior of cudaSetDevice changed: it eagerly creates a context on the target device.
+        # `torch.cuda.set_device(0)` should likewise create a context on device 0.
+        # Initially, we should not have any context on device 0.
+        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+        torch.cuda.set_device(0)
+        if _get_torch_cuda_version() >= (12, 0):
+            # After the device was set, the context should be present in CUDA 12.
+            self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
+        else:
+            # In CUDA 11 the context should not be created.
+            self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
+
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_str_repr(self):
         x = torch.randn(1, device="cuda:1")
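
For reference, a minimal sketch of how the new behavior can be observed from Python, run in a fresh process. This mirrors the test added above and relies on the private helper torch._C._cuda_hasPrimaryContext, so it is illustrative rather than a public API:

import torch
from torch.testing._internal.common_cuda import _get_torch_cuda_version

# In a fresh process, no primary context should exist on device 0 yet.
print(torch._C._cuda_hasPrimaryContext(0))   # expected: False

torch.cuda.set_device(0)

if _get_torch_cuda_version() >= (12, 0):
    # With CUDA 12, set_device now eagerly creates the primary context.
    print(torch._C._cuda_hasPrimaryContext(0))   # expected: True
else:
    # With CUDA 11, the context is still created lazily.
    print(torch._C._cuda_hasPrimaryContext(0))   # expected: False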