From 646c1db5d707969307f0b7065a98d4cdd22a3949 Mon Sep 17 00:00:00 2001
From: shaopeng-666
Date: Sat, 18 Oct 2025 18:08:24 +0800
Subject: [PATCH] Add mrope op fusion (#3509)

### What this PR does / why we need it?
Add the mrope fusion op for Qwen2.5-VL.
This mrope operator doesn't support Qwen3-VL yet, so the fusion currently
only takes effect for Qwen2.5-VL.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: shaopeng666
Co-authored-by: shaopeng666
---
 tests/ut/ops/test_rotary_embedding.py | 86 ++++++++++++++++++++++++++-
 vllm_ascend/ops/rotary_embedding.py   | 36 ++++++++++-
 vllm_ascend/utils.py                  |  5 +-
 3 files changed, 123 insertions(+), 4 deletions(-)

diff --git a/tests/ut/ops/test_rotary_embedding.py b/tests/ut/ops/test_rotary_embedding.py
index 21d95bb71..3a796aed5 100644
--- a/tests/ut/ops/test_rotary_embedding.py
+++ b/tests/ut/ops/test_rotary_embedding.py
@@ -6,13 +6,14 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 from vllm.config import ModelConfig, VllmConfig
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding)
 
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
 
 MODEL = "Qwen3-0.6B"
+MODEL_VL = "Qwen/Qwen2.5-VL-3B-Instruct"
 MAX_NUM_BATCHED_TOKEND = 10000
 
 
@@ -376,3 +377,86 @@ class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
                 expected,
                 places=6,
                 msg=f"Failed for scale={scale}, mscale={mscale}")
+
+
+class TestAscendMRotaryEmbedding(unittest.TestCase):
+
+    def setUp(self):
+        # Common setup for tests
+        self.number_tokens = 3
+        self.num_head = 8
+        self.num_kvhead = 8
+        self.head_size = 128
+        self.max_position_embeddings = 128000
+        self.is_neox_style = True
+        self.rope_theta = 1000000.0
+        self.positions_1d = torch.tensor([1, 2, 3])
+        self.positions_2d = torch.randint(1, 10, (3, self.number_tokens))
+
+        self.query = torch.randn(
+            (self.number_tokens, self.num_head * self.head_size),
+            dtype=torch.bfloat16)
+        self.key = torch.randn(
+            (self.number_tokens, self.num_kvhead * self.head_size),
+            dtype=torch.bfloat16)
+
+        # Qwen2.5-VL mrope section case
+        self.mrope_section = [16, 24, 24]
+
+        self.layer = MRotaryEmbedding(self.head_size,
+                                      self.head_size,
+                                      self.max_position_embeddings,
+                                      base=self.rope_theta,
+                                      is_neox_style=self.is_neox_style,
+                                      dtype=torch.bfloat16,
+                                      mrope_section=self.mrope_section)
+
+        self.mock_config = MagicMock()
+        self.mock_config.torchair_graph_config.enabled = False
+
+    def _create_vllm_config(self):
+        vllm_config = VllmConfig()
+        model_config = ModelConfig(MODEL_VL,
+                                   tokenizer=MODEL_VL,
+                                   max_model_len=MAX_NUM_BATCHED_TOKEND)
+        model_config.hf_config = PretrainedConfig()
+        vllm_config.model_config = model_config
+        return vllm_config
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_1d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_1d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
+
+    @patch('torch_npu.npu_mrope')
+    @patch('vllm.config.ModelConfig.__post_init__', MagicMock())
+    @patch('vllm.config.VllmConfig.__post_init__', MagicMock())
+    @patch('vllm.distributed.parallel_state._DP', MagicMock(world_size=1))
+    @patch('vllm.distributed.parallel_state._TP', MagicMock(world_size=1))
+    def test_forward_oot_2d_positions(self, mock_npu_mrope):
+        mock_npu_mrope.return_value = (torch.zeros_like(self.query),
+                                       torch.zeros_like(self.key))
+
+        vllm_config = self._create_vllm_config()
+        with set_ascend_forward_context(None, vllm_config):
+            result_q, result_k = self.layer.forward_oot(
+                self.positions_2d, self.query, self.key)
+
+        mock_npu_mrope.assert_called_once()
+        self.assertFalse(torch.isnan(result_q).any().item())
+        self.assertFalse(torch.isnan(result_k).any().item())
+        self.assertEqual(result_q.shape, self.query.shape)
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 69102f39e..fddc5238a 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -22,7 +22,7 @@ import torch
 import torch_npu
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.rotary_embedding import (
-    DeepseekScalingRotaryEmbedding, RotaryEmbedding,
+    DeepseekScalingRotaryEmbedding, MRotaryEmbedding, RotaryEmbedding,
     YaRNScalingRotaryEmbedding)
 
 from vllm_ascend.platform import NPUPlatform
@@ -395,3 +395,37 @@ class AscendDeepseekScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
         q_pe, k_pe = _rope_forward_oot(self, positions, query, key,
                                        is_neox_style, offsets)
         return q_pe, k_pe
+
+
+class AscendMRotaryEmbedding(MRotaryEmbedding):
+
+    def forward_oot(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ):
+        if self.mrope_section != [16, 24, 24]:
+            return super().forward_oot(positions, query, key)
+
+        import torch_npu
+        mrope_section = [0, 0, 0
+                         ] if positions.ndim == 1 else self.mrope_section
+
+        if self.cos_sin_cache.device != query.device:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.device)  # type: ignore
+
+        if self.cos_sin_cache.dtype != query.dtype:  # type: ignore
+            self.cos_sin_cache = self.cos_sin_cache.to(  # type: ignore
+                query.dtype)  # type: ignore
+
+        query, key = torch_npu.npu_mrope(positions,
+                                         query.contiguous(),
+                                         key.contiguous(),
+                                         self.cos_sin_cache.contiguous(),
+                                         self.head_size,
+                                         mrope_section=mrope_section,
+                                         rotary_mode='half')
+
+        return query, key
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 0929e40a9..f824662ac 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -517,8 +517,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
                                          AscendReplicatedLinear,
                                          AscendRowParallelLinear)
     from vllm_ascend.ops.rotary_embedding import (
-        AscendDeepseekScalingRotaryEmbedding, AscendRotaryEmbedding,
-        AscendYaRNRotaryEmbedding)
+        AscendDeepseekScalingRotaryEmbedding, AscendMRotaryEmbedding,
+        AscendRotaryEmbedding, AscendYaRNRotaryEmbedding)
     from vllm_ascend.ops.vocab_parallel_embedding import (
         AscendLogitsProcessor, AscendParallelLMHead,
         AscendVocabParallelEmbedding)
@@ -528,6 +528,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
         "QuickGELU": AscendQuickGELU,
"SiluAndMul": AscendSiluAndMul, "RotaryEmbedding": AscendRotaryEmbedding, + "MRotaryEmbedding": AscendMRotaryEmbedding, "ColumnParallelLinear": AscendColumnParallelLinear, "RowParallelLinear": AscendRowParallelLinear, "YaRNScalingRotaryEmbedding": AscendYaRNRotaryEmbedding,