[Bugfix][Ray] Set the cuda context eagerly in the ray worker (#19583)

This commit is contained in:
kourosh hakhamaneshi
2025-06-19 22:01:16 -07:00
committed by GitHub
parent e3a3e4db46
commit 5e666f72cd
4 changed files with 107 additions and 0 deletions

View File

@ -271,6 +271,15 @@ steps:
commands:
- pytest -v -s prefix_caching
- label: Platform Tests (CUDA)
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/cuda
commands:
- pytest -v -s cuda/test_cuda_context.py
- label: Samplers Test # 36min
mirror_hardwares: [amdexperimental]
source_file_dependencies:

View File

@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ctypes
from concurrent.futures import ThreadPoolExecutor
import pytest
import torch
from vllm.platforms import current_platform
def check_cuda_context():
"""Check CUDA driver context status"""
try:
cuda = ctypes.CDLL('libcuda.so')
device = ctypes.c_int()
result = cuda.cuCtxGetDevice(ctypes.byref(device))
return (True, device.value) if result == 0 else (False, None)
except Exception:
return False, None
def run_cuda_test_in_thread(device_input, expected_device_id):
"""Run CUDA context test in separate thread for isolation"""
try:
# New thread should have no CUDA context initially
valid_before, device_before = check_cuda_context()
if valid_before:
return False, \
"CUDA context should not exist in new thread, " \
f"got device {device_before}"
# Test setting CUDA context
current_platform.set_device(device_input)
# Verify context is created correctly
valid_after, device_id = check_cuda_context()
if not valid_after:
return False, "CUDA context should be valid after set_cuda_context"
if device_id != expected_device_id:
return False, \
f"Expected device {expected_device_id}, got {device_id}"
return True, "Success"
except Exception as e:
return False, f"Exception in thread: {str(e)}"
class TestSetCudaContext:
"""Test suite for the set_cuda_context function."""
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="CUDA not available")
@pytest.mark.parametrize(argnames="device_input,expected_device_id",
argvalues=[
(0, 0),
(torch.device('cuda:0'), 0),
('cuda:0', 0),
],
ids=["int", "torch_device", "string"])
def test_set_cuda_context_parametrized(self, device_input,
expected_device_id):
"""Test setting CUDA context in isolated threads."""
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_cuda_test_in_thread, device_input,
expected_device_id)
success, message = future.result(timeout=30)
assert success, message
@pytest.mark.skipif(not current_platform.is_cuda(),
reason="CUDA not available")
def test_set_cuda_context_invalid_device_type(self):
"""Test error handling for invalid device type."""
with pytest.raises(ValueError, match="Expected a cuda device"):
current_platform.set_device(torch.device('cpu'))
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@ -71,6 +71,17 @@ class CudaPlatformBase(Platform):
# though vLLM doesn't support these GPUs.
return [torch.float32]
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
super().set_device(device)
# With this trick we can force the device to be set eagerly
# see https://github.com/pytorch/pytorch/issues/155668
# for why and when it is needed
_ = torch.zeros(1, device=device)
@classmethod
def get_device_capability(cls,
device_id: int = 0

View File

@ -298,6 +298,13 @@ class Platform:
np.random.seed(seed)
torch.manual_seed(seed)
@classmethod
def set_device(cls, device: torch.device) -> None:
"""
Set the device for the current platform.
"""
torch.cuda.set_device(device)
@classmethod
def pre_register_and_update(cls,
parser: Optional[FlexibleArgumentParser] = None