[Feat]mtp aclgraph support (#3244)

### What this PR does / why we need it?
Currently, MTP Model in deepseek can not be capture in ACLGraph. This PR
is use to allow MTP to be captured in ACLGraph mode.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
This commit is contained in:
anon189Ty
2025-10-17 18:14:49 +08:00
committed by GitHub
parent 1b424fb7f1
commit 46e62efd44
6 changed files with 26 additions and 10 deletions

View File

@ -23,6 +23,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.model_executor.layers.layernorm import RMSNorm
@ -179,6 +180,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
return logits
@support_torch_compile
class CustomDeepSeekMTP(DeepSeekMTP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import logger
@ -114,7 +114,9 @@ class EagleProposer(Proposer):
with_prefill: bool = False,
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None):
num_tokens_across_dp: Optional[torch.Tensor] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None):
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
with set_ascend_forward_context(None,

View File

@ -2,7 +2,7 @@ import enum
from typing import Optional
import torch
from vllm.config import VllmConfig
from vllm.config import CUDAGraphMode, VllmConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@ -33,7 +33,9 @@ class Proposer:
with_prefill: bool = False,
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None):
num_tokens_across_dp: Optional[torch.Tensor] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None):
"""Called by dummy_run in modle_runner"""
raise NotImplementedError

View File

@ -5,8 +5,8 @@ import torch.nn as nn
import torchair
from torchair import patch_for_hcom
from vllm.attention.layer import Attention
from vllm.config import (VllmConfig, get_layers_from_vllm_config,
set_current_vllm_config)
from vllm.config import (CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config, set_current_vllm_config)
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.model_executor.model_loader import get_model_loader
from vllm.model_executor.model_loader.utils import (
@ -109,7 +109,9 @@ class MtpProposer(Proposer):
with_prefill: bool = False,
skip_attn: bool = False,
num_reqs: int = 0,
num_tokens_across_dp=None) -> None:
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None) -> None:
if not self.torchair_graph_enabled:
# TODO: adapt enable_dbo later
(num_tokens, num_tokens_across_dp, with_prefill,
@ -151,7 +153,9 @@ class MtpProposer(Proposer):
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=0):
num_actual_tokens=0,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor):
if is_running_torchair:
assert attn_metadata is not None
torch._dynamo.mark_static(input_ids)
@ -442,6 +446,7 @@ class MtpProposer(Proposer):
reserved_mc2_mask=self.runner.reserved_mc2_mask,
moe_comm_type=moe_comm_type,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
in_profile_run=self.runner.in_profile_run,
num_actual_tokens=num_tokens):
with ProfileExecuteDuration().capture_async('mtp_forward'):

View File

@ -1,4 +1,5 @@
import torch
from vllm.config import CUDAGraphMode
from vllm.v1.spec_decode.ngram_proposer import \
NgramProposer as VllmNgramProposer
@ -23,7 +24,9 @@ class NgramProposer(VllmNgramProposer, Proposer):
with_prefill=None,
skip_attn=None,
num_reqs=None,
num_tokens_across_dp=None):
num_tokens_across_dp=None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None):
pass
def generate_token_ids(self,

View File

@ -2479,7 +2479,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill=with_prefill,
skip_attn=True,
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp)
num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor)
if need_dummy_logits:
dummy_compute_logits(hidden_states)
if self.in_profile_run and self.dynamic_eplb: