Add mrope op fusion (#3509)

### What this PR does / why we need it?
Add a fused mrope op for Qwen2.5-VL. The fused mrope operator does not support Qwen3-VL yet, so the fusion only takes effect for Qwen2.5-VL.
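
For context, a minimal usage sketch (not part of this PR): serving Qwen2.5-VL through vLLM on Ascend resolves the model's rotary embedding to `MRotaryEmbedding`, which this plugin now overrides with the fused `torch_npu.npu_mrope` path. The model id and sampling settings below are illustrative only.

```python
# Minimal sketch, assuming a vllm-ascend install on an NPU host; model id and
# sampling parameters are illustrative, not prescribed by this PR.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", max_model_len=4096)
out = llm.generate(["Summarize rotary position embeddings in one sentence."],
                   SamplingParams(temperature=0.0, max_tokens=64))
print(out[0].outputs[0].text)
```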

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
Co-authored-by: shaopeng666 <shaopeng666@noreply.gitcode.com>
Commit 646c1db5d7 (parent 0777e2f899), authored by shaopeng-666 on 2025-10-18 18:08:24 +08:00, committed by GitHub.
3 changed files with 123 additions and 4 deletions

View File

@@ -6,13 +6,14 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)

 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled

 MODEL = "Qwen3-0.6B"
+MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
 MAX_NUM_BATCHED_TOKEND = 10000
@@ -376,3 +377,86 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
                 expected,
                 places=6,
                 msg=f"Failed for scale={scale}, mscale={mscale}")
+
+
+class TestAscendMRotaryEmbedding(unittest.TestCase):
+
+    def setUp(self):
+        # Common setup for tests
+        self.number_tokens = 3
+        self.num_head = 8
+        self.num_kvhead = 8
+        self.head_size = 128
+        self.max_position_embeddings = 128000
+        self.is_neox_style = True
+        self.rope_theta = 1000000.0
+        self.positions_1d = torch.tensor([1, 2, 3])
+        self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
+        self.query = torch.randn(
+            (self.number_tokens, self.num_head * self.head_size),
+            dtype=torch.bfloat16)
+        self.key = torch.randn(
+            (self.number_tokens, self.num_kvhead * self.head_size),
+            dtype=torch.bfloat16)
+        # Qwen2.5-VL mrope section case
+        self.mrope_section = [16, 24, 24]
+        self.layer = MRotaryEmbedding(self.head_size,
+                                      self.head_size,
+                                      self.max_position_embeddings,
+                                      base=self.rope_theta,
+                                      is_neox_style=self.is_neox_style,
+                                      dtype=torch.bfloat16,
+                                      mrope_section=self.mrope_section)
+        self.mock_config = MagicMock()
+        self.mock_config.torchair_graph_config.enabled = False
+
+    def _create_vllm_config(self):
+        vllm_config = VllmConfig()
+        model_config = ModelConfig(MODEL_VL,
+                                   tokenizer=MODEL_VL,
+                                   max_model_len=MAX_NUM_BATCHED_TOKEND)
+        model_config.hf_config = PretrainedConfig()
+        vllm_config.model_config = model_config
+        return vllm_config
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_1d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_1d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_2d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_2d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)

View File

@@ -22,7 +22,7 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)

 from vllm_ascend.platform import NPUPlatform
@@ -395,3 +395,37 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                        is_neox_style, offsets)
         return q_pe, k_pe
+
+
+class AscendMRotaryEmbedding(MRotaryEmbedding):
+
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ):
+        if self.mrope_section != [16, 24, 24]:
+            return super().forward_oot(positions, query, key)
+
+        import torch_npu
+        mrope_section = [0, 0, 0
+                         ] if positions.ndim == 1 else self.mrope_section
+
+        if self.cos_sin_cache.device != query.device:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.device)  # type: ignore
+        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.dtype)  # type: ignore
+
+        query, key = torch_npu.npu_mrope(positions,
+                                         query.contiguous(),
+                                         key.contiguous(),
+                                         self.cos_sin_cache.contiguous(),
+                                         self.head_size,
+                                         mrope_section=mrope_section,
+                                         rotary_mode='half')
+        return query, key
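
The fast path above only triggers for Qwen2.5-VL's `mrope_section == [16, 24, 24]`; for 1-D (text-only) positions it passes `mrope_section=[0, 0, 0]` so `torch_npu.npu_mrope` degrades to plain rotary embedding, and any other section falls back to the upstream eager implementation. As a rough reference, the sketch below shows how a section of [16, 24, 24] splits the rotary half-dimension (head_size // 2 == 64) across the temporal/height/width position streams. It mirrors the eager `MRotaryEmbedding` logic that the fused kernel replaces and is illustrative only, not the kernel itself.

```python
# Illustrative sketch of mrope section splitting; shapes and base are taken
# from the test setup above, everything else is an assumption for clarity.
import torch

head_size = 128
mrope_section = [16, 24, 24]                 # sums to head_size // 2 == 64
positions = torch.randint(0, 100, (3, 5))    # [3, num_tokens]: t/h/w indices

inv_freq = 1.0 / (1e6 ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = positions[..., None].float() * inv_freq       # [3, num_tokens, 64]
cos, sin = freqs.cos(), freqs.sin()

# Take section i of the frequency dim from position stream i, then concatenate,
# so each token gets one 64-wide cos/sin built from its t, h and w indices.
cos = torch.cat(
    [c[i] for i, c in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
sin = torch.cat(
    [s[i] for i, s in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
print(cos.shape)  # torch.Size([5, 64])
```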

View File

@@ -517,8 +517,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         AscendReplicatedLinear,
         AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
-        AscendYaRNRotaryEmbedding)
+        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
+        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -528,6 +528,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "QuickGELU": AscendQuickGELU,
         "SiluAndMul": AscendSiluAndMul,
         "RotaryEmbedding": AscendRotaryEmbedding,
+        "MRotaryEmbedding": AscendMRotaryEmbedding,
         "ColumnParallelLinear": AscendColumnParallelLinear,
         "RowParallelLinear": AscendRowParallelLinear,
         "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,