Mirror of https://github.com/vllm-project/vllm-ascend.git, synced 2025-10-20 13:43:53 +08:00
[Quickfix] update CachedRequestState as NewRequestData changed (#2367)
### What this PR does / why we need it?
1. Update `CachedRequestState` to match the `NewRequestData` changes made in
   https://github.com/vllm-project/vllm/pull/22570 (multi-modal data now travels as `mm_kwargs` instead of `mm_inputs`); a short sketch of the rename follows this list.
2. Drop maintenance of vLLM v0.10.0 on the main branch.
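
For context, a minimal sketch of the rename this PR tracks. The dataclasses below are simplified stand-ins for the vLLM/vllm-ascend types (the real `NewRequestData` and `CachedRequestState` carry many more fields); only the `mm_inputs` -> `mm_kwargs` change is taken from the diff.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class NewRequestData:          # simplified stand-in for vLLM's scheduler output type
    req_id: str
    prompt_token_ids: list[int]
    mm_kwargs: list[Any]       # was `mm_inputs` before vllm-project/vllm#22570
    mm_positions: list[Any]


@dataclass
class CachedRequestState:      # simplified stand-in for vllm-ascend's cached state
    req_id: str
    prompt_token_ids: list[int]
    mm_kwargs: list[Any]       # renamed from `mm_inputs` to follow NewRequestData
    mm_positions: list[Any]


def cache_new_request(new_req_data: NewRequestData) -> CachedRequestState:
    # Mirrors the model-runner change in this PR: read mm_kwargs from the
    # scheduler's NewRequestData instead of the removed mm_inputs field.
    return CachedRequestState(
        req_id=new_req_data.req_id,
        prompt_token_ids=new_req_data.prompt_token_ids,
        mm_kwargs=new_req_data.mm_kwargs,
        mm_positions=new_req_data.mm_positions,
    )
```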
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with the existing tests.
- vLLM version: v0.10.0
- vLLM main: 92ff41abea
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
.github/workflows/vllm_ascend_test.yaml (vendored): 12 changed lines
@@ -81,7 +81,7 @@ jobs:
       VLLM_USE_MODELSCOPE: True
     strategy:
       matrix:
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     steps:
       - name: Install packages
         run: |
@@ -137,7 +137,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-1]
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     name: singlecard e2e test
     runs-on: ${{ matrix.os }}
     container:
@@ -185,9 +185,6 @@ jobs:
         run: |
           pip install -r requirements-dev.txt
           pip install -v -e .
-          if [[ "${{ matrix.vllm_version }}" == "v0.10.0" ]]; then
-            pip install "transformers<4.54.0"
-          fi

       - name: Run e2e test
         env:
@@ -222,7 +219,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-a2-2]
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     name: multicard e2e test
     runs-on: ${{ matrix.os }}
     container:
@@ -270,9 +267,6 @@ jobs:
         run: |
           pip install -r requirements-dev.txt
           pip install -v -e .
-          if [[ "${{ matrix.vllm_version }}" == "v0.10.0" ]]; then
-            pip install "transformers<4.54.0"
-          fi

       - name: Run vllm-project/vllm-ascend test
         env:
.github/workflows/vllm_ascend_test_310p.yaml (vendored): 2 changed lines
@@ -53,7 +53,7 @@ jobs:
       max-parallel: 2
       matrix:
         os: [linux-aarch64-310p-1, linux-aarch64-310p-4]
-        vllm_version: [main, v0.10.0]
+        vllm_version: [main]
     name: 310p e2e test
     runs-on: ${{ matrix.os }}
     container:
@@ -50,7 +50,7 @@ def create_requests(
             request_id=f"{i}",
             prompt_token_ids=[i] * num_tokens,
             sampling_params=sampling_params,
-            multi_modal_inputs=mm_inputs,
+            multi_modal_kwargs=mm_inputs,
             multi_modal_placeholders=mm_position,
             multi_modal_hashes=None,
             eos_token_id=EOS_TOKEN_ID,
@@ -25,7 +25,6 @@ from tests.ut.kv_connector.utils import (assert_scheduler_empty,
                                          create_model_runner_output,
                                          create_request, create_scheduler,
                                          create_vllm_config)
-from vllm_ascend.utils import vllm_version_is


 def test_basic_lifecycle():
@@ -103,13 +102,10 @@ def test_basic_lifecycle():

     # (3b): execute_model()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_sending = [request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_sending=[request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_id])

     # (3c): update_from_output()
     scheduler.update_from_output(scheduler_output, model_runner_output)
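
With the version branch removed, the kv-connector tests signal a finished transfer only through `KVConnectorOutput`. A minimal self-contained sketch of that pattern, with stand-in dataclasses in place of the real classes from `vllm.v1.worker.kv_connector_model_runner_mixin` and `vllm.v1.outputs`:

```python
import copy
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class KVConnectorOutput:       # stand-in for the real vLLM class
    finished_sending: Optional[list[str]] = None
    finished_recving: Optional[list[str]] = None


@dataclass
class ModelRunnerOutput:       # stand-in; the real output has many more fields
    req_ids: list[str] = field(default_factory=list)
    kv_connector_output: Optional[KVConnectorOutput] = None


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput()


def report_finished_sending(request_id: str) -> ModelRunnerOutput:
    # Same shape as the updated tests: copy the empty output and attach a
    # KVConnectorOutput instead of setting the removed finished_sending field.
    output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    output.kv_connector_output = KVConnectorOutput(
        finished_sending=[request_id])
    return output
```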
@@ -164,13 +160,10 @@ def test_prefix_cache_lifecycle():
     scheduler_output = scheduler.schedule()
     scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_sending = [request_remote.request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_sending=[request_remote.request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_sending=[request_remote.request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     _ = scheduler.schedule()
     assert_scheduler_empty(scheduler)
@@ -25,7 +25,6 @@ from tests.ut.kv_connector.utils import (assert_scheduler_empty,
                                          create_model_runner_output,
                                          create_request, create_scheduler,
                                          create_vllm_config)
-from vllm_ascend.utils import vllm_version_is


 def test_basic_lifecycle():
@@ -91,13 +90,10 @@ def test_basic_lifecycle():

     # (2b): forward(): request finishes recv.
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_recving = [request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_recving=[request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])

     # (2c): update_from_output():
     engine_core_outputs = scheduler.update_from_output(scheduler_output,
@@ -211,13 +207,10 @@ def test_full_block_prompt():
     # # STEP (2): Recv.
     scheduler_output = scheduler.schedule()
     model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
-    if vllm_version_is("0.10.0"):
-        model_runner_output.finished_recving = [request_id]
-    else:
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        model_runner_output.kv_connector_output = KVConnectorOutput(
-            finished_recving=[request_id])
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    model_runner_output.kv_connector_output = KVConnectorOutput(
+        finished_recving=[request_id])
     scheduler.update_from_output(scheduler_output, model_runner_output)
     assert len(scheduler.waiting) == 1
     assert (request_id in scheduler.finished_recving_kv_req_ids)
@@ -157,7 +157,7 @@ def create_request(
         request_id=f"id-{request_id}",
         prompt_token_ids=prompt_token_ids,
         sampling_params=sampling_params,
-        multi_modal_inputs=None,
+        multi_modal_kwargs=None,
        multi_modal_placeholders=None,
         multi_modal_hashes=None,
         **({
@@ -187,19 +187,11 @@ def create_model_runner_output(

     # Make output data structure.
     extra_args = {}
-    if not vllm_version_is("0.10.0"):
-        from vllm.v1.worker.kv_connector_model_runner_mixin import \
-            KVConnectorOutput  # type: ignore # noqa
-        kv_connector_output = KVConnectorOutput(
-            finished_sending=finished_sending,
-            finished_recving=finished_recving)
-        extra_args = {"kv_connector_output": kv_connector_output}
-    else:
-        extra_args = {
-            "finished_sending": finished_sending,
-            "finished_recving": finished_recving,
-        }
+    from vllm.v1.worker.kv_connector_model_runner_mixin import \
+        KVConnectorOutput  # type: ignore # noqa
+    kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
+                                            finished_recving=finished_recving)
+    extra_args = {"kv_connector_output": kv_connector_output}

     return ModelRunnerOutput(
         req_ids=req_ids,
         req_id_to_index=req_id_to_index,
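
After the cleanup above, `create_model_runner_output` always hands the transfer lists to `ModelRunnerOutput` through a `kv_connector_output` entry in `extra_args`, never as bare `finished_sending`/`finished_recving` keys. A hedged sketch of that handoff (stand-in class and a hypothetical helper name, not the repository's code):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class KVConnectorOutput:       # stand-in for the real vLLM class
    finished_sending: Optional[list[str]] = None
    finished_recving: Optional[list[str]] = None


def build_extra_args(finished_sending: Optional[list[str]],
                     finished_recving: Optional[list[str]]) -> dict[str, Any]:
    # Post-PR shape: both lists ride inside kv_connector_output, and the dict
    # is splatted into ModelRunnerOutput(..., **extra_args).
    kv_connector_output = KVConnectorOutput(
        finished_sending=finished_sending,
        finished_recving=finished_recving)
    return {"kv_connector_output": kv_connector_output}
```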
@@ -12,7 +12,7 @@ def mock_cached_request_state(req_id="1", prompt=[1, 2, 3], output=[4, 5, 6]):
     return CachedRequestState(
         req_id=req_id,
         prompt_token_ids=prompt,
-        mm_inputs=[],
+        mm_kwargs=[],
         mm_positions=[],
         sampling_params=SamplingParams(),
         pooling_params=None,
@@ -30,8 +30,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
-                                                   get_act_and_mul_fn)
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -43,8 +42,6 @@ from vllm.model_executor.models.qwen2_5_vl import (
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY

-from vllm_ascend.utils import vllm_version_is
-
 MIN_PAD_SIZE = 64  # min_size to pad weight
 MAX_PAD_SIZE = 128  # max_size to pad weight

@@ -202,8 +199,6 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
         )

         act_fn = get_act_and_mul_fn(vision_config.hidden_act)
-        if vllm_version_is("0.10.0"):
-            act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
         self.blocks = nn.ModuleList([
             AscendQwen2_5_VisionBlock(
                 dim=self.hidden_size,
@@ -303,12 +298,9 @@ class AscendQwen2_5_VisionTransformer(Qwen2_5_VisionTransformer):
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
+            ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
+            ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
         ]
-        if not vllm_version_is("0.10.0"):
-            stacked_params_mapping.extend([
-                ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
-                ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
-            ])
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
@@ -30,8 +30,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.config import VllmConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
-from vllm.model_executor.layers.activation import (_ACTIVATION_REGISTRY,
-                                                   get_act_and_mul_fn)
+from vllm.model_executor.layers.activation import get_act_and_mul_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen2_5_vl import (
@@ -43,7 +42,6 @@ from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY

 from vllm_ascend.models.qwen2_5_vl import AscendQwen2_5_VisionRotaryEmbedding
-from vllm_ascend.utils import vllm_version_is


 class AscendQwen2_5_VisionAttention_Without_Padding(Qwen2_5_VisionAttention):
@@ -175,8 +173,6 @@ class AscendQwen2_5_VisionTransformer_Without_Padding(Qwen2_5_VisionTransformer
         )

         act_fn = get_act_and_mul_fn(vision_config.hidden_act)
-        if vllm_version_is("0.10.0"):
-            act_fn = _ACTIVATION_REGISTRY[vision_config.hidden_act]
         self.blocks = nn.ModuleList([
             AscendQwen2_5_VisionBlock_Without_Padding(
                 dim=self.hidden_size,
@@ -105,7 +105,7 @@ def model_input_split_v1_mla_attn(
     [block_table_pre,
      block_table_post] = split_attn_tensor_type(attn_metadata.block_tables,
                                                 seq_index)
-
+    assert attn_metadata.attn_mask is not None
     if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache or attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
         # the attn_mla kernel in torch npu only accept 128*128 attn mask
         attn_mask_pre = attn_mask_post = attn_metadata.attn_mask
@@ -14,12 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from vllm_ascend.utils import vllm_version_is
-
-# Import specific patches for different versions
-if vllm_version_is("0.10.0"):
-    from vllm_ascend.patch.platform import patch_0_10_0  # noqa: F401
-    from vllm_ascend.patch.platform import patch_common  # noqa: F401
-else:
-    from vllm_ascend.patch.platform import patch_common  # noqa: F401
-    from vllm_ascend.patch.platform import patch_main  # noqa: F401
+from vllm_ascend.patch.platform import patch_common  # noqa: F401
+from vllm_ascend.patch.platform import patch_main  # noqa: F401
@@ -1,16 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
@@ -15,12 +15,5 @@
 # limitations under the License.
 #

-from vllm_ascend.utils import vllm_version_is
-
-# Import specific patches for different versions
-if vllm_version_is("0.10.0"):
-    from vllm_ascend.patch.worker import patch_0_10_0  # noqa: F401
-    from vllm_ascend.patch.worker import patch_common  # noqa: F401
-else:
-    from vllm_ascend.patch.worker import patch_common  # noqa: F401
-    from vllm_ascend.patch.worker import patch_main  # noqa: F401
+from vllm_ascend.patch.worker import patch_common  # noqa: F401
+from vllm_ascend.patch.worker import patch_main  # noqa: F401
@@ -1,18 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import vllm_ascend.patch.worker.patch_0_10_0.patch_sampler_gather_logprobs  # noqa
@@ -1,87 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# SPDX-License-Identifier: Apache-2.0
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch
-from vllm.platforms import current_platform
-from vllm.v1.outputs import LogprobsTensors
-from vllm.v1.sample.sampler import Sampler
-
-
-@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
-def batched_count_greater_than(x: torch.Tensor,
-                               values: torch.Tensor) -> torch.Tensor:
-    """
-    Counts elements in each row of x that are greater than the corresponding
-    value in values. Use torch.compile to generate an optimized kernel for
-    this function. otherwise, it will create additional copies of the input
-    tensors and cause memory issues.
-    Args:
-        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
-        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).
-    Returns:
-        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
-    """
-    return (x >= values).sum(-1)
-
-
-def gather_logprobs(
-    self,
-    logprobs: torch.Tensor,
-    num_logprobs: int,
-    token_ids: torch.Tensor,
-) -> LogprobsTensors:
-    """
-    Gather logprobs for topk and sampled/prompt token.
-
-    Args:
-      logprobs: (num tokens) x (vocab) tensor
-      num_logprobs: minimum number of logprobs to
-                    retain per token
-      token_ids: prompt tokens (if prompt logprobs)
-                 or sampled tokens (if sampled
-                 logprobs); 1D token ID tensor
-                 with (num tokens) elements
-                 Must be int64.
-
-    Returns:
-      Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
-      Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
-      Sampled token rank tensor, (num tokens)
-    """
-    assert token_ids.dtype == torch.int64
-    # Find the topK values.
-    topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
-
-    # Get with the logprob of the prompt or sampled token.
-    token_ids = token_ids.unsqueeze(-1)
-    token_logprobs = logprobs.gather(-1, token_ids)
-
-    # Compute the ranks of the actual token.
-    token_ranks = batched_count_greater_than(logprobs, token_logprobs)
-
-    # Concatenate together with the topk.
-    indices = torch.cat((token_ids, topk_indices), dim=1)
-    logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
-
-    # Use int32 to reduce the tensor size.
-    indices = indices.to(torch.int32)
-
-    return LogprobsTensors(indices, logprobs, token_ranks)
-
-
-Sampler.gather_logprobs = gather_logprobs
@@ -51,11 +51,12 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models.interfaces import supports_transcription
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
-from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.multimodal.utils import group_mm_inputs_by_modality
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+from vllm.multimodal.utils import group_mm_kwargs_by_modality
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import GenerationTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LazyLoader, cdiv)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -66,6 +67,7 @@ from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
                                   sanity_check_mm_encoder_outputs,
@@ -86,17 +88,11 @@ from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                ProfileExecuteDuration, is_310p,
-                               maybe_converting_weight_acl_format,
-                               vllm_version_is)
+                               maybe_converting_weight_acl_format)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

-if not vllm_version_is("0.10.0"):
-    from vllm.tasks import GenerationTask, SupportedTask
-    from vllm.v1.worker.kv_connector_model_runner_mixin import \
-        KVConnectorOutput
-
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
@@ -479,7 +475,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.requests[req_id] = CachedRequestState(
                 req_id=req_id,
                 prompt_token_ids=new_req_data.prompt_token_ids,
-                mm_inputs=new_req_data.mm_inputs,
+                mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
                 pooling_params=new_req_data.pooling_params,
@@ -497,18 +493,20 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-                for mm_input in self.requests[req_id].mm_inputs:
+                for item in self.requests[req_id].mm_kwargs:
+                    mm_input = item.require_data()
                     if mm_input.get("image_grid_thw") is not None:
-                        image_grid_thw.extend(
+                        image_grid_thw.append(
                             mm_input["image_grid_thw"].tolist())
                     if mm_input.get("video_grid_thw") is not None:
-                        video_grid_thw.extend(
+                        video_grid_thw.append(
                             mm_input["video_grid_thw"].tolist())
                     if mm_input.get("second_per_grid_ts") is not None:
-                        second_per_grid_ts.extend(
+                        second_per_grid_ts.append(
                             mm_input["second_per_grid_ts"])
                     if mm_input.get("audio_feature_lengths") is not None:
-                        audio_feature_lengths.extend(
+                        audio_feature_lengths.append(
                             mm_input["audio_feature_lengths"])
                     if mm_input.get("use_audio_in_video") is True:
                         use_audio_in_video = True
@@ -912,13 +910,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             return

         # Batch the multi-modal inputs.
-        mm_inputs = list[MultiModalKwargs]()
+        mm_kwargs = list[MultiModalKwargsItem]()
         req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]

             for mm_input_id in encoder_input_ids:
-                mm_inputs.append(req_state.mm_inputs[mm_input_id])
+                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
                 req_ids_pos.append(
                     (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))

@@ -929,14 +927,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
-        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)

         encoder_outputs = []
-        for grouped_mm_inputs in grouped_mm_inputs_list:
-            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
-            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
-                                                           device=self.device)
-
+        for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                mm_kwargs,
+                device=self.device,
+                pin_memory=True,
+        ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
             # 1. A tensor of shape (num_items, feature_size, hidden_size)
@@ -945,11 +941,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
             curr_group_outputs = self.model.get_multimodal_embeddings(
-                **batched_mm_inputs)
+                **mm_kwargs_group)

             sanity_check_mm_encoder_outputs(
                 curr_group_outputs,
-                expected_num_items=len(grouped_mm_inputs),
+                expected_num_items=num_items,
             )

             for output in curr_group_outputs:
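
The two hunks above swap the old `group_mm_inputs_by_modality` + `MultiModalKwargs.batch` pair for vLLM's `group_mm_kwargs_by_modality`, which yields `(modality, num_items, mm_kwargs_group)` and lets the runner feed each group straight to the encoder. The sketch below only mirrors that control flow; `group_mm_kwargs_stub` is a simplified stand-in (the real helper also batches tensors, moves them to `device`, and can pin memory):

```python
from collections import defaultdict
from typing import Any, Callable, Iterable, Iterator


def group_mm_kwargs_stub(
    mm_kwargs: Iterable[dict[str, Any]],
) -> Iterator[tuple[str, int, dict[str, Any]]]:
    # Group per-item kwargs by modality and yield (modality, num_items, group).
    grouped: dict[str, list[dict[str, Any]]] = defaultdict(list)
    for item in mm_kwargs:
        grouped[item["modality"]].append(item)
    for modality, items in grouped.items():
        yield modality, len(items), {"modality": modality, "items": items}


def encode_all(mm_kwargs: list[dict[str, Any]],
               get_multimodal_embeddings: Callable[..., list[Any]]) -> list[Any]:
    encoder_outputs: list[Any] = []
    # Same loop shape as the updated model runner: one encoder call per
    # modality group, validated against the yielded item count.
    for _, num_items, mm_kwargs_group in group_mm_kwargs_stub(mm_kwargs):
        curr_group_outputs = get_multimodal_embeddings(**mm_kwargs_group)
        assert len(curr_group_outputs) == num_items
        encoder_outputs.extend(curr_group_outputs)
    return encoder_outputs
```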
@@ -1604,12 +1600,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     pooler_output.append(raw_output.data.cpu())
                 else:
                     pooler_output.append(None)
-            extra_args = ({
-                "finished_sending": finished_sending,
-                "finished_recving": finished_recving
-            } if vllm_version_is("0.10.0") else {
-                "kv_connector_output": kv_connector_output
-            })
+            extra_args = ({"kv_connector_output": kv_connector_output})

             return ModelRunnerOutput(
                 req_ids=self.input_batch.req_ids,
@@ -1645,15 +1636,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         finished_recving) = (self._process_reqs(scheduler_output,
                                                 intermediate_tensors))
         kv_connector_output = None
-        if not vllm_version_is("0.10.0"):
-            if finished_sending is not None and finished_recving is not None:
-                kv_connector_output = KVConnectorOutput(
-                    finished_sending=finished_sending,
-                    finished_recving=finished_recving)
-            else:
-                kv_connector_output = None
-                finished_sending = None
-                finished_recving = None
+        if finished_sending is not None and finished_recving is not None:
+            kv_connector_output = KVConnectorOutput(
+                finished_sending=finished_sending,
+                finished_recving=finished_recving)
+        else:
+            kv_connector_output = None
+            finished_sending = None
+            finished_recving = None
         with ProfileExecuteDuration().capture_async("post process"):
             # Broadcast PP output for external_launcher (torchrun)
             # to make sure we are synced across pp ranks
@@ -1665,12 +1655,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             if not get_pp_group().is_last_rank:
                 # For mid-pipeline stages, return the hidden states.
                 if not broadcast_pp_output:
-                    if kv_connector_output is not None:
-                        hidden_states.kv_connector_output = kv_connector_output
-                    else:
-                        #TODO: Remove this after we drop vllm v0.10.0
-                        hidden_states.finished_sending = finished_sending
-                        hidden_states.finished_recving = finished_recving
+                    hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
                 assert isinstance(hidden_states, IntermediateTensors)
                 get_pp_group().send_tensor_dict(
@@ -1815,12 +1800,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()

-        extra_args = ({
-            "finished_sending": finished_sending,
-            "finished_recving": finished_recving
-        } if vllm_version_is("0.10.0") else {
-            "kv_connector_output": kv_connector_output
-        })
+        extra_args = ({"kv_connector_output": kv_connector_output})

         model_runner_output = ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -43,7 +43,7 @@ class CachedRequestState:

     req_id: str
     prompt_token_ids: list[int]
-    mm_inputs: list[MultiModalKwargs]
+    mm_kwargs: list[MultiModalKwargs]
     mm_positions: list[PlaceholderRange]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
@@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import logger
 from vllm.lora.request import LoRARequest
 from vllm.sequence import IntermediateTensors
+from vllm.tasks import SupportedTask
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, GiB_bytes
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
@@ -45,12 +46,9 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (init_ascend_soc_version, sleep_mode_enabled,
-                               try_register_lib, vllm_version_is)
+                               try_register_lib)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner

-if not vllm_version_is("0.10.0"):
-    from vllm.tasks import SupportedTask
-

 class NPUWorker(WorkerBase):

@@ -209,26 +207,15 @@ class NPUWorker(WorkerBase):
             if not has_kv_transfer_group():
                 return None

-            is_legacy = vllm_version_is("0.10.0")
-
-            if is_legacy:
-                finished_sending = output.finished_sending
-                finished_recving = output.finished_recving
-            else:
-                kv_connector_output = output.kv_connector_output
-                finished_sending = kv_connector_output.finished_sending
-                finished_recving = kv_connector_output.finished_recving
+            kv_connector_output = output.kv_connector_output
+            finished_sending = kv_connector_output.finished_sending
+            finished_recving = kv_connector_output.finished_recving

             if not finished_sending and not finished_recving:
                 return EMPTY_MODEL_RUNNER_OUTPUT

             new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-            if is_legacy:
-                new_output.finished_sending = finished_sending
-                new_output.finished_recving = finished_recving
-            else:
-                new_output.kv_connector_output = kv_connector_output
+            new_output.kv_connector_output = kv_connector_output
             return new_output

         assert isinstance(output, ModelRunnerOutput)
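
For completeness, the worker-side effect of dropping the v0.10.0 branch: transfer status is now always read from `output.kv_connector_output`. A speculative, self-contained sketch of that post-PR path (stand-in dataclasses; the real types come from vLLM's v1 outputs):

```python
import copy
from dataclasses import dataclass
from typing import Optional


@dataclass
class KVConnectorOutput:       # stand-in
    finished_sending: Optional[list[str]] = None
    finished_recving: Optional[list[str]] = None


@dataclass
class ModelRunnerOutput:       # stand-in
    kv_connector_output: Optional[KVConnectorOutput] = None


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput()


def strip_to_transfer_status(output: ModelRunnerOutput) -> ModelRunnerOutput:
    # Mirrors the worker_v1 change: no is_legacy branch, the lists always
    # live on kv_connector_output.
    kv = output.kv_connector_output
    if kv is None or (not kv.finished_sending and not kv.finished_recving):
        return EMPTY_MODEL_RUNNER_OUTPUT
    new_output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    new_output.kv_connector_output = kv
    return new_output
```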