diff --git a/tests/models/quantization/test_aqlm.py b/tests/models/quantization/test_aqlm.py
index 548053b7ae..1272a62974 100644
--- a/tests/models/quantization/test_aqlm.py
+++ b/tests/models/quantization/test_aqlm.py
@@ -2,6 +2,7 @@
 import pytest
 
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
 
 # These ground truth generations were generated using `transformers==4.38.1
 # aqlm==1.1.0 torch==2.2.0`
@@ -34,7 +35,9 @@ ground_truth_generations = [
 ]
 
 
-@pytest.mark.skipif(not is_quant_method_supported("aqlm"),
+@pytest.mark.skipif(not is_quant_method_supported("aqlm")
+                    or current_platform.is_rocm()
+                    or not current_platform.is_cuda(),
                     reason="AQLM is not supported on this GPU type.")
 @pytest.mark.parametrize("model", ["ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf"])
 @pytest.mark.parametrize("dtype", ["half"])
diff --git a/tests/models/quantization/test_fp8.py b/tests/models/quantization/test_fp8.py
index 4d15675a3a..e01ee20263 100644
--- a/tests/models/quantization/test_fp8.py
+++ b/tests/models/quantization/test_fp8.py
@@ -55,6 +55,14 @@ def test_models(
     Only checks log probs match to cover the discrepancy in
     numerical sensitive kernels.
     """
+
+    if backend == "FLASHINFER" and current_platform.is_rocm():
+        pytest.skip("Flashinfer does not support ROCm/HIP.")
+
+    if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
+        pytest.skip(
+            f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
+
     with monkeypatch.context() as m:
         m.setenv("TOKENIZERS_PARALLELISM", 'true')
         m.setenv(STR_BACKEND_ENV_VAR, backend)
diff --git a/tests/models/quantization/test_gptq_marlin.py b/tests/models/quantization/test_gptq_marlin.py
index 680134c6ea..397bdb9812 100644
--- a/tests/models/quantization/test_gptq_marlin.py
+++ b/tests/models/quantization/test_gptq_marlin.py
@@ -14,6 +14,7 @@ import pytest
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
+from vllm.platforms import current_platform
 
 from ..utils import check_logprobs_close
 
@@ -34,7 +35,9 @@ MODELS = [
 
 
 @pytest.mark.flaky(reruns=3)
-@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin")
+                    or current_platform.is_rocm()
+                    or not current_platform.is_cuda(),
                     reason="gptq_marlin is not supported on this GPU type.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half", "bfloat16"])
diff --git a/tests/models/quantization/test_gptq_marlin_24.py b/tests/models/quantization/test_gptq_marlin_24.py
index ce28f964d5..6fb24b1f43 100644
--- a/tests/models/quantization/test_gptq_marlin_24.py
+++ b/tests/models/quantization/test_gptq_marlin_24.py
@@ -10,6 +10,7 @@ from dataclasses import dataclass
 import pytest
 
 from tests.quantization.utils import is_quant_method_supported
+from vllm.platforms import current_platform
 
 from ..utils import check_logprobs_close
 
@@ -38,7 +39,9 @@ model_pairs = [
 
 
 @pytest.mark.flaky(reruns=2)
-@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24"),
+@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin_24")
+                    or current_platform.is_rocm()
+                    or not current_platform.is_cuda(),
                     reason="Marlin24 is not supported on this GPU type.")
 @pytest.mark.parametrize("model_pair", model_pairs)
 @pytest.mark.parametrize("dtype", ["half"])