mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Fix CI slow tests: ImportError: vLLM is not installed (#4287)
This commit is contained in:
committed by
GitHub
parent
ef40c047aa
commit
c7c041ecc8
@ -22,7 +22,7 @@ from transformers.testing_utils import require_torch_multi_accelerator, torch_de
|
||||
from trl.extras.vllm_client import VLLMClient
|
||||
from trl.scripts.vllm_serve import chunk_list
|
||||
|
||||
from .testing_utils import TrlTestCase, kill_process, require_3_accelerators
|
||||
from .testing_utils import TrlTestCase, kill_process, require_3_accelerators, require_vllm
|
||||
|
||||
|
||||
class TestChunkList(TrlTestCase):
|
||||
@ -53,6 +53,7 @@ class TestChunkList(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_torch_multi_accelerator
|
||||
@require_vllm
|
||||
class TestVLLMClientServer(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -212,6 +213,7 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_3_accelerators
|
||||
@require_vllm
|
||||
class TestVLLMClientServerTP(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -274,6 +276,7 @@ class TestVLLMClientServerTP(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_3_accelerators
|
||||
@require_vllm
|
||||
class TestVLLMClientServerDP(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -336,6 +339,7 @@ class TestVLLMClientServerDP(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_torch_multi_accelerator
|
||||
@require_vllm
|
||||
class TestVLLMClientServerDeviceParameter(TrlTestCase):
|
||||
"""Test the device parameter functionality in init_communicator."""
|
||||
|
||||
|
Reference in New Issue
Block a user