Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[misc][ci] fix quant test (#8449)
@@ -10,6 +10,8 @@ import torch
 from tests.quantization.utils import is_quant_method_supported

+from ..utils import fork_new_process_for_each_test
+
 models_4bit_to_test = [
     ('huggyllama/llama-7b', 'quantize model inflight'),
 ]
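The decorator imported here, `fork_new_process_for_each_test`, comes from the test suite's relative `..utils` module and runs each test in its own forked process so that CUDA and bitsandbytes state cannot leak between parametrized cases. Below is a minimal sketch of what such a decorator can look like; it illustrates the idea and is not vLLM's exact implementation.

```python
import functools
import os
import traceback


def fork_new_process_for_each_test(f):
    """Sketch of a fork-per-test decorator (POSIX only): run the wrapped
    test in a child process so GPU state from one case cannot leak into
    the next."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        pid = os.fork()
        if pid == 0:
            # Child: run the test body and report the outcome via exit code.
            try:
                f(*args, **kwargs)
                os._exit(0)
            except BaseException:
                traceback.print_exc()
                os._exit(1)
        # Parent: wait for the child and surface its failure to pytest.
        _, status = os.waitpid(pid, 0)
        assert os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0, (
            f"{f.__name__} failed in a forked child process")

    return wrapper
```

Because the decorator wraps the test function itself, every parametrized case below gets a fresh process.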
@@ -29,6 +31,7 @@ models_pre_quant_8bit_to_test = [
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
 def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
@@ -41,6 +44,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_qaunt_4bit_to_test)
+@fork_new_process_for_each_test
 def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                        model_name, description) -> None:
@@ -52,6 +56,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_quant_8bit_to_test)
+@fork_new_process_for_each_test
 def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
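Each of the three bitsandbytes tests now carries the same decorator stack. A sketch of how the pieces compose in this module is shown below; the fixtures come from the test suite's conftest, and the body (the `load_in_4bit` kwarg and the `validate_generated_texts` call) is a hypothetical illustration rather than the exact code of this file.

```python
import pytest

from tests.quantization.utils import is_quant_method_supported

from ..utils import fork_new_process_for_each_test


@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@fork_new_process_for_each_test
def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                             model_name, description) -> None:
    # Hypothetical body: quantize the HF reference in-flight with
    # bitsandbytes 4-bit and compare its greedy outputs against vLLM's.
    hf_model_kwargs = {"load_in_4bit": True}
    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
                             model_name, hf_model_kwargs)
```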
@@ -77,18 +82,8 @@ def validate_generated_texts(hf_runner,
                              model_name,
                              hf_model_kwargs=None):
-
-    if hf_model_kwargs is None:
-        hf_model_kwargs = {}
-
-    # Run with HF runner
-    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
-        hf_outputs = llm.generate_greedy(prompts, 8)
-        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
-
-    # Clean up the GPU memory for the next test
-    torch.cuda.synchronize()
-    gc.collect()
-    torch.cuda.empty_cache()
+    # NOTE: run vLLM first, as it requires a clean process
+    # when using distributed inference

     #Run with vLLM runner
     with vllm_runner(model_name,
@@ -104,6 +99,19 @@ def validate_generated_texts(hf_runner,
     gc.collect()
     torch.cuda.empty_cache()

+    if hf_model_kwargs is None:
+        hf_model_kwargs = {}
+
+    # Run with HF runner
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
     # Compare the generated strings
     for hf_log, vllm_log in zip(hf_logs, vllm_logs):
         hf_str = hf_log["generated_text"]
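The helper now runs the vLLM pass first (it needs a clean process for distributed inference) and the HF pass second, with the same three-step GPU cleanup (torch.cuda.synchronize, gc.collect, torch.cuda.empty_cache) after each runner. That cleanup idiom can be expressed as a small helper; the function below is a hypothetical extraction, not part of the diff.

```python
import gc

import torch


def cleanup_gpu_memory() -> None:
    """Release as much GPU memory as possible between two model runs:
    wait for outstanding CUDA work, drop unreachable Python objects, then
    return cached allocator blocks to the driver."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

In the test itself the three calls are simply inlined after each runner's `with` block exits.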
@@ -1,12 +1,10 @@
-import torch
-
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform


 def is_quant_method_supported(quant_method: str) -> bool:
     # Currently, all quantization methods require Nvidia or AMD GPUs
-    if not torch.cuda.is_available():
+    if not (current_platform.is_cuda() or current_platform.is_rocm()):
         return False

     capability = current_platform.get_device_capability()
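In the utility above, the platform check now goes through `current_platform` instead of `torch.cuda.is_available()`, so ROCm GPUs pass the gate as well. The diff cuts off after the `get_device_capability()` call; a hedged sketch of how the remainder of the function plausibly compares that capability against the method's minimum follows, assuming `QUANTIZATION_METHODS` maps a method name to a config class exposing `get_min_capability()` and that the capability unpacks as a (major, minor) pair.

```python
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform


def is_quant_method_supported(quant_method: str) -> bool:
    # Currently, all quantization methods require Nvidia or AMD GPUs.
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False

    # Assumption: the capability unpacks to (major, minor) and each entry in
    # QUANTIZATION_METHODS exposes get_min_capability() as an int such as 75
    # or 80; the real helper may encode this differently.
    major, minor = current_platform.get_device_capability()
    min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability()
    return major * 10 + minor >= min_capability
```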