mirror of https://github.com/vllm-project/vllm.git
[Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (#25055)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
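Context for the change below (not part of the commit): Qwen2-VL-style MRoPE assigns the rotary dimensions to the temporal/height/width position streams as three contiguous blocks, while Qwen3-VL interleaves them, assigning dimension i to t/h/w by i % 3 up to each section's budget. The new `is_interleaved` branch in the Triton kernel encodes exactly this masking; below is a minimal PyTorch sketch of both layouts, with illustrative names.

# Illustrative sketch only (not from the commit): boolean masks selecting which
# rotary dims read the temporal/height/width cos-sin rows, mirroring the
# contiguous and interleaved branches of the Triton kernel in this diff.
import torch


def mrope_section_masks(half_rotary_dim: int, mrope_section: list[int],
                        interleaved: bool):
    idx = torch.arange(half_rotary_dim)
    # mrope_section sums to half_rotary_dim, e.g. [16, 24, 24] for Qwen2-VL.
    t_sec, h_sec, w_sec = mrope_section
    if interleaved:
        # Qwen3-VL: dims cycle t, h, w, t, h, w, ... capped by each section size.
        h_mask = (idx % 3 == 1) & (idx <= 3 * h_sec)
        w_mask = (idx % 3 == 2) & (idx <= 3 * w_sec)
        t_mask = ~(h_mask | w_mask)
    else:
        # Qwen2-VL / Qwen2.5-VL: contiguous [t | h | w] blocks.
        t_mask = idx < t_sec
        h_mask = (idx >= t_sec) & (idx < t_sec + h_sec)
        w_mask = (idx >= t_sec + h_sec) & (idx < half_rotary_dim)
    return t_mask, h_mask, w_mask

The cos/sin rows selected by these masks are merged per dimension before the usual rotary rotation, which is what the kernel does after loading the masked rows.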
@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import NamedTuple
 
 import pytest
 import torch
+from packaging.version import Version
 from transformers import AutoConfig
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.platforms import current_platform
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
                        head_size: int, max_position_embeddings: int,
                        dtype: torch.dtype, device: torch.device):
     """Generate test data for given configuration."""
+    current_platform.seed_everything(42)
     # Create 2D positions (3, num_tokens) for multimodal case
     positions = torch.randint(0,
                               max_position_embeddings // 4, (3, num_tokens),
@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
     return positions, query, key
 
 
-def unroll_model_tp_dict(model_tp_dict):
-    return [(model_name, tp_size)
-            for model_name, tp_sizes in model_tp_dict.items()
-            for tp_size in tp_sizes]
+class MRoPETestInfo(NamedTuple):
+    model_name: str
+    # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
+    atol: float = 1e-2
+    rtol: float = 1.6e-2
+    marks: list[pytest.MarkDecorator] = []
 
 
-model_tp_dict = {
-    "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-    "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
-    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
-    "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
-}
+TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
 
-# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
-dtype_atol_rtol_list = [
-    [torch.bfloat16, 1e-2, 1.6e-2],
+MODELS_TO_TEST = [
+    MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-4B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
 ]
 
 num_tokens_list = [11, 8192]
@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]
 
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize("model_name, tp_size",
-                         unroll_model_tp_dict(model_tp_dict))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("num_tokens", num_tokens_list)
-def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
+def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
+               dtype: torch.dtype, num_tokens: int):
+
+    atol = model_info.atol
+    rtol = model_info.rtol
+
     config = AutoConfig.from_pretrained(model_name)
     config = config.get_text_config()
 
     # get the model config
     total_num_kv_heads = config.num_key_value_heads
     total_num_heads = config.num_attention_heads
     num_heads = total_num_heads // tp_size
     num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
     is_neox_style = True
 
     rope_theta = config.rope_theta
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
 
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize(
-    "model_name, tp_size",
-    unroll_model_tp_dict({
-        "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-        "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
-    }))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
-@pytest.mark.parametrize("num_tokens", [4])
-def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
-                                     num_tokens):
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("num_tokens", num_tokens_list)
+def test_mrope_torch_compile_tracing(model_name: str,
+                                     model_info: MRoPETestInfo, tp_size: int,
+                                     dtype: torch.dtype, num_tokens: int):
+
+    atol = model_info.atol
+    rtol = model_info.rtol
+
     config = AutoConfig.from_pretrained(model_name)
     config = config.get_text_config()
+
     # get the model config
     total_num_kv_heads = config.num_key_value_heads
     total_num_heads = config.num_attention_heads
     num_heads = total_num_heads // tp_size
     num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
     is_neox_style = True
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
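The test parametrization above expands MODELS_TO_TEST into one pytest.param per model and attaches each entry's skip marks, so the Qwen3-VL cases are skipped on Transformers releases older than 4.57. A standalone sketch of the same pattern, with placeholder names, is shown here for reference (illustrative only, not part of the commit):

# Standalone sketch of the version-gated parametrization pattern used in the
# test diff above; "DemoInfo" and the model names here are placeholders.
from typing import NamedTuple

import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION


class DemoInfo(NamedTuple):
    model_name: str
    marks: list[pytest.MarkDecorator] = []


CASES = [
    DemoInfo(model_name="model-tested-everywhere"),
    DemoInfo(model_name="model-needing-new-transformers",
             marks=[
                 pytest.mark.skipif(
                     Version(Version(TRANSFORMERS_VERSION).base_version)
                     < Version("4.57.0"),
                     reason="requires Transformers >= 4.57")
             ]),
]


@pytest.mark.parametrize(
    "info", [pytest.param(case, marks=case.marks, id=case.model_name)
             for case in CASES])
def test_demo(info: DemoInfo):
    assert info.model_name

Attaching marks per pytest.param keeps each skip condition scoped to the individual model case instead of the whole test.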
@@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch
 
 
 @triton.jit
-def _triton_qwen2vl_mrope_forward(
+def _triton_mrope_forward(
     q_ptr,
     k_ptr,
     cos,
@@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
     pad_hd: tl.constexpr,
     mrope_section_t: tl.constexpr,
     mrope_section_h: tl.constexpr,
+    mrope_section_w: tl.constexpr,
+    is_interleaved: tl.constexpr,
 ):
     # Adapted from
     # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
     # This version supports flatten input tensors from vllm
     # and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
-    # instead of (3, bsz, seq_len, head_dim)
+    # instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
     pid = tl.program_id(0)
     # locate start address
     q_ptr = q_ptr + pid * (n_qh * hd)
@@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
     # ####################################################################
     # Note: cos and sin now have shape (3, num_tokens, head_dim // 2)
 
-    t_end = mrope_section_t
-    h_end = t_end + mrope_section_h
-
     # Updated stride calculation for half head_dim
     half_rd = rd // 2
     t_cos = cos + pid * half_rd
@@ -61,9 +60,18 @@ def _triton_qwen2vl_mrope_forward(
 
     # Updated offsets for half head_dim
     cos_offsets = tl.arange(0, pad_hd // 2)
-    t_mask = cos_offsets < t_end
-    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
-    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
+    if is_interleaved:
+        h_mask = (((cos_offsets % 3) == 1) &
+                  (cos_offsets <= 3 * mrope_section_h))
+        w_mask = (((cos_offsets % 3) == 2) &
+                  (cos_offsets <= 3 * mrope_section_w))
+        t_mask = ~(h_mask | w_mask)
+    else:
+        t_end = mrope_section_t
+        h_end = t_end + mrope_section_h
+        t_mask = cos_offsets < mrope_section_t
+        h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
+        w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
 
     t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
     h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
@@ -131,6 +139,7 @@ def triton_mrope(
     mrope_section: list[int],
     head_size: int,
     rotary_dim: int,
+    mrope_interleaved: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Qwen2VL mrope kernel.
 
@@ -158,7 +167,7 @@ def triton_mrope(
     cos = cos.contiguous()
     sin = sin.contiguous()
 
-    _triton_qwen2vl_mrope_forward[(n_row, )](
+    _triton_mrope_forward[(n_row, )](
         q,
         k,
         cos,
@@ -173,6 +182,8 @@ def triton_mrope(
         pad_hd,
         mrope_section[0],
         mrope_section[1],
+        mrope_section[2],
+        mrope_interleaved,
     )
     return q, k
 
@@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         is_neox_style: bool,
         dtype: torch.dtype,
         mrope_section: Optional[list[int]] = None,
-        mrope_interleaved: Optional[bool] = False,
+        mrope_interleaved: bool = False,
     ) -> None:
         # In Qwen2.5-VL, the maximum index value is related to the duration of
         # the input video. We enlarge max_position_embeddings to 4 times to get
@@ -282,10 +293,6 @@ class MRotaryEmbedding(RotaryEmbedding):
         assert positions.ndim == 1 or positions.ndim == 2
         assert key is not None
 
-        if self.mrope_interleaved:
-            # TODO: add triton implementation to support mrope-interleaved
-            return self.forward_native(positions, query, key)
-
         num_tokens = positions.shape[-1]
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -302,6 +309,7 @@ class MRotaryEmbedding(RotaryEmbedding):
             self.mrope_section,
             self.head_size,
             self.rotary_dim,
+            self.mrope_interleaved,
         )
 
         return q.reshape(query_shape), k.reshape(key_shape)
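The tests above carry per-model atol/rtol values via MRoPETestInfo. A hypothetical harness for comparing the native and CUDA/Triton MRoPE paths under such tolerances might look like the sketch below; the function name and body are illustrative and are not the committed test code.

# Hypothetical comparison harness (illustrative, not the committed test body):
# run MRotaryEmbedding's native and CUDA/Triton paths on identical inputs and
# compare them with MRoPETestInfo-style tolerances.
import torch


def assert_mrope_paths_match(rope, positions, query, key,
                             atol: float = 1e-2, rtol: float = 1.6e-2):
    q_ref, k_ref = rope.forward_native(positions, query.clone(), key.clone())
    q_out, k_out = rope.forward_cuda(positions, query.clone(), key.clone())
    torch.testing.assert_close(q_out, q_ref, atol=atol, rtol=rtol)
    torch.testing.assert_close(k_out, k_ref, atol=atol, rtol=rtol)

The default tolerances here mirror MRoPETestInfo's bfloat16 defaults (atol=1e-2, rtol=1.6e-2) from the test diff.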