mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[CI/Build] Add test decorator for minimum GPU memory (#8925)
This commit is contained in:
@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files):
|
||||
assert output2[i] == expected_lora_output[i]
|
||||
|
||||
|
||||
@pytest.mark.skip("Requires multiple GPUs")
|
||||
@pytest.mark.parametrize("fully_sharded", [True, False])
|
||||
def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < 4:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
|
||||
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
|
||||
num_gpus_available, fully_sharded):
|
||||
if num_gpus_available < 4:
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
|
||||
|
||||
llm_tp1 = vllm.LLM(MODEL_PATH,
|
||||
enable_lora=True,
|
||||
|
@ -71,10 +71,10 @@ def do_sample(llm: vllm.LLM,
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("tp_size", [1])
|
||||
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < tp_size:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
|
||||
tp_size):
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
llm = vllm.LLM(
|
||||
model=model.model_path,
|
||||
@ -164,11 +164,10 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.skip("Requires multiple GPUs")
|
||||
def test_quant_model_tp_equality(tinyllama_lora_files, model):
|
||||
# Cannot use as it will initialize torch.cuda too early...
|
||||
# if torch.cuda.device_count() < 2:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
|
||||
def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
|
||||
model):
|
||||
if num_gpus_available < 2:
|
||||
pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
|
||||
|
||||
llm_tp1 = vllm.LLM(
|
||||
model=model.model_path,
|
||||
|
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from ....utils import large_gpu_test
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
MODELS = [
|
||||
@ -69,20 +70,10 @@ def test_phimoe_routing_function():
|
||||
assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
|
||||
|
||||
|
||||
def get_gpu_memory():
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
gpu_memory = props.total_memory / (1024**3)
|
||||
return gpu_memory
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(condition=is_cpu(),
|
||||
reason="This test takes a lot time to run on CPU, "
|
||||
"and vllm CI's disk space is not enough for this model.")
|
||||
@pytest.mark.skipif(condition=get_gpu_memory() < 100,
|
||||
reason="Skip this test if GPU memory is insufficient.")
|
||||
@large_gpu_test(min_gb=80)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [64])
|
||||
|
@ -11,6 +11,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||
_VideoAssets)
|
||||
from ....utils import large_gpu_test
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
# Video test
|
||||
@ -164,9 +165,7 @@ def run_video_test(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
@ -210,9 +209,7 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"sizes",
|
||||
@ -306,9 +303,7 @@ def run_image_test(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
|
@ -17,7 +17,7 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
|
||||
from vllm.multimodal import MultiModalDataBuiltins
|
||||
from vllm.sequence import Logprob, SampleLogprobs
|
||||
|
||||
from ....utils import VLLM_PATH
|
||||
from ....utils import VLLM_PATH, large_gpu_test
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -121,10 +121,7 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
|
||||
for tokens, text, logprobs in json_data]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Model is too big, test passed on A100 locally but will OOM on CI machine."
|
||||
)
|
||||
@large_gpu_test(min_gb=80)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@ -157,10 +154,7 @@ def test_chat(
|
||||
name_1="output")
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Model is too big, test passed on A100 locally but will OOM on CI machine."
|
||||
)
|
||||
@large_gpu_test(min_gb=80)
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
|
||||
|
@ -9,6 +9,7 @@ from vllm.sequence import SampleLogprobs
|
||||
|
||||
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
|
||||
_ImageAssets)
|
||||
from ....utils import large_gpu_test
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
_LIMIT_IMAGE_PER_PROMPT = 1
|
||||
@ -227,29 +228,26 @@ def _run_test(
|
||||
)
|
||||
|
||||
|
||||
SIZES = [
|
||||
# Text only
|
||||
[],
|
||||
# Single-size
|
||||
[(512, 512)],
|
||||
# Single-size, batched
|
||||
[(512, 512), (512, 512), (512, 512)],
|
||||
# Multi-size, batched
|
||||
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
|
||||
(1024, 1024), (512, 1536), (512, 2028)],
|
||||
# Multi-size, batched, including text only
|
||||
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
|
||||
(1024, 1024), (512, 1536), (512, 2028), None],
|
||||
# mllama has 8 possible aspect ratios, carefully set the sizes
|
||||
# to cover all of them
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("sizes", SIZES)
|
||||
@pytest.mark.parametrize(
|
||||
"sizes",
|
||||
[
|
||||
# Text only
|
||||
[],
|
||||
# Single-size
|
||||
[(512, 512)],
|
||||
# Single-size, batched
|
||||
[(512, 512), (512, 512), (512, 512)],
|
||||
# Multi-size, batched
|
||||
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
|
||||
(1024, 1024), (512, 1536), (512, 2028)],
|
||||
# Multi-size, batched, including text only
|
||||
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
|
||||
(1024, 1024), (512, 1536), (512, 2028), None],
|
||||
# mllama has 8 possible aspect ratios, carefully set the sizes
|
||||
# to cover all of them
|
||||
])
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
|
@ -24,8 +24,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.model_executor.model_loader.loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
|
||||
get_open_port, is_hip)
|
||||
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
|
||||
cuda_device_count_stateless, get_open_port, is_hip)
|
||||
|
||||
if current_platform.is_rocm():
|
||||
from amdsmi import (amdsmi_get_gpu_vram_usage,
|
||||
@ -455,6 +455,37 @@ def fork_new_process_for_each_test(
|
||||
return wrapper
|
||||
|
||||
|
||||
def large_gpu_test(*, min_gb: int):
|
||||
"""
|
||||
Decorate a test to be skipped if no GPU is available or it does not have
|
||||
sufficient memory.
|
||||
|
||||
Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
|
||||
"""
|
||||
try:
|
||||
if current_platform.is_cpu():
|
||||
memory_gb = 0
|
||||
else:
|
||||
memory_gb = current_platform.get_device_total_memory() / GB_bytes
|
||||
except Exception as e:
|
||||
warnings.warn(
|
||||
f"An error occurred when finding the available memory: {e}",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
memory_gb = 0
|
||||
|
||||
test_skipif = pytest.mark.skipif(
|
||||
memory_gb < min_gb,
|
||||
reason=f"Need at least {memory_gb}GB GPU memory to run the test.",
|
||||
)
|
||||
|
||||
def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
|
||||
return test_skipif(fork_new_process_for_each_test(f))
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def multi_gpu_test(*, num_gpus: int):
|
||||
"""
|
||||
Decorate a test to be run only when multiple GPUs are available.
|
||||
|
@ -1,3 +1,4 @@
|
||||
import psutil
|
||||
import torch
|
||||
|
||||
from .interface import Platform, PlatformEnum
|
||||
@ -10,6 +11,10 @@ class CpuPlatform(Platform):
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return "cpu"
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
return psutil.virtual_memory().total
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
|
||||
return pynvml.nvmlDeviceGetName(handle)
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
@with_nvml_context
|
||||
def get_physical_device_total_memory(device_id: int = 0) -> int:
|
||||
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
|
||||
return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
|
||||
|
||||
|
||||
@with_nvml_context
|
||||
def warn_if_different_devices():
|
||||
device_ids: int = pynvml.nvmlDeviceGetCount()
|
||||
@ -107,6 +114,11 @@ class CudaPlatform(Platform):
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
return get_physical_device_name(physical_device_id)
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
physical_device_id = device_id_to_physical_device_id(device_id)
|
||||
return get_physical_device_total_memory(physical_device_id)
|
||||
|
||||
@classmethod
|
||||
@with_nvml_context
|
||||
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
|
||||
|
@ -85,6 +85,12 @@ class Platform:
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
"""Get the name of a device."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
"""Get the total memory of a device in bytes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -29,3 +29,8 @@ class RocmPlatform(Platform):
|
||||
@lru_cache(maxsize=8)
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return torch.cuda.get_device_name(device_id)
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
device_props = torch.cuda.get_device_properties(device_id)
|
||||
return device_props.total_memory
|
||||
|
@ -10,6 +10,10 @@ class TpuPlatform(Platform):
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
return torch.no_grad()
|
||||
|
@ -8,13 +8,15 @@ class XPUPlatform(Platform):
|
||||
|
||||
@staticmethod
|
||||
def get_device_capability(device_id: int = 0) -> DeviceCapability:
|
||||
return DeviceCapability(major=int(
|
||||
torch.xpu.get_device_capability(device_id)['version'].split('.')
|
||||
[0]),
|
||||
minor=int(
|
||||
torch.xpu.get_device_capability(device_id)
|
||||
['version'].split('.')[1]))
|
||||
major, minor, *_ = torch.xpu.get_device_capability(
|
||||
device_id)['version'].split('.')
|
||||
return DeviceCapability(major=int(major), minor=int(minor))
|
||||
|
||||
@staticmethod
|
||||
def get_device_name(device_id: int = 0) -> str:
|
||||
return torch.xpu.get_device_name(device_id)
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
device_props = torch.xpu.get_device_properties(device_id)
|
||||
return device_props.total_memory
|
||||
|
@ -119,6 +119,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
GB_bytes = 1_000_000_000
|
||||
"""The number of bytes in one gigabyte (GB)."""
|
||||
|
||||
GiB_bytes = 1 << 30
|
||||
"""The number of bytes in one gibibyte (GiB)."""
|
||||
|
||||
|
Reference in New Issue
Block a user