[CI/Build] Add test decorator for minimum GPU memory (#8925)

This commit is contained in:
Cyrus Leung
2024-09-29 10:50:51 +08:00
committed by GitHub
parent d081da0064
commit 26a68d5d7e
14 changed files with 117 additions and 73 deletions

View File

@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files):
assert output2[i] == expected_lora_output[i]
@pytest.mark.skip("Requires multiple GPUs")
@pytest.mark.parametrize("fully_sharded", [True, False])
def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
def test_baichuan_tensor_parallel_equality(baichuan_lora_files,
num_gpus_available, fully_sharded):
if num_gpus_available < 4:
pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,

View File

@ -71,10 +71,10 @@ def do_sample(llm: vllm.LLM,
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", [1])
def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
tp_size):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(
model=model.model_path,
@ -164,11 +164,10 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.skip("Requires multiple GPUs")
def test_quant_model_tp_equality(tinyllama_lora_files, model):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 2:
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
model):
if num_gpus_available < 2:
pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
llm_tp1 = vllm.LLM(
model=model.model_path,

View File

@ -7,6 +7,7 @@ import torch
from vllm.utils import is_cpu
from ....utils import large_gpu_test
from ...utils import check_logprobs_close
MODELS = [
@ -69,20 +70,10 @@ def test_phimoe_routing_function():
assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"])
def get_gpu_memory():
try:
props = torch.cuda.get_device_properties(torch.cuda.current_device())
gpu_memory = props.total_memory / (1024**3)
return gpu_memory
except Exception:
return 0
@pytest.mark.skipif(condition=is_cpu(),
reason="This test takes a lot time to run on CPU, "
"and vllm CI's disk space is not enough for this model.")
@pytest.mark.skipif(condition=get_gpu_memory() < 100,
reason="Skip this test if GPU memory is insufficient.")
@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])

View File

@ -11,6 +11,7 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_VideoAssets)
from ....utils import large_gpu_test
from ...utils import check_logprobs_close
# Video test
@ -164,9 +165,7 @@ def run_video_test(
)
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
@ -210,9 +209,7 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors,
)
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"sizes",
@ -306,9 +303,7 @@ def run_image_test(
)
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])

View File

@ -17,7 +17,7 @@ from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins
from vllm.sequence import Logprob, SampleLogprobs
from ....utils import VLLM_PATH
from ....utils import VLLM_PATH, large_gpu_test
from ...utils import check_logprobs_close
if TYPE_CHECKING:
@ -121,10 +121,7 @@ def load_outputs_w_logprobs(filename: "StrPath") -> OutputsLogprobs:
for tokens, text, logprobs in json_data]
@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@ -157,10 +154,7 @@ def test_chat(
name_1="output")
@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:

View File

@ -9,6 +9,7 @@ from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....utils import large_gpu_test
from ...utils import check_logprobs_close
_LIMIT_IMAGE_PER_PROMPT = 1
@ -227,29 +228,26 @@ def _run_test(
)
SIZES = [
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
[(512, 512), (512, 512), (512, 512)],
# Multi-size, batched
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028)],
# Multi-size, batched, including text only
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
]
@pytest.mark.skip(
reason=
"Model is too big, test passed on L40 locally but will OOM on CI machine.")
@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("sizes", SIZES)
@pytest.mark.parametrize(
"sizes",
[
# Text only
[],
# Single-size
[(512, 512)],
# Single-size, batched
[(512, 512), (512, 512), (512, 512)],
# Multi-size, batched
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028)],
# Multi-size, batched, including text only
[(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024),
(1024, 1024), (512, 1536), (512, 2028), None],
# mllama has 8 possible aspect ratios, carefully set the sizes
# to cover all of them
])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])

View File

@ -24,8 +24,8 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.utils import (FlexibleArgumentParser, cuda_device_count_stateless,
get_open_port, is_hip)
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
cuda_device_count_stateless, get_open_port, is_hip)
if current_platform.is_rocm():
from amdsmi import (amdsmi_get_gpu_vram_usage,
@ -455,6 +455,37 @@ def fork_new_process_for_each_test(
return wrapper
def large_gpu_test(*, min_gb: int):
"""
Decorate a test to be skipped if no GPU is available or it does not have
sufficient memory.
Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
"""
try:
if current_platform.is_cpu():
memory_gb = 0
else:
memory_gb = current_platform.get_device_total_memory() / GB_bytes
except Exception as e:
warnings.warn(
f"An error occurred when finding the available memory: {e}",
stacklevel=2,
)
memory_gb = 0
test_skipif = pytest.mark.skipif(
memory_gb < min_gb,
reason=f"Need at least {memory_gb}GB GPU memory to run the test.",
)
def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
return test_skipif(fork_new_process_for_each_test(f))
return wrapper
def multi_gpu_test(*, num_gpus: int):
"""
Decorate a test to be run only when multiple GPUs are available.

View File

@ -1,3 +1,4 @@
import psutil
import torch
from .interface import Platform, PlatformEnum
@ -10,6 +11,10 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
return psutil.virtual_memory().total
@classmethod
def inference_mode(cls):
return torch.no_grad()

View File

@ -59,6 +59,13 @@ def get_physical_device_name(device_id: int = 0) -> str:
return pynvml.nvmlDeviceGetName(handle)
@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_total_memory(device_id: int = 0) -> int:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total)
@with_nvml_context
def warn_if_different_devices():
device_ids: int = pynvml.nvmlDeviceGetCount()
@ -107,6 +114,11 @@ class CudaPlatform(Platform):
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_name(physical_device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_total_memory(physical_device_id)
@classmethod
@with_nvml_context
def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:

View File

@ -85,6 +85,12 @@ class Platform:
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""
raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
"""Get the total memory of a device in bytes."""
raise NotImplementedError
@classmethod

View File

@ -29,3 +29,8 @@ class RocmPlatform(Platform):
@lru_cache(maxsize=8)
def get_device_name(cls, device_id: int = 0) -> str:
return torch.cuda.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.cuda.get_device_properties(device_id)
return device_props.total_memory

View File

@ -10,6 +10,10 @@ class TpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError
@classmethod
def inference_mode(cls):
return torch.no_grad()

View File

@ -8,13 +8,15 @@ class XPUPlatform(Platform):
@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
return DeviceCapability(major=int(
torch.xpu.get_device_capability(device_id)['version'].split('.')
[0]),
minor=int(
torch.xpu.get_device_capability(device_id)
['version'].split('.')[1]))
major, minor, *_ = torch.xpu.get_device_capability(
device_id)['version'].split('.')
return DeviceCapability(major=int(major), minor=int(minor))
@staticmethod
def get_device_name(device_id: int = 0) -> str:
return torch.xpu.get_device_name(device_id)
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
device_props = torch.xpu.get_device_properties(device_id)
return device_props.total_memory

View File

@ -119,6 +119,9 @@ STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""
GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""