Revert "[C10][CUDA] Eagerly create context on torch.cuda.set_device(device) call (#155900)"
This reverts commit 8142a0286016e63a0e91b5667e1fb1a5e868ffd7.
Reverted https://github.com/pytorch/pytorch/pull/155900 on behalf of https://github.com/clee2000 due to an apparent hang in test_distributed_spawn [GH job link](https://github.com/pytorch/pytorch/actions/runs/15678895788/job/44168117193) [HUD commit link](8142a02860), note to self: bad TD ([comment](https://github.com/pytorch/pytorch/pull/155900#issuecomment-2977365699))
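For context, the change being reverted made `torch.cuda.set_device` create the target device's primary CUDA context immediately (mirroring `cudaSetDevice` semantics on CUDA 12), rather than deferring creation until the device is first used. Below is a minimal sketch of how the difference can be observed, assuming a CUDA-enabled build with at least one visible GPU; `torch._C._cuda_hasPrimaryContext` is the same internal probe used by the test removed in this revert.

```python
import torch

# Assumes a CUDA build with at least one visible GPU.
if torch.cuda.is_available():
    # A fresh process has no primary context on device 0 yet.
    print(torch._C._cuda_hasPrimaryContext(0))  # False

    torch.cuda.set_device(0)
    # With the reverted change (on CUDA >= 12) the primary context was
    # created eagerly at this point; after the revert it remains uninitialized.
    print(torch._C._cuda_hasPrimaryContext(0))

    # Touching the device (e.g. allocating a tensor) always creates the context.
    _ = torch.empty(1, device="cuda:0")
    print(torch._C._cuda_hasPrimaryContext(0))  # True
```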
@@ -130,8 +130,8 @@ DeviceIndex current_device() {
   return cur_device;
 }
 
-void set_device(DeviceIndex device, const bool force) {
-  C10_CUDA_CHECK(c10::cuda::SetDevice(device, force));
+void set_device(DeviceIndex device) {
+  C10_CUDA_CHECK(c10::cuda::SetDevice(device));
 }
 
 void device_synchronize() {
@@ -231,11 +231,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   return err;
 }
 
-cudaError_t SetDevice(DeviceIndex device, const bool force) {
-  TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
-  if (force) {
-    return cudaSetDevice(device);
-  }
+cudaError_t SetDevice(DeviceIndex device) {
+  TORCH_CHECK(device >= 0, "device id must be positive!", device);
   targetDeviceIndex = -1;
   int cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
@@ -312,11 +309,8 @@ cudaError_t GetDevice(DeviceIndex* device) {
   return err;
 }
 
-cudaError_t SetDevice(DeviceIndex device, const bool force) {
-  TORCH_CHECK(device >= 0, "device id must be non-negative!", device);
-  if (force) {
-    return cudaSetDevice(device);
-  }
+cudaError_t SetDevice(DeviceIndex device) {
+  TORCH_CHECK(device >= 0, "device id must be positive!", device);
   int cur_device = -1;
   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
   if (device == cur_device) {
@@ -27,7 +27,7 @@ C10_CUDA_API DeviceIndex device_count_ensure_non_zero();
 
 C10_CUDA_API DeviceIndex current_device();
 
-C10_CUDA_API void set_device(DeviceIndex device, const bool force = false);
+C10_CUDA_API void set_device(DeviceIndex device);
 
 C10_CUDA_API void device_synchronize();
 
@@ -38,8 +38,7 @@ C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
 
 C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device);
 
-C10_CUDA_API cudaError_t
-SetDevice(DeviceIndex device, const bool force = false);
+C10_CUDA_API cudaError_t SetDevice(DeviceIndex device);
 
 C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device);
 
@@ -4,11 +4,7 @@ import sys
 import unittest
 
 import torch
-from torch.testing._internal.common_cuda import (
-    _get_torch_cuda_version,
-    TEST_CUDA,
-    TEST_MULTIGPU,
-)
+from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU
 from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
 
 
@@ -35,19 +31,6 @@ class TestCudaPrimaryCtx(TestCase):
             TestCudaPrimaryCtx.CTX_ALREADY_CREATED_ERR_MSG,
         )
 
-    def test_set_device_0(self):
-        # In CUDA 12 the behavior of cudaSetDevice has changed. It eagerly creates context on target.
-        # The behavior of `torch.cuda.set_device(0)` should also create context on the device 0.
-        # Initially, we should not have any context on device 0.
-        self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
-        torch.cuda.set_device(0)
-        if _get_torch_cuda_version() >= (12, 0):
-            # Now after the device was set, the contex should present in CUDA 12.
-            self.assertTrue(torch._C._cuda_hasPrimaryContext(0))
-        else:
-            # In CUDA 11 the context should not be created.
-            self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
-
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_str_repr(self):
         x = torch.randn(1, device="cuda:1")
@@ -64,7 +64,7 @@ PyObject* THCPModule_setDevice_wrap(PyObject* self, PyObject* arg) {
   auto device = THPUtils_unpackLong(arg);
 
   torch::utils::device_lazy_init(at::kCUDA);
-  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device), /*force*/ true);
+  c10::cuda::set_device(static_cast<c10::DeviceIndex>(device));
 
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS