Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[FEAT] [Performance] Add triton mrope to replace the torch code path (#22375)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
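
For context: mrope ("multimodal rotary position embedding") splits the rotary frequency channels into three sections indexed by temporal, height, and width positions respectively, then applies the usual neox-style rotation. The sketch below is a minimal PyTorch reference of that computation, which the new Triton kernel replaces; `mrope_reference` and its argument names are illustrative, not part of this commit, and shapes follow the diff (it assumes sum(mrope_section) == head_size // 2).

import torch

def mrope_reference(q, cos, sin, mrope_section, head_size):
    # q: [num_tokens, num_heads * head_size]
    # cos/sin: [3, num_tokens, head_size // 2], one row each for T/H/W positions
    half = head_size // 2
    # Channel c of the half head dim reads the T, H, or W cache depending on
    # which mrope section [t, h, w] it falls into (t + h + w == half).
    idx = torch.repeat_interleave(torch.arange(3), torch.tensor(mrope_section))
    cos_row = cos[idx, :, torch.arange(half)].T  # [num_tokens, half]
    sin_row = sin[idx, :, torch.arange(half)].T
    x = q.view(q.shape[0], -1, head_size)
    x1, x2 = x[..., :half], x[..., half:]
    # Neox-style rotation: y = [x1, x2] * cos + [-x2, x1] * sin
    out = torch.cat(
        (x1 * cos_row[:, None] - x2 * sin_row[:, None],
         x2 * cos_row[:, None] + x1 * sin_row[:, None]),
        dim=-1,
    )
    return out.reshape(q.shape)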
benchmarks/kernels/benchmark_mrope.py (new file, 328 lines)
@@ -0,0 +1,328 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
# It generates test data, runs benchmarks, and saves results to a CSV file.
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
# == Usage Examples ==
#
# Single model benchmark:
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models benchmark:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different TP sizes:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
#
# All models with different token counts:
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
import csv
import os
import time
from datetime import datetime
from typing import Any, Optional

import numpy as np
import torch

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.utils import FlexibleArgumentParser

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(
    num_tokens: int,
    num_q_heads: int,
    num_kv_heads: int,
    head_size: int,
    max_position_embeddings: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Generate test data for given configuration."""
    # Create 2D positions (3, num_tokens) for multimodal case
    positions = torch.randint(
        0, max_position_embeddings // 4, (3, num_tokens), device=device
    )

    # Create query and key tensors
    query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
    key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)

    return positions, query, key


def calculate_stats(times: list[float]) -> dict[str, float]:
    """Calculate statistics from a list of times."""
    times_array = np.array(times)
    return {
        "mean": np.mean(times_array),
        "median": np.median(times_array),
        "p99": np.percentile(times_array, 99),
        "min": np.min(times_array),
        "max": np.max(times_array),
    }


def benchmark_mrope(
    model_name: str,
    num_tokens: int,
    head_dim: int,
    tp_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position: int = 8192,
    rope_theta: float = 10000,
    is_neox_style: bool = True,
    rope_scaling: Optional[dict[str, Any]] = None,
    dtype: torch.dtype = torch.bfloat16,
    seed: int = 0,
    warmup_iter: int = 10,
    benchmark_iter: int = 100,
    csv_writer=None,
):
    current_platform.seed_everything(seed)
    torch.set_default_device(device)
    # num_heads and num_kv_heads are already per-TP-rank values (see __main__)
    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=rope_scaling,
        dtype=dtype,
    ).to(device=device)

    print(80 * "=")
    print(
        f"Evaluating model: {model_name} "
        f"with tp_size: {tp_size} "
        f"and num_tokens: {num_tokens}, "
        f"dtype: {dtype}"
    )

    # create rotary pos emb and q, k input tensors
    positions, query, key = generate_test_data(
        num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
    )

    # Warm up
    for _ in range(warmup_iter):
        mrope_helper_class.forward_native(
            positions,
            query.clone(),
            key.clone(),
        )

        mrope_helper_class.forward_cuda(
            positions,
            query.clone(),
            key.clone(),
        )

    torch.cuda.synchronize()

    # Time reference implementation
    torch_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()

        mrope_helper_class.forward_native(
            positions,
            query_clone,
            key_clone,
        )

        torch.cuda.synchronize()
        torch_times.append(time.time() - start_time)

    # Time triton kernel implementation
    triton_times = []
    for _ in range(benchmark_iter):
        query_clone = query.clone()
        key_clone = key.clone()
        torch.cuda.synchronize()
        start_time = time.time()
        mrope_helper_class.forward_cuda(
            positions,
            query_clone,
            key_clone,
        )
        torch.cuda.synchronize()
        triton_times.append(time.time() - start_time)

    # Calculate statistics
    torch_stats = calculate_stats(torch_times)
    triton_stats = calculate_stats(triton_times)
    print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")

    print(
        f"Torch implementation: "
        f"mean={torch_stats['mean']:.8f}s, "
        f"median={torch_stats['median']:.8f}s, "
        f"p99={torch_stats['p99']:.8f}s"
    )

    print(
        f"Triton implementation: "
        f"mean={triton_stats['mean']:.8f}s, "
        f"median={triton_stats['median']:.8f}s, "
        f"p99={triton_stats['p99']:.8f}s"
    )

    print(
        f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
    )

    # Write to CSV
    if csv_writer:
        row = [
            model_name,
            tp_size,
            num_tokens,
            num_heads,
            num_kv_heads,
            head_dim,
            max_position,
            rope_theta,
            is_neox_style,
            str(rope_scaling),
            str(dtype).split(".")[-1],
            torch_stats["mean"],
            torch_stats["median"],
            torch_stats["p99"],
            torch_stats["min"],
            torch_stats["max"],
            triton_stats["mean"],
            triton_stats["median"],
            triton_stats["p99"],
            triton_stats["min"],
            triton_stats["max"],
            torch_stats["mean"] / triton_stats["mean"],  # speedup
        ]
        csv_writer.writerow(row)

    return torch_stats, triton_stats


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description="Benchmark the rotary embedding kernels."
    )
    parser.add_argument("--model-name", type=str, default="")
    # nargs="+" so the documented "--tp-size 1 2 4 8" form parses
    parser.add_argument("--tp-size", type=int, nargs="+", default=[1])
    parser.add_argument("--warmup-iter", type=int, default=10)
    parser.add_argument("--benchmark-iter", type=int, default=100)
    parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
    args = parser.parse_args()
    print(args)

    # Create CSV file for results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"

    with open(csv_filename, "w", newline="") as csvfile:
        csv_writer = csv.writer(csvfile)
        # Write header
        header = [
            "model_name",
            "tp_size",
            "num_tokens",
            "num_heads",
            "num_kv_heads",
            "head_dim",
            "max_position",
            "rope_theta",
            "is_neox_style",
            "rope_scaling",
            "dtype",
            "torch_mean",
            "torch_median",
            "torch_p99",
            "torch_min",
            "torch_max",
            "triton_mean",
            "triton_median",
            "triton_p99",
            "triton_min",
            "triton_max",
            "speedup",
        ]
        csv_writer.writerow(header)

        model_tp_dict = {}
        if args.model_name == "":
            # Benchmark all supported models; in this mode the TP sizes come
            # from this dict rather than --tp-size
            model_tp_dict = {
                "Qwen/Qwen2-VL-2B-Instruct": [1],
                "Qwen/Qwen2-VL-7B-Instruct": [1],
                "Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
                "Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
                "Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
            }
        else:
            model_tp_dict[args.model_name] = args.tp_size

        if args.num_tokens is None:
            num_tokens_list = [2**i for i in range(0, 18)]
        else:
            num_tokens_list = args.num_tokens

        for model_name, tp_list in model_tp_dict.items():
            config = get_config(model_name, trust_remote_code=args.trust_remote_code)
            for tp_size in tp_list:
                # derive the per-rank sizes from the model config
                total_num_kv_heads = config.num_key_value_heads
                total_num_heads = config.num_attention_heads
                num_heads = total_num_heads // tp_size
                num_kv_heads = max(1, total_num_kv_heads // tp_size)
                head_dim = config.hidden_size // total_num_heads
                q_size = num_heads * head_dim
                kv_size = num_kv_heads * head_dim
                is_neox_style = True
                rope_theta = config.rope_theta
                max_position = config.max_position_embeddings

                for num_tokens in num_tokens_list:
                    benchmark_mrope(
                        model_name=model_name,
                        num_tokens=num_tokens,
                        head_dim=head_dim,
                        tp_size=tp_size,
                        num_heads=num_heads,
                        num_kv_heads=num_kv_heads,
                        max_position=max_position,
                        rope_theta=rope_theta,
                        is_neox_style=is_neox_style,
                        rope_scaling=config.rope_scaling,
                        dtype=getattr(torch, args.dtype),
                        seed=args.seed,
                        warmup_iter=args.warmup_iter,
                        benchmark_iter=args.benchmark_iter,
                        csv_writer=csv_writer,
                    )

    print(f"Benchmark results saved to {csv_filename}")
tests/kernels/test_mrope.py (new file, 207 lines)
@@ -0,0 +1,207 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
from transformers import AutoConfig

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
                       head_size: int, max_position_embeddings: int,
                       dtype: torch.dtype, device: torch.device):
    """Generate test data for given configuration."""
    # Create 2D positions (3, num_tokens) for multimodal case
    positions = torch.randint(0,
                              max_position_embeddings // 4, (3, num_tokens),
                              device=device)

    # Create query and key tensors
    query = torch.randn(num_tokens,
                        num_q_heads * head_size,
                        dtype=dtype,
                        device=device)
    key = torch.randn(num_tokens,
                      num_kv_heads * head_size,
                      dtype=dtype,
                      device=device)

    return positions, query, key


def unroll_model_tp_dict(model_tp_dict):
    return [(model_name, tp_size)
            for model_name, tp_sizes in model_tp_dict.items()
            for tp_size in tp_sizes]


model_tp_dict = {
    "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
    "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2]
}

# Tolerances per dtype, following
# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
dtype_atol_rtol_list = [
    [torch.bfloat16, 1e-5, 1.6e-2],
]

num_tokens_list = [11, 8192]


@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize("model_name, tp_size",
                         unroll_model_tp_dict(model_tp_dict))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("num_tokens", num_tokens_list)
def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
    config = AutoConfig.from_pretrained(model_name)

    # get the model config
    total_num_kv_heads = config.num_key_value_heads
    total_num_heads = config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    head_dim = config.hidden_size // total_num_heads
    is_neox_style = True

    rope_theta = config.rope_theta
    max_position = config.max_position_embeddings

    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=config.rope_scaling,
        dtype=dtype,
    ).to(device=device)

    # create rotary pos emb and q, k input tensors
    positions, query, key = generate_test_data(num_tokens, num_heads,
                                               num_kv_heads, head_dim,
                                               max_position, dtype, device)

    query_native, key_native = mrope_helper_class.forward_native(
        positions,
        query.clone(),
        key.clone(),
    )

    query_cuda, key_cuda = mrope_helper_class.forward_cuda(
        positions,
        query.clone(),
        key.clone(),
    )

    torch.testing.assert_close(query_native, query_cuda, atol=atol, rtol=rtol)
    torch.testing.assert_close(key_native, key_cuda, atol=atol, rtol=rtol)


@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Skipping CUDA/ROCm only tests.")
@pytest.mark.parametrize(
    "model_name, tp_size",
    unroll_model_tp_dict({"Qwen/Qwen2-VL-7B-Instruct": [1, 2]}))
@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
@pytest.mark.parametrize("num_tokens", [4])
def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
                                     num_tokens):
    config = AutoConfig.from_pretrained(model_name)

    # get the model config
    total_num_kv_heads = config.num_key_value_heads
    total_num_heads = config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    head_dim = config.hidden_size // total_num_heads
    is_neox_style = True
    rope_theta = config.rope_theta
    max_position = config.max_position_embeddings

    mrope_helper_class = get_rope(
        head_size=head_dim,
        rotary_dim=head_dim,
        max_position=max_position,
        base=rope_theta,
        is_neox_style=is_neox_style,
        rope_scaling=config.rope_scaling,
        dtype=dtype,
    ).to(device=device)

    # Generate test data
    positions, query, key = generate_test_data(num_tokens, num_heads,
                                               num_kv_heads, head_dim,
                                               max_position, dtype, device)

    # Create a wrapper that makes the in-place function appear functional
    def functional_forward_cuda(pos, q, k):
        """Wrapper that converts the in-place operation to functional style.

        CUDA Graph does not support in-place operations, so this wrapper
        clones the input tensors and modifies the copies.
        """
        q_work = q.clone()  # Create working copies
        k_work = k.clone()
        # The in-place kernel modifies q_work and k_work
        mrope_helper_class.forward_cuda(pos, q_work, k_work)
        return q_work, k_work  # Return the modified tensors

    # Get reference results
    query_native, key_native = mrope_helper_class.forward_native(
        positions,
        query.clone(),
        key.clone(),
    )

    try:
        compiled_forward_cuda = torch.compile(functional_forward_cuda,
                                              fullgraph=True,
                                              backend="inductor",
                                              mode="reduce-overhead",
                                              dynamic=False)

        # Run compiled version
        query_compiled_cuda, key_compiled_cuda = compiled_forward_cuda(
            positions,
            query,
            key,
        )

        # Run original version for comparison
        query_cuda = query.clone()
        key_cuda = key.clone()
        mrope_helper_class.forward_cuda(positions, query_cuda, key_cuda)

        # Verify results
        torch.testing.assert_close(query_compiled_cuda,
                                   query_cuda,
                                   atol=atol,
                                   rtol=rtol)
        torch.testing.assert_close(key_compiled_cuda,
                                   key_cuda,
                                   atol=atol,
                                   rtol=rtol)
        torch.testing.assert_close(query_compiled_cuda,
                                   query_native,
                                   atol=atol,
                                   rtol=rtol)
        torch.testing.assert_close(key_compiled_cuda,
                                   key_native,
                                   atol=atol,
                                   rtol=rtol)

        print("✓ forward_cuda successfully traced with torch.compile inductor")

    except Exception as e:
        pytest.fail(
            f"forward_cuda failed to trace with torch.compile inductor: {e}")
@@ -8,10 +8,173 @@ import numpy as np
import torch
from transformers import PretrainedConfig

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton

from .base import RotaryEmbedding
from .common import apply_rotary_emb_dispatch


@triton.jit
def _triton_qwen2vl_mrope_forward(
    q_ptr,
    k_ptr,
    cos,
    sin,
    num_tokens,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
):
    # Adapted from
    # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
    # This version supports the flattened input tensors used by vLLM
    # and a cos/sin cache of shape (3, num_tokens, head_dim // 2)
    # instead of (3, bsz, seq_len, head_dim).
    pid = tl.program_id(0)
    # locate the start address of this token's q and k rows
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # Get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for the token
    # position m of this program instance.
    # ####################################################################
    # Note: cos and sin have shape (3, num_tokens, head_dim // 2)

    t_end = mrope_section_t
    h_end = t_end + mrope_section_h

    # Stride calculation for the half head_dim
    half_hd = hd // 2
    t_cos = cos + pid * half_hd
    h_cos = t_cos + num_tokens * half_hd
    w_cos = h_cos + num_tokens * half_hd
    t_sin = sin + pid * half_hd
    h_sin = t_sin + num_tokens * half_hd
    w_sin = h_sin + num_tokens * half_hd

    # Offsets for the half head_dim; each mask selects the channels that
    # belong to the T, H, or W mrope section, so the masked loads below
    # sum to a single merged cos/sin row.
    cos_offsets = tl.arange(0, pad_hd // 2)
    t_mask = cos_offsets < t_end
    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
    w_mask = (h_end <= cos_offsets) & (cos_offsets < half_hd)

    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)

    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right halves of q and k for the current
    # program instance (i.e. for the current token) separately.
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(
        0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(
        0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(
        0, pad_hd // 2)[None, :] < hd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(
        0, pad_hd // 2)[None, :] < hd // 2)

    q_tile_1 = tl.load(q_ptr + first_half_q_offsets,
                       mask=first_q_mask,
                       other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets,
                       mask=first_k_mask,
                       other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask

    q_tile_2 = tl.load(q_ptr + second_half_q_offsets,
                       mask=second_q_mask,
                       other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets,
                       mask=second_k_mask,
                       other=0).to(sin_row.dtype)

    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
    # Since cos and sin are half-size, the same cos_row and sin_row
    # are used for both halves.
    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def triton_mrope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    mrope_section: list[int],
    head_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Qwen2VL mrope kernel.

    Args:
        q: [num_tokens, num_heads * head_size]
        k: [num_tokens, num_kv_heads * head_size]
        cos: [3, num_tokens, head_size // 2]
            (T/H/W positions with multimodal inputs)
        sin: [3, num_tokens, head_size // 2]
            (T/H/W positions with multimodal inputs)
        mrope_section: [t, h, w]
        head_size: int
    """
    n_row, n_q_head_head_dim = q.shape
    n_q_head = n_q_head_head_dim // head_size
    n_kv_head = k.shape[1] // head_size
    pad_hd = triton.next_power_of_2(head_size)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)

    # Ensure the tensors passed into the kernel are contiguous;
    # this is a no-op if they already are.
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    _triton_qwen2vl_mrope_forward[(n_row, )](
        q,
        k,
        cos,
        sin,
        n_row,
        n_q_head,
        n_kv_head,
        head_size,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
    )
    return q, k
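# Note: the kernel above is launched with a one-dimensional grid of n_row
# programs, one per token. Each program rotates all q and k heads of its
# token in place, which is why triton_mrope can return the q and k buffers
# it was given.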


class MRotaryEmbedding(RotaryEmbedding):
    """Rotary Embedding with Multimodal Sections."""
@@ -36,11 +199,34 @@ class MRotaryEmbedding(RotaryEmbedding):
        if self.mrope_section:
            assert sum(self.mrope_section) == rotary_dim // 2

        self.use_triton = current_platform.is_cuda_alike()

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """MRope forward.

        Args:
            positions:
                [num_tokens,] (text only) or
                [3, num_tokens] (T/H/W positions with multimodal inputs)
            query: [num_tokens, num_heads * head_size]
            key: [num_tokens, num_kv_heads * head_size]
        """
        if self.use_triton:
            return self.forward_cuda(positions, query, key)
        else:
            return self.forward_native(positions, query, key)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().
@@ -88,6 +274,51 @@ class MRotaryEmbedding(RotaryEmbedding):
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        assert positions.ndim == 1 or positions.ndim == 2
        assert key is not None

        num_tokens = positions.shape[-1]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        key_shape = key.shape
        if positions.ndim == 2:
            assert self.mrope_section

            q, k = triton_mrope(
                query,
                key,
                cos,
                sin,
                self.mrope_section,
                self.head_size,
            )

            return q.reshape(query_shape), k.reshape(key_shape)

        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        query_pass = query[..., self.rotary_dim:]
        query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin,
                                              self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., :self.rotary_dim]
        key_pass = key[..., self.rotary_dim:]
        key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin,
                                            self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    @classmethod
    def get_input_positions(
        cls,