[Misc] add w8a8 asym models (#11075)

This commit is contained in:
Dipika Sikka
2024-12-23 13:33:20 -05:00
committed by GitHub
parent b866cdbd05
commit 8cef6e02dc

View File

@ -79,12 +79,12 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
assert output
@pytest.mark.parametrize(
"model_path",
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
# TODO static & asymmetric
])
@pytest.mark.parametrize("model_path", [
"neuralmagic/Llama-3.2-1B-quantized.w8a8",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym"
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
@ -92,6 +92,10 @@ def test_compressed_tensors_w8a8_logprobs(hf_runner, vllm_runner,
max_tokens, num_logprobs):
dtype = "bfloat16"
# Skip the language-translation prompt (the last example prompt) for the static per-tensor asymmetric model
if model_path == "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym": # noqa: E501
example_prompts = example_prompts[0:-1]
with hf_runner(model_path, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)