[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)

Signed-off-by: Bill Nell <bnell@redhat.com>
Author: bnellnm
Date: 2025-06-11 12:53:10 -04:00
Committed by: GitHub
Parent: b2d9be6f7d
Commit: 29fa5cac1c
15 changed files with 458 additions and 396 deletions
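At a high level, the new chunking path in FusedMoEModularKernel.forward (shown in the modular_kernel.py hunk further below) splits the token dimension M into chunks of at most VLLM_FUSED_MOE_CHUNK_SIZE whenever the expert implementation reports supports_chunking(), sizes the intermediate workspaces for a single chunk, and invokes the expert kernel once per chunk while writing into a preallocated output tensor. A minimal, self-contained sketch of the idea (the helper name and the lambda below are illustrative, not the vLLM API):

```python
import torch

def chunked_apply(apply_fn, a: torch.Tensor, output: torch.Tensor,
                  chunk_size: int) -> torch.Tensor:
    """Run apply_fn over row-chunks of `a`, writing results into `output`.

    Illustrative only: the real kernel also slices topk_ids and the
    activation scales per chunk, and sizes its workspaces for one chunk.
    """
    M = a.shape[0]
    num_chunks = (M + chunk_size - 1) // chunk_size  # cdiv(M, chunk_size)
    for c in range(num_chunks):
        begin = c * chunk_size
        end = min((c + 1) * chunk_size, M)
        # Each call sees at most chunk_size tokens, which bounds the size
        # of the temporary workspaces regardless of the full batch size.
        apply_fn(output[begin:end], a[begin:end])
    return output

# e.g. with an elementwise stand-in for the expert computation:
out = chunked_apply(lambda o, x: o.copy_(torch.nn.functional.silu(x)),
                    torch.randn(4096, 128), torch.empty(4096, 128),
                    chunk_size=1024)
```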

View File

@ -29,6 +29,7 @@ MNK_FACTORS = [
(224, 1024, 1536),
(224, 3072, 1024),
(224, 3072, 1536),
(1024 * 128, 1024, 1024),
]
vllm_config = VllmConfig(parallel_config=ParallelConfig(

View File

@ -15,7 +15,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@ -76,6 +77,13 @@ def test_fused_moe(
else:
e_map = None
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=None)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a,
@ -103,7 +111,20 @@ def test_fused_moe(
expert_map=e_map,
renormalize=False)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_triton_output = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(m_triton_output,
torch_output,
atol=2e-2,
rtol=0)
torch.testing.assert_close(iterative_output,
torch_output,
atol=2e-2,

View File

@ -1,10 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
@ -14,6 +15,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
try:
from pplx_kernels import AllToAll
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
@ -64,6 +67,7 @@ def pplx_cutlass_moe(
out_dtype,
per_act_token: bool,
per_out_ch: bool,
group_name: Optional[str],
):
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
@ -84,7 +88,7 @@ def pplx_cutlass_moe(
else:
scale_elems = (hidden_dim + block_size - 1) // block_size
ata = AllToAll.internode(
args = dict(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
@ -96,6 +100,12 @@ def pplx_cutlass_moe(
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
)
if group_name is None:
ata = AllToAll.internode(**args)
else:
args["group_name"] = group_name
ata = AllToAll.intranode(**args)
w1 = w1.to(device)
w2 = w2.to(device)
w1_scale = w1_scale.to(device)
@ -113,7 +123,10 @@ def pplx_cutlass_moe(
)
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
out_dtype, per_act_token, per_out_ch)
out_dtype,
per_act_token,
per_out_ch,
use_batched_format=True)
fused_cutlass_experts = FusedMoEModularKernel(
prepare_finalize,
@ -184,11 +197,17 @@ def _pplx_moe(
w2_full: torch.Tensor,
per_act_token: bool,
per_out_ch: bool,
use_internode: bool,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
with set_current_vllm_config(vllm_config):
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
@ -196,7 +215,7 @@ def _pplx_moe(
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
w2_scale, topk_weights, topk_ids,
a1_scale, out_dtype, per_act_token,
per_out_ch)
per_out_ch, group_name)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
@ -207,7 +226,8 @@ def _pplx_moe(
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
nvshmem_finalize()
if use_internode:
nvshmem_finalize()
@pytest.mark.parametrize("m", [2, 224])
@ -218,6 +238,7 @@ def _pplx_moe(
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
@pytest.mark.parametrize("use_internode", [False])
@pytest.mark.skipif(
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
current_platform.get_device_capability()),
@ -232,6 +253,7 @@ def test_cutlass_moe_pplx(
per_act_token: bool,
per_out_ch: bool,
world_dp_size: tuple[int, int],
use_internode: bool,
):
current_platform.seed_everything(7)
@ -284,4 +306,5 @@ def test_cutlass_moe_pplx(
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
use_internode)
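The repeated internode/intranode branching added in these tests follows one pattern: build a common kwargs dict, then pick the pplx all-to-all constructor based on whether a torch.distributed group name was supplied. A condensed sketch, assuming only the pplx_kernels API already used above:

```python
from typing import Optional

def make_all_to_all(common_args: dict, group_name: Optional[str]):
    """NVSHMEM internode when no group name is given, intranode otherwise."""
    from pplx_kernels import AllToAll  # optional dependency, as in the tests
    if group_name is None:
        return AllToAll.internode(**common_args)
    # The intranode constructor additionally takes the process-group name.
    return AllToAll.intranode(**common_args, group_name=group_name)
```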

View File

@ -18,7 +18,6 @@ try:
except ImportError:
has_pplx = False
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
@ -30,6 +29,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from .deepep_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif(
not has_pplx,
reason="Requires PPLX kernels",
@ -153,7 +154,10 @@ def batched_moe(
num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1,
rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
@ -229,9 +233,15 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
return t[(r * chunk):(r + 1) * chunk]
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
topk_weight: torch.Tensor, topk_ids: torch.Tensor,
num_experts: int) -> torch.Tensor:
def pplx_prepare_finalize(
pgi: ProcessGroupInfo,
dp_size: int,
a: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
group_name: Optional[str],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
@ -245,7 +255,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
world_size = pgi.world_size
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
ata = AllToAll.internode(
args = dict(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
@ -259,6 +269,12 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
torch.float32.itemsize)),
)
if group_name is None:
ata = AllToAll.internode(**args)
else:
args["group_name"] = group_name
ata = AllToAll.intranode(**args)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
@ -318,11 +334,19 @@ def _pplx_prepare_finalize(
score: torch.Tensor,
topk: torch.Tensor,
num_experts: int,
use_internode: bool,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
device = pgi.device
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
@ -335,14 +359,15 @@ def _pplx_prepare_finalize(
a.dtype)
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
num_experts)
num_experts, group_name)
torch_output = chunk_by_rank(torch_output, pgi.rank,
pgi.world_size).to(pplx_output.device)
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
if use_internode:
nvshmem_finalize()
# TODO (bnell): this test point does not work for odd M due to how the test is
@ -353,6 +378,7 @@ def _pplx_prepare_finalize(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_prepare_finalize(
mnk: tuple[int, int, int],
@ -360,6 +386,7 @@ def test_pplx_prepare_finalize(
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
use_internode: bool,
):
current_platform.seed_everything(7)
m, n, k = mnk
@ -369,10 +396,11 @@ def test_pplx_prepare_finalize(
score = torch.randn((m, e), device=device, dtype=dtype)
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
topk, e)
topk, e, use_internode)
def pplx_moe(
group_name: Optional[str],
rank: int,
world_size: int,
dp_size: int,
@ -394,7 +422,7 @@ def pplx_moe(
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
ata = AllToAll.internode(
args = dict(
max_num_tokens=max_num_tokens,
num_experts=num_experts,
experts_per_token=topk,
@ -408,6 +436,12 @@ def pplx_moe(
torch.float32.itemsize)),
)
if group_name is None:
ata = AllToAll.internode(**args)
else:
args["group_name"] = group_name
ata = AllToAll.intranode(**args)
topk_ids = topk_ids.to(dtype=torch.uint32)
prepare_finalize = PplxPrepareAndFinalize(
@ -522,11 +556,18 @@ def _pplx_moe(
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
use_internode: bool,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
if use_internode:
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
torch.distributed.broadcast(uid, src=0)
nvshmem_init(uid, pgi.rank, pgi.world_size)
group_name = None
else:
group_ranks = list(range(pgi.world_size))
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
group_name = cpu_group.group_name
m, k = a.shape
e, _, n = w2.shape
@ -536,8 +577,8 @@ def _pplx_moe(
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
topk_weight, topk_ids)
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
a, w1, w2, topk_weight, topk_ids)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
@ -548,7 +589,8 @@ def _pplx_moe(
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize()
if use_internode:
nvshmem_finalize()
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@ -556,6 +598,7 @@ def _pplx_moe(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_moe(
mnk: tuple[int, int, int],
@ -563,6 +606,7 @@ def test_pplx_moe(
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
use_internode: bool,
):
current_platform.seed_everything(7)
m, n, k = mnk
@ -572,4 +616,5 @@ def test_pplx_moe(
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
use_internode)
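The use_internode parameter toggles the per-worker communication setup: NVSHMEM init/finalize for the internode path, or a gloo subgroup whose name is handed to the intranode all-to-all. A consolidated helper mirroring the branches repeated in each _pplx_* worker above (same imports and ProcessGroupInfo fields as the test code):

```python
from typing import Optional
import torch

def setup_comm(pgi, use_internode: bool) -> Optional[str]:
    """Returns the group_name to pass along (None selects internode)."""
    if use_internode:
        # NVSHMEM bootstrap: rank 0 creates the unique id, all ranks share it.
        from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
                                          nvshmem_get_unique_id, nvshmem_init)
        uid = (nvshmem_get_unique_id()
               if pgi.rank == 0 else nvshmem_alloc_empty_unique_id())
        torch.distributed.broadcast(uid, src=0)
        nvshmem_init(uid, pgi.rank, pgi.world_size)
        return None
    # Intranode: a CPU gloo group whose name identifies the ranks to AllToAll.
    cpu_group = torch.distributed.new_group(list(range(pgi.world_size)),
                                            backend="gloo")
    return cpu_group.group_name
```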

View File

@ -13,7 +13,8 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
@ -45,7 +46,7 @@ N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384]
# Deepseek-V3's intermediate size is 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
M_moe = [1, 2, 7, 83, 128, 2048]
M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
M_moe_dg = [128, 192, 1335, 2048]
N_moe = [128, 256, 1024, 4608] # [13824]
K_moe = [256, 512, 7168] # [13824]
@ -214,6 +215,13 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score = torch.randn((M, E), dtype=dtype)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
@ -231,6 +239,16 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_out = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=E,
w1_scale=w1_s,
w2_scale=w2_s)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
@ -239,6 +257,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
rel_diff = (torch.mean(
torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
def per_block_cast_to_fp8(
x: torch.Tensor,

View File

@ -1,123 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import os
import traceback
from typing import Callable
import torch
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
"tcp://localhost:29500",
worker,
) + args,
nprocs=world_size,
join=True,
)
def parallel_launch_from_env(
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current node.
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert not kwargs
world_size = int(os.environ["WORLD_SIZE"])
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
node_rank = int(os.environ["NODE_RANK"])
assert "MASTER_ADDR" in os.environ
assert "MASTER_PORT" in os.environ
spawn(
_worker_parallel_launch,
args=(
world_size,
world_local_size,
node_rank,
"env://",
worker,
) + args,
nprocs=world_local_size,
join=True,
)
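For context, the launcher defined in this now-removed file spawns world_size processes, initializes a gloo+nccl process group per rank, and calls the worker with a ProcessGroupInfo followed by the extra positional arguments; the tests in this commit import equivalent helpers from .deepep_utils instead. A usage sketch, assuming the definitions above (or their deepep_utils equivalents) are in scope and two GPUs are available:

```python
import torch

def _worker(pgi: "ProcessGroupInfo", scale: float) -> None:
    # Runs once per spawned rank, with the process group already initialized.
    x = torch.ones(4, device=pgi.device) * scale
    torch.distributed.all_reduce(x)
    assert torch.allclose(x, torch.full_like(x, scale * pgi.world_size))

if __name__ == "__main__":
    # Spawns 2 single-node ranks bound to tcp://localhost:29500.
    parallel_launch(2, _worker, 3.0)
```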

View File

@ -36,6 +36,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert (len(self.block_shape) == 2 and all(
[v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))
def supports_chunking(self) -> bool:
return False
def workspace_shapes(
self,
a: torch.Tensor,
@ -45,17 +48,19 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
return (workspace13, workspace2, a.dtype)
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -72,7 +77,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
import deep_gemm as dg
assert hidden_states.ndim == 3
@ -89,7 +94,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
workspace3 = _resize_cache(workspace13, (E, max_num_tokens, K))
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
@ -118,8 +122,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
out=workspace3,
out=output,
masked_m=expert_num_tokens,
expected_m=expected_m)
return workspace3

View File

@ -64,6 +64,15 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_shape=self.block_shape, # type: ignore[arg-type]
) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None
assert (self.batched_deep_gemm_experts is not None
or self.batched_triton_experts is not None)
def supports_chunking(self) -> bool:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
return ((bdge is None or bdge.supports_chunking())
and (bte is None or bte.supports_chunking()))
def workspace_shapes(
self,
a: torch.Tensor,
@ -73,7 +82,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
@ -87,6 +96,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -103,7 +113,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
use_batched_deep_gemm_experts = (self.allow_deep_gemm
and self.batched_deep_gemm_experts
is not None)
@ -111,7 +121,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if use_batched_deep_gemm_experts else
self.batched_triton_experts)
assert experts is not None
return experts.apply(hidden_states, w1, w2, topk_ids, activation,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale,
workspace13, workspace2, expert_num_tokens)
experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_num_tokens)

View File

@ -14,6 +14,7 @@ from vllm.scalar_type import scalar_types
def run_cutlass_moe_fp8(
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -31,7 +32,8 @@ def run_cutlass_moe_fp8(
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
) -> torch.Tensor:
use_batched_format: bool,
):
a1q = hidden_states
assert w1_scale is not None
@ -61,23 +63,20 @@ def run_cutlass_moe_fp8(
if expert_map is not None:
assert expert_num_tokens is None
# We have two modes: PPLX and non-PPLX. We differentiate them by checking
# if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX
# uses to track the number of tokens per expert).
# In the non-PPLX mode, the input tokens are not padded: thus, the shape
# We have two modes: batched experts and non-batched experts.
# In the non-batched mode, the input tokens are not padded: thus, the shape
# of the input is [total_num_tokens, hidden_size]. The input and output
# require shuffling by a_map and c_map such that the tokens assigned to
# each expert are contiguous.
# In the PPLX mode, the input tokens are padded per expert to ensure that
# the PPLX dispatch and combine functions work correctly: thus, the shape
# In the batched mode, the input tokens are padded per expert to ensure that
# the batched dispatch and combine functions work correctly: thus, the shape
# of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
# The PPLX input and output require no shuffling by a_map and c_map since
# The batched input and output require no shuffling by a_map and c_map since
# their tokens are already contiguous for each expert as a result of
# the dispatch function.
is_pplx = expert_num_tokens is not None
M = a1q.shape[0] # no pplx
padded_M = a1q.shape[1] # pplx
M = a1q.shape[0] # non batched expert M
padded_M = a1q.shape[1] # batched expert M
_, K, N = w2.shape
device = a1q.device
@ -95,7 +94,9 @@ def run_cutlass_moe_fp8(
topk = local_topk_ids.shape[1]
local_E = w1.shape[0]
if is_pplx:
if use_batched_format:
assert expert_num_tokens is not None
expert_offsets = torch.empty((local_E),
dtype=torch.int32,
device=device)
@ -167,7 +168,7 @@ def run_cutlass_moe_fp8(
device=device,
dtype=torch.int64)
if is_pplx:
if use_batched_format:
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
@ -192,12 +193,15 @@ def run_cutlass_moe_fp8(
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
per_act_token, per_out_ch)
if is_pplx:
return c3.reshape(local_E, padded_M, K)
if use_batched_format:
output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
else:
return c3[c_map].view(M, topk, K)
# We can't do this inplace because output may point to the same tensor
# as c3.
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
# TODO (bnell): split class batched vs. non-batched?
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
@ -206,12 +210,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
use_batched_format: bool = False,
):
super().__init__()
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.per_act_token = per_act_token
self.per_out_ch = per_out_ch
self.use_batched_format = use_batched_format
def supports_chunking(self) -> bool:
return not self.use_batched_format
def workspace_shapes(
self,
@ -222,14 +231,24 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
padded_M = aq.shape[1]
workspace1 = self.max_experts_per_worker * padded_M * max(N, K)
workspace2 = self.max_experts_per_worker * padded_M * (N // 2)
return (workspace1, workspace2, self.out_dtype)
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1: tuple[int, ...] = ()
workspace2: tuple[int, ...] = ()
output: tuple[int, ...] = ()
if self.use_batched_format:
padded_M = aq.shape[1]
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
output = (self.max_experts_per_worker, padded_M, K)
else:
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M * topk, K)
return (workspace1, workspace2, output, self.out_dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -246,16 +265,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch)
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch,
self.use_batched_format)
def cutlass_moe_fp8(
@ -325,6 +345,7 @@ def cutlass_moe_fp8(
out_dtype=out_dtype,
per_act_token=per_act_token,
per_out_ch=per_out_ch,
use_batched_format=False,
),
)

View File

@ -70,6 +70,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
super().__init__()
self.block_shape = deep_gemm_block_shape()
def supports_chunking(self) -> bool:
return True
def workspace_shapes(
self,
a: torch.Tensor,
@ -79,18 +82,18 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = M_sum * max(N * 2, K)
workspace2 = M_sum * max(N, K)
return (workspace1, workspace2, a.dtype)
workspace1 = (M_sum, max(N * 2, K))
workspace2 = (M_sum, max(N, K))
output = (M * topk, K)
return (workspace1, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -107,7 +110,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
import deep_gemm as dg
a1q = hidden_states
@ -143,7 +146,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K))
out = _resize_cache(workspace13, (inv_perm.size(0), K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
@ -159,9 +161,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
torch.index_select(mm2_out, 0, inv_perm, out=out)
return out
torch.index_select(mm2_out, 0, inv_perm, out=output)
def deep_gemm_moe_fp8(

View File

@ -335,9 +335,6 @@ def invoke_moe_batched_triton_kernel(
BLOCK_M = config['BLOCK_SIZE_M']
BLOCK_N = config['BLOCK_SIZE_N']
BLOCK_K = config['BLOCK_SIZE_K']
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or max_num_tokens % BLOCK_M == 0)
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
triton.cdiv(B.size(1), BLOCK_N))
@ -390,8 +387,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
that the PPLX dispatch/combine kernels use.
"""
def __init__(self, max_num_tokens: Optional[int], world_size: int,
dp_size: int, rank: int):
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
rank: int):
super().__init__()
self.world_size = world_size
self.dp_size = dp_size
@ -430,14 +427,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_tokens, hidden_dim = a1.size()
topk = topk_ids.size(1)
if self.max_num_tokens is None:
tokens_per_expert = torch.bincount(topk_ids.view(-1),
minlength=num_experts)
self.max_num_tokens = int(tokens_per_expert.max().item())
else:
tokens_per_expert = torch.zeros(num_experts,
dtype=torch.int,
device=a1.device)
tokens_per_expert = torch.zeros(num_experts,
dtype=torch.int,
device=a1.device)
assert num_experts % self.world_size == 0
@ -497,9 +489,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_num_tokens: int,
world_size: int,
dp_size: int,
max_num_tokens: Optional[int] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@ -518,6 +510,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.world_size = world_size
self.dp_size = dp_size
def supports_chunking(self) -> bool:
return False
def workspace_shapes(
self,
a: torch.Tensor,
@ -527,18 +522,16 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
workspace13 = num_experts * max_num_tokens * num_dp * K
workspace2 = max_num_tokens * num_dp * N
return (workspace13, workspace2, a.dtype)
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
workspace2 = (self.max_num_tokens * num_dp, N)
return (workspace13, workspace2, workspace13, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -555,20 +548,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
assert hidden_states.dim() == 3
assert expert_num_tokens is not None
hidden_dim = hidden_states.size(-1)
if self.max_num_tokens is None:
max_num_tokens = hidden_states.size(1)
else:
max_num_tokens = self.max_num_tokens
max_num_tokens = self.max_num_tokens
num_dp = self.world_size // self.dp_size
num_experts = global_num_experts
out = _resize_cache(workspace13,
(num_experts, max_num_tokens * num_dp, hidden_dim))
num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), (
f"{num_local_experts} == {w1.size(0)}")
@ -585,15 +570,13 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
if (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()):
num = max_num_tokens * num_dp
num = hidden_states.shape[1]
else:
num = int(expert_num_tokens[expert].item())
tmp = _resize_cache(workspace2, (num, N))
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
self.activation(activation, tmp, input)
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
return out
output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
@ -630,6 +613,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert not use_int4_w4a16, "NYI"
assert self.block_shape is None, "NYI"
def supports_chunking(self) -> bool:
return False
def workspace_shapes(
self,
a: torch.Tensor,
@ -639,17 +625,19 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
return (workspace13, workspace2, a.dtype)
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -666,7 +654,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
@ -723,8 +711,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
(E, max_num_tokens, N))
intermediate_cache2 = _resize_cache(workspace2,
(E, max_num_tokens, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(E, max_num_tokens, K))
# MM1
invoke_moe_batched_triton_kernel(A=hidden_states,
@ -761,7 +747,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2,
C=intermediate_cache3,
C=output,
expert_num_tokens=expert_num_tokens,
compute_type=compute_type,
A_scale=a2q_scale,
@ -772,4 +758,3 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int4_w4a16=self.use_int4_w4a16,
config=config,
block_shape=self.block_shape)
return intermediate_cache3
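The batched expert classes in this file return False from supports_chunking() because their activations are laid out per local expert rather than per token, so slicing the leading dimension would split experts instead of tokens. A shape-only illustration of the two layouts (arbitrary sizes):

```python
import torch

hidden = 128

# Standard token-major layout: chunking slices tokens along dim 0.
a_standard = torch.randn(4096, hidden)      # [num_tokens, hidden]
chunk = a_standard[:1024]                   # still a valid MoE input

# Batched layout used by the batched/pplx experts: dim 0 indexes the local
# expert, so a_batched[:1024] would drop experts rather than tokens, which
# is why these classes report supports_chunking() -> False.
a_batched = torch.randn(8, 512, hidden)     # [num_experts, max_tokens, hidden]
```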

View File

@ -1542,6 +1542,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int4_w4a16=use_int4_w4a16)
self.per_channel_quant = per_channel_quant
def supports_chunking(self) -> bool:
return True
def workspace_shapes(
self,
a: torch.Tensor,
@ -1551,14 +1554,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
factor = num_experts if a.dim() == 3 else 1
workspace1 = M * topk * max(N * 2, K) * factor
workspace2 = M * topk * N * factor
return (workspace1, workspace2, a.dtype)
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
workspace1 = (M, topk, max(N * 2, K))
workspace2 = (M, topk, N)
output = (M, topk, K)
return (workspace1, workspace2, output, a.dtype)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -1575,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
@ -1632,8 +1636,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
(num_tokens, top_k_num, N))
intermediate_cache2 = _resize_cache(workspace2,
(num_tokens * top_k_num, N // 2))
intermediate_cache3 = _resize_cache(workspace13,
(num_tokens, top_k_num, K))
sorted_token_ids, expert_ids, num_tokens_post_padded = (
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
@ -1671,7 +1673,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
intermediate_cache3,
output,
a2q_scale,
w2_scale,
w2_zp,
@ -1690,8 +1692,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape)
return intermediate_cache3
def modular_triton_fused_moe(
use_fp8_w8a8: bool,

View File

@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from math import prod
from typing import Optional
import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
@ -115,9 +120,9 @@ class FusedMoEPrepareAndFinalize(ABC):
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional tensor as big as number of local experts that contains the
number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
"""
raise NotImplementedError
@ -159,7 +164,7 @@ class FusedMoEPrepareAndFinalize(ABC):
Some PrepareFinalize All2All implementations are batched, meaning
they can process only a set of tokens at a time. This
function returns the batch size, i.e. the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
"""
raise NotImplementedError
@ -171,6 +176,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
above.
"""
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
"""
A flag indicating whether or not this class supports activation
chunking.
"""
raise NotImplementedError
@abstractmethod
def workspace_shapes(
self,
@ -181,19 +195,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
"""
Compute the number of elements for the temporary outputs of the two
gemms and activation in the fused expert function. Since the
gemms are independent, the workspace for the first gemm can be shared
with the workspace for the last gemm.
Compute the shapes for the temporary and final outputs of the two gemms
and activation in the fused expert function. Since the gemms are
independent, the workspace for the first gemm can be shared with the
workspace for the last gemm.
Returns a tuple of:
- Number of workspace13 elements: must be large enough to hold the
- workspace13 shape tuple: must be large enough to hold the
result of either expert gemm.
- Number of workspace2 elements: must be large enough to hold the
- workspace2 shape tuple: must be large enough to hold the
result of the activation function.
- output shape tuple: must be exact size of the final gemm output.
- Workspace type: The dtype to use for the workspace tensors.
- Note: in order for activation chunking to work, the first dimension
of each tuple must be the number of tokens.
"""
raise NotImplementedError
@ -210,6 +227,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
@abstractmethod
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -226,12 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
"""
This function computes the intermediate result of a Mixture of Experts
(MoE) layer using two sets of weights, w1 and w2.
Parameters:
- output: (torch.Tensor): The unweighted, unreduced output tensor.
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
layer.
- w1 (torch.Tensor): The first set of expert weights.
@ -259,13 +278,20 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
Returns:
- torch.Tensor: The unweighted, unreduced output tensor
"""
raise NotImplementedError
def _chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
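_chunk_scales is needed because activation scales can be per-tensor (a single element shared by every chunk) or per-token (one row per token that must be sliced in step with the activations). Restating the helper above so the example is self-contained:

```python
from typing import Optional
import torch

def _chunk_scales(scales: Optional[torch.Tensor], start: int,
                  end: int) -> Optional[torch.Tensor]:
    # Per-tensor scales are shared across chunks; per-token scales are sliced.
    if scales is not None:
        return scales if scales.numel() == 1 else scales[start:end]
    return None

per_tensor = torch.tensor([0.02])   # one scale for the whole activation
per_token = torch.rand(4096, 1)     # one scale per token row

assert _chunk_scales(per_tensor, 0, 1024).numel() == 1
assert _chunk_scales(per_token, 0, 1024).shape[0] == 1024
assert _chunk_scales(None, 0, 1024) is None
```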
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
@ -288,61 +314,6 @@ class FusedMoEModularKernel(torch.nn.Module):
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
def _do_fused_experts(
self,
a1: torch.Tensor, # input to forward fn
a1q: torch.Tensor, # output of prepare fn
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
expert_num_tokens: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
# Use a1 here to decipher the correct workspace datatype
workspace13_shape, workspace2_shape, workspace_dtype = (
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the time
# we need cache3, we're done with cache1
workspace13 = torch.zeros(workspace13_shape,
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(workspace2_shape,
device=a1.device,
dtype=workspace_dtype)
fused_out = self.fused_experts.apply(
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
return fused_out
def forward(
self,
hidden_states: torch.Tensor,
@ -408,12 +379,14 @@ class FusedMoEModularKernel(torch.nn.Module):
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
global_num_experts, expert_map, apply_router_weight_on_input)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
_expert_topk_weights)
fused_out = None
if a1q.numel() == 0:
# This happens when none of the tokens from the all2all reach this
# EP rank. Also, note that this is only relevant for CUDAGraph
@ -423,22 +396,107 @@ class FusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self._do_fused_experts(
a1=a1,
a1q=a1q,
w1=w1,
w2=w2,
topk_ids=topk_ids,
expert_num_tokens=expert_num_tokens,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale)
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
if self.fused_experts.supports_chunking():
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)
else:
CHUNK_SIZE = M
num_chunks = 1
if num_chunks == 1:
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts)
else:
# Use the full M to get the final output shape.
_, _, fused_out_shape, _ = (
self.fused_experts.workspace_shapes(
a1, a1q, M, N, K, top_k, global_num_experts))
# Use the CHUNK_SIZE to get the workspace shapes.
workspace13_shape, workspace2_shape, _, workspace_dtype = (
self.fused_experts.workspace_shapes(
a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))
# We can reuse the memory between cache1 and cache3 because by the
# time we need cache3, we're done with cache1.
workspace13 = torch.zeros(prod(workspace13_shape),
device=a1.device,
dtype=workspace_dtype)
workspace2 = torch.zeros(prod(workspace2_shape),
device=a1.device,
dtype=workspace_dtype)
if num_chunks == 1:
fused_out = _resize_cache(workspace13, fused_out_shape)
self.fused_experts.apply(
fused_out,
a1q,
w1,
w2,
topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
else:
# The leading output dimension may not be equal to M, so
# we compute output indices separately.
M_out = fused_out_shape[0]
assert M_out >= M
factor = M_out // M
assert factor > 0
OUT_CHUNK_SIZE = CHUNK_SIZE * factor
fused_out = torch.empty(fused_out_shape,
device=a1q.device,
dtype=workspace_dtype)
assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")
for chunk in range(num_chunks):
begin_chunk_idx = chunk * CHUNK_SIZE
end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
begin_out_idx = chunk * OUT_CHUNK_SIZE
end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
end_chunk_idx)
curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
end_chunk_idx)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
self.fused_experts.apply(
fused_out[begin_out_idx:end_out_idx],
curr_a1q,
w1,
w2,
curr_topk_ids,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_zp=w1_zp,
w2_zp=w2_zp,
a1q_scale=curr_a1q_scale,
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
)
self.prepare_finalize.finalize(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input)
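Taken together, the interface change means an expert implementation now (1) declares whether it can be chunked along the token dimension, (2) reports workspace and output shapes as tuples whose leading dimension is the token count, and (3) writes its result into the caller-provided output tensor instead of returning a workspace view. A toy, non-quantized sketch of that contract (not drop-in code: the real apply()/workspace_shapes() signatures carry many more arguments):

```python
import torch

class ToyExperts:
    """A deliberately simplified stand-in (not a real vLLM class) that mirrors
    the updated FusedMoEPermuteExpertsUnpermute contract: report chunking
    support, return workspace/output shapes, and write into `output`."""

    def supports_chunking(self) -> bool:
        # Token-major activations, so slicing along M is safe.
        return True

    def workspace_shapes(self, M: int, N: int, K: int, topk: int):
        # The first dimension of every tuple is the token count, as required
        # for chunking. (The real signature also takes the activation tensors
        # and num_experts, and returns a workspace dtype.)
        workspace13 = (M, topk, max(N, K))
        workspace2 = (M, topk, N)
        output = (M, topk, K)
        return workspace13, workspace2, output

    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
              w1: torch.Tensor, w2: torch.Tensor,
              topk_ids: torch.Tensor) -> None:
        # Un-gated SiLU MLP per selected expert. Real implementations stage the
        # gemm results in workspace13/workspace2 (via _resize_cache) and take
        # quantization scales, expert maps, etc.
        x = hidden_states[:, None, None, :]                    # [M, 1, 1, K]
        h = torch.nn.functional.silu(
            x @ w1[topk_ids.long()].transpose(-1, -2))         # [M, topk, 1, N]
        out = h @ w2[topk_ids.long()].transpose(-1, -2)        # [M, topk, 1, K]
        output.copy_(out.squeeze(2))                           # write in place


# Minimal usage: the caller allocates the exact output shape, then calls apply().
M, K, N, E, topk = 16, 32, 64, 4, 2
experts = ToyExperts()
w1 = torch.randn(E, N, K) / K ** 0.5
w2 = torch.randn(E, K, N) / N ** 0.5
a = torch.randn(M, K)
topk_ids = torch.randint(0, E, (M, topk))
out = torch.empty(experts.workspace_shapes(M, N, K, topk)[2])
experts.apply(out, a, w1, w2, topk_ids)
```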

View File

@ -34,6 +34,12 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None
def supports_chunking(self) -> bool:
dge = self.deep_gemm_expert
te = self.triton_expert
return ((dge is None or dge.supports_chunking())
and (te is None or te.supports_chunking()))
def workspace_shapes(
self,
a: torch.Tensor,
@ -43,7 +49,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
K: int,
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
@ -57,6 +63,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -73,45 +80,31 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
) -> torch.Tensor:
):
N = w1.size(1)
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2)):
assert self.deep_gemm_expert is not None
return self.deep_gemm_expert.apply(
hidden_states,
w1,
w2,
topk_ids,
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)
else:
return self.triton_expert.apply(
hidden_states,
w1,
w2,
topk_ids,
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)
use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
and _valid_deep_gemm(hidden_states, w1, w2))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
assert experts is not None
experts.apply(
output,
hidden_states,
w1,
w2,
topk_ids,
activation,
global_num_experts,
expert_map,
w1_scale,
w2_scale,
w1_zp,
w2_zp,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_num_tokens,
)

View File

@ -562,9 +562,12 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
(moe.num_experts + prepare_finalize.world_size - 1) //
prepare_finalize.world_size)
experts = CutlassExpertsFp8(
max_experts_per_worker, moe.in_dtype,
max_experts_per_worker,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
use_batched_format=True,
)
if has_pplx and isinstance(
prepare_finalize,