Mirror of https://github.com/vllm-project/vllm-ascend.git, synced 2025-10-20 13:43:53 +08:00
Add mrope op fusion (#3509)
### What this PR does / why we need it?

Add the mrope fusion op for Qwen2.5-VL. This mrope operator doesn't support Qwen3-VL yet, so the fused path only takes effect for Qwen2.5-VL.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
Co-authored-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
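For background on the diff below: Qwen2.5-VL uses multimodal RoPE (M-RoPE), where the rotary half-dimension is split into sections that take their angles from separate temporal/height/width position rows. The fused path added in this PR is gated on `mrope_section == [16, 24, 24]`, the Qwen2.5-VL layout (16 + 24 + 24 = 64 = head_size // 2 for head_size 128). A minimal sketch of that sectioning, using an illustrative `cos_table` and shapes rather than the layer's real cache:

```python
import torch

head_size = 128
mrope_section = [16, 24, 24]        # 16 + 24 + 24 == head_size // 2
num_tokens = 4
cos_table = torch.randn(32768, head_size // 2)   # illustrative cos half of a cos/sin cache

# Row 0: temporal positions, row 1: height, row 2: width.
positions = torch.randint(0, 1024, (3, num_tokens))
cos = cos_table[positions]          # (3, num_tokens, 64)

# Section i of the rotary half-dim takes its angles from position row i.
chunks = cos.split(mrope_section, dim=-1)
cos_merged = torch.cat([chunks[i][i] for i in range(3)], dim=-1)
assert cos_merged.shape == (num_tokens, head_size // 2)
```

With 1-D positions there is only a single position row and nothing to merge, which is presumably why the new `forward_oot` passes `mrope_section=[0, 0, 0]` to `torch_npu.npu_mrope` in that case.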
@@ -6,13 +6,14 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
 
 MODEL = "Qwen3-0.6B"
 MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
 MAX_NUM_BATCHED_TOKEND = 10000
 
 
@@ -376,3 +377,86 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
                                    expected,
                                    places=6,
                                    msg=f"Failed for scale={scale}, mscale={mscale}")
+
+
+class TestAscendMRotaryEmbedding(unittest.TestCase):
+
+    def setUp(self):
+        # Common setup for tests
+        self.number_tokens = 3
+        self.num_head = 8
+        self.num_kvhead = 8
+        self.head_size = 128
+        self.max_position_embeddings = 128000
+        self.is_neox_style = True
+        self.rope_theta = 1000000.0
+        self.positions_1d = torch.tensor([1, 2, 3])
+        self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
+
+        self.query = torch.randn(
+            (self.number_tokens, self.num_head * self.head_size),
+            dtype=torch.bfloat16)
+        self.key = torch.randn(
+            (self.number_tokens, self.num_kvhead * self.head_size),
+            dtype=torch.bfloat16)
+
+        # Qwen2.5-VL mrope section case
+        self.mrope_section = [16, 24, 24]
+
+        self.layer = MRotaryEmbedding(self.head_size,
+                                      self.head_size,
+                                      self.max_position_embeddings,
+                                      base=self.rope_theta,
+                                      is_neox_style=self.is_neox_style,
+                                      dtype=torch.bfloat16,
+                                      mrope_section=self.mrope_section)
+
+        self.mock_config = MagicMock()
+        self.mock_config.torchair_graph_config.enabled = False
+
+    def _create_vllm_config(self):
+        vllm_config = VllmConfig()
+        model_config = ModelConfig(MODEL_VL,
+                                   tokenizer=MODEL_VL,
+                                   max_model_len=MAX_NUM_BATCHED_TOKEND)
+        model_config.hf_config = PretrainedConfig()
+        vllm_config.model_config = model_config
+        return vllm_config
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_1d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_1d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_2d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_2d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
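In these tests, `torch_npu.npu_mrope` is mocked so the suite can run without an NPU, and `ModelConfig.__post_init__` / `VllmConfig.__post_init__` are patched out so building the Qwen2.5-VL config does not need to resolve the real Hugging Face checkpoint (a bare `PretrainedConfig` stands in for `hf_config`). On actual Ascend hardware the mock could be dropped and the fused path checked against vLLM's reference implementation; a rough sketch of such a check, where the `forward_native` comparison, device handling, and tolerances are assumptions rather than part of this PR:

```python
# Assumes torch_npu plus an NPU, and `layer`, `positions_2d`, `query`, `key`
# built exactly as in setUp() above.
ref_q, ref_k = layer.forward_native(positions_2d, query.clone(), key.clone())
fused_q, fused_k = layer.forward_oot(positions_2d.npu(), query.npu(), key.npu())
torch.testing.assert_close(fused_q.cpu(), ref_q, rtol=4e-3, atol=4e-3)
torch.testing.assert_close(fused_k.cpu(), ref_k, rtol=4e-3, atol=4e-3)
```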
@@ -22,7 +22,7 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
@@ -395,3 +395,37 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                        is_neox_style, offsets)
         return q_pe, k_pe
+
+
+class AscendMRotaryEmbedding(MRotaryEmbedding):
+
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ):
+        if self.mrope_section != [16, 24, 24]:
+            return super().forward_oot(positions, query, key)
+
+        import torch_npu
+        mrope_section = [0, 0, 0
+                         ] if positions.ndim == 1 else self.mrope_section
+
+        if self.cos_sin_cache.device != query.device:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.device)  # type: ignore
+
+        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.dtype)  # type: ignore
+
+        query, key = torch_npu.npu_mrope(positions,
+                                         query.contiguous(),
+                                         key.contiguous(),
+                                         self.cos_sin_cache.contiguous(),
+                                         self.head_size,
+                                         mrope_section=mrope_section,
+                                         rotary_mode='half')
+
+        return query, key
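The method hands the entire rotation to the fused `torch_npu.npu_mrope` kernel: query and key stay in their flat `(num_tokens, num_heads * head_size)` layout, the layer's `cos_sin_cache` is moved to the right device and dtype once, and models whose `mrope_section` is not `[16, 24, 24]` (the commit calls out Qwen3-VL) fall back to `super().forward_oot` unchanged. `rotary_mode='half'` presumably matches the neox-style layout used here (`is_neox_style=True`), where each head's dimension is split into two halves; a small reference of what that rotation computes per head, offered as a sketch under that layout assumption rather than a description of the kernel:

```python
import torch

def rope_half_reference(x: torch.Tensor, cos: torch.Tensor,
                        sin: torch.Tensor) -> torch.Tensor:
    """Neox-style ("half") rotation.

    x: (num_tokens, num_heads, head_size); cos/sin: (num_tokens, head_size // 2).
    """
    x1, x2 = x.chunk(2, dim=-1)          # split each head into two halves
    cos = cos.unsqueeze(-2)              # broadcast over the head axis
    sin = sin.unsqueeze(-2)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
```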
@@ -517,8 +517,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
                                         AscendReplicatedLinear,
                                         AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
-        AscendYaRNRotaryEmbedding)
+        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
+        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -528,6 +528,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "QuickGELU": AscendQuickGELU,
         "SiluAndMul": AscendSiluAndMul,
         "RotaryEmbedding": AscendRotaryEmbedding,
+        "MRotaryEmbedding": AscendMRotaryEmbedding,
         "ColumnParallelLinear": AscendColumnParallelLinear,
         "RowParallelLinear": AscendRowParallelLinear,
         "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,
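The new dict entry registers `AscendMRotaryEmbedding` under the op name `"MRotaryEmbedding"`, so vLLM's custom-op plumbing can substitute it for the stock `MRotaryEmbedding` when running on Ascend, and the fused `forward_oot` above is what executes at inference time. Constructed directly, as a standalone smoke test on an Ascend device with torch_npu installed, the layer would be exercised roughly like this (shapes mirror the unit tests; the direct call is illustrative, not how models normally reach the op):

```python
import torch
from vllm_ascend.ops.rotary_embedding import AscendMRotaryEmbedding

rope = AscendMRotaryEmbedding(128,            # head_size
                              128,            # rotary_dim
                              128000,         # max_position_embeddings
                              base=1000000.0,
                              is_neox_style=True,
                              dtype=torch.bfloat16,
                              mrope_section=[16, 24, 24])

num_tokens, num_heads = 16, 8
positions = torch.randint(0, 1024, (3, num_tokens)).npu()   # (t, h, w) position rows
q = torch.randn(num_tokens, num_heads * 128, dtype=torch.bfloat16).npu()
k = torch.randn(num_tokens, num_heads * 128, dtype=torch.bfloat16).npu()

q_rot, k_rot = rope.forward_oot(positions, q, k)   # fused torch_npu.npu_mrope path
```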