[Feat] MTP ACLGraph support (#3244)

### What this PR does / why we need it?
Currently, the DeepSeek MTP model cannot be captured in ACLGraph. This PR
allows the MTP model to be captured in ACLGraph mode.
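
For context, a minimal sketch of how MTP speculative decoding is typically enabled from the offline API; the `speculative_config` keys follow upstream vLLM conventions and should be verified against the pinned version:

```python
# Hypothetical usage sketch; config keys follow upstream vLLM conventions
# and may differ on vllm-ascend. With this PR, the MTP drafter can be
# captured in ACLGraph mode instead of falling back to eager execution.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # any DeepSeek checkpoint with MTP weights
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
)
print(llm.generate(["Hello"], SamplingParams(max_tokens=32)))
```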

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
Author: anon189Ty
Date: 2025-10-17 18:14:49 +08:00
Committed by: GitHub
Parent: 1b424fb7f1
Commit: 46e62efd44
6 changed files with 26 additions and 10 deletions


@@ -23,6 +23,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -179,6 +180,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
         return logits
 
 
+@support_torch_compile
 class CustomDeepSeekMTP(DeepSeekMTP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
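
For context, `support_torch_compile` is vLLM's class decorator that routes a decorated module's `forward` through the compilation and graph-capture machinery. The sketch below shows the general wrap-the-class pattern it relies on; `my_support_compile` and `TinyMTPHead` are illustrative stand-ins, not vLLM code:

```python
# Illustrative pattern only: wrap a module class so its forward is
# compiled once and replayed on later same-shaped calls. The real
# decorator also consults vllm_config.compilation_config.
import torch
import torch.nn as nn

def my_support_compile(cls):  # hypothetical stand-in for support_torch_compile
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # Compile lazily, after weights and config exist on the instance.
        self.forward = torch.compile(self.forward, dynamic=True)

    cls.__init__ = __init__
    return cls

@my_support_compile
class TinyMTPHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

print(TinyMTPHead()(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```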


@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from vllm.attention.layer import Attention
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import logger
@@ -114,7 +114,9 @@ class EagleProposer(Proposer):
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,


@@ -2,7 +2,7 @@ import enum
 from typing import Optional
 
 import torch
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -33,7 +33,9 @@ class Proposer:
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         """Called by dummy_run in model_runner"""
         raise NotImplementedError
 
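
To make the widened contract concrete, here is a toy implementation of the extended `dummy_run` signature; only `CUDAGraphMode` comes from vLLM, the class itself is illustrative:

```python
# Illustrative no-op proposer showing the signature introduced above.
from typing import Optional

import torch
from vllm.config import CUDAGraphMode

class NoOpProposer:

    def dummy_run(self,
                  num_tokens: int,
                  with_prefill: bool = False,
                  skip_attn: bool = False,
                  num_reqs: int = 0,
                  num_tokens_across_dp: Optional[torch.Tensor] = None,
                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                  batch_descriptor=None) -> None:
        # A real proposer runs its draft model here; the runtime mode tells
        # it whether this warm-up pass is being captured into a graph, in
        # which case input shapes must stay static for replay.
        if aclgraph_runtime_mode != CUDAGraphMode.NONE:
            # batch_descriptor identifies which captured graph to replay.
            pass
```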


@@ -5,8 +5,8 @@ import torch.nn as nn
 import torchair
 from torchair import patch_for_hcom
 from vllm.attention.layer import Attention
-from vllm.config import (VllmConfig, get_layers_from_vllm_config,
-                         set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
@@ -109,7 +109,9 @@ class MtpProposer(Proposer):
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp=None) -> None:
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -151,7 +153,9 @@ class MtpProposer(Proposer):
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 in_profile_run=self.runner.in_profile_run,
-                num_actual_tokens=0):
+                num_actual_tokens=0,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor):
             if is_running_torchair:
                 assert attn_metadata is not None
                 torch._dynamo.mark_static(input_ids)
@@ -442,6 +446,7 @@ class MtpProposer(Proposer):
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
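
A sketch of what a call site might pass for the two new arguments; `BatchDescriptor` is the class from `vllm.forward_context` already imported in this file, though its exact fields (assumed here to be `num_tokens` and `uniform_decode`) should be checked against the pinned vLLM version:

```python
# Assumed call-site sketch; `proposer` stands for an MtpProposer instance
# owned by the model runner, and the BatchDescriptor fields are assumptions.
from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor

num_tokens = 64
descriptor = BatchDescriptor(num_tokens=num_tokens, uniform_decode=True)

proposer.dummy_run(num_tokens=num_tokens,
                   num_reqs=8,
                   aclgraph_runtime_mode=CUDAGraphMode.FULL,
                   batch_descriptor=descriptor)
```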


@@ -1,4 +1,5 @@
 import torch
+from vllm.config import CUDAGraphMode
 from vllm.v1.spec_decode.ngram_proposer import \
     NgramProposer as VllmNgramProposer
 
@@ -23,7 +24,9 @@ class NgramProposer(VllmNgramProposer, Proposer):
                   with_prefill=None,
                   skip_attn=None,
                   num_reqs=None,
-                  num_tokens_across_dp=None):
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         pass
 
     def generate_token_ids(self,


@@ -2479,7 +2479,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 with_prefill=with_prefill,
                 skip_attn=True,
                 num_reqs=num_reqs,
-                num_tokens_across_dp=num_tokens_across_dp)
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor)
             if need_dummy_logits:
                 dummy_compute_logits(hidden_states)
             if self.in_profile_run and self.dynamic_eplb:
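
For orientation, the capture-time flow this call site participates in, reduced to a sketch; `capture_sizes` and the exact `_dummy_run` call shape are assumptions about the runner's internals, not its real API:

```python
# Assumed capture loop shape (illustrative, not NPUModelRunner's real code).
from vllm.config import CUDAGraphMode

capture_sizes = [8, 4, 2, 1]  # hypothetical token counts to capture

for num_tokens in capture_sizes:
    # Inside _dummy_run the mode and batch descriptor are now forwarded to
    # drafter.dummy_run (the hunk above), so the MTP drafter is captured
    # alongside the target model instead of breaking the graph.
    runner._dummy_run(num_tokens, aclgraph_runtime_mode=CUDAGraphMode.FULL)
```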