[Test] Remove VLLM_USE_V1 in example and tests (#1733)

V1 is enabled by default, so there is no need to set it by hand any more. This PR removes the now-useless setting from the examples and tests.
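
For reference, a minimal offline-inference call now runs on the V1 engine with no extra environment variable. A short sketch (assuming the Qwen/Qwen2.5-0.5B-Instruct model the tests already use):

    from vllm import LLM, SamplingParams

    # VLLM_USE_V1 no longer needs to be set: the V1 engine is the default.
    llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)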

- vLLM version: v0.9.2
- vLLM main:
9ad0a4588b

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Author: wangxiyuan
Date: 2025-07-15 12:49:57 +08:00
Committed by: GitHub
Parent: eb921d2b6f
Commit: 787010a637
29 changed files with 186 additions and 291 deletions

View File

@@ -41,16 +41,10 @@ concurrency:
 jobs:
   lint:
-    # Only trigger lint on pull request
-    if: ${{ github.event_name == 'pull_request' }}
     uses: ./.github/workflows/pre-commit.yml
   changes:
-    # Only trigger changes on pull request
-    if: ${{ github.event_name == 'pull_request' }}
     runs-on: ubuntu-latest
-    permissions:
-      pull-requests: read
     outputs:
       e2e_tracker: ${{ steps.filter.outputs.e2e_tracker }}
       ut_tracker: ${{ steps.filter.outputs.ut_tracker }}
@@ -60,20 +54,24 @@ jobs:
         with:
           filters: |
             e2e_tracker:
+              - '.github/workflows/vllm_ascend_test.yaml'
               - 'vllm_ascend/**'
               - 'csrc/**'
               - 'cmake/**'
               - 'tests/e2e/**'
-              - 'tests/conftest.py'
-              - 'tests/model_utils.py'
-              - 'tests/utils.py'
+              - 'CMakeLists.txt'
+              - 'setup.py'
+              - 'requirements.txt'
+              - 'requirements-dev.txt'
+              - 'requirements-lint.txt'
+              - 'packages.txt'
             ut_tracker:
               - 'tests/ut/**'
   ut:
     needs: [lint, changes]
     name: unit test
-    # only trigger unit test after lint passed and the change is e2e and ut related. Or the PR is merged.
-    if: ${{ github.event_name == 'push' || (needs.lint.result == 'success' && (needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.ut_tracker == 'true')) }}
+    # only trigger unit test after lint passed and the change is e2e and ut related.
+    if: ${{ needs.lint.result == 'success' && (needs.changes.outputs.e2e_tracker == 'true' || needs.changes.outputs.ut_tracker == 'true') }}
     runs-on: ubuntu-latest
     container:
       image: quay.io/ascend/cann:8.1.rc1-910b-ubuntu22.04-py3.10
@@ -112,9 +110,8 @@ jobs:
           python3 -m pip install -r requirements-dev.txt --extra-index https://download.pytorch.org/whl/cpu/
           python3 -m pip install -v . --extra-index https://download.pytorch.org/whl/cpu/
-      - name: Run unit test for V1 Engine
+      - name: Run unit test
         env:
-          VLLM_USE_V1: 1
           VLLM_WORKER_MULTIPROC_METHOD: spawn
           TORCH_DEVICE_BACKEND_AUTOLOAD: 0
         run: |
@@ -133,8 +130,8 @@ jobs:
   e2e:
     needs: [lint, changes]
-    # only trigger e2e test after lint passed and the change is e2e related.
-    if: ${{ needs.lint.result == 'success' && needs.changes.outputs.e2e_tracker == 'true' }}
+    # only trigger e2e test after lint passed and the change is e2e related with pull request.
+    if: ${{ github.event_name == 'pull_request' && needs.lint.result == 'success' && needs.changes.outputs.e2e_tracker == 'true' }}
     strategy:
       max-parallel: 2
       matrix:
@@ -189,9 +186,8 @@ jobs:
           pip install -r requirements-dev.txt
           pip install -v -e .
-      - name: Run e2e test for V1 Engine
+      - name: Run e2e test
         env:
-          VLLM_USE_V1: 1
           VLLM_WORKER_MULTIPROC_METHOD: spawn
           VLLM_USE_MODELSCOPE: True
         run: |
@@ -213,26 +209,6 @@ jobs:
           # TODO: revert me when test_v1_spec_decode.py::test_ngram_correctness is fixed
           VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/singlecard/spec_decode_v1/test_v1_spec_decode.py
-      - name: Run e2e test on V0 engine
-        if: ${{ github.event_name == 'schedule' }}
-        env:
-          VLLM_USE_V1: 0
-          VLLM_USE_MODELSCOPE: True
-        run: |
-          pytest -sv tests/e2e/singlecard/test_offline_inference.py
-          pytest -sv tests/e2e/singlecard/test_ilama_lora.py
-          pytest -sv tests/e2e/singlecard/test_guided_decoding.py
-          pytest -sv tests/e2e/singlecard/test_camem.py
-          pytest -sv tests/e2e/singlecard/test_prompt_embedding.py
-          pytest -sv tests/e2e/singlecard/test_embedding.py
-          pytest -sv tests/e2e/singlecard/ \
-            --ignore=tests/e2e/singlecard/test_offline_inference.py \
-            --ignore=tests/e2e/singlecard/test_ilama_lora.py \
-            --ignore=tests/e2e/singlecard/test_guided_decoding.py \
-            --ignore=tests/e2e/singlecard/test_camem.py \
-            --ignore=tests/e2e/singlecard/test_prompt_embedding.py \
-            --ignore=tests/e2e/singlecard/test_embedding.py

   e2e-4-cards:
     needs: [e2e]
     if: ${{ needs.e2e.result == 'success' }}
@@ -290,9 +266,8 @@ jobs:
           pip install -r requirements-dev.txt
           pip install -v -e .
-      - name: Run vllm-project/vllm-ascend test for V1 Engine
+      - name: Run vllm-project/vllm-ascend test
        env:
-          VLLM_USE_V1: 1
           VLLM_WORKER_MULTIPROC_METHOD: spawn
           VLLM_USE_MODELSCOPE: True
         run: |
@@ -308,19 +283,3 @@ jobs:
           pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
             --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
             --ignore=tests/e2e/multicard/test_data_parallel.py
-      - name: Run vllm-project/vllm-ascend test on V0 engine
-        if: ${{ github.event_name == 'schedule' }}
-        env:
-          VLLM_USE_V1: 0
-          VLLM_USE_MODELSCOPE: True
-        run: |
-          pytest -sv tests/e2e/multicard/test_ilama_lora_tp2.py
-          # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py will raise error.
-          # To avoid oom, we need to run the test in a single process.
-          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
-          pytest -sv tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
-          pytest -sv tests/e2e/multicard/test_data_parallel.py
-          pytest -sv tests/e2e/multicard/ --ignore=tests/e2e/multicard/test_ilama_lora_tp2.py \
-            --ignore=tests/e2e/multicard/test_offline_inference_distributed.py \
-            --ignore=tests/e2e/multicard/test_data_parallel.py

View File

@@ -120,7 +120,6 @@ def main(
     trust_remote_code,
 ):
     # DP only support on V1 engine
-    os.environ["VLLM_USE_V1"] = "1"
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)

View File

@@ -5,7 +5,6 @@ from vllm import LLM, SamplingParams

 # enable dual-batch overlap for vllm ascend
 os.environ["VLLM_ASCEND_ENABLE_DBO"] = "1"
-os.environ["VLLM_USE_V1"] = "1"

 # Sample prompts.
 prompts = ["The president of the United States is"] * 41

View File

@@ -22,7 +22,6 @@ import torch
 from vllm import LLM, SamplingParams
 from vllm.utils import GiB_bytes

-os.environ["VLLM_USE_V1"] = "1"
 os.environ["VLLM_USE_MODELSCOPE"] = "True"
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

View File

@ -1,4 +1,3 @@
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1 export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh source /usr/local/Ascend/nnal/atb/set_env.sh

View File

@@ -12,4 +12,5 @@ xgrammar
 zmq
 types-psutil
 pytest-cov
+regex
 sentence_transformers

View File

@@ -4,5 +4,6 @@ pre-commit==4.0.1
 # type checking
 mypy==1.11.1
 types-PyYAML
+types-regex
 types-requests
 types-setuptools

View File

@@ -39,8 +39,8 @@ from vllm.sampling_params import BeamSearchParams
 from vllm.transformers_utils.utils import maybe_model_redirect
 from vllm.utils import is_list_of

-from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
-                               TokensTextLogprobsPromptLogprobs)
+from tests.e2e.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
+                                   TokensTextLogprobsPromptLogprobs)
 # TODO: remove this part after the patch merged into vllm, if
 # we not explicitly patch here, some of them might be effectiveless
 # in pytest scenario
@@ -62,7 +62,7 @@ PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
 PromptVideoInput = _PromptMultiModalInput[np.ndarray]

 _TEST_DIR = os.path.dirname(__file__)
-_TEST_PROMPTS = [os.path.join(_TEST_DIR, "e2e", "prompts", "example.txt")]
+_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]


 def cleanup_dist_env_and_memory(shutdown_ray: bool = False):

View File

@@ -26,12 +26,11 @@ from unittest.mock import patch
 from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams

-from tests.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner


 @patch.dict(
     os.environ, {
-        "VLLM_USE_V1": "1",
         "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
         "TASK_QUEUE_ENABLE": "1",
         "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
@@ -56,12 +55,10 @@ def test_generate_with_allgather():
         vllm_model.generate(example_prompts, sampling_params)


-@patch.dict(
-    os.environ, {
-        "VLLM_USE_V1": "1",
-        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
-        "TASK_QUEUE_ENABLE": "1"
-    })
+@patch.dict(os.environ, {
+    "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
+    "TASK_QUEUE_ENABLE": "1"
+})
 def test_generate_with_alltoall():
     example_prompts = ["Hello, my name is"]
     sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
@@ -79,4 +76,4 @@ def test_generate_with_alltoall():
                 },
                 "expert_tensor_parallel_size": 1
             }) as vllm_model:
-        vllm_model.generate(example_prompts, sampling_params)
+        vllm_model.generate(example_prompts, sampling_params)

View File

@ -1,7 +1,7 @@
import pytest import pytest
from modelscope import snapshot_download # type: ignore from modelscope import snapshot_download # type: ignore
from tests.conftest import VllmRunner from tests.e2e.conftest import VllmRunner
from tests.e2e.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT, from tests.e2e.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT,
MODEL_PATH, do_sample) MODEL_PATH, do_sample)

View File

@@ -27,7 +27,7 @@ from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
 from vllm.model_executor.models.registry import ModelRegistry

-from tests.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner

 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

View File

@@ -16,7 +16,7 @@
 #
 import pytest

-from tests.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner

 MODELS = [
     "Qwen/Qwen3-0.6B",

View File

@@ -2,12 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Compare the with and without prefix caching on V1 scheduler or AscendScheduler."""
-import os
-
 import pytest

-from tests.conftest import VllmRunner
-from tests.model_utils import check_outputs_equal
+from tests.e2e.conftest import VllmRunner
+from tests.e2e.model_utils import check_outputs_equal

 MODELS = [
     # for MHA
@@ -60,8 +58,6 @@ INPUT_PROMPTS = [
 ]


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="mtp is not supported on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [50])
 def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None:
@@ -89,8 +85,6 @@ def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None:
     )


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="mtp is not supported on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [50])
 def test_prefix_cache_with_ascend_scheduler(model: str,

View File

@@ -22,9 +22,7 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`.
 import os
 from typing import Dict

-import pytest
-
-from tests.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner

 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
@@ -78,8 +76,6 @@ def _deepseek_torchair_test_fixture(
         print(f"Generated text: {vllm_output[i][1]!r}")


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="torchair graph is not supported on v0")
 def test_e2e_deepseekv3_with_torchair():
     additional_config = {
         "torchair_graph_config": {
@@ -89,8 +85,6 @@ def test_e2e_deepseekv3_with_torchair():
     _deepseek_torchair_test_fixture(additional_config)


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="torchair graph is not supported on v0")
 def test_e2e_deepseekv3_with_torchair_ms_mla():
     additional_config = {
         "torchair_graph_config": {
@@ -150,8 +144,6 @@ def _pangu_torchair_test_fixture(
         print(f"Generated text: {vllm_output[i][1]!r}")


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="torchair graph is not supported on v0")
 def test_e2e_pangu_with_torchair():
     additional_config = {
         "torchair_graph_config": {

View File

@@ -1,15 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import gc
-import os

 import pytest
 import torch
 from vllm import LLM

-if os.getenv("VLLM_USE_V1", "0") != "1":
-    pytest.skip("Test package requires V1", allow_module_level=True)
-
 MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
 PROMPT = "Hello my name is Robert and I"

View File

@@ -9,8 +9,8 @@ Run `pytest tests/e2e/singlecard/core/ascend_scheduler/test_chunk_prefill.py`.
 """
 import pytest

-from tests.conftest import VllmRunner
-from tests.model_utils import check_outputs_equal
+from tests.e2e.conftest import VllmRunner
+from tests.e2e.model_utils import check_outputs_equal

 MODELS = [
     "Qwen/Qwen3-0.6B-Base",

View File

@@ -53,7 +53,6 @@ def model_name():
 @pytest.mark.skipif(
     True, reason="TODO: Enable me after test_mtp_correctness is fixed")
 def test_mtp_correctness(
-        monkeypatch: pytest.MonkeyPatch,
         test_prompts: list[list[dict[str, Any]]],
         sampling_config: SamplingParams,
         model_name: str,
@@ -62,33 +61,30 @@ def test_mtp_correctness(
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using mtp speculative decoding.
     '''
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-
-        ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-
-        spec_llm = LLM(model=model_name,
-                       trust_remote_code=True,
-                       speculative_config={
-                           "method": "deepseek_mtp",
-                           "num_speculative_tokens": 1,
-                       },
-                       max_model_len=256,
-                       enforce_eager=True)
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
-
-        # Heuristic: expect at least 66% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.66 * len(ref_outputs))
-        del spec_llm
+    ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+
+    spec_llm = LLM(model=model_name,
+                   trust_remote_code=True,
+                   speculative_config={
+                       "method": "deepseek_mtp",
+                       "num_speculative_tokens": 1,
+                   },
+                   max_model_len=256,
+                   enforce_eager=True)
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.66 * len(ref_outputs))
+    del spec_llm

View File

@@ -60,7 +60,6 @@ def eagle3_model_name():
 def test_ngram_correctness(
-        monkeypatch: pytest.MonkeyPatch,
         test_prompts: list[list[dict[str, Any]]],
         sampling_config: SamplingParams,
         model_name: str,
@@ -70,44 +69,40 @@ def test_ngram_correctness(
     should be the same when using ngram speculative decoding.
     '''
     pytest.skip("Not current support for the test.")
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-
-        ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-
-        spec_llm = LLM(
-            model=model_name,
-            speculative_config={
-                "method": "ngram",
-                "prompt_lookup_max": 5,
-                "prompt_lookup_min": 3,
-                "num_speculative_tokens": 3,
-            },
-            max_model_len=1024,
-            enforce_eager=True,
-        )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
-
-        # Heuristic: expect at least 70% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.7 * len(ref_outputs))
-        del spec_llm
+    ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+
+    spec_llm = LLM(
+        model=model_name,
+        speculative_config={
+            "method": "ngram",
+            "prompt_lookup_max": 5,
+            "prompt_lookup_min": 3,
+            "num_speculative_tokens": 3,
+        },
+        max_model_len=1024,
+        enforce_eager=True,
+    )
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 70% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.7 * len(ref_outputs))
+    del spec_llm


 @pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
 def test_eagle_correctness(
-        monkeypatch: pytest.MonkeyPatch,
         test_prompts: list[list[dict[str, Any]]],
         sampling_config: SamplingParams,
         model_name: str,
@@ -119,43 +114,40 @@ def test_eagle_correctness(
     '''
     if not use_eagle3:
         pytest.skip("Not current support for the test.")
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_USE_V1", "1")
-
-        ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
-        ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-        del ref_llm
-
-        spec_model_name = eagle3_model_name(
-        ) if use_eagle3 else eagle_model_name()
-        spec_llm = LLM(
-            model=model_name,
-            trust_remote_code=True,
-            enable_chunked_prefill=True,
-            max_num_seqs=1,
-            max_num_batched_tokens=2048,
-            gpu_memory_utilization=0.6,
-            speculative_config={
-                "method": "eagle3" if use_eagle3 else "eagle",
-                "model": spec_model_name,
-                "num_speculative_tokens": 2,
-                "max_model_len": 128,
-            },
-            max_model_len=128,
-            enforce_eager=True,
-        )
-        spec_outputs = spec_llm.chat(test_prompts, sampling_config)
-        matches = 0
-        misses = 0
-        for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-            if ref_output.outputs[0].text == spec_output.outputs[0].text:
-                matches += 1
-            else:
-                misses += 1
-                print(f"ref_output: {ref_output.outputs[0].text}")
-                print(f"spec_output: {spec_output.outputs[0].text}")
-
-        # Heuristic: expect at least 66% of the prompts to match exactly
-        # Upon failure, inspect the outputs to check for inaccuracy.
-        assert matches > int(0.66 * len(ref_outputs))
-        del spec_llm
+    ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
+    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
+    del ref_llm
+
+    spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
+    spec_llm = LLM(
+        model=model_name,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+        max_num_seqs=1,
+        max_num_batched_tokens=2048,
+        gpu_memory_utilization=0.6,
+        speculative_config={
+            "method": "eagle3" if use_eagle3 else "eagle",
+            "model": spec_model_name,
+            "num_speculative_tokens": 2,
+            "max_model_len": 128,
+        },
+        max_model_len=128,
+        enforce_eager=True,
+    )
+    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
+    matches = 0
+    misses = 0
+    for ref_output, spec_output in zip(ref_outputs, spec_outputs):
+        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+            matches += 1
+        else:
+            misses += 1
+            print(f"ref_output: {ref_output.outputs[0].text}")
+            print(f"spec_output: {spec_output.outputs[0].text}")
+
+    # Heuristic: expect at least 66% of the prompts to match exactly
+    # Upon failure, inspect the outputs to check for inaccuracy.
+    assert matches > int(0.66 * len(ref_outputs))
+    del spec_llm

View File

@@ -20,14 +20,12 @@ Compare the outputs of vLLM with and without aclgraph.
 Run `pytest tests/compile/test_aclgraph.py`.
 """

-import os
-
 import pytest
 import torch
 from vllm import LLM, SamplingParams

-from tests.conftest import VllmRunner
-from tests.model_utils import check_outputs_equal
+from tests.e2e.conftest import VllmRunner
+from tests.e2e.model_utils import check_outputs_equal

 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",
@@ -36,37 +34,29 @@ MODELS = [
 ]


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
 def test_models(
     model: str,
     max_tokens: int,
-    monkeypatch: pytest.MonkeyPatch,
 ) -> None:
-    with monkeypatch.context() as m:
-        prompts = [
-            "Hello, my name is", "The president of the United States is",
-            "The capital of France is", "The future of AI is"
-        ]
-
-        # aclgraph only support on v1
-        m.setenv("VLLM_USE_V1", "1")
-
-        sampling_params = SamplingParams(max_tokens=max_tokens,
-                                         temperature=0.0)
-        # TODO: change to use vllmrunner when the registry of custom op is solved
-        # while running pytest
-        vllm_model = LLM(model)
-        vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
-        del vllm_model
-        torch.npu.empty_cache()
-
-        vllm_model = LLM(model, enforce_eager=True)
-        vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
-        del vllm_model
-        torch.npu.empty_cache()
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+
+    sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
+    # TODO: change to use vllmrunner when the registry of custom op is solved
+    # while running pytest
+    vllm_model = LLM(model)
+    vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
+    del vllm_model
+    torch.npu.empty_cache()
+
+    vllm_model = LLM(model, enforce_eager=True)
+    vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
+    del vllm_model
+    torch.npu.empty_cache()

     vllm_aclgraph_outputs_list = []
     for output in vllm_aclgraph_outputs:
@@ -86,12 +76,9 @@ def test_models(
     )


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="aclgraph only support on v1")
 def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None:
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_MODELSCOPE", "True")
-        m.setenv("VLLM_USE_V1", "1")
         with pytest.raises(NotImplementedError) as excinfo:
             VllmRunner("deepseek-ai/DeepSeek-V2-Lite-Chat",
                        max_model_len=1024,

View File

@@ -21,7 +21,7 @@ import torch
 from vllm import LLM, SamplingParams
 from vllm.utils import GiB_bytes

-from tests.utils import fork_new_process_for_each_test
+from tests.e2e.utils import fork_new_process_for_each_test
 from vllm_ascend.device_allocator.camem import CaMemAllocator

View File

@@ -20,8 +20,6 @@ Compare the outputs of vLLM with and without aclgraph.
 Run `pytest tests/compile/test_aclgraph.py`.
 """

-import os
-
 import pytest
 import torch
 from vllm import LLM, SamplingParams
@@ -29,8 +27,6 @@ from vllm import LLM, SamplingParams
 MODELS = ["deepseek-ai/DeepSeek-V2-Lite"]


-@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
-                    reason="new chunked only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [1])
 def test_models(
@@ -39,36 +35,33 @@ def test_models(
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     return
-    with monkeypatch.context() as m:
-        prompts = "The president of the United States is"
-        m.setenv("VLLM_USE_V1", "1")
-        sampling_params = SamplingParams(
-            max_tokens=max_tokens,
-            temperature=0.0,
-        )
-        vllm_model = LLM(model,
-                         long_prefill_token_threshold=4,
-                         enforce_eager=True)
-        output_chunked = vllm_model.generate(prompts, sampling_params)
-        logprobs_chunked = output_chunked.outputs[0].logprobs
-        del vllm_model
-        torch.npu.empty_cache()
-
-        vllm_model = LLM(model,
-                         enforce_eager=True,
-                         additional_config={
-                             'ascend_scheduler_config': {
-                                 'enabled': True
-                             },
-                         })
-        output = vllm_model.generate(prompts, sampling_params)
-        logprobs = output.outputs[0].logprobs
-        del vllm_model
-        torch.npu.empty_cache()
-
-        logprobs_similarity = torch.cosine_similarity(
-            logprobs_chunked.flatten(), logprobs.flatten(), dim=0)
-        assert logprobs_similarity > 0.95
+
+    prompts = "The president of the United States is"
+    sampling_params = SamplingParams(
+        max_tokens=max_tokens,
+        temperature=0.0,
+    )
+    vllm_model = LLM(model, long_prefill_token_threshold=4, enforce_eager=True)
+    output_chunked = vllm_model.generate(prompts, sampling_params)
+    logprobs_chunked = output_chunked.outputs[0].logprobs
+    del vllm_model
+    torch.npu.empty_cache()
+
+    vllm_model = LLM(model,
+                     enforce_eager=True,
+                     additional_config={
+                         'ascend_scheduler_config': {
+                             'enabled': True
+                         },
+                     })
+    output = vllm_model.generate(prompts, sampling_params)
+    logprobs = output.outputs[0].logprobs
+    del vllm_model
+    torch.npu.empty_cache()
+
+    logprobs_similarity = torch.cosine_similarity(logprobs_chunked.flatten(),
+                                                  logprobs.flatten(),
+                                                  dim=0)
+    assert logprobs_similarity > 0.95

View File

@@ -21,8 +21,8 @@ from typing import Optional
 from modelscope import snapshot_download  # type: ignore[import-untyped]

-from tests.conftest import HfRunner
-from tests.utils import check_embeddings_close, matryoshka_fy
+from tests.e2e.conftest import HfRunner
+from tests.e2e.utils import check_embeddings_close, matryoshka_fy


 def run_embedding_correctness_test(

View File

@@ -18,14 +18,14 @@
 #
 import json
 import os
-import re

 import jsonschema
 import pytest
+import regex as re
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams

-from tests.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner

 os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"

 MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
@@ -85,11 +85,7 @@ def sample_json_schema():

 def check_backend(guided_decoding_backend: str):
-    if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
-            "VLLM_USE_V1") == "0":
-        pytest.skip(f"{guided_decoding_backend} does not support v0, skip it.")
-    if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
-            "VLLM_USE_V1") == "1":
+    if guided_decoding_backend not in GuidedDecodingBackendV1:
         pytest.skip(f"{guided_decoding_backend} does not support v1, skip it.")

View File

@@ -3,7 +3,7 @@ import vllm
 from modelscope import snapshot_download  # type: ignore
 from vllm.lora.request import LoRARequest

-from tests.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner

 MODEL_PATH = "vllm-ascend/ilama-3.2-1B"

View File

@@ -30,7 +30,7 @@ from vllm import SamplingParams
 from vllm.assets.image import ImageAsset

 import vllm_ascend  # noqa: F401
-from tests.conftest import VllmRunner
+from tests.e2e.conftest import VllmRunner

 MODELS = [
     "Qwen/Qwen2.5-0.5B-Instruct",

View File

@@ -14,7 +14,6 @@
 #
 import os

-from unittest import mock
 from transformers import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
@@ -170,25 +169,23 @@ class TestAscendConfig(TestBase):
         init_ascend_config(test_vllm_config)
         check_ascend_config(test_vllm_config, False)

-        # For V1 engine
-        with mock.patch.dict(os.environ, {"VLLM_USE_V1": "1"}):
-            test_vllm_config.additional_config = {
-                "torchair_graph_config": {
-                    "enabled": True,
-                },
-                "refresh": True
-            }
-            init_ascend_config(test_vllm_config)
-            check_ascend_config(test_vllm_config, False)
+        test_vllm_config.additional_config = {
+            "torchair_graph_config": {
+                "enabled": True,
+            },
+            "refresh": True
+        }
+        init_ascend_config(test_vllm_config)
+        check_ascend_config(test_vllm_config, False)

         test_vllm_config.additional_config = {
             "torchair_graph_config": {
                 "enabled": False,
             },
             "refresh": True
         }
         init_ascend_config(test_vllm_config)
         check_ascend_config(test_vllm_config, False)

     @_clean_up_ascend_config
     def test_check_ascend_config_wrong_case(self):

View File

@@ -373,7 +373,6 @@ class TestNPUPlatform(TestBase):
     @patch("vllm_ascend.utils.is_310p", return_value=False)
     @patch("vllm_ascend.ascend_config.check_ascend_config")
     @patch("vllm_ascend.ascend_config.init_ascend_config")
-    @patch("vllm.envs.VLLM_USE_V1", True)
     def test_check_and_update_config_v1_worker_class_selection(
             self, mock_init_ascend, mock_check_ascend, mock_is_310p):
         mock_init_ascend.return_value = self.mock_ascend_config
@@ -392,7 +391,6 @@ class TestNPUPlatform(TestBase):
     @patch("vllm_ascend.ascend_config.check_ascend_config")
     @patch("vllm_ascend.ascend_config.init_ascend_config")
     @patch("vllm_ascend.utils.is_310p", return_value=True)
-    @patch("vllm.envs.VLLM_USE_V1", True)
     def test_check_and_update_config_310p_no_custom_ops(
             self, mock_is_310p, mock_init_ascend, mock_check_ascend):
         mock_init_ascend.return_value = self.mock_ascend_config