From cea91a32f2364d19d5e708026e84ce21a450c53d Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Fri, 19 Sep 2025 18:27:49 +0800 Subject: [PATCH] [Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (#25055) Signed-off-by: Isotr0py --- tests/kernels/core/test_mrope.py | 92 +++++++++++++------ .../layers/rotary_embedding/mrope.py | 36 +++++--- 2 files changed, 85 insertions(+), 43 deletions(-) diff --git a/tests/kernels/core/test_mrope.py b/tests/kernels/core/test_mrope.py index 3f2f330f6d..5a903438f5 100644 --- a/tests/kernels/core/test_mrope.py +++ b/tests/kernels/core/test_mrope.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import NamedTuple import pytest import torch +from packaging.version import Version from transformers import AutoConfig +from transformers import __version__ as TRANSFORMERS_VERSION from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.platforms import current_platform @@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, head_size: int, max_position_embeddings: int, dtype: torch.dtype, device: torch.device): """Generate test data for given configuration.""" + current_platform.seed_everything(42) # Create 2D positions (3, num_tokens) for multimodal case positions = torch.randint(0, max_position_embeddings // 4, (3, num_tokens), @@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int, return positions, query, key -def unroll_model_tp_dict(model_tp_dict): - return [(model_name, tp_size) - for model_name, tp_sizes in model_tp_dict.items() - for tp_size in tp_sizes] +class MRoPETestInfo(NamedTuple): + model_name: str + # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 + atol: float = 1e-2 + rtol: float = 1.6e-2 + marks: list[pytest.MarkDecorator] = [] -model_tp_dict = { - "Qwen/Qwen2-VL-7B-Instruct": [1, 2], - "Qwen/Qwen2-VL-72B-Instruct": [1, 2], - "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2], - "zai-org/GLM-4.1V-9B-Thinking": [1, 2], -} +TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version -# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317 -dtype_atol_rtol_list = [ - [torch.bfloat16, 1e-2, 1.6e-2], +MODELS_TO_TEST = [ + MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"), + MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"), + MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"), + MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"), + MRoPETestInfo( + model_name="Qwen/Qwen3-VL-4B-Instruct", + marks=[ + pytest.mark.skipif( + Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), + reason="Qwen3-VL only available after Transformers v4.57", + ) + ]), + MRoPETestInfo( + model_name="Qwen/Qwen3-VL-30B-A3B-Instruct", + marks=[ + pytest.mark.skipif( + Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"), + reason="Qwen3-VL only available after Transformers v4.57", + ) + ]), ] num_tokens_list = [11, 8192] @@ -56,20 +75,29 @@ num_tokens_list = [11, 8192] @pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests.") -@pytest.mark.parametrize("model_name, tp_size", - unroll_model_tp_dict(model_tp_dict)) -@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) +@pytest.mark.parametrize("model_info, model_name", [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST +]) +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("num_tokens", num_tokens_list) -def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): +def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int, + dtype: torch.dtype, num_tokens: int): + + atol = model_info.atol + rtol = model_info.rtol config = AutoConfig.from_pretrained(model_name) + config = config.get_text_config() # get the model config total_num_kv_heads = config.num_key_value_heads total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = config.hidden_size // total_num_heads + head_dim = (config.head_dim if hasattr(config, "head_dim") else + config.hidden_size // total_num_heads) is_neox_style = True rope_theta = config.rope_theta @@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens): @pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests.") -@pytest.mark.parametrize( - "model_name, tp_size", - unroll_model_tp_dict({ - "Qwen/Qwen2-VL-7B-Instruct": [1, 2], - "zai-org/GLM-4.1V-9B-Thinking": [1, 2] - })) -@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list) -@pytest.mark.parametrize("num_tokens", [4]) -def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol, - num_tokens): +@pytest.mark.parametrize("model_info, model_name", [ + pytest.param(test_config, test_config.model_name, marks=test_config.marks) + for test_config in MODELS_TO_TEST +]) +@pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("num_tokens", num_tokens_list) +def test_mrope_torch_compile_tracing(model_name: str, + model_info: MRoPETestInfo, tp_size: int, + dtype: torch.dtype, num_tokens: int): + + atol = model_info.atol + rtol = model_info.rtol + config = AutoConfig.from_pretrained(model_name) + config = config.get_text_config() # get the model config total_num_kv_heads = config.num_key_value_heads total_num_heads = config.num_attention_heads num_heads = total_num_heads // tp_size num_kv_heads = max(1, total_num_kv_heads // tp_size) - head_dim = config.hidden_size // total_num_heads + head_dim = (config.head_dim if hasattr(config, "head_dim") else + config.hidden_size // total_num_heads) is_neox_style = True rope_theta = config.rope_theta max_position = config.max_position_embeddings diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index ef61dbc1a5..ccc59bbbe2 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch @triton.jit -def _triton_qwen2vl_mrope_forward( +def _triton_mrope_forward( q_ptr, k_ptr, cos, @@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward( pad_hd: tl.constexpr, mrope_section_t: tl.constexpr, mrope_section_h: tl.constexpr, + mrope_section_w: tl.constexpr, + is_interleaved: tl.constexpr, ): # Adapted from # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py # This version supports flatten input tensors from vllm # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2) - # instead of (3, bsz, seq_len, head_dim) + # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary pid = tl.program_id(0) # locate start address q_ptr = q_ptr + pid * (n_qh * hd) @@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward( # #################################################################### # Note: cos and sin now have shape (3, num_tokens, head_dim // 2) - t_end = mrope_section_t - h_end = t_end + mrope_section_h - # Updated stride calculation for half head_dim half_rd = rd // 2 t_cos = cos + pid * half_rd @@ -61,9 +60,18 @@ def _triton_qwen2vl_mrope_forward( # Updated offsets for half head_dim cos_offsets = tl.arange(0, pad_hd // 2) - t_mask = cos_offsets < t_end - h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) - w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) + if is_interleaved: + h_mask = (((cos_offsets % 3) == 1) & + (cos_offsets <= 3 * mrope_section_h)) + w_mask = (((cos_offsets % 3) == 2) & + (cos_offsets <= 3 * mrope_section_w)) + t_mask = ~(h_mask | w_mask) + else: + t_end = mrope_section_t + h_end = t_end + mrope_section_h + t_mask = cos_offsets < mrope_section_t + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) + w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd) t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) @@ -131,6 +139,7 @@ def triton_mrope( mrope_section: list[int], head_size: int, rotary_dim: int, + mrope_interleaved: bool, ) -> tuple[torch.Tensor, torch.Tensor]: """Qwen2VL mrope kernel. @@ -158,7 +167,7 @@ def triton_mrope( cos = cos.contiguous() sin = sin.contiguous() - _triton_qwen2vl_mrope_forward[(n_row, )]( + _triton_mrope_forward[(n_row, )]( q, k, cos, @@ -173,6 +182,8 @@ def triton_mrope( pad_hd, mrope_section[0], mrope_section[1], + mrope_section[2], + mrope_interleaved, ) return q, k @@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding): is_neox_style: bool, dtype: torch.dtype, mrope_section: Optional[list[int]] = None, - mrope_interleaved: Optional[bool] = False, + mrope_interleaved: bool = False, ) -> None: # In Qwen2.5-VL, the maximum index value is related to the duration of # the input video. We enlarge max_position_embeddings to 4 times to get @@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding): assert positions.ndim == 1 or positions.ndim == 2 assert key is not None - if self.mrope_interleaved: - # TODO: add triton implementation to support mrope-interleaved - return self.forward_native(positions, query, key) - num_tokens = positions.shape[-1] cos_sin = self.cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) @@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding): self.mrope_section, self.head_size, self.rotary_dim, + self.mrope_interleaved, ) return q.reshape(query_shape), k.reshape(key_shape)