vllm-ascend/tests/conftest.py
Yikun Jiang d5e7756028 [Core] Init vllm-ascend (#3)
### What this PR does / why we need it?
vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on
the Ascend NPU.

This plugin is the recommended approach for supporting the Ascend
backend within the vLLM community. It adheres to the principles outlined
in the [RFC]: Hardware pluggable, providing a hardware-pluggable
interface that decouples the integration of the Ascend NPU with vLLM.
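
For background, vLLM discovers such hardware plugins through Python entry points. The sketch below shows the general shape of that registration; it is a minimal illustration assuming vLLM's `vllm.platform_plugins` entry-point group, and the package, module, and class names are hypothetical rather than the exact ones used by vllm-ascend.

```python
# setup.py of a hypothetical platform plugin package.
# Assumption: vLLM scans the "vllm.platform_plugins" entry-point group and
# calls each registered entry point; the callable returns the dotted path of
# the Platform implementation to activate (or None to stay inactive).
from setuptools import setup

setup(
    name="my-ascend-plugin",          # hypothetical package name
    packages=["my_ascend_plugin"],
    entry_points={
        "vllm.platform_plugins": [
            "my_ascend_plugin = my_ascend_plugin:register",
        ]
    },
)

# my_ascend_plugin/__init__.py (hypothetical)
def register():
    # Dotted path of the Platform subclass this plugin provides.
    return "my_ascend_plugin.platform.NPUPlatform"
```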

This patch also includes changes to make CI work and to use caches to speed
up the e2e tests, including:
1. Change the push (post-merge CI) and pull_request (PR CI) trigger branch
   to main.
2. Make mypy work by ignoring base_communicator and clearing unused deps.
3. Several improvements for vllm_ascend_test:
   - Use caches (pip, ms, hf) to speed up the e2e test (25 mins --> 5 mins).
   - Switch the `git clone` command to `actions/checkout` to speed up
     checkout.
   - Enable `-sv` for pytest for a better info dump.
   - Remove the network host to resolve `docker: conflicting options: cannot
     attach both user-defined and non-user-defined network-modes`, which is
     a problem on docker 1.45 but not on 1.39.
4. Adapt MLA decode optimizations:
   cabaf4eff3

### Does this PR introduce _any_ user-facing change?
Yes, this PR initializes the vllm-ascend plugin.

### How was this patch tested?
- This is the first PR to make the Ascend NPU work with vLLM. All code is
tested on Ascend NPU with the vLLM V0 engine.
- CI passed

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: wangshuai09 <391746016@qq.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
2025-02-05 10:53:12 +08:00

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Tuple, TypeVar, Union

import numpy as np
import pytest
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.config import TaskOption
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import is_list_of

from tests.model_utils import (TokensTextLogprobs,
                               TokensTextLogprobsPromptLogprobs)

logger = init_logger(__name__)

_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]

class VllmRunner:

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        # Use smaller max model length, otherwise bigger model cannot run due
        # to kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                if image is not None:
                    inputs[i]["multi_modal_data"] = {"image": image}

        if videos is not None:
            for i, video in enumerate(videos):
                if video is not None:
                    inputs[i]["multi_modal_data"] = {"video": video}

        if audios is not None:
            for i, audio in enumerate(audios):
                if audio is not None:
                    inputs[i]["multi_modal_data"] = {"audio": audio}

        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                output_logprobs = sample.logprobs
            outputs.append((output_ids, output_str, output_logprobs,
                            req_output.prompt_logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))

        # Omit prompt logprobs if not required by sampling params
        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
                if sampling_params.prompt_logprobs is None else
                toks_str_logsprobs_prompt_logprobs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))

        # Omit prompt logprobs if not required by sampling params
        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
                if sampling_params.prompt_logprobs is None else
                toks_str_logsprobs_prompt_logprobs)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop: Optional[List[str]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if is_list_of(prompts, str, check="all"):
            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
        else:
            prompts = [
                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
            ]
        outputs = self.model.beam_search(
            prompts,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[List[float]]:
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.embed(inputs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, List[str]],
        text_2: Union[str, List[str]],
    ) -> List[float]:
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.score for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
def vllm_runner():
    return VllmRunner
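
As a usage note: the session-scoped `vllm_runner` fixture hands the test the `VllmRunner` class itself, so a test constructs it (typically as a context manager, which tears down the engine and distributed state on exit) and then calls helpers such as `generate_greedy`. The sketch below is illustrative only; the model name, prompt, and asserted properties are hypothetical placeholders, not values from the real test suite.

# Illustrative test sketch; model name and prompt are hypothetical placeholders.
def test_greedy_generation_example(vllm_runner):
    with vllm_runner("facebook/opt-125m",
                     max_model_len=512,
                     enforce_eager=True) as runner:
        # generate_greedy returns one (token_ids, text) pair per prompt.
        outputs = runner.generate_greedy(["Hello, my name is"], max_tokens=16)
        for token_ids, text in outputs:
            assert isinstance(text, str)
            assert len(token_ids) > 0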