Mirror of https://github.com/vllm-project/vllm-ascend.git, synced 2025-10-20 13:43:53 +08:00
[Feat] MTP ACLGraph support (#3244)
### What this PR does / why we need it?
Currently, the MTP model in DeepSeek cannot be captured in ACLGraph. This PR allows the MTP model to be captured in ACLGraph mode.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
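For readers unfamiliar with the mechanism, the change has two parts: the MTP model class is decorated with vLLM's `support_torch_compile` so it can be compiled and captured, and every spec-decode proposer's `dummy_run` now threads an `aclgraph_runtime_mode` (a `CUDAGraphMode`) plus a `batch_descriptor` into the Ascend forward context, so capture-time dummy passes run under the same mode as later replays. A minimal, self-contained sketch of that threading pattern, using hypothetical stand-in names rather than the actual vllm-ascend code:

```python
from contextlib import contextmanager
from enum import Enum


class GraphMode(Enum):
    # stand-in for vllm.config.CUDAGraphMode
    NONE = 0
    PIECEWISE = 1
    FULL = 2


@contextmanager
def forward_context(graph_mode: GraphMode = GraphMode.NONE,
                    batch_descriptor=None):
    # stand-in for vllm_ascend's set_ascend_forward_context: records which
    # capture mode the upcoming forward pass runs under
    print(f"forward pass under graph mode {graph_mode.name}")
    yield


class Proposer:
    def dummy_run(self,
                  num_tokens: int,
                  aclgraph_runtime_mode: GraphMode = GraphMode.NONE,
                  batch_descriptor=None) -> None:
        # the dummy pass now runs under the same mode the runner will replay,
        # so graph capture sees a forward context consistent with real runs
        with forward_context(aclgraph_runtime_mode, batch_descriptor):
            pass  # the draft-model forward would run here


Proposer().dummy_run(8, aclgraph_runtime_mode=GraphMode.FULL)
```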
@@ -23,6 +23,7 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -179,6 +180,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
         return logits


+@support_torch_compile
 class CustomDeepSeekMTP(DeepSeekMTP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
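The decorator added above is vLLM's real `support_torch_compile`; once `CustomDeepSeekMTP` carries it, the compilation machinery can wrap the module's forward for graph capture just as it does for the main model classes. As a rough intuition only (a toy, not vLLM's implementation), a class decorator of this kind can hook construction and swap in a compiled forward:

```python
import torch
import torch.nn as nn


def toy_support_compile(cls):
    # toy stand-in: wrap __init__ so each instance compiles its own forward
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        self.forward = torch.compile(self.forward)

    cls.__init__ = __init__
    return cls


@toy_support_compile
class TinyMTP(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


print(TinyMTP()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```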
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 from vllm.attention.layer import Attention
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import logger
@@ -114,7 +114,9 @@ class EagleProposer(Proposer):
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
@@ -2,7 +2,7 @@ import enum
 from typing import Optional

 import torch
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -33,7 +33,9 @@ class Proposer:
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         """Called by dummy_run in modle_runner"""
         raise NotImplementedError

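Both new parameters now sit on the `Proposer` base interface, so every proposer (Eagle, MTP, ngram) accepts them even when, as in the ngram case, they are ignored. The `batch_descriptor` comes from `vllm.forward_context.BatchDescriptor`; in vLLM v0.11 this is, to the best of my knowledge, a small hashable named tuple keyed on the dummy batch shape (verify the exact fields against your vLLM checkout):

```python
# Assumed shape for illustration; field names may differ across vLLM versions.
from vllm.forward_context import BatchDescriptor

# a hashable key identifying which captured graph a batch should replay
desc = BatchDescriptor(num_tokens=64, uniform_decode=True)
print(desc)
```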
@@ -5,8 +5,8 @@ import torch.nn as nn
 import torchair
 from torchair import patch_for_hcom
 from vllm.attention.layer import Attention
-from vllm.config import (VllmConfig, get_layers_from_vllm_config,
-                         set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
@@ -109,7 +109,9 @@ class MtpProposer(Proposer):
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp=None) -> None:
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -151,7 +153,9 @@ class MtpProposer(Proposer):
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 in_profile_run=self.runner.in_profile_run,
-                num_actual_tokens=0):
+                num_actual_tokens=0,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor):
             if is_running_torchair:
                 assert attn_metadata is not None
                 torch._dynamo.mark_static(input_ids)
@@ -442,6 +446,7 @@ class MtpProposer(Proposer):
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
@@ -1,4 +1,5 @@
 import torch
+from vllm.config import CUDAGraphMode
 from vllm.v1.spec_decode.ngram_proposer import \
     NgramProposer as VllmNgramProposer

@@ -23,7 +24,9 @@ class NgramProposer(VllmNgramProposer, Proposer):
                   with_prefill=None,
                   skip_attn=None,
                   num_reqs=None,
-                  num_tokens_across_dp=None):
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         pass

     def generate_token_ids(self,
@@ -2479,7 +2479,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 with_prefill=with_prefill,
                 skip_attn=True,
                 num_reqs=num_reqs,
-                num_tokens_across_dp=num_tokens_across_dp)
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor)
         if need_dummy_logits:
             dummy_compute_logits(hidden_states)
         if self.in_profile_run and self.dynamic_eplb:
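With the runner hunk above, `NPUModelRunner`'s dummy run forwards whatever mode and descriptor it was invoked with straight into the drafter's dummy pass. The selection logic itself lives in the runner and is not shown in this diff; a hedged sketch of how such a decision could look (illustrative only, using real `CUDAGraphMode` members but an invented helper name and fields):

```python
from vllm.config import CUDAGraphMode
from vllm.forward_context import BatchDescriptor


def pick_runtime_mode(num_tokens: int, with_prefill: bool,
                      captured_sizes: set):
    # shapes that were never captured (or any prefill) must run eagerly
    if with_prefill or num_tokens not in captured_sizes:
        return CUDAGraphMode.NONE, None
    # decode-only batches of a captured size can replay the full graph
    return CUDAGraphMode.FULL, BatchDescriptor(num_tokens=num_tokens,
                                               uniform_decode=True)
```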