[Bugfix] Fused MoE Modular Kernel chunking loop (#20392)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
commit fdadb6f43a (parent 41060c6e08), committed by GitHub
tests/kernels/moe/test_count_expert_num_tokens.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests compute_expert_num_tokens kernels
"""

import dataclasses
from typing import Optional

import pytest
import torch

from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens


@dataclasses.dataclass
class TestTensors:

    topk_ids: torch.Tensor
    expert_map: Optional[torch.Tensor] = None

    def to_device(self, device: str):
        self.topk_ids = self.topk_ids.to(device=device)
        if self.expert_map is not None:
            self.expert_map = self.expert_map.to(device=device)

    @staticmethod
    def make(num_tokens: int, num_topk: int, num_experts: int, device: str,
             topk_ids_dtype: torch.dtype) -> "TestTensors":

        # make topk ids
        topk_ids = torch.empty((num_tokens, num_topk),
                               device=device,
                               dtype=torch.int64)
        for x in range(num_tokens):
            topk_ids[x] = torch.randperm(num_experts)[:num_topk]
        topk_ids = topk_ids.to(dtype=torch.int64)
        return TestTensors(topk_ids=topk_ids)

    def with_ep_rank(self, ep_rank: int, num_global_experts: int,
                     num_local_experts: int, device: str):
        # make an expert map
        expert_map = torch.empty((num_global_experts),
                                 device=device,
                                 dtype=torch.int32)
        expert_map.fill_(-1)
        s = ep_rank * num_local_experts
        e = s + num_local_experts
        expert_map[s:e] = torch.tensor(list(range(num_local_experts)),
                                       device=device)

        return TestTensors(topk_ids=self.topk_ids.clone(),
                           expert_map=expert_map)


def ref_impl(tt: TestTensors, expert_num_tokens: torch.Tensor):
    # do the reference in cpu
    tt.to_device("cpu")
    expert_ids, counts = tt.topk_ids.unique(return_counts=True)

    for eid, count in zip(expert_ids, counts):
        if eid != -1 and tt.expert_map is not None:
            eid = tt.expert_map[eid]

        if eid == -1:
            continue

        expert_num_tokens[eid] += count


def do_test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                      num_experts: int, ep_size: int,
                                      topk_ids_dtype: torch.dtype):

    assert num_topk <= num_experts

    tt = TestTensors.make(num_tokens,
                          num_topk,
                          num_experts,
                          topk_ids_dtype=topk_ids_dtype,
                          device="cpu")

    num_global_experts = num_experts
    assert num_global_experts % ep_size == 0
    num_local_experts = num_global_experts // ep_size
    for ep_rank in range(ep_size):
        tt_rank = tt.with_ep_rank(ep_rank, num_global_experts,
                                  num_local_experts, "cpu")

        ref_expert_num_tokens = torch.zeros((num_local_experts),
                                            device="cpu",
                                            dtype=torch.int32)
        ref_impl(tt_rank, ref_expert_num_tokens)
        ref_expert_num_tokens = ref_expert_num_tokens.to("cuda")

        tt_rank.to_device("cuda")
        # Test with expert_map
        triton_expert_num_tokens_w_emap = count_expert_num_tokens(
            tt_rank.topk_ids, num_local_experts, tt_rank.expert_map)

        # Test without expert map
        topk_ids = tt_rank.expert_map[tt_rank.topk_ids].to(topk_ids_dtype)
        triton_expert_num_tokens_wo_emap = count_expert_num_tokens(
            topk_ids, num_local_experts, expert_map=None)

        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_w_emap,
                                   atol=0,
                                   rtol=0)
        torch.testing.assert_close(ref_expert_num_tokens,
                                   triton_expert_num_tokens_wo_emap,
                                   atol=0,
                                   rtol=0)


@pytest.mark.parametrize(
    "num_tokens", [1, 4, 8, 11, 19, 128, 127, 405, 1024, 3333, 6666, 7317])
@pytest.mark.parametrize("num_topk", [2, 6, 8])
@pytest.mark.parametrize("num_experts", [64])
@pytest.mark.parametrize("ep_size", [1, 2, 4])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens(num_tokens: int, num_topk: int,
                                   num_experts: int, ep_size: int,
                                   topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens, num_topk, num_experts,
                                      ep_size, topk_ids_dtype)


@pytest.mark.parametrize("numel", list(range(1, 8192, 11)))
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("ep_size", [2])
@pytest.mark.parametrize("topk_ids_dtype", [torch.int64])
def test_compute_expert_num_tokens_from_numel(numel: int, num_experts: int,
                                              ep_size: int,
                                              topk_ids_dtype: torch.dtype):
    do_test_compute_expert_num_tokens(num_tokens=numel,
                                      num_topk=1,
                                      num_experts=num_experts,
                                      ep_size=ep_size,
                                      topk_ids_dtype=topk_ids_dtype)
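As an illustrative aside (not part of the test file above), the sketch below shows a standalone call to count_expert_num_tokens with an expert map built the same way TestTensors.with_ep_rank builds one; it assumes a CUDA device and all sizes are made up.

# Illustrative sketch only: count tokens per local expert for one EP rank.
import torch
from vllm.model_executor.layers.fused_moe.utils import count_expert_num_tokens

num_global_experts, num_local_experts, ep_rank = 8, 4, 1
topk_ids = torch.randint(0, num_global_experts, (16, 2),
                         dtype=torch.int64, device="cuda")

# Global -> local expert map for this rank; -1 marks non-local experts.
expert_map = torch.full((num_global_experts, ), -1,
                        dtype=torch.int32, device="cuda")
s = ep_rank * num_local_experts
expert_map[s:s + num_local_experts] = torch.arange(num_local_experts,
                                                   dtype=torch.int32,
                                                   device="cuda")

counts = count_expert_num_tokens(topk_ids, num_local_experts, expert_map)
print(counts)  # int32 tensor of shape (num_local_experts, )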
@@ -98,7 +98,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        M_sum = round_up(M_sum, block_m)
        workspace1 = (M_sum, max(N * 2, K))
        workspace2 = (M_sum, max(N, K))
-       output = (M * topk, K)
+       output = (M, topk, K)
        return (workspace1, workspace2, output, a.dtype)

    def apply(
@@ -172,7 +172,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

-       torch.index_select(mm2_out, 0, inv_perm, out=output)
+       torch.index_select(mm2_out, 0, inv_perm, out=output.view((-1, K)))


def deep_gemm_moe_fp8(
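As an illustrative aside (not part of the diff), the DeepGemmExperts change above allocates the output as (M, topk, K) rather than (M * topk, K), so the unpermute scatter writes through a flattened view of the same storage. A minimal shape sketch with made-up sizes:

# Illustrative sketch only: index_select into a (M, topk, K) buffer via its
# flattened (M * topk, K) view.
import torch

M, topk, K = 4, 2, 8                   # hypothetical sizes
mm2_out = torch.randn(M * topk, K)     # one row per (token, expert) pair
inv_perm = torch.randperm(M * topk)    # stand-in for the unpermute order

output = torch.empty(M, topk, K)
torch.index_select(mm2_out, 0, inv_perm, out=output.view(-1, K))
assert output.view(-1, K).shape == (M * topk, K)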
@@ -10,7 +10,8 @@ import torch

import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.fused_moe.utils import (  # yapf: disable
+    _resize_cache, count_expert_num_tokens)
from vllm.utils import cdiv

#
@@ -421,6 +422,177 @@ class FusedMoEModularKernel(torch.nn.Module):
            f"{fused_experts.__class__.__name__}."
            f"{fused_experts.activation_formats[0]}")

    def _do_fused_experts(
            self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
            a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
            topk_ids: torch.Tensor, activation: str, global_num_experts: int,
            local_num_experts: int, expert_map: Optional[torch.Tensor],
            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
            a1q_scale: Optional[torch.Tensor],
            a2_scale: Optional[torch.Tensor],
            expert_tokens_meta: Optional[ExpertTokensMetadata]
    ) -> torch.Tensor:

        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

        (workspace13_shape, workspace2_shape, fused_out_shape,
         workspace_dtype) = self.fused_experts.workspace_shapes(
             a1, a1q, M, N, K, top_k, global_num_experts, local_num_experts)

        # We can reuse the memory between cache1 and cache3 because by the
        # time we need cache3, we're done with cache1.
        workspace13 = torch.empty(prod(workspace13_shape),
                                  device=a1.device,
                                  dtype=workspace_dtype)
        workspace2 = torch.empty(prod(workspace2_shape),
                                 device=a1.device,
                                 dtype=workspace_dtype)

        assert fused_out is None or fused_out.shape == fused_out_shape, (
            f"fused_out {fused_out.shape} but expected {fused_out_shape}")
        if fused_out is None:
            # reuse workspace13 for the output
            fused_out = _resize_cache(workspace13, fused_out_shape)

        self.fused_experts.apply(fused_out,
                                 a1q,
                                 w1,
                                 w2,
                                 topk_ids=topk_ids,
                                 activation=activation,
                                 global_num_experts=global_num_experts,
                                 expert_map=expert_map,
                                 w1_scale=w1_scale,
                                 w2_scale=w2_scale,
                                 w1_zp=w1_zp,
                                 w2_zp=w2_zp,
                                 a1q_scale=a1q_scale,
                                 a2_scale=a2_scale,
                                 workspace13=workspace13,
                                 workspace2=workspace2,
                                 expert_tokens_meta=expert_tokens_meta)

        return fused_out
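As an illustrative aside (not part of the diff), when the caller passes no fused_out buffer, _do_fused_experts above carves the output out of workspace13 via _resize_cache instead of allocating a new tensor. A minimal sketch of that aliasing idea, with hypothetical shapes:

# Illustrative sketch only: carving an output view out of a flat workspace,
# similar in spirit to _resize_cache(workspace13, fused_out_shape).
import math
import torch

workspace13_shape = (1024, 512)      # hypothetical
fused_out_shape = (256, 2, 512)      # hypothetical; numel fits the workspace

workspace13 = torch.empty(math.prod(workspace13_shape), dtype=torch.bfloat16)
fused_out = workspace13[:math.prod(fused_out_shape)].view(fused_out_shape)

assert fused_out.data_ptr() == workspace13.data_ptr()  # same storage, no copy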

    def _maybe_chunk_fused_experts(
            self, a1: torch.Tensor, a1q: torch.Tensor, w1: torch.Tensor,
            w2: torch.Tensor, topk_ids: torch.Tensor, activation: str,
            global_num_experts: int, local_num_experts: int,
            expert_map: Optional[torch.Tensor],
            w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
            w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
            a1q_scale: Optional[torch.Tensor],
            a2_scale: Optional[torch.Tensor],
            expert_tokens_meta: Optional[ExpertTokensMetadata]
    ) -> torch.Tensor:

        _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

        CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
        num_chunks = cdiv(M, CHUNK_SIZE)

        if not self.fused_experts.supports_chunking() or num_chunks == 1:
            return self._do_fused_experts(
                fused_out=None,
                a1=a1,
                a1q=a1q,
                w1=w1,
                w2=w2,
                topk_ids=topk_ids,
                activation=activation,
                global_num_experts=global_num_experts,
                local_num_experts=local_num_experts,
                expert_map=expert_map,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                w1_zp=w1_zp,
                w2_zp=w2_zp,
                a1q_scale=a1q_scale,
                a2_scale=a2_scale,
                expert_tokens_meta=expert_tokens_meta)

        # Chunking required case
        assert num_chunks > 1

        # Construct the entire output that can then be processed in chunks.
        (_, _, fused_out_shape,
         _) = self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
                                                  global_num_experts,
                                                  local_num_experts)
        fused_out = torch.empty(fused_out_shape,
                                device=a1q.device,
                                dtype=a1.dtype)

        def slice_input_tensors(
            chunk_idx: int
        ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                   Optional[torch.Tensor], torch.Tensor]:
            s = chunk_idx * CHUNK_SIZE
            e = min(s + CHUNK_SIZE, M)
            return (a1q[s:e], _chunk_scales(a1q_scale, s, e),
                    _chunk_scales(a2_scale, s, e), topk_ids[s:e])

        def slice_output_tensor(chunk_idx: int) -> torch.Tensor:
            assert fused_out.size(0) % M == 0, (
                f"fused_out shape {fused_out.shape} vs M {M}")
            factor = fused_out.size(0) // M
            out_chunk_size = CHUNK_SIZE * factor
            s = chunk_idx * out_chunk_size
            e = min(s + out_chunk_size, fused_out.size(0))
            return fused_out[s:e]

        def slice_expert_tokens_metadata(
                full_expert_tokens_meta: ExpertTokensMetadata,
                chunk_topk_ids: torch.Tensor, local_num_experts: int,
                expert_map: Optional[torch.Tensor]) -> ExpertTokensMetadata:
            # The existing expert_num_tokens is for the entire a1q
            # input. Chunking forces recomputation of the number
            # of tokens assigned to each expert.
            c_expert_num_tokens = count_expert_num_tokens(
                chunk_topk_ids, local_num_experts, expert_map)

            c_expert_num_tokens_cpu = None
            need_expert_num_tokens_cpu = (
                full_expert_tokens_meta.expert_num_tokens_cpu is not None)
            if need_expert_num_tokens_cpu:
                c_expert_num_tokens_cpu = c_expert_num_tokens.to(
                    "cpu", non_blocking=True)

            return ExpertTokensMetadata(
                expert_num_tokens=c_expert_num_tokens,
                expert_num_tokens_cpu=c_expert_num_tokens_cpu)

        for chunk_idx in range(num_chunks):
            c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids = (
                slice_input_tensors(chunk_idx))

            c_expert_tokens_meta = None
            if expert_tokens_meta is not None:
                c_expert_tokens_meta = slice_expert_tokens_metadata(
                    expert_tokens_meta, c_topk_ids, local_num_experts,
                    expert_map)

            self._do_fused_experts(fused_out=slice_output_tensor(chunk_idx),
                                   a1=a1,
                                   a1q=c_a1q,
                                   w1=w1,
                                   w2=w2,
                                   topk_ids=c_topk_ids,
                                   activation=activation,
                                   global_num_experts=global_num_experts,
                                   local_num_experts=local_num_experts,
                                   expert_map=expert_map,
                                   w1_scale=w1_scale,
                                   w2_scale=w2_scale,
                                   w1_zp=w1_zp,
                                   w2_zp=w2_zp,
                                   a1q_scale=c_a1q_scale,
                                   a2_scale=c_a2_scale,
                                   expert_tokens_meta=c_expert_tokens_meta)

        return fused_out
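As an illustrative aside (not part of the diff), slice_output_tensor above scales the chunk boundaries by factor = fused_out.size(0) // M because an expert implementation may emit more than one output row per token row (for example M * topk rows). A minimal sketch of that index arithmetic with made-up sizes:

# Illustrative sketch only: mapping an input-row chunk [s_in, e_in) to the
# matching slice of an output whose leading dim is a multiple of M.
M, topk, CHUNK_SIZE = 10, 2, 4     # hypothetical sizes
M_out = M * topk                   # e.g. output laid out as (M * topk, K)
factor = M_out // M                # 2
out_chunk_size = CHUNK_SIZE * factor

num_chunks = -(-M // CHUNK_SIZE)   # cdiv(M, CHUNK_SIZE) == 3
for chunk_idx in range(num_chunks):
    s_in = chunk_idx * CHUNK_SIZE
    e_in = min(s_in + CHUNK_SIZE, M)
    s_out = chunk_idx * out_chunk_size
    e_out = min(s_out + out_chunk_size, M_out)
    # input rows [s_in, e_in) produce output rows [s_out, e_out)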

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -512,110 +684,23 @@ class FusedMoEModularKernel(torch.nn.Module):
            # and can never run into the tensor.numel() == 0 case.
            fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
        else:
-            _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
-
-            if self.fused_experts.enable_chunking():
-                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
-                num_chunks = cdiv(M, CHUNK_SIZE)
-            else:
-                CHUNK_SIZE = M
-                num_chunks = 1
-
-            if num_chunks == 1:
-                (workspace13_shape, workspace2_shape, fused_out_shape,
-                 workspace_dtype) = self.fused_experts.workspace_shapes(
-                     a1, a1q, M, N, K, top_k, global_num_experts,
-                     local_num_experts)
-            else:
-                # Use the full M to get the final output shape.
-                _, _, fused_out_shape, _ = (
-                    self.fused_experts.workspace_shapes(
-                        a1, a1q, M, N, K, top_k, global_num_experts,
-                        local_num_experts))
-                # Use the CHUNK_SIZE to get the workspace shapes.
-                workspace13_shape, workspace2_shape, _, workspace_dtype = (
-                    self.fused_experts.workspace_shapes(
-                        a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts,
-                        local_num_experts))
-
-            # We can reuse the memory between cache1 and cache3 because by the
-            # time we need cache3, we're done with cache1.
-            workspace13 = torch.empty(prod(workspace13_shape),
-                                      device=a1.device,
-                                      dtype=workspace_dtype)
-            workspace2 = torch.empty(prod(workspace2_shape),
-                                     device=a1.device,
-                                     dtype=workspace_dtype)
-
-            if num_chunks == 1:
-                fused_out = _resize_cache(workspace13, fused_out_shape)
-
-                self.fused_experts.apply(
-                    fused_out,
-                    a1q,
-                    w1,
-                    w2,
-                    topk_ids,
-                    activation=activation,
-                    global_num_experts=global_num_experts,
-                    expert_map=expert_map,
-                    w1_scale=w1_scale,
-                    w2_scale=w2_scale,
-                    w1_zp=w1_zp,
-                    w2_zp=w2_zp,
-                    a1q_scale=a1q_scale,
-                    a2_scale=a2_scale,
-                    workspace13=workspace13,
-                    workspace2=workspace2,
-                    expert_tokens_meta=expert_tokens_meta,
-                )
-            else:
-                # The leading output dimension may not be equal to M, so
-                # we compute output indices separately.
-                M_out = fused_out_shape[0]
-                assert M_out >= M
-                factor = M_out // M
-                assert factor > 0
-                OUT_CHUNK_SIZE = CHUNK_SIZE * factor
-
-                fused_out = torch.empty(fused_out_shape,
-                                        device=a1q.device,
-                                        dtype=workspace_dtype)
-
-                assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
-                    f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")
-
-                for chunk in range(num_chunks):
-                    begin_chunk_idx = chunk * CHUNK_SIZE
-                    end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
-                    begin_out_idx = chunk * OUT_CHUNK_SIZE
-                    end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
-                    curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
-                    curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
-                                                   end_chunk_idx)
-                    curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
-                                                  end_chunk_idx)
-                    curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
-
-                    self.fused_experts.apply(
-                        fused_out[begin_out_idx:end_out_idx],
-                        curr_a1q,
-                        w1,
-                        w2,
-                        curr_topk_ids,
-                        activation=activation,
-                        global_num_experts=global_num_experts,
-                        expert_map=expert_map,
-                        w1_scale=w1_scale,
-                        w2_scale=w2_scale,
-                        w1_zp=w1_zp,
-                        w2_zp=w2_zp,
-                        a1q_scale=curr_a1q_scale,
-                        a2_scale=curr_a2_scale,
-                        workspace13=workspace13,
-                        workspace2=workspace2,
-                        expert_tokens_meta=expert_tokens_meta,
-                    )
+            fused_out = self._maybe_chunk_fused_experts(
+                a1=a1,
+                a1q=a1q,
+                w1=w1,
+                w2=w2,
+                topk_ids=topk_ids,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                local_num_experts=local_num_experts,
+                expert_map=expert_map,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+                w1_zp=w1_zp,
+                w2_zp=w2_zp,
+                a1q_scale=a1q_scale,
+                a2_scale=a2_scale,
+                expert_tokens_meta=expert_tokens_meta)

        self.prepare_finalize.finalize(output, fused_out, topk_weights,
                                       topk_ids, apply_router_weight_on_input)
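As an illustrative aside (not part of the diff), the core of this bugfix is that expert token counts computed for the full batch cannot be reused for a chunk: each chunk sees a different slice of topk_ids, so its per-expert counts generally differ from the full-batch counts. A minimal sketch with made-up sizes, using torch.bincount as a stand-in for count_expert_num_tokens:

# Illustrative sketch only: per-expert token counts must be recomputed per
# chunk; the full-batch counts do not describe any single chunk.
import torch

num_experts, CHUNK_SIZE = 4, 8
topk_ids = torch.randint(0, num_experts, (20, 2))

full_counts = torch.bincount(topk_ids.flatten(), minlength=num_experts)
for s in range(0, topk_ids.size(0), CHUNK_SIZE):
    chunk = topk_ids[s:s + CHUNK_SIZE]
    chunk_counts = torch.bincount(chunk.flatten(), minlength=num_experts)
    # chunk_counts, not full_counts, is what this chunk's kernel launch needs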
@@ -13,9 +13,81 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    quant_dequant_mxfp4)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv


@triton.jit
def _count_expert_num_tokens(topk_ids_ptr, expert_num_tokens_ptr, num_experts,
                             topk_numel, expert_map,
                             HAS_EXPERT_MAP: tl.constexpr,
                             BLOCK_SIZE: tl.constexpr):

    curr_expert = tl.program_id(0)

    offsets = tl.arange(0, BLOCK_SIZE)
    topk_ids_ptrs = topk_ids_ptr + offsets

    acc = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)
    for x in range(tl.cdiv(topk_numel, BLOCK_SIZE)):
        mask = offsets < (topk_numel - x * BLOCK_SIZE)
        expert_ids = tl.load(topk_ids_ptrs, mask=mask, other=-1)
        if HAS_EXPERT_MAP:
            expert_map_ptrs = expert_map + expert_ids
            expert_map_mask = expert_ids >= 0
            expert_ids = tl.load(expert_map_ptrs,
                                 mask=expert_map_mask,
                                 other=-1)

        has_curr_expert = tl.where(expert_ids == curr_expert, 1, 0)
        acc = acc + has_curr_expert
        topk_ids_ptrs += BLOCK_SIZE

    if curr_expert < num_experts:
        tl.store(expert_num_tokens_ptr + curr_expert, tl.sum(acc))
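As an illustrative aside (not part of the diff), each program instance of _count_expert_num_tokens handles exactly one expert id (tl.program_id(0)) and scans all of topk_ids in BLOCK_SIZE tiles, optionally remapping ids through expert_map. The sketch below restates what a single program computes in plain Python; the helper name is made up:

# Illustrative sketch only: the count a single Triton program produces.
def count_one_expert(topk_ids, curr_expert, expert_map=None):
    ids = topk_ids.flatten().tolist()
    if expert_map is not None:
        emap = expert_map.tolist()
        # Negative ids stay invalid; valid ids are remapped global -> local.
        ids = [emap[i] if i >= 0 else -1 for i in ids]
    return sum(1 for i in ids if i == curr_expert)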

def count_expert_num_tokens(
        topk_ids: torch.Tensor, num_local_experts: int,
        expert_map: Optional[torch.Tensor]) -> torch.Tensor:
    """
    Count the number of tokens assigned to each expert.

    Parameters:
    - topk_ids (torch.Tensor): Tensor mapping each token to its
      list of experts.
    - num_local_experts (int): Number of experts on this rank.
    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
      from the global expert space to the local expert space of the expert
      parallel shard.

    Returns:
    A tensor of size num_local_experts, where tensor[i] holds the number
    of tokens assigned to the ith expert.
    """
    assert topk_ids.dtype.is_signed, (
        "The kernel uses -1 to represent invalid topk_ids")
    expert_num_tokens = torch.empty((num_local_experts),
                                    device=topk_ids.device,
                                    dtype=torch.int32)

    grid = num_local_experts
    BLOCK_SIZE = min(topk_ids.numel(), 1024)
    BLOCK_SIZE = triton.next_power_of_2(BLOCK_SIZE)

    _count_expert_num_tokens[(grid, )](
        topk_ids,
        expert_num_tokens,
        num_local_experts,
        topk_ids.numel(),
        expert_map,
        HAS_EXPERT_MAP=expert_map is not None,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return expert_num_tokens
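As an illustrative aside (not part of the diff), the wrapper above launches one program per local expert with a power-of-two BLOCK_SIZE capped at 1024. A quick sketch of how those launch parameters work out for hypothetical sizes; next_power_of_2 here is a stand-in for triton.next_power_of_2:

# Illustrative sketch only: launch-parameter arithmetic of the wrapper.
def next_power_of_2(n: int) -> int:   # stand-in for triton.next_power_of_2
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

num_local_experts, num_tokens, num_topk = 16, 7317, 8
numel = num_tokens * num_topk                     # 58536 topk ids to scan
grid = num_local_experts                          # one program per expert
BLOCK_SIZE = next_power_of_2(min(numel, 1024))    # 1024
num_tiles = -(-numel // BLOCK_SIZE)               # each program scans 58 tiles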
def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor:
    """
    Shrink the given tensor and apply the given view to it. This is