[C10][CUDA] Eagerly create context on torch.cuda.set_device(device) call (#155900)

Fixes #155668

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155900
Approved by: https://github.com/ngimel
This commit is contained in:
Aidyn-A
2025-06-17 18:59:41 +00:00
committed by PyTorch MergeBot
parent fc177801af
commit 4a26bb8a12
4 changed files with 34 additions and 10 deletions

View File

@ -4,7 +4,11 @@ import sys
import unittest
import torch
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
from torch.testing._internal.common_cuda import (
_get_torch_cuda_version,
TEST_CUDA,
TEST_MULTIGPU,
)
from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
@ -31,6 +35,19 @@ class TestCudaPrimaryCtx(TestCase):
TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG,
)
def test_set_device_0(self):
# In CUDA 12 the behavior of cudaSetDevice has changed. It eagerly creates context on target.
# The behavior of `torch.cuda.set_device(0)` should also create context on the device 0.
# Initially, we should not have any context on device 0.
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
torch.cuda.set_device(0)
if _get_torch_cuda_version() >= (12, 0):
# Now after the device was set, the contex should present in CUDA 12.
self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
else:
# In CUDA 11 the context should not be created.
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_str_repr(self):
x = torch.randn(1, device="cuda:1")