[Bugfix][Ray] Set the cuda context eagerly in the ray worker (#19583)
This commit is contained in:
committed by
GitHub
parent
e3a3e4db46
commit
5e666f72cd
@ -271,6 +271,15 @@ steps:
|
|||||||
commands:
|
commands:
|
||||||
- pytest -v -s prefix_caching
|
- pytest -v -s prefix_caching
|
||||||
|
|
||||||
|
|
||||||
|
# Runs the CUDA platform test that checks the device context is set eagerly.
- label: Platform Tests (CUDA)
  mirror_hardwares: [amdexperimental]
  # Re-run whenever the library or the CUDA tests change.
  source_file_dependencies:
  - vllm/
  - tests/cuda
  commands:
    # NOTE(review): the path is relative, presumably to the tests/ directory.
    - pytest -v -s cuda/test_cuda_context.py
|
||||||
|
|
||||||
- label: Samplers Test # 36min
|
- label: Samplers Test # 36min
|
||||||
mirror_hardwares: [amdexperimental]
|
mirror_hardwares: [amdexperimental]
|
||||||
source_file_dependencies:
|
source_file_dependencies:
|
||||||
|
|||||||
80
tests/cuda/test_cuda_context.py
Normal file
80
tests/cuda/test_cuda_context.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import ctypes
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
def check_cuda_context():
    """Check the CUDA driver context status of the calling thread.

    Loads the CUDA driver library through ``ctypes`` and calls
    ``cuCtxGetDevice``.

    Returns:
        ``(True, device_id)`` when ``cuCtxGetDevice`` returns 0, otherwise
        ``(False, None)``. ``(False, None)`` is also returned when the
        driver library cannot be loaded or does not export the symbol.
    """
    cuda = None
    # Try the unversioned name first, then fall back to the versioned
    # soname in case only that one is installed.
    for lib_name in ('libcuda.so', 'libcuda.so.1'):
        try:
            cuda = ctypes.CDLL(lib_name)
        except OSError:
            continue
        break
    if cuda is None:
        return False, None

    device = ctypes.c_int()
    try:
        result = cuda.cuCtxGetDevice(ctypes.byref(device))
    except AttributeError:
        # The loaded library does not export cuCtxGetDevice.
        return False, None
    return (True, device.value) if result == 0 else (False, None)
|
||||||
|
|
||||||
|
|
||||||
|
def run_cuda_test_in_thread(device_input, expected_device_id):
    """Run the CUDA context test; meant to execute in a separate thread.

    Checks that the calling thread has no CUDA context, calls
    ``current_platform.set_device(device_input)``, then checks that a
    context now exists on the expected device.

    Args:
        device_input: Device specification passed to ``set_device``.
        expected_device_id: Device id the driver should report afterwards.

    Returns:
        A ``(success, message)`` tuple. ``message`` is ``"Success"`` or a
        description of the first check that failed.
    """
    try:
        # New thread should have no CUDA context initially
        valid_before, device_before = check_cuda_context()
        if valid_before:
            return False, ("CUDA context should not exist in new thread, "
                           f"got device {device_before}")

        # Test setting CUDA context
        current_platform.set_device(device_input)

        # Verify context is created correctly
        valid_after, device_id = check_cuda_context()
        if not valid_after:
            return False, "CUDA context should be valid after set_device"
        if device_id != expected_device_id:
            return False, (f"Expected device {expected_device_id}, "
                           f"got {device_id}")

        return True, "Success"
    except Exception as e:
        # Failures are reported through the result tuple rather than raised,
        # so the caller can assert on a single (success, message) pair.
        return False, f"Exception in thread: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="CUDA not available")
class TestSetCudaContext:
    """Tests for ``current_platform.set_device`` on CUDA platforms."""

    @pytest.mark.parametrize(
        "device_input,expected_device_id",
        [
            pytest.param(0, 0, id="int"),
            pytest.param(torch.device('cuda:0'), 0, id="torch_device"),
            pytest.param('cuda:0', 0, id="string"),
        ],
    )
    def test_set_cuda_context_parametrized(self, device_input,
                                           expected_device_id):
        """Each accepted device spelling sets the context in a new thread."""
        # A single-worker pool runs the check on a thread other than the
        # test's own thread.
        with ThreadPoolExecutor(max_workers=1) as pool:
            pending = pool.submit(run_cuda_test_in_thread, device_input,
                                  expected_device_id)
            outcome = pending.result(timeout=30)
        ok, detail = outcome
        assert ok, detail

    def test_set_cuda_context_invalid_device_type(self):
        """A non-CUDA device is rejected with ``ValueError``."""
        cpu_device = torch.device('cpu')
        with pytest.raises(ValueError, match="Expected a cuda device"):
            current_platform.set_device(cpu_device)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running this test file directly, in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)
|
||||||
@ -71,6 +71,17 @@ class CudaPlatformBase(Platform):
|
|||||||
# though vLLM doesn't support these GPUs.
|
# though vLLM doesn't support these GPUs.
|
||||||
return [torch.float32]
|
return [torch.float32]
|
||||||
|
|
||||||
|
@classmethod
def set_device(cls, device: torch.device) -> None:
    """
    Set the device for the current platform and make it take effect eagerly.

    Delegates to the parent implementation, then allocates a one-element
    tensor on ``device``.
    """
    super().set_device(device)
    # Allocating a tiny tensor is a trick to force the device to be set
    # eagerly; see https://github.com/pytorch/pytorch/issues/155668
    # for why and when it is needed.
    torch.zeros(1, device=device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_capability(cls,
|
def get_device_capability(cls,
|
||||||
device_id: int = 0
|
device_id: int = 0
|
||||||
|
|||||||
@ -298,6 +298,13 @@ class Platform:
|
|||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
|
@classmethod
def set_device(cls, device: torch.device) -> None:
    """
    Set the device for the current platform.

    The selection is made through ``torch.cuda.set_device``.
    """
    # NOTE(review): this base-class implementation is tied to torch.cuda;
    # confirm whether non-CUDA platforms are expected to override it.
    torch.cuda.set_device(device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def pre_register_and_update(cls,
|
def pre_register_and_update(cls,
|
||||||
parser: Optional[FlexibleArgumentParser] = None
|
parser: Optional[FlexibleArgumentParser] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user