[torch.compile] rework test plans (#9866)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-10-31 22:20:17 -07:00
Committed by: GitHub
parent 37a4947dcd
commit 566cd27797
4 changed files with 226 additions and 31 deletions

View File

@@ -1,3 +1,4 @@
import dataclasses
from typing import Dict, List, Optional
import pytest
@@ -8,33 +9,109 @@ from vllm.utils import cuda_device_count_stateless
from ..utils import compare_all_settings
@dataclasses.dataclass
class TestSetting:
model: str
model_args: List[str]
pp_size: int
tp_size: int
attn_backend: str
method: str
fullgraph: bool
# representative settings for testing
test_settings = [
# basic llama model
TestSetting(
model="meta-llama/Llama-3.2-1B",
model_args=[],
pp_size=2,
tp_size=2,
attn_backend="FLASHINFER",
method="generate",
fullgraph=True,
),
# llama model with quantization
TestSetting(
model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
model_args=["--quantization", "gptq"],
pp_size=1,
tp_size=1,
attn_backend="FLASH_ATTN",
method="generate",
fullgraph=True,
),
# MoE model
TestSetting(
model="ibm/PowerMoE-3b",
model_args=[],
pp_size=1,
tp_size=2,
attn_backend="FLASH_ATTN",
method="generate",
fullgraph=True,
),
# embedding model
TestSetting(
model="BAAI/bge-multilingual-gemma2",
model_args=["--task", "embedding"],
pp_size=1,
tp_size=1,
attn_backend="FLASHINFER",
method="encode",
fullgraph=True,
),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
model_args=["--trust-remote-code", "--max-model-len", "2048"],
pp_size=2,
tp_size=1,
attn_backend="FLASH_ATTN",
method="generate_with_image",
fullgraph=False,
),
]
# we cannot afford to test the full Cartesian product
# of all models and all compilation levels
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
method, fullgraph):
@pytest.mark.parametrize("test_setting", test_settings)
def test_compile_correctness(test_setting: TestSetting):
# this test is run under multiple test suites with different GPUs.
# make sure we only run the test with the correct number of CUDA devices.
# don't compare with "<", as that would duplicate the test across suites.
model = test_setting.model
model_args = test_setting.model_args
pp_size = test_setting.pp_size
tp_size = test_setting.tp_size
attn_backend = test_setting.attn_backend
method = test_setting.method
fullgraph = test_setting.fullgraph
if cuda_device_count_stateless() != pp_size * tp_size:
pytest.skip("Not the correct number of CUDA devices for the test.")
import os
os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
all_args = [["--enforce-eager"] + model_args + ["-pp", str(pp_size)] +
["-tp", str(tp_size)]] * 3
# don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
# inductor will change the output, so we cannot compare them.
final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
["-tp", str(tp_size)]
all_envs: List[Optional[Dict[str, str]]] = []
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
]:
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
# inductor will change the output, so we only check that the outputs
# are close, not exactly the same.
compare_all_settings(
model, [final_args] * 2,
all_envs,
method=method if method != "generate" else "generate_close")
all_envs.clear()
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
@@ -46,4 +123,4 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
all_envs[-1][
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore
compare_all_settings(model, all_args, all_envs, method=method)
compare_all_settings(model, [final_args] * 3, all_envs, method=method)
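
The PIECEWISE comparison above uses "generate_close" instead of exact matching because inductor perturbs the low-order digits of the logprobs. A minimal standalone illustration of the rounding idea used by the _test_completion_close helper in the next file (the numbers here are made up):

# two top-logprob dicts that differ only in low-order digits, as eager vs.
# inductor-compiled runs typically do, become equal after rounding to 2 decimals
ref = {"Paris": -0.12344, " the": -2.30267}
new = {"Paris": -0.12351, " the": -2.30259}

def round2(logprobs):
    return {token: round(lp, 2) for token, lp in logprobs.items()}

assert round2(ref) == round2(new)  # both round to {"Paris": -0.12, " the": -2.3}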

View File

@@ -1,4 +1,5 @@
import asyncio
import copy
import functools
import os
import signal
@@ -8,13 +9,14 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union
import openai
import pytest
import requests
import torch
from openai.types.completion import Completion
from typing_extensions import ParamSpec, assert_never
from typing_extensions import ParamSpec
import vllm.envs as envs
from tests.models.utils import TextTextLogprobs
@@ -272,6 +274,31 @@ def _test_completion(
return results
def _test_completion_close(
client: openai.OpenAI,
model: str,
prompt: str,
):
results = []
# test with text prompt
completion = client.completions.create(model=model,
prompt=prompt,
max_tokens=1,
logprobs=5,
temperature=0.0)
logprobs = completion.choices[0].logprobs.top_logprobs[0]
logprobs = {k: round(v, 2) for k, v in logprobs.items()}
results.append({
"test": "completion_close",
"logprobs": logporbs,
})
return results
def _test_embeddings(
client: openai.OpenAI,
model: str,
@@ -295,13 +322,81 @@ def _test_embeddings(
return results
def _test_image_text(
client: openai.OpenAI,
model_name: str,
image_url: str,
):
results = []
# test pure text input
messages = [{
"role":
"user",
"content": [
{
"type": "text",
"text": "How do you feel today?"
},
],
}]
chat_completion = client.chat.completions.create(model=model_name,
messages=messages,
temperature=0.0,
max_tokens=1,
logprobs=True,
top_logprobs=5)
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
for x in top_logprobs:
x.logprob = round(x.logprob, 2)
results.append({
"test": "pure_text",
"logprobs": top_logprobs,
})
messages = [{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "text",
"text": "What's in this image?"
},
],
}]
chat_completion = client.chat.completions.create(model=model_name,
messages=messages,
temperature=0.0,
max_tokens=1,
logprobs=True,
top_logprobs=5)
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
results.append({
"test": "text_image",
"logprobs": top_logprobs,
})
return results
def compare_two_settings(model: str,
arg1: List[str],
arg2: List[str],
env1: Optional[Dict[str, str]] = None,
env2: Optional[Dict[str, str]] = None,
*,
method: Literal["generate", "encode"] = "generate",
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with two different sets of arguments/environments
@@ -328,7 +423,7 @@ def compare_all_settings(model: str,
all_args: List[List[str]],
all_envs: List[Optional[Dict[str, str]]],
*,
method: Literal["generate", "encode"] = "generate",
method: str = "generate",
max_wait_seconds: Optional[float] = None) -> None:
"""
Launch API server with several different sets of arguments/environments
@@ -397,10 +492,17 @@ def compare_all_settings(model: str,
if method == "generate":
results += _test_completion(client, model, prompt, token_ids)
elif method == "generate_close":
results += _test_completion_close(client, model, prompt)
elif method == "generate_with_image":
results += _test_image_text(
client, model,
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
)
elif method == "encode":
results += _test_embeddings(client, model, prompt)
else:
assert_never(method)
raise ValueError(f"Unknown method: {method}")
if i > 0:
# if any setting fails, raise an error early
@@ -410,6 +512,18 @@ def compare_all_settings(model: str,
compare_envs = all_envs[i]
for ref_result, compare_result in zip(ref_results,
compare_results):
ref_result = copy.deepcopy(ref_result)
compare_result = copy.deepcopy(compare_result)
if "embedding" in ref_result and method == "encode":
ref_embedding = torch.tensor(ref_result["embedding"])
compare_embedding = torch.tensor(
compare_result["embedding"])
mse = ((ref_embedding - compare_embedding)**2).mean()
assert mse < 1e-6, (
f"Embedding for {model=} are not the same.\n"
f"mse={mse}\n")
del ref_result["embedding"]
del compare_result["embedding"]
assert ref_result == compare_result, (
f"Results for {model=} are not the same.\n"
f"{ref_args=} {ref_envs=}\n"

View File

@@ -493,13 +493,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
:class:`LlavaImageInputs`
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
@@ -511,7 +507,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
else:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
input_ids = None
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,

View File

@@ -679,7 +679,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
@@ -690,9 +689,14 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None
else:
inputs_embeds = None
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,
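
Both model changes above follow the same pattern for torch.compile: the caller always resolves token IDs to embeddings first and passes input_ids=None, so the compiled language-model trunk sees a single input signature whether the request is text-only or multimodal. A standalone sketch of the idea (not vLLM code; the names and shapes are illustrative):

import torch
import torch.nn as nn

class ToyTrunk(nn.Module):
    """Stands in for language_model.model: it only ever consumes embeddings."""

    def __init__(self, dim: int = 16, vocab: int = 128):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.layer = nn.Linear(dim, dim)

    def forward(self, input_ids, positions, inputs_embeds):
        # mirrors the diff: input_ids is always None here, so torch.compile
        # traces one graph over inputs_embeds for every kind of request
        assert input_ids is None
        return self.layer(inputs_embeds)

trunk = ToyTrunk()
compiled = torch.compile(trunk)

ids = torch.randint(0, 128, (1, 8))
positions = torch.arange(8)

# text-only request: embed the token IDs outside the compiled trunk
text_embeds = trunk.embed_tokens(ids).detach()
out_text = compiled(None, positions, text_embeds)

# multimodal request: some positions are replaced by (fake) vision features,
# but the trunk still receives the same kind of input as before
mm_embeds = text_embeds.clone()
mm_embeds[:, :2] = torch.randn(1, 2, 16)
out_mm = compiled(None, positions, mm_embeds)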