mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[ROCm] Skip tests for quantizations incompatible with ROCm (#17905)
Signed-off-by: Hissu Hyvarinen <hissu.hyvarinen@amd.com>
This commit is contained in:
@ -2,6 +2,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
# These ground truth generations were generated using `transformers==4.38.1
|
# These ground truth generations were generated using `transformers==4.38.1
|
||||||
# aqlm==1.1.0 torch==2.2.0`
|
# aqlm==1.1.0 torch==2.2.0`
|
||||||
@ -34,7 +35,9 @@ ground_truth_generations = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
|
@pytest.mark.skipif(not is_quant_method_supported("aqlm")
|
||||||
|
or current_platform.is_rocm()
|
||||||
|
or not current_platform.is_cuda(),
|
||||||
reason="AQLM is not supported on this GPU type.")
|
reason="AQLM is not supported on this GPU type.")
|
||||||
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
|
@pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
|||||||
@ -55,6 +55,14 @@ def test_models(
|
|||||||
Only checks log probs match to cover the discrepancy in
|
Only checks log probs match to cover the discrepancy in
|
||||||
numerical sensitive kernels.
|
numerical sensitive kernels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if backend == "FLASHINFER" and current_platform.is_rocm():
|
||||||
|
pytest.skip("Flashinfer does not support ROCm/HIP.")
|
||||||
|
|
||||||
|
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
|
||||||
|
pytest.skip(
|
||||||
|
f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
|
||||||
|
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("TOKENIZERS_PARALLELISM", 'true')
|
m.setenv("TOKENIZERS_PARALLELISM", 'true')
|
||||||
m.setenv(STR_BACKEND_ENV_VAR, backend)
|
m.setenv(STR_BACKEND_ENV_VAR, backend)
|
||||||
|
|||||||
@ -14,6 +14,7 @@ import pytest
|
|||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
|
from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import check_logprobs_close
|
from ..utils import check_logprobs_close
|
||||||
|
|
||||||
@ -34,7 +35,9 @@ MODELS = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.flaky(reruns=3)
|
@pytest.mark.flaky(reruns=3)
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
|
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin")
|
||||||
|
or current_platform.is_rocm()
|
||||||
|
or not current_platform.is_cuda(),
|
||||||
reason="gptq_marlin is not supported on this GPU type.")
|
reason="gptq_marlin is not supported on this GPU type.")
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
|
@pytest.mark.parametrize("dtype", ["half", "bfloat16"])
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from dataclasses import dataclass
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import check_logprobs_close
|
from ..utils import check_logprobs_close
|
||||||
|
|
||||||
@ -38,7 +39,9 @@ model_pairs = [
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.flaky(reruns=2)
|
@pytest.mark.flaky(reruns=2)
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
|
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24")
|
||||||
|
or current_platform.is_rocm()
|
||||||
|
or not current_platform.is_cuda(),
|
||||||
reason="Marlin24 is not supported on this GPU type.")
|
reason="Marlin24 is not supported on this GPU type.")
|
||||||
@pytest.mark.parametrize("model_pair", model_pairs)
|
@pytest.mark.parametrize("model_pair", model_pairs)
|
||||||
@pytest.mark.parametrize("dtype", ["half"])
|
@pytest.mark.parametrize("dtype", ["half"])
|
||||||
|
|||||||
Reference in New Issue
Block a user