Revert "[C10][CUDA] Eagerly create context on torch.cuda.set_device(device) call (#155900)"

This reverts commit 8142a0286016e63a0e91b5667e1fb1a5e868ffd7.

Reverted https://github.com/pytorch/pytorch/pull/155900 on behalf of https://github.com/clee2000 due to apparently causing a hang in test_distributed_spawn [GH job link](https://github.com/pytorch/pytorch/actions/runs/15678895788/job/44168117193) [HUD commit link](8142a02860) note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/155900#issuecomment-2977365699))
PyTorch MergeBot
2025-06-16 16:59:25 +00:00
parent 2a4e357192
commit 365ce465f3
4 changed files with 10 additions and 34 deletions
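For reference, the change being reverted made torch.cuda.set_device eagerly initialize the primary context on the target device; with this revert, context creation is again left to the CUDA runtime, which the removed test below notes is version-dependent. A minimal sketch of how the difference can be observed, assuming a CUDA-enabled build with at least one GPU and a fresh process (torch._C._cuda_hasPrimaryContext is the internal helper used by the test removed in this revert):

import torch

assert torch.cuda.is_available()

# Expected False in a fresh process that has done no work on device 0.
print(torch._C._cuda_hasPrimaryContext(0))

torch.cuda.set_device(0)

# Under the reverted change this call was meant to create the primary context
# eagerly; after the revert, whether a context now exists depends on the CUDA
# runtime (cudaSetDevice creates it eagerly on CUDA 12+, per the removed test).
print(torch._C._cuda_hasPrimaryContext(0))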

View File

@@ -130,8 +130,8 @@ DeviceIndex current_device() {
   return cur_device;
 }
-void set_device(DeviceIndex device, const bool force) {
-  C10_CUDA_CHECK(c10::cuda::SetDevice(device, force));
+void set_device(DeviceIndex device) {
+  C10_CUDA_CHECK(c10::cuda::SetDevice(device));
 }
 void device_synchronize() {
@@ -231,11 +231,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   return err;
 }
-cudaError_t SetDevice(DeviceIndex device, const bool force) {
-  TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
-  if (force) {
-    return cudaSetDevice(device);
-  }
+cudaError_t SetDevice(DeviceIndex device) {
+  TORCH_CHECK(device >= 0, "device id must be positive!", device);
   targetDeviceIndex = -1;
   int cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
@@ -312,11 +309,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   return err;
 }
-cudaError_t SetDevice(DeviceIndex device, const bool force) {
-  TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
-  if (force) {
-    return cudaSetDevice(device);
-  }
+cudaError_t SetDevice(DeviceIndex device) {
+  TORCH_CHECK(device >= 0, "device id must be positive!", device);
   int cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
   if (device == cur_device) {

View File

@@ -27,7 +27,7 @@ C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
 C10_CUDA_API DeviceIndex current_device();
-C10_CUDA_API void set_device(DeviceIndex device, const bool force = false);
+C10_CUDA_API void set_device(DeviceIndex device);
 C10_CUDA_API void device_synchronize();
@@ -38,8 +38,7 @@ C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
 C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
-C10_CUDA_API cudaError_t
-SetDevice(DeviceIndex device, const bool force = false);
+C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);
 C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);

View File

@@ -4,11 +4,7 @@ import sys
 import unittest
 import torch
-from torch.testing._internal.common_cuda import (
-    _get_torch_cuda_version,
-    TEST_CUDA,
-    TEST_MULTIGPU,
-)
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
 from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
@@ -35,19 +31,6 @@ class TestCudaPrimaryCtx(TestCase):
                 TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG,
             )
-    def test_set_device_0(self):
-        # In CUDA 12 the behavior of cudaSetDevice has changed. It eagerly creates context on target.
-        # The behavior of `torch.cuda.set_device(0)` should also create context on the device 0.
-        # Initially, we should not have any context on device 0.
-        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
-        torch.cuda.set_device(0)
-        if _get_torch_cuda_version() >= (12, 0):
-            # Now after the device was set, the contex should present in CUDA 12.
-            self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
-        else:
-            # In CUDA 11 the context should not be created.
-            self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_str_repr(self):
         x = torch.randn(1, device="cuda:1")

View File

@@ -64,7 +64,7 @@ PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
   auto device = THPUtils_unpackLong(arg);
   torch::utils::device_lazy_init(at::kCUDA);
-  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device), /*force*/ true);
+  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS