[Quickfix] update CachedRequestState as NewRequestData changed (#2367)

### What this PR does / why we need it?
1. Update `CachedRequestState` to follow the `NewRequestData` changes introduced in
https://github.com/vllm-project/vllm/pull/22570 (sketched below).
2. Drop maintenance of vLLM v0.10.0 on the main branch (also sketched below).
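For context, here is a minimal, self-contained sketch of the shape of change 1. It uses stand-in classes rather than the real vLLM types (`FakeMMKwargsItem` and `FakeCachedRequestState` are illustrative names only); the actual vllm-ascend hunks appear further down.

```python
from dataclasses import dataclass
from typing import Any


# Stand-in for vLLM's per-item container (MultiModalKwargsItem in the diff),
# which wraps the tensors of one multi-modal item behind require_data().
@dataclass
class FakeMMKwargsItem:
    data: dict[str, Any]

    def require_data(self) -> dict[str, Any]:
        return self.data


@dataclass
class FakeCachedRequestState:
    req_id: str
    prompt_token_ids: list[int]
    # Renamed from `mm_inputs` to `mm_kwargs`; entries are now per-item kwargs
    # objects rather than batched MultiModalKwargs mappings.
    mm_kwargs: list[FakeMMKwargsItem]


# Consumers unwrap each item before reading its fields.
state = FakeCachedRequestState(
    req_id="1",
    prompt_token_ids=[1, 2, 3],
    mm_kwargs=[FakeMMKwargsItem({"image_grid_thw": [[1, 2, 2]]})],
)
image_grid_thw = []
for item in state.mm_kwargs:
    mm_input = item.require_data()
    if mm_input.get("image_grid_thw") is not None:
        image_grid_thw.append(mm_input["image_grid_thw"])
```

Change 2 removes the `vllm_version_is("0.10.0")` branches; on the retained path, KV-transfer bookkeeping rides inside a connector-output object instead of being set as `finished_sending` / `finished_recving` directly on the model-runner output. A rough sketch of the kept shape, again with stand-in dataclasses rather than the real vLLM classes:

```python
from dataclasses import dataclass, field
from typing import Optional


# Stand-ins mirroring the retained (post-v0.10.0) output shape.
@dataclass
class FakeKVConnectorOutput:
    finished_sending: Optional[list[str]] = None
    finished_recving: Optional[list[str]] = None


@dataclass
class FakeModelRunnerOutput:
    req_ids: list[str] = field(default_factory=list)
    kv_connector_output: Optional[FakeKVConnectorOutput] = None


# Legacy (v0.10.0-only) path, now deleted:
#   extra_args = {"finished_sending": [...], "finished_recving": [...]}
# Retained path: finished request IDs always travel inside the connector output.
kv_out = FakeKVConnectorOutput(finished_sending=["req-0"], finished_recving=["req-1"])
output = FakeModelRunnerOutput(req_ids=["req-0", "req-1"], kv_connector_output=kv_out)
assert output.kv_connector_output.finished_sending == ["req-0"]
```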

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with the existing tests.


- vLLM version: v0.10.0
- vLLM main: 92ff41abea

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Author: Mengqing Cao
Date: 2025-08-15 07:35:27 +08:00
Committed by: GitHub
Parent: 2ad7e1251e
Commit: 61866b8ac6
18 changed files with 77 additions and 285 deletions

View File

@ -81,7 +81,7 @@ jobs:
VLLM_USE_MODELSCOPE: True
strategy:
matrix:
vllm_version: [main, v0.10.0]
vllm_version: [main]
steps:
- name: Install packages
run: |
@ -137,7 +137,7 @@ jobs:
max-parallel: 2
matrix:
os: [linux-aarch64-a2-1]
vllm_version: [main, v0.10.0]
vllm_version: [main]
name: singlecard e2e test
runs-on: ${{ matrix.os }}
container:
@ -185,9 +185,6 @@ jobs:
run: |
pip install -r requirements-dev.txt
pip install -v -e .
if [[ "${{ matrix.vllm_version }}" == "v0.10.0" ]]; then
pip install "transformers<4.54.0"
fi
- name: Run e2e test
env:
@ -222,7 +219,7 @@ jobs:
max-parallel: 2
matrix:
os: [linux-aarch64-a2-2]
vllm_version: [main, v0.10.0]
vllm_version: [main]
name: multicard e2e test
runs-on: ${{ matrix.os }}
container:
@ -270,9 +267,6 @@ jobs:
run: |
pip install -r requirements-dev.txt
pip install -v -e .
if [[ "${{ matrix.vllm_version }}" == "v0.10.0" ]]; then
pip install "transformers<4.54.0"
fi
- name: Run vllm-project/vllm-ascend test
env:

View File

@ -53,7 +53,7 @@ jobs:
max-parallel: 2
matrix:
os: [linux-aarch64-310p-1, linux-aarch64-310p-4]
vllm_version: [main, v0.10.0]
vllm_version: [main]
name: 310p e2e test
runs-on: ${{ matrix.os }}
container:

View File

@ -50,7 +50,7 @@ def create_requests(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_kwargs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,

View File

@ -25,7 +25,6 @@ from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
from vllm_ascend.utils import vllm_version_is
def test_basic_lifecycle():
@ -103,9 +102,6 @@ def test_basic_lifecycle():
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
if vllm_version_is("0.10.0"):
model_runner_output.finished_sending = [request_id]
else:
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
@ -164,9 +160,6 @@ def test_prefix_cache_lifecycle():
scheduler_output = scheduler.schedule()
scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
if vllm_version_is("0.10.0"):
model_runner_output.finished_sending = [request_remote.request_id]
else:
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(

View File

@ -25,7 +25,6 @@ from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
from vllm_ascend.utils import vllm_version_is
def test_basic_lifecycle():
@ -91,9 +90,6 @@ def test_basic_lifecycle():
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
if vllm_version_is("0.10.0"):
model_runner_output.finished_recving = [request_id]
else:
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
@ -211,9 +207,6 @@ def test_full_block_prompt():
# # STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
if vllm_version_is("0.10.0"):
model_runner_output.finished_recving = [request_id]
else:
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(

View File

@ -157,7 +157,7 @@ def create_request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
multi_modal_inputs=None,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
**({
@ -187,19 +187,11 @@ def create_model_runner_output(
# Make output data structure.
extra_args = {}
if not vllm_version_is("0.10.0"):
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
finished_recving=finished_recving)
extra_args = {"kv_connector_output": kv_connector_output}
else:
extra_args = {
"finished_sending": finished_sending,
"finished_recving": finished_recving,
}
return ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,

View File

@ -12,7 +12,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
return CachedRequestState(
req_id=req_id,
prompt_token_ids=prompt,
mm_inputs=[],
mm_kwargs=[],
mm_positions=[],
sampling_params=SamplingParams(),
pooling_params=None,

View File

@ -30,8 +30,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
get_act_and_mul_fn)
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@ -43,8 +42,6 @@ from vllm.model_executor.models.qwen2_5_vl import (
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.utils import vllm_version_is
MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight
@ -202,8 +199,6 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
)
act_fn = get_act_and_mul_fn(vision_config.hidden_act)
if vllm_version_is("0.10.0"):
act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
self.blocks = nn.ModuleList([
AscendQwen2_5_VisionBlock(
dim=self.hidden_size,
@ -303,12 +298,9 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
if not vllm_version_is("0.10.0"):
stacked_params_mapping.extend([
("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
("mlp.gate_up_proj.", "mlp.up_proj.", 1),
])
]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:

View File

@ -30,8 +30,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
get_act_and_mul_fn)
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2_5_vl import (
@ -43,7 +42,6 @@ from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
from vllm_ascend.utils import vllm_version_is
class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@ -175,8 +173,6 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
)
act_fn = get_act_and_mul_fn(vision_config.hidden_act)
if vllm_version_is("0.10.0"):
act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
self.blocks = nn.ModuleList([
AscendQwen2_5_VisionBlock_Without_Padding(
dim=self.hidden_size,

View File

@ -105,7 +105,7 @@ def model_input_split_v1_mla_attn(
[block_table_pre,
block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
seq_index)
assert attn_metadata.attn_mask is not None
if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
# the attn_mla kernel in torch npu only accepts a 128*128 attn mask
attn_mask_pre = attn_mask_post = attn_metadata.attn_mask

View File

@ -14,12 +14,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm_ascend.utils import vllm_version_is
# Import specific patches for different versions
if vllm_version_is("0.10.0"):
from vllm_ascend.patch.platform import patch_0_10_0 # noqa: F401
from vllm_ascend.patch.platform import patch_common # noqa: F401
else:
from vllm_ascend.patch.platform import patch_common # noqa: F401
from vllm_ascend.patch.platform import patch_main # noqa: F401
from vllm_ascend.patch.platform import patch_common # noqa: F401
from vllm_ascend.patch.platform import patch_main # noqa: F401

View File

@ -1,16 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

View File

@ -15,12 +15,5 @@
# limitations under the License.
#
from vllm_ascend.utils import vllm_version_is
# Import specific patches for different versions
if vllm_version_is("0.10.0"):
from vllm_ascend.patch.worker import patch_0_10_0 # noqa: F401
from vllm_ascend.patch.worker import patch_common # noqa: F401
else:
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.patch.worker import patch_main # noqa: F401
from vllm_ascend.patch.worker import patch_common # noqa: F401
from vllm_ascend.patch.worker import patch_main # noqa: F401

View File

@ -1,18 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import vllm_ascend.patch.worker.patch_0_10_0.patch_sampler_gather_logprobs # noqa

View File

@ -1,87 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
from vllm.platforms import current_platform
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.sampler import Sampler
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def batched_count_greater_than(x: torch.Tensor,
values: torch.Tensor) -> torch.Tensor:
"""
Counts elements in each row of x that are greater than the corresponding
value in values. Use torch.compile to generate an optimized kernel for
this function. otherwise, it will create additional copies of the input
tensors and cause memory issues.
Args:
x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
Returns:
torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
"""
return (x >= values).sum(-1)
def gather_logprobs(
self,
logprobs: torch.Tensor,
num_logprobs: int,
token_ids: torch.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Must be int64.
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
assert token_ids.dtype == torch.int64
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
# Get with the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = batched_count_greater_than(logprobs, token_logprobs)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
# Use int32 to reduce the tensor size.
indices = indices.to(torch.int32)
return LogprobsTensors(indices, logprobs, token_ranks)
Sampler.gather_logprobs = gather_logprobs

View File

@ -51,11 +51,12 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_transcription
from vllm.model_executor.models.interfaces_base import (
VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LazyLoader, cdiv)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@ -66,6 +67,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
@ -86,17 +88,11 @@ from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
ProfileExecuteDuration, is_310p,
maybe_converting_weight_acl_format,
vllm_version_is)
maybe_converting_weight_acl_format)
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if not vllm_version_is("0.10.0"):
from vllm.tasks import GenerationTask, SupportedTask
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput
if TYPE_CHECKING:
import xgrammar as xgr # type: ignore[import-untyped]
from vllm.v1.core.sched.output import SchedulerOutput
@ -479,7 +475,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.requests[req_id] = CachedRequestState(
req_id=req_id,
prompt_token_ids=new_req_data.prompt_token_ids,
mm_inputs=new_req_data.mm_inputs,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
sampling_params=sampling_params,
pooling_params=new_req_data.pooling_params,
@ -497,18 +493,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
second_per_grid_ts = []
audio_feature_lengths = []
use_audio_in_video = False
for mm_input in self.requests[req_id].mm_inputs:
for item in self.requests[req_id].mm_kwargs:
mm_input = item.require_data()
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
image_grid_thw.append(
mm_input["image_grid_thw"].tolist())
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
video_grid_thw.append(
mm_input["video_grid_thw"].tolist())
if mm_input.get("second_per_grid_ts") is not None:
second_per_grid_ts.extend(
second_per_grid_ts.append(
mm_input["second_per_grid_ts"])
if mm_input.get("audio_feature_lengths") is not None:
audio_feature_lengths.extend(
audio_feature_lengths.append(
mm_input["audio_feature_lengths"])
if mm_input.get("use_audio_in_video") is True:
use_audio_in_video = True
@ -912,13 +910,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return
# Batch the multi-modal inputs.
mm_inputs = list[MultiModalKwargs]()
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_inputs.append(req_state.mm_inputs[mm_input_id])
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
@ -929,14 +927,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# in the same batch while still being able to benefit from batching
# multimodal inputs. The proper solution should be reordering the
# encoder outputs.
grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
encoder_outputs = []
for grouped_mm_inputs in grouped_mm_inputs_list:
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
device=self.device)
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
mm_kwargs,
device=self.device,
pin_memory=True,
):
# Run the encoder.
# `curr_group_outputs` is either of the following:
# 1. A tensor of shape (num_items, feature_size, hidden_size)
@ -945,11 +941,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# (feature_size, hidden_size) in case the feature size is dynamic
# depending on the input multimodal items.
curr_group_outputs = self.model.get_multimodal_embeddings(
**batched_mm_inputs)
**mm_kwargs_group)
sanity_check_mm_encoder_outputs(
curr_group_outputs,
expected_num_items=len(grouped_mm_inputs),
expected_num_items=num_items,
)
for output in curr_group_outputs:
@ -1604,12 +1600,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
pooler_output.append(raw_output.data.cpu())
else:
pooler_output.append(None)
extra_args = ({
"finished_sending": finished_sending,
"finished_recving": finished_recving
} if vllm_version_is("0.10.0") else {
"kv_connector_output": kv_connector_output
})
extra_args = ({"kv_connector_output": kv_connector_output})
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@ -1645,7 +1636,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
finished_recving) = (self._process_reqs(scheduler_output,
intermediate_tensors))
kv_connector_output = None
if not vllm_version_is("0.10.0"):
if finished_sending is not None and finished_recving is not None:
kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending,
@ -1665,12 +1655,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if not get_pp_group().is_last_rank:
# For mid-pipeline stages, return the hidden states.
if not broadcast_pp_output:
if kv_connector_output is not None:
hidden_states.kv_connector_output = kv_connector_output
else:
#TODO: Remove this after we drop vllm v0.10.0
hidden_states.finished_sending = finished_sending
hidden_states.finished_recving = finished_recving
return hidden_states
assert isinstance(hidden_states, IntermediateTensors)
get_pp_group().send_tensor_dict(
@ -1815,12 +1800,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().clear_connector_metadata()
extra_args = ({
"finished_sending": finished_sending,
"finished_recving": finished_recving
} if vllm_version_is("0.10.0") else {
"kv_connector_output": kv_connector_output
})
extra_args = ({"kv_connector_output": kv_connector_output})
model_runner_output = ModelRunnerOutput(
req_ids=self.input_batch.req_ids,

View File

@ -43,7 +43,7 @@ class CachedRequestState:
req_id: str
prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs]
mm_kwargs: list[MultiModalKwargs]
mm_positions: list[PlaceholderRange]
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]

View File

@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import logger
from vllm.lora.request import LoRARequest
from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@ -45,12 +46,9 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
try_register_lib, vllm_version_is)
try_register_lib)
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if not vllm_version_is("0.10.0"):
from vllm.tasks import SupportedTask
class NPUWorker(WorkerBase):
@ -209,12 +207,6 @@ class NPUWorker(WorkerBase):
if not has_kv_transfer_group():
return None
is_legacy = vllm_version_is("0.10.0")
if is_legacy:
finished_sending = output.finished_sending
finished_recving = output.finished_recving
else:
kv_connector_output = output.kv_connector_output
finished_sending = kv_connector_output.finished_sending
finished_recving = kv_connector_output.finished_recving
@ -223,11 +215,6 @@ class NPUWorker(WorkerBase):
return EMPTY_MODEL_RUNNER_OUTPUT
new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
if is_legacy:
new_output.finished_sending = finished_sending
new_output.finished_recving = finished_recving
else:
new_output.kv_connector_output = kv_connector_output
return new_output