Files
vllm-ascend/vllm_ascend/platform.py
Shirley125 b4233a2ec3 [Bugfix] Route requests requiring KVC recomputation from the decode instance to the P instance (#3448)
### What this PR does / why we need it?
This PR fixes an out-of-memory (OOM) bug triggered by recomputation in the
decode instance. When recomputation happens on the decode side, KV cache usage
may exceed the pre-allocated memory and cause an OOM.

We therefore propose a new scheduling strategy: when the decode instance cannot
allocate a new block for running requests, the request that would otherwise be
preempted is stopped instead. The proxy recognizes these stopped requests and
sends them back to the prefill instance to recompute the KV cache, after which
they are routed to the decode instance again.

This is a temporary fix for the bug. The long-term strategy is to use CPU
offloading in the decode instance.

### Does this PR introduce _any_ user-facing change?
A new Ascend configuration option, **`recompute_scheduler_enable = True`**,
enables this strategy. It defaults to `False`.
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
2025-10-18 15:56:44 +08:00

426 lines
18 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import gc
import os
from datetime import timedelta
from typing import TYPE_CHECKING, Optional, Tuple
import torch
import vllm.envs as envs_vllm
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.logger import logger
from vllm.platforms import Platform, PlatformEnum
from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
init_ascend_config)
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
update_aclgraph_sizes)
# These names are only needed for annotations. Static type checkers resolve
# them to the real vLLM classes; at runtime they are simply bound to None so
# the vLLM config/utils modules are not imported here.
if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
    from vllm.utils import FlexibleArgumentParser
else:
    ModelConfig = VllmConfig = FlexibleArgumentParser = None
class NPUPlatform(Platform):
    """Out-of-tree vLLM platform for Ascend NPUs.

    Routes vLLM's device hooks to ``torch.npu`` and selects the Ascend
    specific workers, attention backends, communicator, LoRA wrapper, graph
    wrapper and scheduler configuration.
    """

    _enum = PlatformEnum.OOT
    device_name: str = "npu"
    device_type: str = "npu"
    simple_compile_backend: str = "eager"  # Disable torch.compile()
    ray_device_key: str = "NPU"
    device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
    dispatch_key: str = "PrivateUse1"

    supported_quantization: list[str] = [ASCEND_QUANTIZATION_METHOD]

    def is_sleep_mode_available(self) -> bool:
        """Sleep mode is always reported as available on NPU."""
        return True

    @classmethod
    def pre_register_and_update(cls,
                                parser: Optional[FlexibleArgumentParser] = None
                                ) -> None:
        """Apply the global patches and expose the ``ascend`` quantization.

        Args:
            parser: The CLI argument parser (e.g. of ``vllm serve``). When
                given, ``ascend`` is appended to the allowed choices of its
                ``--quantization`` option if it is not already present.
        """
        # Adapt the global patch here.
        from vllm_ascend.utils import adapt_patch
        adapt_patch(is_global_patch=True)

        # For online serving, "ascend" quantization method is not a choice natively,
        # so we need to add "ascend" quantization method to quantization methods list
        # and the user can enable quantization using "vllm serve --quantization ascend".
        if parser is not None:
            quant_action = parser._option_string_actions.get('--quantization')
            if quant_action and hasattr(quant_action,
                                        'choices') and quant_action.choices:
                if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
                    quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)

        # Imported only for its import-time side effect (the name is unused,
        # hence noqa); presumably this registers AscendQuantConfig with vLLM.
        # NOTE(review): confirm against vllm_ascend/quantization/quant_config.py.
        from vllm_ascend.quantization.quant_config import \
            AscendQuantConfig  # noqa: F401

    @classmethod
    def get_device_capability(cls, device_id: int = 0):
        """No device capability is reported for NPU; always returns None."""
        return None

    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        """Return the name of NPU ``device_id`` as reported by torch.npu."""
        return torch.npu.get_device_name(device_id)

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        """Async output processing is supported regardless of eager mode."""
        return True

    @classmethod
    def inference_mode(cls):
        """Return the context manager used for inference (torch.inference_mode)."""
        return torch.inference_mode()

    @classmethod
    def set_device(cls, device: torch.device):
        """Make ``device`` the current NPU."""
        torch.npu.set_device(device)

    @classmethod
    def empty_cache(cls):
        """Release cached, unused NPU memory held by the allocator."""
        torch.npu.empty_cache()

    @classmethod
    def synchronize(cls):
        """Block until all queued NPU work has completed."""
        torch.npu.synchronize()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        """Return ``torch.npu.mem_get_info()`` (free and total device memory)."""
        return torch.npu.mem_get_info()

    @classmethod
    def clear_npu_memory(cls):
        """Run Python GC, empty the NPU cache and reset peak memory stats."""
        gc.collect()
        torch.npu.empty_cache()
        torch.npu.reset_peak_memory_stats()

    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        """Validate ``vllm_config`` and adapt it in place for Ascend NPUs.

        Adjusts scheduler settings, the KV-cache dtype, compilation and
        graph-capture settings, the worker class and cache block size. Finally,
        it swaps in the Ascend or recompute scheduler config when one of them
        is enabled.

        Raises:
            ValueError: If the vLLM V1 engine is disabled.
        """
        if not envs_vllm.VLLM_USE_V1:
            raise ValueError("vLLM Ascend does not support V0 engine.")
        # initialize ascend config from vllm additional_config
        ascend_config = init_ascend_config(vllm_config)

        from vllm.config import CompilationLevel  # noqa: E402
        compilation_config = vllm_config.compilation_config
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        scheduler_config = vllm_config.scheduler_config
        ascend_scheduler_config = ascend_config.ascend_scheduler_config
        structured_outputs_config = vllm_config.structured_outputs_config

        if (model_config is not None and not model_config.use_mla
                and not scheduler_config.async_scheduling
                and model_config.runner_type != "pooling"):
            logger.info(
                "Non-MLA LLMs forcibly disable the chunked prefill feature, "
                "as the performance of operators supporting this feature "
                "functionality is currently suboptimal.")
            if (not model_config.is_multimodal_model
                    and structured_outputs_config.backend == "auto"
                    and not getattr(scheduler_config,
                                    "scheduler_delay_factor", 0) > 0
                    and not scheduler_config.send_delta_data
                    and scheduler_config.policy == "fcfs"):
                ascend_scheduler_config.enabled = True
                chunked_prefill_enabled_in_ascend_scheduler = getattr(
                    ascend_scheduler_config, "enable_chunked_prefill", False)
                if chunked_prefill_enabled_in_ascend_scheduler:
                    logger.warning(
                        "Chunked prefill feature is enabled in ascend_scheduler, "
                        "but note that the operator supporting this feature "
                        "would lead to performance degradation.")
                # In this situation, max_num_batched_tokens would have been rewritten.
                # So we must make sure max_num_batched_tokens is not smaller than max_model_len.
                if (scheduler_config.max_num_batched_tokens
                        < scheduler_config.max_model_len
                        and not chunked_prefill_enabled_in_ascend_scheduler):
                    scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len

        # An explicit "kv_cache_dtype" in additional_config wins; otherwise
        # models whose hf_config has `index_topk` use the model dtype.
        kv_cache_dtype = vllm_config.additional_config.get(
            "kv_cache_dtype", None)
        if kv_cache_dtype is not None:
            vllm_config.cache_config.cache_dtype = kv_cache_dtype
        elif model_config and hasattr(model_config.hf_config, "index_topk"):
            vllm_config.cache_config.cache_dtype = str(
                model_config.dtype).replace("torch.", "")

        if model_config is None:
            logger.warning("Model config is missing. This may indicate "
                           "that we are running a test case")
            enforce_eager = False
        else:
            enforce_eager = getattr(model_config, "enforce_eager", False)

        check_ascend_config(vllm_config, enforce_eager)
        from vllm.config.compilation import CUDAGraphMode
        if enforce_eager:
            logger.info("Compilation disabled, using eager mode by default")
            compilation_config.level = CompilationLevel.NO_COMPILATION

        compilation_config.cudagraph_num_of_warmups = 1

        if compilation_config.level not in [
                CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE
        ]:
            logger.warning(
                "NPU does not support %s compilation level. Setting CUDAGraphMode to NONE",
                compilation_config.level)
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE

        # Set CUDAGraphMode to NONE when torchair is enabled, no matter what compilation_config.level is.
        if ascend_config.torchair_graph_config.enabled:
            logger.info(
                "Torchair compilation enabled on NPU. Setting CUDAGraphMode to NONE"
            )
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            # Note: We delete the torchair cache folder here to prevent runtime issues caused by dimension
            # mismatches or configuration inconsistencies when users reuse cached computation graphs. Though
            # this will increase graph compilation duration, it significantly enhances robustness and decreases
            # graph launching time during inference.
            if check_torchair_cache_exist(
            ) and not ascend_config.torchair_graph_config.use_cached_kv_cache_bytes:
                logger.warning(
                    "Torchair cache folder is deleted here to prevent runtime issues caused by dimension "
                    "mismatches or configuration inconsistencies when users reuse cached computation graphs. "
                    "In order to decrease torchair graph compilation time, users can enable both use_cached_graph "
                    "and use_cached_kv_cache_bytes in torchair_graph_config.")
                delete_torchair_cache_file()

        # Set cudagraph sizes before extending `compilation_config.splitting_ops`.
        vllm_config._set_cudagraph_sizes()

        # TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism
        # is supported by vllm-ascend.
        # The model_config None-check keeps the config-less (test) path
        # working instead of raising AttributeError here.
        if (parallel_config.tensor_parallel_size > 1
                and model_config is not None
                and not model_config.enforce_eager
                and enable_sp(vllm_config)):
            original_sizes = compilation_config.cudagraph_capture_sizes
            sp_aclgraph_sizes = \
                vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
            assert sp_aclgraph_sizes, (
                f"cudagraph_capture_sizes {original_sizes} does not contain "
                f"values that are multiples of tp_size "
                f"{parallel_config.tensor_parallel_size}")
            if len(sp_aclgraph_sizes) != len(original_sizes):
                compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
                vllm_config.compilation_config.init_with_cudagraph_sizes(
                    sp_aclgraph_sizes)

        # TODO: Full graph is fully supported later, and the default value will be set to full graph.
        if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
            compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE

        if compilation_config.cudagraph_mode == CUDAGraphMode.NONE:
            compilation_config.level = CompilationLevel.NO_COMPILATION
        elif compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE:
            logger.info(
                "PIECEWISE compilation enabled on NPU. use_inductor not supported - "
                "using only ACL Graph mode")
            assert compilation_config.level == CompilationLevel.PIECEWISE, \
                "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
            compilation_config.set_splitting_ops_for_v1()
            compilation_config.use_inductor = False
            # Attention ops are graph split points for the piecewise ACL graph.
            compilation_config.splitting_ops.extend([
                "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
            ])
            update_aclgraph_sizes(vllm_config)
        elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
            logger.info(
                "FULL_DECODE_ONLY compilation enabled on NPU. use_inductor not supported - "
                "using only ACL Graph mode")
            compilation_config.use_inductor = False
            warning_message = """\033[91m
**********************************************************************************
* WARNING: You have enabled the *full graph* feature.
* This is an early experimental stage and may involve various unknown issues.
* A known problem is that capturing too many batch sizes can lead to OOM
* (Out of Memory) errors or inference hangs. If you encounter such issues,
* consider reducing `gpu_memory_utilization` or manually specifying a smaller
* batch size for graph capture.
* For more details, please refer to:
* https://docs.vllm.ai/en/stable/configuration/conserving_memory.html#reduce-cuda-graphs
**********************************************************************************\033[0m
"""
            logger.warning(warning_message)
        else:
            logger.info(
                "%s cudagraph_mode is not support on NPU. falling back to NONE",
                compilation_config.cudagraph_mode)
            compilation_config.cudagraph_mode = CUDAGraphMode.NONE
            compilation_config.level = CompilationLevel.NO_COMPILATION

        if parallel_config and parallel_config.worker_cls == "auto":
            # TODO: this is a tricky way to disable `use_sequence_parallel_moe` in vllm.
            os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
            if ascend_config.torchair_graph_config.enabled:
                parallel_config.worker_cls = "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker"
            else:
                parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"

        if cache_config:
            if cache_config.block_size is None:
                cache_config.block_size = 128

            if cache_config.enable_prefix_caching and cache_config.block_size != 128:
                logger.warning(
                    "If prefix caching is enabled, block size must be set to 128."
                )
                cache_config.block_size = 128

        # Activate custom ops for v1, except on 310P
        if not is_310p():
            compilation_config.custom_ops = ["all"]

        # If ascend_scheduler_config is enabled, extend the original
        # scheduler_config to use AscendScheduler; otherwise, optionally
        # switch to the recompute scheduler config.
        if ascend_config.ascend_scheduler_config.enabled:
            from vllm_ascend.core.schedule_config import AscendSchedulerConfig
            ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
                vllm_config.scheduler_config,
                ascend_config.ascend_scheduler_config)
            vllm_config.scheduler_config = ascend_scheduler_config
        elif ascend_config.recompute_scheduler_enable:
            from vllm_ascend.core.recompute_schedule_config import \
                RecomputeSchedulerConfig
            recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
                vllm_config.scheduler_config)
            vllm_config.scheduler_config = recompute_scheduler_config

    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend,
        head_size,
        dtype,
        kv_cache_dtype,
        block_size,
        use_v1,
        use_mla,
        has_sink=False,
        use_sparse=False,
    ):
        """Return the dotted path of the attention backend class to use.

        The choice depends only on ``use_mla`` and ``use_sparse``, plus two
        Ascend config settings: ``enable_shared_expert_dp`` and whether
        torchair graph mode is enabled. The remaining arguments are accepted
        for interface compatibility.

        Raises:
            ValueError: If ``use_v1`` is false.
        """
        if not use_v1:
            raise ValueError("vLLM Ascend does not support V0 engine.")

        ascend_config = get_ascend_config()

        # With shared-expert DP, MLA models always use the torchair backends
        # (SFA when sparse, MLA otherwise), independent of the torchair
        # graph-mode switch below.
        if use_mla and ascend_config.enable_shared_expert_dp:
            if use_sparse:
                return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
            return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"

        use_torchair = ascend_config.torchair_graph_config.enabled
        # Keyed by (use_mla, use_sparse, use_torchair).
        # NOTE(review): there is no entry for sparse without MLA, so that
        # combination raises KeyError.
        backend_map = {
            (True, False, True):
            "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
            (True, False, False):
            "vllm_ascend.attention.mla_v1.AscendMLABackend",
            (False, False, True):
            "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
            (False, False, False):
            "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
            (True, True, False):
            "vllm_ascend.attention.sfa_v1.AscendSFABackend",
            (True, True, True):
            "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
        }
        return backend_map[(use_mla, use_sparse, use_torchair)]

    @classmethod
    def get_punica_wrapper(cls) -> str:
        """Return the dotted path of the NPU LoRA punica wrapper."""
        return "vllm_ascend.lora.punica_npu.PunicaWrapperNPU"

    @classmethod
    def get_current_memory_usage(cls,
                                 device: Optional[torch.types.Device] = None
                                 ) -> float:
        """Return NPU memory currently allocated on ``device``.

        The peak counter is reset first, so the reported "max allocated"
        reflects the current allocation (assuming torch.npu mirrors the
        torch.cuda memory-stats semantics).
        """
        torch.npu.reset_peak_memory_stats(device)
        return torch.npu.max_memory_allocated(device)

    @classmethod
    def get_device_communicator_cls(cls) -> str:
        """Return the dotted path of the NPU device communicator."""
        return "vllm_ascend.distributed.communicator.NPUCommunicator"

    @classmethod
    def is_pin_memory_available(cls):
        """Pinned host memory is always reported as available."""
        return True

    @classmethod
    def supports_v1(cls, model_config: ModelConfig) -> bool:
        """Returns whether the current platform can support v1 for the supplied
        model configuration.
        """
        return True

    @classmethod
    def get_static_graph_wrapper_cls(cls) -> str:
        """
        Get piecewise backend class for piecewise graph.
        """
        return "vllm_ascend.compilation.acl_graph.ACLGraphWrapper"  # noqa

    @classmethod
    def stateless_init_device_torch_dist_pg(
        cls,
        backend: str,
        prefix_store: PrefixStore,
        group_rank: int,
        group_size: int,
        timeout: timedelta,
    ) -> ProcessGroup:
        """Build a ProcessGroup backed by HCCL without global dist init.

        Args:
            backend: Accepted for interface compatibility; unused here.
            prefix_store: Store shared by the group members.
            group_rank: Rank of this process within the group.
            group_size: Number of processes in the group.
            timeout: Timeout applied to the HCCL backend options.

        Returns:
            A ProcessGroup with the HCCL backend registered for ``npu``.
        """
        from torch.distributed import is_hccl_available
        from torch_npu._C._distributed_c10d import ProcessGroupHCCL

        assert is_hccl_available()

        pg: ProcessGroup = ProcessGroup(
            prefix_store,
            group_rank,
            group_size,
        )

        backend_options = ProcessGroupHCCL.Options()
        backend_options._timeout = timeout

        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
                                         backend_options)
        device = torch.device("npu")
        # TODO(Yizhou): _set_default_backend is not implemented in the 2.5.1
        # version of PyTorch. But we need to set it after the latest version
        # is released.
        # pg._set_default_backend(backend_type)
        backend_class._set_sequence_number_for_group()
        backend_type = ProcessGroup.BackendType.CUSTOM

        pg._register_backend(device, backend_type, backend_class)
        return pg

    @classmethod
    def support_hybrid_kv_cache(cls) -> bool:
        """Hybrid KV cache is reported as supported."""
        return True

    @classmethod
    def support_static_graph_mode(cls) -> bool:
        """Static graph mode is reported as supported."""
        return True