[Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (#25055)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date: 2025-09-19 18:27:49 +08:00
Committed by: GitHub
Parent: a684c0124c
Commit: cea91a32f2
2 changed files with 85 additions and 43 deletions
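Qwen3-VL uses an interleaved MRoPE layout: within the rotated half of the head dimension, the temporal/height/width sections alternate channel by channel (t, h, w, t, h, w, ...), whereas Qwen2-VL, Qwen2.5-VL and GLM-4.1V use contiguous [t | h | w] blocks. The new is_interleaved branch in the Triton kernel below encodes exactly this mapping. As a minimal illustrative sketch (not part of the commit; section_ids is a hypothetical helper, and it assumes sum(mrope_section) equals half the rotary dim, as in these models), the per-channel section assignment looks like:

    import torch

    def section_ids(half_rotary_dim: int, mrope_section: list[int],
                    interleaved: bool) -> torch.Tensor:
        """Return 0 (t), 1 (h) or 2 (w) for each channel of the rotated half."""
        t, h, w = mrope_section
        idx = torch.arange(half_rotary_dim)
        ids = torch.zeros(half_rotary_dim, dtype=torch.long)  # default: temporal
        if interleaved:
            # Qwen3-VL: h/w channels alternate with t; the tail stays temporal.
            ids[(idx % 3 == 1) & (idx <= 3 * h)] = 1
            ids[(idx % 3 == 2) & (idx <= 3 * w)] = 2
        else:
            # Qwen2-VL / Qwen2.5-VL / GLM-4.1V: contiguous [t | h | w] blocks.
            ids[(idx >= t) & (idx < t + h)] = 1
            ids[(idx >= t + h) & (idx < t + h + w)] = 2
        return ids

    print(section_ids(4, [2, 1, 1], interleaved=False))  # tensor([0, 0, 1, 2])
    print(section_ids(4, [2, 1, 1], interleaved=True))   # tensor([0, 1, 2, 0])

The two mask predicates in the interleaved branch mirror the ones used by _triton_mrope_forward in the kernel diff further down.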

View File

@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
import pytest
import torch
from packaging.version import Version
from transformers import AutoConfig
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
head_size: int, max_position_embeddings: int,
dtype: torch.dtype, device: torch.device):
"""Generate test data for given configuration."""
current_platform.seed_everything(42)
# Create 2D positions (3, num_tokens) for multimodal case
positions = torch.randint(0,
max_position_embeddings // 4, (3, num_tokens),
@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
return positions, query, key
def unroll_model_tp_dict(model_tp_dict):
return [(model_name, tp_size)
for model_name, tp_sizes in model_tp_dict.items()
for tp_size in tp_sizes]
class MRoPETestInfo(NamedTuple):
model_name: str
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
atol: float = 1e-2
rtol: float = 1.6e-2
marks: list[pytest.MarkDecorator] = []
model_tp_dict = {
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
"Qwen/Qwen2-VL-72B-Instruct": [1, 2],
"Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
"zai-org/GLM-4.1V-9B-Thinking": [1, 2],
}
TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
dtype_atol_rtol_list = [
[torch.bfloat16, 1e-2, 1.6e-2],
MODELS_TO_TEST = [
MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-4B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
MRoPETestInfo(
model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
marks=[
pytest.mark.skipif(
Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
reason="Qwen3-VL only available after Transformers v4.57",
)
]),
]
num_tokens_list = [11, 8192]
@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_name, tp_size",
unroll_model_tp_dict(model_tp_dict))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):
atol = model_info.atol
rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()
# get the model config
total_num_kv_heads = config.num_key_value_heads
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = config.hidden_size // total_num_heads
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
is_neox_style = True
rope_theta = config.rope_theta
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
"model_name, tp_size",
unroll_model_tp_dict({
"Qwen/Qwen2-VL-7B-Instruct": [1, 2],
"zai-org/GLM-4.1V-9B-Thinking": [1, 2]
}))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("num_tokens", [4])
def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
num_tokens):
@pytest.mark.parametrize("model_info, model_name", [
pytest.param(test_config, test_config.model_name, marks=test_config.marks)
for test_config in MODELS_TO_TEST
])
@pytest.mark.parametrize("tp_size", [1, 2])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope_torch_compile_tracing(model_name: str,
model_info: MRoPETestInfo, tp_size: int,
dtype: torch.dtype, num_tokens: int):
atol = model_info.atol
rtol = model_info.rtol
config = AutoConfig.from_pretrained(model_name)
config = config.get_text_config()
# get the model config
total_num_kv_heads = config.num_key_value_heads
total_num_heads = config.num_attention_heads
num_heads = total_num_heads // tp_size
num_kv_heads = max(1, total_num_kv_heads // tp_size)
head_dim = config.hidden_size // total_num_heads
head_dim = (config.head_dim if hasattr(config, "head_dim") else
config.hidden_size // total_num_heads)
is_neox_style = True
rope_theta = config.rope_theta
max_position = config.max_position_embeddings

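The test refactor above folds per-model tolerances and version-gated skips into a single MRoPETestInfo entry, replacing the old model_tp_dict / dtype_atol_rtol_list pair. A small self-contained sketch of the same pattern (CaseInfo, CASES and test_tolerances are illustrative names, not part of the test file):

    from typing import NamedTuple

    import pytest

    class CaseInfo(NamedTuple):
        model_name: str
        atol: float = 1e-2
        rtol: float = 1.6e-2
        marks: list[pytest.MarkDecorator] = []

    CASES = [
        CaseInfo(model_name="model-a"),
        CaseInfo(model_name="model-b",
                 marks=[pytest.mark.skip(reason="example of a gated entry")]),
    ]

    @pytest.mark.parametrize("info, model_name", [
        pytest.param(case, case.model_name, marks=case.marks) for case in CASES
    ])
    def test_tolerances(info: CaseInfo, model_name: str):
        # Each case carries its own atol/rtol, so no separate
        # dtype/atol/rtol parametrization is needed.
        assert model_name == info.model_name

pytest.param lets the skip marks travel with the individual model entry, which is how the Qwen3-VL cases are gated on Transformers >= 4.57 above.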
View File

@@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch
@triton.jit
def _triton_qwen2vl_mrope_forward(
def _triton_mrope_forward(
q_ptr,
k_ptr,
cos,
@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
pad_hd: tl.constexpr,
mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
):
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
# This version supports flattened input tensors from vLLM
# and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
# instead of (3, bsz, seq_len, head_dim)
# instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
pid = tl.program_id(0)
# locate start address
q_ptr = q_ptr + pid * (n_qh * hd)
@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
# ####################################################################
# Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
t_end = mrope_section_t
h_end = t_end + mrope_section_h
# Updated stride calculation for half head_dim
half_rd = rd // 2
t_cos = cos + pid * half_rd
@@ -61,9 +60,18 @@ def _triton_qwen2vl_mrope_forward(
# Updated offsets for half head_dim
cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
if is_interleaved:
h_mask = (((cos_offsets % 3) == 1) &
(cos_offsets <= 3 * mrope_section_h))
w_mask = (((cos_offsets % 3) == 2) &
(cos_offsets <= 3 * mrope_section_w))
t_mask = ~(h_mask | w_mask)
else:
t_end = mrope_section_t
h_end = t_end + mrope_section_h
t_mask = cos_offsets < mrope_section_t
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
@@ -131,6 +139,7 @@ def triton_mrope(
mrope_section: list[int],
head_size: int,
rotary_dim: int,
mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Qwen2VL mrope kernel.
@@ -158,7 +167,7 @@ def triton_mrope(
cos = cos.contiguous()
sin = sin.contiguous()
_triton_qwen2vl_mrope_forward[(n_row, )](
_triton_mrope_forward[(n_row, )](
q,
k,
cos,
@@ -173,6 +182,8 @@ def triton_mrope(
pad_hd,
mrope_section[0],
mrope_section[1],
mrope_section[2],
mrope_interleaved,
)
return q, k
@@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[list[int]] = None,
mrope_interleaved: Optional[bool] = False,
mrope_interleaved: bool = False,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get
@@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding):
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None
if self.mrope_interleaved:
# TODO: add triton implementation to support mrope-interleaved
return self.forward_native(positions, query, key)
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
@@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)
return q.reshape(query_shape), k.reshape(key_shape)
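With the interleaved layout handled inside the kernel, forward_cuda no longer falls back to forward_native for interleaved models; it simply forwards self.mrope_interleaved to triton_mrope. For cross-checking, here is a sketch (not part of the commit) of the cache reduction that the kernel's masked loads perform: given per-channel section ids like those produced by the mapping sketched near the top of this page, each channel of the (3, num_tokens, rotary_dim // 2) cos or sin cache picks its t, h or w row.

    import torch

    def select_rows(cache: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
        """Reduce a (3, num_tokens, half_dim) cos/sin cache to
        (num_tokens, half_dim) by taking, per channel, the row named by
        ids (0 = t, 1 = h, 2 = w)."""
        three, num_tokens, half_dim = cache.shape
        assert three == 3 and ids.shape == (half_dim,)
        index = ids.long().view(1, 1, half_dim).expand(1, num_tokens, half_dim)
        return cache.gather(0, index).squeeze(0)

Applying the standard neox-style rotation with the cos and sin selected this way should agree with the Triton output within the tolerances used in the tests above.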