mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[CI/Build] Add is_quant_method_supported
to control quantization test configurations (#5253)
This commit is contained in:
@ -4,17 +4,8 @@ Run `pytest tests/models/test_aqlm.py`.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
aqlm_not_supported = True
|
||||
|
||||
if torch.cuda.is_available():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
aqlm_not_supported = (capability <
|
||||
QUANTIZATION_METHODS["aqlm"].get_min_capability())
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
# In this test we hardcode prompts and generations for the model so we don't
|
||||
# need to require the AQLM package as a dependency
|
||||
@ -67,7 +58,7 @@ ground_truth_generations = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.skipif(aqlm_not_supported,
|
||||
@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
|
||||
reason="AQLM is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
|
@ -8,8 +8,8 @@ import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
@ -67,16 +67,8 @@ EXPECTED_STRS_MAP = {
|
||||
},
|
||||
}
|
||||
|
||||
fp8_not_supported = True
|
||||
|
||||
if torch.cuda.is_available():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
fp8_not_supported = (capability <
|
||||
QUANTIZATION_METHODS["fp8"].get_min_capability())
|
||||
|
||||
|
||||
@pytest.mark.skipif(fp8_not_supported,
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="fp8 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_name", MODELS)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
|
@ -11,9 +11,8 @@ Run `pytest tests/models/test_gptq_marlin.py`.
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
|
||||
|
||||
from .utils import check_logprobs_close
|
||||
@ -22,14 +21,6 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
MAX_MODEL_LEN = 1024
|
||||
|
||||
gptq_marlin_not_supported = True
|
||||
|
||||
if torch.cuda.is_available():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
gptq_marlin_not_supported = (
|
||||
capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability())
|
||||
|
||||
MODELS = [
|
||||
# act_order==False, group_size=channelwise
|
||||
("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"),
|
||||
@ -53,7 +44,7 @@ MODELS = [
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=3)
|
||||
@pytest.mark.skipif(gptq_marlin_not_supported,
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
||||
reason="gptq_marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
|
||||
|
@ -9,18 +9,9 @@ Run `pytest tests/models/test_marlin_24.py`.
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.models.utils import check_logprobs_close
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
marlin_not_supported = True
|
||||
|
||||
if torch.cuda.is_available():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
marlin_not_supported = (
|
||||
capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -47,7 +38,7 @@ model_pairs = [
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(marlin_not_supported,
|
||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
|
||||
reason="Marlin24 is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_pair", model_pairs)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
|
@ -13,20 +13,11 @@ Run `pytest tests/models/test_marlin.py`.
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
marlin_not_supported = True
|
||||
|
||||
if torch.cuda.is_available():
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
marlin_not_supported = (
|
||||
capability < QUANTIZATION_METHODS["marlin"].get_min_capability())
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelPair:
|
||||
@ -45,7 +36,7 @@ model_pairs = [
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.skipif(marlin_not_supported,
|
||||
@pytest.mark.skipif(not is_quant_method_supported("marlin"),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("model_pair", model_pairs)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
|
@ -5,16 +5,12 @@ Run `pytest tests/quantization/test_bitsandbytes.py`.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
|
||||
reason='bitsandbytes is not supported on this GPU type.')
|
||||
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
|
||||
reason='bitsandbytes is not supported on this GPU type.')
|
||||
def test_load_bnb_model(vllm_runner) -> None:
|
||||
with vllm_runner('huggyllama/llama-7b',
|
||||
quantization='bitsandbytes',
|
||||
|
@ -5,16 +5,12 @@ Run `pytest tests/quantization/test_fp8.py --forked`.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
capability < QUANTIZATION_METHODS["fp8"].get_min_capability(),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.")
|
||||
def test_load_fp16_model(vllm_runner) -> None:
|
||||
with vllm_runner("facebook/opt-125m", quantization="fp8") as llm:
|
||||
|
||||
|
14
tests/quantization/utils.py
Normal file
14
tests/quantization/utils.py
Normal file
@ -0,0 +1,14 @@
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
|
||||
def is_quant_method_supported(quant_method: str) -> bool:
|
||||
# Currently, all quantization methods require Nvidia or AMD GPUs
|
||||
if not torch.cuda.is_available():
|
||||
return False
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
return (capability <
|
||||
QUANTIZATION_METHODS[quant_method].get_min_capability())
|
Reference in New Issue
Block a user