mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Refactor] Remove moe_align_block_size_triton
(#21335)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@ -5,9 +5,8 @@ import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size_triton,
|
||||
moe_align_block_size,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
@ -21,60 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||
)
|
||||
|
||||
|
||||
def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
|
||||
"""
|
||||
Verifies vllm vs. Triton
|
||||
"""
|
||||
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||
|
||||
# 1. malloc space for triton and vllm
|
||||
# malloc enough space (max_num_tokens_padded) for the sorted ids
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
expert_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
|
||||
expert_ids_vllm = torch.empty_like(expert_ids_triton)
|
||||
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
|
||||
|
||||
# 2. run implementations
|
||||
moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_triton,
|
||||
expert_ids_triton,
|
||||
num_tokens_post_pad_triton,
|
||||
)
|
||||
|
||||
ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids_vllm,
|
||||
expert_ids_vllm,
|
||||
num_tokens_post_pad_vllm,
|
||||
)
|
||||
print(f"✅ VLLM implementation works with {num_experts} experts!")
|
||||
|
||||
# 3. compare results
|
||||
if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
|
||||
num_tokens_post_pad_triton, num_tokens_post_pad_vllm
|
||||
):
|
||||
print("✅ Triton and VLLM implementations match.")
|
||||
else:
|
||||
print("❌ Triton and VLLM implementations DO NOT match.")
|
||||
print("Triton expert_ids:", expert_ids_triton)
|
||||
print("VLLM expert_ids:", expert_ids_vllm)
|
||||
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
|
||||
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
|
||||
|
||||
|
||||
# test configurations
|
||||
num_tokens_range = [1, 16, 256, 4096]
|
||||
num_experts_range = [16, 64, 224, 256, 280, 512]
|
||||
@ -87,8 +32,8 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
|
||||
x_names=["num_tokens", "num_experts", "topk"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["vllm", "triton"], # "triton"
|
||||
line_names=["VLLM", "Triton"], # "Triton"
|
||||
line_vals=["vllm"],
|
||||
line_names=["vLLM"],
|
||||
plot_name="moe-align-block-size-performance",
|
||||
args={},
|
||||
)
|
||||
@ -98,36 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
block_size = 256
|
||||
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
|
||||
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "vllm":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: ops.moe_align_block_size(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
elif provider == "triton":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||
lambda: moe_align_block_size_triton(
|
||||
topk_ids,
|
||||
num_experts,
|
||||
block_size,
|
||||
sorted_ids.clone(),
|
||||
expert_ids.clone(),
|
||||
num_tokens_post_pad.clone(),
|
||||
),
|
||||
lambda: moe_align_block_size(topk_ids, block_size, num_experts),
|
||||
quantiles=quantiles,
|
||||
)
|
||||
|
||||
@ -151,6 +71,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Running correctness check...")
|
||||
check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
|
||||
benchmark.run(print_data=True, show_plots=True)
|
||||
|
@ -5,144 +5,8 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils import cdiv, round_up
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage1(
|
||||
topk_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
|
||||
off_c = (pid + 1) * num_experts
|
||||
|
||||
for i in range(tokens_per_thread):
|
||||
if start_idx + i < numel:
|
||||
idx = tl.load(topk_ids_ptr + start_idx + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
|
||||
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage2(
|
||||
tokens_cnts_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
last_cnt = 0
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
|
||||
last_cnt = last_cnt + token_cnt
|
||||
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage3(
|
||||
total_tokens_post_pad_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
last_cumsum = 0
|
||||
off_cnt = num_experts * num_experts
|
||||
for i in range(1, num_experts + 1):
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
|
||||
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
|
||||
tl.store(cumsum_ptr + i, last_cumsum)
|
||||
tl.store(total_tokens_post_pad_ptr, last_cumsum)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def moe_align_block_size_stage4(
|
||||
topk_ids_ptr,
|
||||
sorted_token_ids_ptr,
|
||||
expert_ids_ptr,
|
||||
tokens_cnts_ptr,
|
||||
cumsum_ptr,
|
||||
num_experts: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
numel: tl.constexpr,
|
||||
tokens_per_thread: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
start_idx = tl.load(cumsum_ptr + pid)
|
||||
end_idx = tl.load(cumsum_ptr + pid + 1)
|
||||
|
||||
for i in range(start_idx, end_idx, block_size):
|
||||
tl.store(expert_ids_ptr + i // block_size, pid)
|
||||
|
||||
start_idx = pid * tokens_per_thread
|
||||
off_t = pid * num_experts
|
||||
|
||||
for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
|
||||
numel)):
|
||||
expert_id = tl.load(topk_ids_ptr + i)
|
||||
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
|
||||
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
|
||||
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
|
||||
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
|
||||
|
||||
|
||||
# Triton implementation based on:
|
||||
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
|
||||
def moe_align_block_size_triton(
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
block_size: int,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_pad: torch.Tensor,
|
||||
) -> None:
|
||||
numel = topk_ids.numel()
|
||||
grid = (num_experts, )
|
||||
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
cumsum = torch.zeros((num_experts + 1, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
tokens_per_thread = cdiv(numel, num_experts)
|
||||
sorted_token_ids.fill_(numel)
|
||||
expert_ids.zero_()
|
||||
|
||||
moe_align_block_size_stage1[grid](
|
||||
topk_ids,
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
moe_align_block_size_stage2[grid](
|
||||
tokens_cnts,
|
||||
num_experts,
|
||||
)
|
||||
moe_align_block_size_stage3[(1, )](
|
||||
num_tokens_post_pad,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
)
|
||||
moe_align_block_size_stage4[grid](
|
||||
topk_ids,
|
||||
sorted_token_ids,
|
||||
expert_ids,
|
||||
tokens_cnts,
|
||||
cumsum,
|
||||
num_experts,
|
||||
block_size,
|
||||
numel,
|
||||
tokens_per_thread,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
def moe_align_block_size(
|
||||
|
Reference in New Issue
Block a user