[Kernels] MoE refactor (#19636)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: ElizaWszola <ewszola@redhat.com>
Co-authored-by: ElizaWszola <ewszola@redhat.com>
Author: bnellnm
Date: 2025-07-02 09:08:27 -04:00
Committed by: GitHub
Parent: b95877509b
Commit: c1909e7e8c
36 changed files with 2698 additions and 1584 deletions

View File

@ -113,6 +113,7 @@ def bench_run(
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
num_repeats: int,
):
for _ in range(num_repeats):
@ -124,7 +125,8 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
a1_scale=a_scale,
per_act_token,
a1_scale=None,
)
def run_cutlass_from_graph(
@ -148,7 +150,8 @@ def bench_run(
topk_ids,
w1_scale,
w2_scale,
a1_scale=a_scale,
per_act_token,
a1_scale=None,
)
def run_triton_from_graph(
@ -227,6 +230,7 @@ def bench_run(
"w2_q": w2_q,
"w1_scale": w1_scale,
"w2_scale": w2_scale,
"per_act_token": per_act_token,
# cuda graph params
"cutlass_graph": cutlass_graph,
"triton_graph": triton_graph,
@ -287,12 +291,13 @@ def bench_run(
w2_scale,
topk_weights,
topk_ids,
per_act_token,
num_warmup,
)
results.append(
benchmark.Timer(
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
globals=globals,
label=label,
sub_label=sub_label,

View File

@ -0,0 +1,190 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
world_local_size: int
rank: int
node_rank: int
local_rank: int
device: torch.device
def _worker_parallel_launch(
local_rank: int,
world_size: int,
world_local_size: int,
node_rank: int,
init_method: str,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
rank = node_rank * world_local_size + local_rank
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
torch.distributed.init_process_group(
backend="cpu:gloo,cuda:nccl",
init_method=init_method,
rank=rank,
world_size=world_size,
device_id=device,
)
barrier = torch.tensor([rank], device=device)
torch.distributed.all_reduce(barrier)
try:
worker(
ProcessGroupInfo(
world_size=world_size,
world_local_size=world_local_size,
rank=rank,
node_rank=node_rank,
local_rank=local_rank,
device=device,
),
*args,
**kwargs,
)
except Exception as ex:
print(ex)
traceback.print_exc()
raise
finally:
torch.distributed.destroy_process_group()
def parallel_launch(
world_size: int,
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
*args: P.args,
**kwargs: P.kwargs,
) -> None:
assert not kwargs
spawn(
_worker_parallel_launch,
args=(
world_size,
world_size,
0,
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
worker,
) + args,
nprocs=world_size,
join=True,
)
## DeepEP specific utils
@dataclasses.dataclass
class DeepEPHTArgs:
num_local_experts: int
@dataclasses.dataclass
class DeepEPLLArgs:
max_tokens_per_rank: int
hidden_size: int
num_experts: int
use_fp8_dispatch: bool
def make_deepep_ht_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
ht_args: DeepEPHTArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# high throughput a2a
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
buffer = deep_ep.Buffer(group=pg,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=low_latency_mode,
num_qps_per_rank=num_qps_per_rank)
return DeepEPHTPrepareAndFinalize(buffer=buffer,
world_size=pgi.world_size,
rank=pgi.rank,
dp_size=dp_size,
rank_expert_offset=pgi.rank *
ht_args.num_local_experts)
def make_deepep_ll_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ll_args: DeepEPLLArgs,
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
import deep_ep
# low-latency a2a
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
pgi.world_size, deepep_ll_args.num_experts)
buffer = deep_ep.Buffer(group=pg,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=deepep_ll_args.num_experts //
pgi.world_size)
return DeepEPLLPrepareAndFinalize(
buffer=buffer,
world_size=pgi.world_size,
dp_size=dp_size,
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
)
def make_deepep_a2a(pg: ProcessGroup,
pgi: ProcessGroupInfo,
dp_size: int,
deepep_ht_args: Optional[DeepEPHTArgs],
deepep_ll_args: Optional[DeepEPLLArgs],
q_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
if deepep_ht_args is not None:
assert deepep_ll_args is None
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
block_shape)
assert deepep_ll_args is not None
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
block_shape)
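
A minimal usage sketch (not part of the commit) of the helpers above: a hypothetical worker run under parallel_launch on two local GPUs. The names _example_worker and test_parallel_launch_smoke are illustrative only.

import torch

from .parallel_utils import ProcessGroupInfo, parallel_launch


def _example_worker(pgi: ProcessGroupInfo, hidden_size: int) -> None:
    # Each spawned rank receives a ProcessGroupInfo with the process group
    # already initialized, so collectives can be issued directly.
    x = torch.full((hidden_size, ),
                   float(pgi.rank),
                   device=pgi.device,
                   dtype=torch.float32)
    torch.distributed.all_reduce(x)
    expected = float(sum(range(pgi.world_size)))
    assert torch.allclose(x, torch.full_like(x, expected))


def test_parallel_launch_smoke() -> None:
    # Spawn two ranks on the local node; extra positional args are forwarded
    # to the worker after the ProcessGroupInfo argument.
    parallel_launch(2, _example_worker, 128)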

View File

@ -2,18 +2,59 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional
import pytest
import torch
import triton.language as tl
from tests.kernels.moe.utils import (batched_moe,
make_quantized_test_activations,
make_test_weights, triton_moe)
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
MNK_FACTORS = [
(1, 128, 128),
(1, 128, 2048),
(1, 512, 512),
(1, 1024, 128),
(1, 1024, 2048),
(32, 128, 128),
(32, 512, 512),
(32, 1024, 2048),
(45, 128, 128),
(45, 128, 2048),
(45, 512, 512),
(45, 1024, 128),
(45, 1024, 2048),
(64, 128, 128),
(64, 512, 512),
(64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048),
(222, 512, 512),
(222, 1024, 128),
(222, 1024, 2048),
]
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass
class BatchedMMConfig:
dtype: torch.dtype
in_dtype: torch.dtype
quant_dtype: Optional[torch.dtype]
out_dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
K: int
@ -32,79 +73,127 @@ class BatchedMMTensors:
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
dtype=config.in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
dtype=config.in_dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
dtype=config.out_dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
device="cuda",
dtype=torch.int32)
return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@pytest.mark.parametrize("num_experts", [16, 32])
@pytest.mark.parametrize("num_experts", [8, 16, 32])
@pytest.mark.parametrize("max_tokens_per_expert",
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("block_shape", [None])
@pytest.mark.parametrize("per_act_token_quant", [False])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
N: int, dtype: torch.dtype,
block_shape: Optional[list[int]],
per_act_token_quant: bool):
current_platform.seed_everything(7)
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
test_output = tensors.C
ref_output = test_output.clone()
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
pytest.skip("Don't test blocking for non-quantized types.")
if per_act_token_quant and block_shape is not None:
pytest.skip("Skip illegal quantization test.")
if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None
num_expert_tokens = torch.randint(low=0,
high=max_tokens_per_expert,
size=(num_experts, ),
device="cuda",
dtype=torch.int32)
A, A_q, A_scale = make_quantized_test_activations(
num_experts,
max_tokens_per_expert,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant)
B, B_q, B_scale, _, _, _ = make_test_weights(
num_experts,
N // 2,
K,
in_dtype=act_dtype,
quant_dtype=quant_dtype,
block_shape=block_shape,
)
out_shape = (num_experts, max_tokens_per_expert, N)
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
assert A_q.dtype == B_q.dtype
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
A_q,
B_q,
test_output,
tensors.num_expert_tokens,
num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
A_scale,
B_scale,
None,
# Quantization schemes
False,
use_fp8_w8a8,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
},
block_shape=block_shape,
)
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
ref_output = native_batched_masked_quant_matmul(
A,
B,
ref_output,
num_expert_tokens,
None,
None,
None,
)
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
num_expert_tokens,
A_scale, B_scale,
block_shape)
rtol, atol = {
torch.float16: (6e-2, 6e-2),
@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("per_act_token_quant", [False])
@pytest.mark.parametrize("block_shape", [None])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
per_act_token_quant: bool,
block_shape: Optional[list[int]],
):
current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
pytest.skip("Skip quantization test for non-quantized type")
if per_act_token_quant and block_shape is not None or topk > e:
pytest.skip("Skip illegal quantization test.")
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
if dtype.itemsize == 1:
act_dtype = torch.bfloat16
quant_dtype = dtype
else:
act_dtype = dtype
quant_dtype = None
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
block_shape=block_shape,
in_dtype=act_dtype,
quant_dtype=quant_dtype)
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
batched_output = batched_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
baseline_output = torch_experts(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
triton_output = triton_moe(
a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
torch.testing.assert_close(triton_output,
baseline_output,
atol=2e-2,
rtol=2e-2)
torch.testing.assert_close(triton_output,
batched_output,
atol=2e-2,
rtol=2e-2)

View File

@ -0,0 +1,296 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.platforms import current_platform
dg_available = False
try:
import deep_gemm
dg_available = True
except ImportError:
pass
if current_platform.get_device_capability() < (9, 0):
pytest.skip("FP8 Triton requires compute capability 9.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# Test configurations
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# Deepseek-V3's intermediate size is 18432, so N is 18432*2/8=4608 at TP8,
# and its hidden size is 7168.
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4608, 128),
(1, 4608, 512),
(1, 4608, 7168),
(83, 128, 128),
(83, 512, 512),
(83, 1024, 7168),
(83, 4608, 512),
(83, 4608, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4608, 512),
(128, 4608, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4608, 512),
(2048, 4608, 7168),
(8192, 128, 128),
(8192, 512, 512),
(8192, 128, 7168),
(8192, 1024, 7168),
(8192, 4608, 512),
(8192, 4608, 7168),
]
MNK_FACTORS_DG = [
(128, 128, 128),
(128, 512, 512),
(128, 128, 7168),
(128, 1024, 7168),
(128, 4608, 128),
(128, 4608, 512),
(128, 4608, 7168),
(192, 128, 128),
(192, 512, 512),
(192, 1024, 7168),
(192, 4608, 512),
(192, 4608, 7168),
(1335, 128, 128),
(1335, 1024, 7168),
(1335, 4608, 512),
(1335, 4608, 7168),
(2048, 128, 128),
(2048, 512, 512),
(2048, 128, 7168),
(2048, 1024, 7168),
(2048, 4608, 128),
(2048, 4608, 512),
(2048, 4608, 7168),
]
BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16] # [128, 256]
TOP_KS = [1, 2, 6]
SEEDS = [0]
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
topk = topk_ids.size(1)
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")
@pytest.fixture(autouse=True)
def setup_cuda():
torch.set_default_device("cuda")
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=False,
block_shape=block_size)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(
a,
w1,
w2,
w1_s,
w2_s,
topk_weights,
topk_ids,
block_size,
)
out = fused_experts(
a,
w1,
w2,
topk_weights,
topk_ids,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
m_out = m_fused_moe(
a,
w1,
w2,
topk_weights,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
)
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
tol = 0.035 if M < 40000 else 0.039
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
chunk_size = 1024
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=block_size)
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids, block_size)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)

View File

@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from tests.kernels.moe.utils import make_test_weights
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
native_w8a8_block_matmul)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.platforms import current_platform
if current_platform.get_device_capability() < (7, 0):
pytest.skip("INT8 Triton requires compute capability 7.0 or higher",
allow_module_level=True)
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16]
MNK_FACTORS = [
(1, 128, 128),
(1, 512, 512),
(1, 128, 7168),
(1, 1024, 7168),
(1, 4096, 128),
(1, 4096, 512),
(1, 4096, 7168),
(33, 128, 128),
(33, 512, 512),
(33, 128, 7168),
(33, 1024, 7168),
(33, 4096, 128),
(33, 4096, 512),
(33, 4096, 7168),
(128, 128, 128),
(128, 512, 512),
(128, 1024, 7168),
(128, 4096, 512),
(128, 4096, 7168),
(222, 128, 128),
(222, 512, 512),
(222, 1024, 7168),
(222, 4096, 512),
(222, 4096, 7168),
(2048, 128, 128),
(2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 7168),
]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
# Reference torch implementation used for testing.
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""This function performs fused moe with block-wise quantization using
native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_int8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
"""Sets the default CUDA device for all tests in this module."""
torch.set_default_device("cuda")
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
@pytest.mark.parametrize("E", E)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
native torch reference."""
torch.manual_seed(seed)
a = torch.randn((M, K), dtype=dtype) / 10
score = torch.randn((M, E), dtype=dtype)
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
N,
K,
dtype,
torch.int8,
per_act_token_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
# Check results
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)

View File

@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
n_b_scales = 2 * n if per_out_channel else 1
k_b_scales = k if per_out_channel else 1
# Get the right scale for tests.
_, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
a_scale,
use_per_token_if_dynamic=per_act_token)
a_q, a_scale = ops.scaled_fp8_quant(
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
def run_8_bit(moe_tensors: MOETensors8Bit,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
per_act_token: bool,
num_local_experts: Optional[int] = None) -> torch.Tensor:
assert not any([
t is None for t in [
@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
'topk_ids': topk_ids,
'w1_scale': moe_tensors.w1_scale,
'w2_scale': moe_tensors.w2_scale,
'a1_scale': moe_tensors.a_scale
'per_act_token': per_act_token,
'a1_scale': None #moe_tensors.a_scale
}
num_experts = moe_tensors.w1.size(0)
@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
# Note 5.5 only needed for larger problem sizes, 5 works ok for
# the rest.
torch.testing.assert_close(triton_output,
cutlass_output,
atol=5e-2,
atol=5.5e-2,
rtol=1e-2)
@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
per_act_token)
torch.cuda.synchronize()
graph.replay()
@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
cutlass_output = run_8_bit(mt,
topk_weights,
topk_ids,
per_act_token,
num_local_experts=e // ep_size)
torch.testing.assert_close(triton_output,

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
"""
Test DeepEP + DeepGEMM integration
Test DeepEP + DeepGEMM integration
DeepGEMM is a set of GEMM kernels specialized for the
fp8 block-quantized case.
"""
@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
from .utils import make_test_weights
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
@ -30,10 +29,9 @@ if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
if has_deep_gemm():
import deep_gemm
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
@ -60,25 +58,6 @@ def next_power_of_2(x):
return 2**math.ceil(math.log2(x))
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def make_block_quant_fp8_weights(
e: int,
n: int,
@ -86,43 +65,11 @@ def make_block_quant_fp8_weights(
block_size: list[int],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
Return weights w1q, w2q, w1_scale, w2_scale
"""
dtype = torch.bfloat16
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w2 = (n + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
device="cuda",
dtype=torch.float32)
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
device="cuda",
dtype=torch.float32)
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(e):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
return w1, w2, w1_s, w2_s
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
return w1q, w2q, w1_scale, w2_scale
@dataclasses.dataclass
@ -132,6 +79,7 @@ class TestConfig:
k: int
n: int
num_experts: int
per_act_token_quant: bool
block_size: list[int]
# configs for testing low-latency kernels
low_latency: bool
@ -150,8 +98,7 @@ class TestTensors:
def make(config: TestConfig, rank) -> "TestTensors":
dtype = torch.bfloat16
topk, m, k, block_size = (config.topk, config.m, config.k,
config.block_size)
topk, m, k = (config.topk, config.m, config.k)
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
@ -159,9 +106,7 @@ class TestTensors:
rank_tokens = torch.randn(
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
block_k = block_size[1]
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
rank_token_scales = None
topk_ids = torch.randint(
low=0,
@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype=q_dtype,
block_shape=test_config.block_size)
fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
world_size=pgi.world_size,
dp_size=dp_size,
block_shape=test_config.block_size)
fused_experts = BatchedDeepGemmExperts(
max_num_tokens=max_tokens_per_rank,
world_size=pgi.world_size,
dp_size=dp_size,
block_shape=test_config.block_size,
per_act_token_quant=test_config.per_act_token_quant)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
"""
Tests for High-Throughput DeepEP + DeepGemm integration.
"""
import deep_gemm
m, n, k = mnk
current_platform.seed_everything(7)
@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=False,
use_fp8_dispatch=None)
@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False]
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@requires_deep_ep
@requires_deep_gemm
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
int], num_experts: int, topk: int,
use_fp8_dispatch: bool, block_size: list[int],
world_dp_size: tuple[int, int]):
def test_ll_deepep_deepgemm_moe(
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
use_fp8_dispatch: bool,
block_size: list[int],
world_dp_size: tuple[int, int],
):
"""
Tests for Low-Latency DeepEP + DeepGemm integration.
"""
@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
k=k,
n=n,
num_experts=num_experts,
per_act_token_quant=False,
block_size=block_size,
low_latency=True,
use_fp8_dispatch=use_fp8_dispatch,

View File

@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.platforms import current_platform
from vllm.utils import has_deep_ep
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
@ -31,7 +31,7 @@ if has_deep_ep():
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
requires_deep_ep = pytest.mark.skipif(
not has_deep_ep(),
@ -102,10 +102,6 @@ class TestTensors:
rank_tokens = torch.randn(
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
rank_token_scales = None
if config.dtype == torch.float8_e4m3fn:
# low_latency_mode kernels don't support per-token quant.
_, rank_token_scales = ops.scaled_fp8_quant(
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
topk = torch.randint(low=0,
high=config.num_experts,
@ -121,11 +117,18 @@ class TestTensors:
config=config)
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
low_latency_mode: bool, hidden_size: int, dp_size: int,
num_experts: int, num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool) -> FusedMoEModularKernel:
def make_modular_kernel(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
low_latency_mode: bool,
hidden_size: int,
dp_size: int,
num_experts: int,
num_local_experts: int,
q_dtype: Optional[torch.dtype],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> FusedMoEModularKernel:
is_quantized = q_dtype is not None
@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
deepep_ll_args = ll_args)
if low_latency_mode:
assert not per_act_token_quant, "not supported in ll mode"
fused_experts = BatchedTritonExperts(
max_num_tokens=MAX_TOKENS_PER_RANK,
world_size=pgi.world_size,
@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False)
use_int4_w4a16=False,
per_act_token_quant=False,
)
else:
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False)
fused_experts = TritonExperts(
use_fp8_w8a8=is_quantized,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_act_token_quant=per_act_token_quant,
)
mk = FusedMoEModularKernel(prepare_finalize=a2a,
fused_experts=fused_experts)
return mk
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
low_latency_mode: bool, dp_size: int,
test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], num_experts: int,
use_fp8_dispatch: bool) -> torch.Tensor:
def deep_ep_moe_impl(
pg: ProcessGroup,
pgi: ProcessGroupInfo,
low_latency_mode: bool,
dp_size: int,
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
num_experts: int,
use_fp8_dispatch: bool,
per_act_token_quant: bool,
) -> torch.Tensor:
num_local_experts = w1.size(0)
@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
q_dtype = torch.float8_e4m3fn
# Make modular kernel
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
hidden_size, dp_size,
num_experts,
num_local_experts, q_dtype,
use_fp8_dispatch)
mk: FusedMoEModularKernel = make_modular_kernel(
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
total_num_tokens = test_tensors.rank_tokens.size(0)
@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
return out_hidden_states
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
def torch_moe_impl(
test_tensors: TestTensors,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
using_fp8_dispatch: bool,
per_act_token_quant: bool,
):
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
test_tensors.topk_weights)
@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
# The DeepEP implementation is requested to dispatch using FP8.
# For numerical stability for testing, emulate the fp8 dispatch by
# blockwise quant and de-quant.
assert not per_act_token_quant
a = test_tensors.rank_tokens
aq, aq_scale = per_token_group_quant_fp8(a, 128)
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
@ -310,6 +331,7 @@ def _deep_ep_moe(
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
use_fp8_dispatch: bool,
per_act_token_quant: bool,
):
if not low_latency_mode:
@ -331,7 +353,8 @@ def _deep_ep_moe(
with set_current_vllm_config(VllmConfig()):
# Reference
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
w2_scale, use_fp8_dispatch)
w2_scale, use_fp8_dispatch,
per_act_token_quant)
# Splice experts for this rank.
num_local_experts = config.num_experts // pgi.world_size
@ -356,6 +379,7 @@ def _deep_ep_moe(
w2_scale_ep,
config.num_experts,
use_fp8_dispatch,
per_act_token_quant,
)
torch.testing.assert_close(
@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
@pytest.mark.parametrize("num_experts", [32])
@pytest.mark.parametrize("topk", [6])
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@requires_deep_ep
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
num_experts: int, topk: int, world_dp_size: tuple[int,
int]):
def test_deep_ep_moe(
dtype: torch.dtype,
mnk: tuple[int, int, int],
num_experts: int,
topk: int,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
):
low_latency_mode = False
use_fp8_dispatch = False
m, n, k = mnk
@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
per_act_token_quant)
MNKs = [
@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
False)

View File

@ -17,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import (
@ -142,6 +143,10 @@ def test_fused_moe(
# Setup test data
#
#
# Setup test data
#
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
@ -169,7 +174,7 @@ def test_fused_moe(
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
per_act_token_quant=False,
block_shape=None)
def m_fused_moe(
@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
if dtype == torch.float32:
pytest.skip("AITER ROCm test skip for float32")
monkeypatch.setenv('RANK', "0")
monkeypatch.setenv('LOCAL_RANK', "0")
monkeypatch.setenv('WORLD_SIZE', "1")
monkeypatch.setenv('MASTER_ADDR', 'localhost')
monkeypatch.setenv('MASTER_PORT', '12345')
init_distributed_environment()
# Instantiate our and huggingface's MoE blocks
vllm_config.compilation_config.static_forward_context = dict()
with (set_current_vllm_config(vllm_config),

View File

@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.platforms import current_platform
if not current_platform.has_device_capability(100):
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
allow_module_level=True)
MNK_FACTORS = [

View File

@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
try:
from pplx_kernels import AllToAll
@ -93,7 +93,7 @@ def pplx_cutlass_moe(
num_experts=num_experts,
experts_per_token=topk,
rank=rank,
world_size=pgi.world_size,
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
@ -118,8 +118,6 @@ def pplx_cutlass_moe(
pgi.world_size,
rank,
dp_size,
quant_dtype=torch.float8_e4m3fn,
per_act_token=per_act_token,
)
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,

View File

@ -18,18 +18,20 @@ try:
except ImportError:
has_pplx = False
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
from tests.kernels.utils import torch_experts
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
get_default_config)
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.platforms import current_platform
from vllm.utils import round_up
from .utils import ProcessGroupInfo, parallel_launch
from .parallel_utils import ProcessGroupInfo, parallel_launch
requires_pplx = pytest.mark.skipif(
not has_pplx,
@ -144,25 +146,6 @@ def torch_batched_moe(
return torch_finalize(out, topk_weight, topk_ids)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
) -> torch.Tensor:
num_experts = w1.shape[0]
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1,
rank=0),
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024])
@ -188,7 +171,7 @@ def test_fused_moe_batched_experts(
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
torch.testing.assert_close(baseline_output,
torch_output,
@ -226,7 +209,6 @@ def pplx_prepare_finalize(
topk = topk_ids.shape[1]
num_tokens, hidden_dim = a.shape
block_size = 128
device = pgi.device
rank = pgi.rank
world_size = pgi.world_size
@ -241,9 +223,7 @@ def pplx_prepare_finalize(
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
hidden_dim_scale_bytes=0,
)
if group_name is None:
@ -260,7 +240,6 @@ def pplx_prepare_finalize(
world_size,
rank,
dp_size,
a.dtype,
)
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
@ -276,6 +255,7 @@ def pplx_prepare_finalize(
num_experts,
None,
False,
FusedMoEQuantConfig(),
)
b_a = b_a * 1.5
@ -350,6 +330,7 @@ def _pplx_prepare_finalize(
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
# TODO (bnell) add fp8 tests
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@ -386,18 +367,31 @@ def pplx_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
use_compile: bool = False,
use_cudagraphs: bool = True,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
device = torch.device("cuda", rank)
hidden_dim = a.shape[1]
num_experts = w1.shape[0]
block_size = 128
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
max_num_tokens,
hidden_dim,
a.dtype,
qtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
args = dict(
max_num_tokens=max_num_tokens,
@ -407,10 +401,8 @@ def pplx_moe(
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=scale_bytes,
)
if group_name is None:
@ -429,9 +421,11 @@ def pplx_moe(
dp_size,
)
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
world_size=world_size,
dp_size=dp_size)
dp_size=dp_size,
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
block_shape=block_shape)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
@ -447,6 +441,13 @@ def pplx_moe(
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if w1_scale is not None:
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
else:
w1_scale_chunk = None
w2_scale_chunk = None
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
@ -465,6 +466,8 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
@ -477,6 +480,8 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
rank=rank,
)
experts = BatchedExperts(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1)
experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
world_size=1,
dp_size=1)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
@ -539,7 +544,12 @@ def _pplx_moe(
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
use_internode: bool,
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
use_internode: bool = False,
):
if use_internode:
uid = nvshmem_get_unique_id(
@ -557,11 +567,28 @@ def _pplx_moe(
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
device = torch.device("cuda", pgi.rank)
a = a.to(device)
w1 = w1.to(device)
w2 = w2.to(device)
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_s,
w2_scale=w2_s,
quant_dtype=qtype,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape)
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
a, w1, w2, topk_weight, topk_ids)
a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
qtype, per_act_token_quant, block_shape)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
@ -581,6 +608,8 @@ def _pplx_moe(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@pytest.mark.parametrize("per_act_token_quant", [False, True])
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
@pytest.mark.parametrize("use_internode", [False])
@requires_pplx
def test_pplx_moe(
@ -589,15 +618,33 @@ def test_pplx_moe(
topk: int,
dtype: torch.dtype,
world_dp_size: tuple[int, int],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
use_internode: bool,
):
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
if dtype == torch.float8_e4m3fn:
use_fp8_w8a8 = True
quant_dtype = dtype
else:
use_fp8_w8a8 = False
quant_dtype = None
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
pytest.skip("Skip quantization test for non-quantized type")
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
n,
k,
quant_dtype=quant_dtype,
block_shape=block_shape)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
use_internode)

View File

@ -1,194 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
"""
DeepEP test utilities
"""
import dataclasses
import importlib
import os
import traceback
from typing import Callable, Optional
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import torch
from torch.distributed import ProcessGroup
from torch.multiprocessing import (
spawn) # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec
from vllm.utils import get_open_port
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
if has_deep_ep:
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
DeepEPHTPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
DeepEPLLPrepareAndFinalize)
## Parallel Processes Utils
P = ParamSpec("P")
import vllm._custom_ops as ops
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
per_block_cast_to_int8)
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
            worker,
        ) + args,
        nprocs=world_size,
        join=True,
    )


## DeepEP specific utils


@dataclasses.dataclass
class DeepEPHTArgs:
    num_local_experts: int


@dataclasses.dataclass
class DeepEPLLArgs:
    max_tokens_per_rank: int
    hidden_size: int
    num_experts: int
    use_fp8_dispatch: bool


def make_deepep_ht_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       ht_args: DeepEPHTArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # high throughput a2a
    num_nvl_bytes = 1024 * 1024 * 1024  # 1GB
    num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
    buffer = deep_ep.Buffer(group=pg,
                            num_nvl_bytes=num_nvl_bytes,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=low_latency_mode,
                            num_qps_per_rank=num_qps_per_rank)
    return DeepEPHTPrepareAndFinalize(buffer=buffer,
                                      world_size=pgi.world_size,
                                      rank=pgi.rank,
                                      dp_size=dp_size,
                                      rank_expert_offset=pgi.rank *
                                      ht_args.num_local_experts,
                                      quant_dtype=q_dtype,
                                      block_shape=block_shape)


def make_deepep_ll_a2a(pg: ProcessGroup,
                       pgi: ProcessGroupInfo,
                       dp_size: int,
                       deepep_ll_args: DeepEPLLArgs,
                       q_dtype: Optional[torch.dtype] = None,
                       block_shape: Optional[list[int]] = None):

    import deep_ep

    # low-latency a2a
    num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
        deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
        pgi.world_size, deepep_ll_args.num_experts)

    buffer = deep_ep.Buffer(group=pg,
                            num_rdma_bytes=num_rdma_bytes,
                            low_latency_mode=True,
                            num_qps_per_rank=deepep_ll_args.num_experts //
                            pgi.world_size)

    return DeepEPLLPrepareAndFinalize(
        buffer=buffer,
        world_size=pgi.world_size,
        dp_size=dp_size,
        max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
        quant_dtype=q_dtype,
        block_shape=block_shape,
        use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
    )


def make_deepep_a2a(pg: ProcessGroup,
                    pgi: ProcessGroupInfo,
                    dp_size: int,
                    deepep_ht_args: Optional[DeepEPHTArgs],
                    deepep_ll_args: Optional[DeepEPLLArgs],
                    q_dtype: Optional[torch.dtype] = None,
                    block_shape: Optional[list[int]] = None):
    if deepep_ht_args is not None:
        assert deepep_ll_args is None
        return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
                                  block_shape)

    assert deepep_ll_args is not None
    return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
                              block_shape)

View File

from vllm.model_executor.layers.fused_moe.utils import (
    moe_kernel_quantize_input)
from vllm.utils import round_up


def triton_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant=False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale,
                         per_channel_quant=per_act_token_quant,
                         use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
                         block_shape=block_shape)


def batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    max_num_tokens = round_up(a.shape[0], 64)

    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  world_size=1,
                                  dp_size=1,
                                  rank=0),
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            world_size=1,
            dp_size=1,
            use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        ),
    )

    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale)


def naive_batched_moe(
    a: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weight: torch.Tensor,
    topk_ids: torch.Tensor,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
    quant_dtype: Optional[torch.dtype] = None,
    per_act_token_quant: bool = False,
    block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
    max_num_tokens = round_up(a.shape[0], 64)

    fused_experts = FusedMoEModularKernel(
        BatchedPrepareAndFinalize(max_num_tokens,
                                  world_size=1,
                                  dp_size=1,
                                  rank=0),
        NaiveBatchedExperts(
            max_num_tokens=max_num_tokens,
            dp_size=1,
            world_size=1,
            use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
            per_act_token_quant=per_act_token_quant,
            block_shape=block_shape,
        ),
    )

    return fused_experts(a,
                         w1,
                         w2,
                         topk_weight,
                         topk_ids,
                         w1_scale=w1_scale,
                         w2_scale=w2_scale,
                         a1_scale=a1_scale,
                         a2_scale=a2_scale)
def chunk_scales(scales: Optional[torch.Tensor], start: int,
end: int) -> Optional[torch.Tensor]:
if scales is not None:
if scales.numel() == 1:
return scales
else:
return scales[start:end]
return None
def make_quantized_test_activations(
E: int,
m: int,
k: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
a_q = a
a_scale = None
if quant_dtype is not None:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
a_q = torch.zeros_like(a, dtype=quant_dtype)
a_scale_l = [None] * E
for e in range(E):
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
a[e], None, quant_dtype, per_act_token_quant, block_shape)
a_scale = torch.stack(a_scale_l)
if not per_act_token_quant and block_shape is None:
a_scale = a_scale.view(E, 1, 1)
return a, a_q, a_scale
def moe_quantize_weights(
w: torch.Tensor,
w_s: Optional[torch.Tensor],
quant_dtype: Optional[torch.dtype],
per_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert (quant_dtype == torch.float8_e4m3fn
or quant_dtype == torch.int8), "only fp8/int8 supported"
if block_shape is not None:
assert not per_token_quant
if quant_dtype == torch.int8:
w, w_s = per_block_cast_to_int8(w, block_shape)
else:
w, w_s = per_block_cast_to_fp8(w, block_shape)
else:
if quant_dtype == torch.int8:
w, w_s = ops.scaled_int8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
else:
w, w_s = ops.scaled_fp8_quant(
w, w_s, use_per_token_if_dynamic=per_token_quant)
return w, w_s
def make_test_weight(
e: int,
rows: int,
cols: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
if quant_dtype is not None:
w_l = [None] * e
w_s_l = [None] * e
for idx in range(e):
w_l[idx], w_s_l[idx] = moe_quantize_weights(
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
w = torch.stack(w_l)
w_s = torch.stack(w_s_l)
if w_s.ndim == 2:
assert w_s.shape[-1] == 1
w_s = w_s.view(-1, 1, 1)
if block_shape is not None:
block_n, block_k = block_shape
n_tiles = (rows + block_n - 1) // block_n
k_tiles = (cols + block_k - 1) // block_k
assert w_s.shape == (e, n_tiles, k_tiles)
else:
w = w_16
w_s = None
return w_16, w, w_s
def make_test_weights(
e: int,
n: int,
k: int,
in_dtype: torch.dtype = torch.bfloat16,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
torch.Tensor, Optional[torch.Tensor]]:
return (
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
per_act_token_quant),
)
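# A minimal usage sketch for the test helpers above, assuming they are
# importable as tests.kernels.moe.utils (the module path is an assumption)
# and that a CUDA device with fp8 support is available. The tolerance is only
# a loose sanity check mirroring the thresholds used in these tests.
import torch

from tests.kernels.moe.utils import batched_moe, make_test_weights, triton_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

E, M, N, K, topk = 8, 32, 256, 128, 2
a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16) / 10
_, w1, w1_s, _, w2, w2_s = make_test_weights(E, N, K,
                                             quant_dtype=torch.float8_e4m3fn,
                                             block_shape=[128, 128])
score = torch.randn((M, E), device="cuda", dtype=torch.bfloat16)
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)

out = triton_moe(a, w1, w2, topk_weights, topk_ids,
                 w1_scale=w1_s, w2_scale=w2_s,
                 quant_dtype=torch.float8_e4m3fn, block_shape=[128, 128])
ref = batched_moe(a, w1, w2, topk_weights, topk_ids,
                  w1_scale=w1_s, w2_scale=w2_s,
                  quant_dtype=torch.float8_e4m3fn, block_shape=[128, 128])
rel_diff = (torch.mean(torch.abs(out.float() - ref.float())) /
            torch.mean(torch.abs(ref.float())))
assert rel_diff < 0.03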

View File

@@ -5,7 +5,10 @@ from typing import Optional, Union
import torch
from vllm.model_executor.layers.quantization.utils.quant_utils import (
group_broadcast)
from vllm.platforms import current_platform
from vllm.utils import round_up
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
@@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
return ref_out, ref_scale.view((1, ))
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, Bs: torch.Tensor, block_size,
output_dtype):
def native_w8a8_block_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
compute_type: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and
@@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
`Bs` (float32).
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
A = A.to(compute_type)
B = B.to(compute_type)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
@@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}"
assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}"
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
@@ -152,3 +161,152 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
C = C.reshape(origin_C_shape).to(output_dtype)
return C
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
"be divisible by `group_size`")
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
def native_per_token_group_quant_int8(x,
group_size,
eps=1e-10,
dtype=torch.int8):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch.
It converts the tensor values into int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` must be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_min = iinfo.min
int8_max = iinfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
# Use float32 for scale calculation for stability
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / int8_max
x_q = (x_.to(torch.float32) / x_s).round().clamp(
min=int8_min, max=int8_max).to(dtype) # Round before clamping
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
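# A small round-trip sketch for native_per_token_group_quant_fp8 above; the
# tests.kernels.quant_utils module path matches the imports used elsewhere in
# this change, and the error bound is only a loose sanity check.
import torch

from tests.kernels.quant_utils import native_per_token_group_quant_fp8

x = torch.randn((4, 256), device="cuda", dtype=torch.float32)
x_q, x_s = native_per_token_group_quant_fp8(x, group_size=128)
assert x_q.shape == x.shape
assert x_s.shape == (4, 256 // 128)

# Dequantize by scaling each 128-element group by its own scale.
x_dq = x_q.to(torch.float32).view(4, 2, 128) * x_s.view(4, 2, 1)
rel_err = (x_dq.view(4, 256) - x).abs().mean() / x.abs().mean()
assert rel_err < 0.1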
DEFAULT_BLOCK_SHAPE = [128, 128]
def per_block_cast_to_fp8(
x: torch.Tensor,
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
block_m, block_n = block_shape
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def per_block_cast_to_int8(
x: torch.Tensor,
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
) -> tuple[torch.Tensor, torch.Tensor]:
block_m, block_n = block_shape
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (256.0 / x_amax)).to(torch.int8)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 256.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def dequant(
t: torch.Tensor,
scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
per_act_token_quant: bool,
out_dtype: Optional[torch.dtype] = torch.float32,
) -> torch.Tensor:
if scale is not None:
f32 = torch.float32
if per_act_token_quant or block_shape is None:
return (t.to(f32) * scale).to(out_dtype)
else:
return (t.to(f32) * group_broadcast(scale, t.shape)).to(out_dtype)
else:
return t.to(out_dtype)
def native_batched_masked_quant_matmul(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
num_expert_tokens: torch.Tensor,
A_scale: Optional[torch.Tensor] = None,
B_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
if A.dtype.itemsize == 1 and block_shape is not None:
assert A_scale is not None and B_scale is not None
tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
block_shape, C.dtype)
C[e, :num_tokens, :] = tmp[:num_tokens, :]
elif A.dtype.itemsize == 1 and block_shape is None:
assert A_scale is not None and B_scale is not None
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
C[e, :num_tokens, :] = (
A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
else:
assert A_scale is None
assert B_scale is None
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
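# A minimal unquantized sketch of the batched masked reference above: only the
# first num_expert_tokens[e] rows of each expert's slice are written, and the
# remaining rows of C stay untouched (zeros here).
import torch

from tests.kernels.quant_utils import native_batched_masked_quant_matmul

E, max_tokens, K, N = 4, 8, 32, 16
A = torch.randn((E, max_tokens, K), device="cuda", dtype=torch.bfloat16)
B = torch.randn((E, N, K), device="cuda", dtype=torch.bfloat16)
C = torch.zeros((E, max_tokens, N), device="cuda", dtype=torch.bfloat16)
num_expert_tokens = torch.tensor([8, 3, 0, 5], device="cuda", dtype=torch.int32)

native_batched_masked_quant_matmul(A, B, C, num_expert_tokens)
assert torch.all(C[1, 3:] == 0) and torch.all(C[2] == 0)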

View File

@@ -7,16 +7,10 @@ import itertools
import pytest
import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size)
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
native_w8a8_block_matmul,
per_block_cast_to_fp8)
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
@@ -46,78 +40,10 @@ N = [128, 512, 7168, 7748, 13824]
K = [256, 3884, 4096, 13824, 16384]
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
# and its hidden size is 7168.
M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
M_moe_dg = [128, 192, 1335, 2048]
N_moe = [128, 256, 1024, 4608] # [13824]
K_moe = [256, 512, 7168] # [13824]
BLOCK_SIZE = [[128, 128]]
E = [2, 8, 16, 24] # [128, 256]
TOP_KS = [1, 2, 6]
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0]
def native_per_token_group_quant_fp8(x,
group_size,
eps=1e-10,
dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
"be divisible by `group_size`")
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""Fused moe with block-wise quantization using native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
# Skip all tests if CUDA is not available
pytest.importorskip("torch.cuda")
@@ -177,111 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
assert rel_diff < 0.001
@pytest.mark.parametrize(
"M,N,K,E,topk,block_size,dtype,seed",
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
if topk > E:
pytest.skip(f"Skipping test; topk={topk} > E={E}")
torch.manual_seed(seed)
factor_for_scale = 1e-2
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = (torch.rand(
(E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
del w1_bf16
w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
del w2_bf16
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_s = torch.rand(
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
w2_s = torch.rand(
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
score = torch.randn((M, E), dtype=dtype)
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=block_size)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
m_out = m_fused_moe(a,
w1,
w2,
topk_weights,
topk_ids,
global_num_experts=E,
w1_scale=w1_s,
w2_scale=w2_s)
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
rel_diff = (torch.mean(
torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03
def per_block_cast_to_fp8(
x: torch.Tensor,
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128,
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
@pytest.mark.parametrize(
"M,N,K,block_size,out_dtype,seed",
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
@@ -324,187 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001
def fp8_perm(m, idx):
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
M, K = a.shape
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
num_tokens = topk * M
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
a = fp8_perm(a, sorted_token_ids // topk)
if a_s is not None:
a_s = a_s[sorted_token_ids // topk]
return a, a_s, m_indices, inv_perm
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
M = topk_weight.shape[0]
out = out[inv_perm, ...]
tmp_out = out.view(-1, topk, K)
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
block_shape):
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
num_groups = w1.shape[0]
M, K = a.shape
N = w2.shape[-1]
topk_weight, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = per_token_group_quant_fp8(a, block_m)
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
num_groups, topk, block_m)
inter_out = torch.zeros((a_q.shape[0], N * 2),
dtype=torch.bfloat16,
device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
inter_out, m_indices)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
return final_out
@pytest.mark.parametrize(
"M,N,K,E,topk,seed",
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
@torch.inference_mode()
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
monkeypatch):
if topk > E:
pytest.skip(f"Skipping test: topk={topk} > E={E}")
if not _valid_deep_gemm_shape(M, N, K):
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
chunk_size = 1024
torch.manual_seed(seed)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
block_size = [block_m, block_m]
dtype = torch.bfloat16
fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
fp8_max).clamp(min=fp8_min, max=fp8_max)
score = torch.randn((M, E), dtype=dtype)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w2 = (N + block_k - 1) // block_k
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
for i in range(E):
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
# Note: for now use_compile will error out if the problem size is
# large enough to trigger chunking. I'm leaving the flag and
# setup code in case we are able to revisit this later.
use_compile = False
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
and current_platform.is_cuda_alike())
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
if M >= 128:
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
score, topk, block_size)
else:
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
if use_compile:
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
backend="inductor",
fullgraph=True)
torch._dynamo.mark_dynamic(a, 0)
torch._dynamo.mark_dynamic(topk_weights, 0)
torch._dynamo.mark_dynamic(topk_ids, 0)
else:
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
if use_cudagraph:
out.fill_(0)
stream = torch.cuda.Stream()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
topk_ids)
torch.cuda.synchronize()
graph.replay()
torch.cuda.synchronize()
#print(f"{out.sum()=}")
#print(f"{ref_out.sum()=}")
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.03

View File

@@ -8,9 +8,7 @@ import pytest
import torch
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.int8_utils import (
w8a8_block_int8_matmul)
from vllm.platforms import current_platform
@@ -23,82 +21,10 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
# For test
def native_per_token_group_quant_int8(x,
group_size,
eps=1e-10,
dtype=torch.int8):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch.
It converts the tensor values into int8 values and returns the
quantized tensor along with the scaling factor used for quantization.
"""
assert (x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
iinfo = torch.iinfo(dtype)
int8_min = iinfo.min
int8_max = iinfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
# Use float32 for scale calculation for stability
amax = x_.abs().max(dim=-1,
keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / int8_max
x_q = (x_.to(torch.float32) / x_s).round().clamp(
min=int8_min, max=int8_max).to(dtype) # Round before clamping
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
return x_q, x_s
# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
"""This function performs fused moe with block-wise quantization using
native torch."""
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
score = torch.softmax(score, dim=-1, dtype=torch.float32)
topk_weight, topk_ids = torch.topk(score, topk)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_matmul(a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_int8(
act_out, block_k)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_matmul(act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222]
N = [128, 1024]
K = [256, 4096]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
@@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.001
@pytest.mark.parametrize(
"M, N, K, E, topk, block_size, dtype, seed",
itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
native torch reference."""
torch.manual_seed(seed)
# Use a smaller factor for scale initialization to prevent large
# values/overflow especially when output dtype might be float16
factor_for_scale = 1e-2
int8_info = torch.iinfo(torch.int8)
int8_max, int8_min = int8_info.max, int8_info.min
a = torch.randn((M, K), dtype=dtype) / 10
w1_fp32 = (torch.rand(
(E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
block_n, block_k = block_size[0], block_size[1]
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_s = (torch.rand(
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
w2_s = (torch.rand(
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
score = torch.randn((M, E), dtype=dtype)
# Set the context to avoid lots of warning spam.
with set_current_vllm_config(vllm_config):
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
use_int8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
block_size)
# Check results
rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
torch.mean(torch.abs(ref_out.to(torch.float32))))
assert rel_diff < 0.06

View File

@@ -13,8 +13,11 @@ import pytest
import torch
from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
@@ -1054,32 +1057,77 @@ def compute_max_diff(output, output_ref):
torch.abs(output_ref))
def torch_experts(a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
def torch_experts(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
quant_dtype: Optional[torch.dtype] = None,
per_act_token_quant=False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
assert (global_num_experts == -1
or (global_num_experts == w1.shape[0] and expert_map is None)
or (expert_map is not None
and global_num_experts == expert_map.shape[0]))
M, K = a.shape
topk = topk_ids.shape[1]
B, D = a.shape
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
per_act_token_quant, block_shape)
num_experts = w1.shape[0]
topk_ids = topk_ids.view(-1)
if expert_map is not None:
topk_ids = expert_map[topk_ids]
for i in range(w1.shape[0]):
for i in range(num_experts):
mask = topk_ids == i
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
return (out.view(B, -1, w2.shape[1]) *
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
if quant_dtype is None:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
elif block_shape is not None:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
out.dtype)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = moe_kernel_quantize_input(
tmp2, None, quant_dtype, per_act_token_quant, block_shape)
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
out.dtype)
else:
assert (a_scale is not None and w1_scale is not None
and w2_scale is not None)
f32 = torch.float32
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
tmp1 = a[mask].to(f32) * scales
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
tmp1 = tmp1 @ w1_dq
tmp2 = SiluAndMul()(tmp1)
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
def torch_moe(a: torch.Tensor,

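# A short sketch of the unquantized torch_experts reference path above,
# assuming the helper lives in tests.kernels.utils (the module path is an
# assumption). topk weights and ids come from a plain torch.topk here.
import torch

from tests.kernels.utils import torch_experts

E, M, N, K, topk = 8, 16, 64, 128, 2
a = torch.randn((M, K), device="cuda", dtype=torch.bfloat16) / 10
w1 = torch.randn((E, 2 * N, K), device="cuda", dtype=torch.bfloat16) / 10
w2 = torch.randn((E, K, N), device="cuda", dtype=torch.bfloat16) / 10
score = torch.softmax(torch.randn((M, E), device="cuda").float(), dim=-1)
topk_weights, topk_ids = torch.topk(score, topk)

ref = torch_experts(a, w1, w2, topk_weights, topk_ids)
assert ref.shape == (M, K)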
View File

@@ -1274,7 +1274,7 @@ def scaled_fp8_quant(
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale

View File

@@ -4,8 +4,12 @@
from contextlib import contextmanager
from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
@@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]:
__all__ = [
"FusedMoE",
"FusedMoEConfig",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"override_config",
"get_config",
]
@@ -36,11 +44,21 @@ if HAS_TRITON:
# import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp4, cutlass_moe_fp8)
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts, fused_experts, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
__all__ += [
"fused_moe",
@@ -50,5 +68,11 @@ if HAS_TRITON:
"grouped_topk",
"cutlass_moe_fp8",
"cutlass_moe_fp4",
"CutlassExpertsFp8",
"TritonExperts",
"BatchedTritonExperts",
"DeepGemmExperts",
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts",
]

View File

@@ -5,6 +5,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
@@ -179,28 +180,44 @@ def silu_mul_fp8_quant_deep_gemm(
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# The Deep Gemm kernels only support block size of 128
DEEPGEMM_BLOCK_SHAPE = 128
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
block_shape: list[int]):
def __init__(self,
max_num_tokens: int,
world_size: int,
dp_size: int,
block_shape: list[int],
per_act_token_quant=False):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
world_size: Number of EP ranks
dp_size: Number of data-parallel ranks
block_shape: Block quantization block shape
"""
super().__init__()
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.dp_size = dp_size
self.block_shape = block_shape
assert (len(self.block_shape) == 2 and all(
[v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def workspace_shapes(
self,
a: torch.Tensor,
@@ -248,6 +265,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
):
import deep_gemm as dg
assert hidden_states.ndim == 3
assert self.block_shape is not None
a1q = hidden_states
_, N, K = w1.size()

View File

@@ -6,6 +6,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts)
@@ -20,43 +21,45 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token_quant: bool = False,
allow_deep_gemm: bool = False):
super().__init__()
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape,
per_act_token_quant=per_act_token_quant,
))
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.dp_size = dp_size
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.use_int4_w4a16 = use_int4_w4a16
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.allow_deep_gemm = allow_deep_gemm
# BatchedTritonKernel doesn't support block quantization
# at the moment.
self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=self.max_num_tokens,
use_fp8_w8a8=self.use_fp8_w8a8,
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
block_shape=self.block_shape,
world_size=self.world_size,
dp_size=self.dp_size) if self.block_shape is None else None
dp_size=self.dp_size,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape,
) if self.block_shape is None else None
is_fp8_128_block_quantized = (
use_fp8_w8a8 and self.block_shape
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
is_fp8_128_block_quantized = (self.use_fp8_w8a8
and self.block_shape is not None
and len(self.block_shape) == 2 and all(
[b == 128
for b in self.block_shape]))
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
max_num_tokens=self.max_num_tokens,
world_size=self.world_size,
@@ -67,12 +70,31 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert (self.batched_deep_gemm_experts is not None
or self.batched_triton_experts is not None)
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.batched_triton_experts is not None:
assert (self.batched_deep_gemm_experts is None
or self.batched_deep_gemm_experts.activation_formats
== self.batched_triton_experts.activation_formats)
return self.batched_triton_experts.activation_formats
else:
assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.activation_formats
def supports_chunking(self) -> bool:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
return ((bdge is None or bdge.supports_chunking())
and (bte is None or bte.supports_chunking()))
def supports_expert_map(self) -> bool:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
return ((bdge is None or bdge.supports_expert_map())
and (bte is None or bte.supports_expert_map()))
def workspace_shapes(
self,
a: torch.Tensor,
@@ -87,7 +109,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
if self.allow_deep_gemm:
assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.workspace_shapes(
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
else:

View File

@@ -0,0 +1,410 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional, Union
import torch
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
logger = init_logger(__name__)
def _get_quant_config_quantization_args(
quant_config: Optional[QuantizationConfig],
prop_name: str,
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(prop_name)
else:
return None
def get_quant_config_input_quant(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
return _get_quant_config_quantization_args(quant_config,
"input_activations")
def get_quant_config_weight_quant(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
return _get_quant_config_quantization_args(quant_config, "weights")
# TODO (bnell): use scalar_type instead of bools?
def get_config_quant_dtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
return None
@dataclass
class FusedMoEQuantConfig:
# The post quantization activation type.
quant_dtype: Optional[torch.dtype] = None
per_act_token_quant: bool = False
per_out_ch_quant: bool = False
block_shape: Optional[list[int]] = None
# TODO: add col major flag?
# add detailed quant info for input, intermediates, weights, etc?
@staticmethod
def make(
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
block_shape: Optional[list[int]] = None,
) -> "FusedMoEQuantConfig":
assert sum([
int(flag) for flag in [
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
]
]) <= 1, "Quantization flags are mutually exclusive."
quant_dtype = get_config_quant_dtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
)
return FusedMoEQuantConfig(
quant_dtype,
per_act_token_quant,
per_out_ch_quant,
block_shape,
)
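# A minimal sketch of how this factory resolves the quantization dtype: only
# the w8a8 flags (fp8/int8) quantize activations, so weight-only schemes keep
# quant_dtype as None (see get_config_quant_dtype above). Assumes vLLM is
# installed so the config module is importable.
import torch

from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig

cfg = FusedMoEQuantConfig.make(use_fp8_w8a8=True, block_shape=[128, 128])
assert cfg.quant_dtype == torch.float8_e4m3fn
assert cfg.block_shape == [128, 128]
assert not cfg.per_act_token_quant

assert FusedMoEQuantConfig.make(use_int8_w8a16=True).quant_dtype is None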
@dataclass
class FusedMoEParallelConfig:
tp_size: int
dp_size: int
ep_size: int
tp_rank: int
dp_rank: int
ep_rank: int
world_size: int
use_ep: bool # whether to use EP or not
@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@property
def use_pplx_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
@property
def use_deepep_ht_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
@property
def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@staticmethod
def make(tp_size_: int, dp_size_: int, world_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input tp_size_,
dp_size_, ep_size_ and vllm's parallel config, determine what
levels of parallelism to use in the fused moe layer.
Args:
tp_size_ (int): tp_size passed into the FusedMoE constructor.
dp_size_ (int): dp_size passed into the FusedMoE constructor.
ep_size_ (int): ep_size passed into the FusedMoE constructor.
world_size_ (int): the world size of the current All2All manager.
vllm_parallel_config (ParallelConfig): vllm's parallel config
object.
Examples:
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
we simply return the sizes unaltered and the ranks set to 0.
Expert Parallelism is considered only when either dp_size_ or tp_size_
is non-trivial.
When TP = 2, DP = 1 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
legend : {size, rank}
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
When TP = 1, DP = 2 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 2 devices.
When TP = 2, DP = 2 and EP = False, the configuration on different
devices,
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
When, TP = 2, DP = 1 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
When, TP = 1, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
When TP = 2, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
- Comment: There are 2 engine instances and the experts are split
between the 4 devices.
"""
def flatten_tp_across_dp(dp_rank: int):
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size_ * tp_size_ devices. Update tp_size
# and tp_rank so we shard across all devices.
tp_size = dp_size_ * tp_size_
tp_rank = dp_rank * tp_size_ + tp_rank
return tp_size, tp_rank
use_ep = (dp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel)
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
if not use_ep:
return FusedMoEParallelConfig(tp_size=tp_size,
tp_rank=tp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
world_size=world_size_,
use_ep=False)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallel. Update tp_size, tp_rank, ep_size and ep_rank to reflect that.
ep_size = tp_size
ep_rank = tp_rank
return FusedMoEParallelConfig(tp_size=1,
tp_rank=0,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
world_size=world_size_,
use_ep=True)
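# A standalone illustration (plain Python, no vLLM imports) of the
# flatten_tp_across_dp logic above: with TP=2 and DP=2 the four devices get
# flattened tp ranks 0..3, matching the TP = {4, 0..3} example in the
# docstring.
tp_size_, dp_size_ = 2, 2

def flattened(dp_rank: int, local_tp_rank: int) -> tuple[int, int]:
    tp_size = dp_size_ * tp_size_
    tp_rank = dp_rank * tp_size_ + local_tp_rank
    return tp_size, tp_rank

assert [flattened(dp, tp) for dp in range(dp_size_)
        for tp in range(tp_size_)] == [(4, 0), (4, 1), (4, 2), (4, 3)]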
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class FusedMoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig
# The activation type.
in_dtype: torch.dtype
quant_config: Optional[FusedMoEQuantConfig] = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
def __post_init__(self):
if self.dp_size > 1:
logger.debug("Using FusedMoEConfig::max_num_tokens=%d",
self.max_num_tokens)
@property
def quant_dtype(self) -> Optional[torch.dtype]:
if self.quant_config is not None:
return self.quant_config.quant_dtype
else:
return None
@property
def block_shape(self) -> Optional[list[int]]:
if self.quant_config is not None:
return self.quant_config.block_shape
else:
return None
@property
def per_act_token_quant(self) -> bool:
if self.quant_config is not None:
return self.quant_config.per_act_token_quant
else:
return False
@property
def per_out_ch_quant(self) -> bool:
if self.quant_config is not None:
return self.quant_config.per_out_ch_quant
else:
return False
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def world_size(self):
return self.moe_parallel_config.world_size
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@property
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@staticmethod
def make(
num_experts: int,
experts_per_token: int,
hidden_dim: int,
num_local_experts: int,
moe_parallel_config: FusedMoEParallelConfig,
in_dtype: torch.dtype,
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config: Optional[Union[FusedMoEQuantConfig,
QuantizationConfig]] = None
) -> "FusedMoEConfig":
_quant_config: Optional[FusedMoEQuantConfig] = None
if quant_config is not None and isinstance(quant_config,
QuantizationConfig):
if hasattr(quant_config, 'weight_block_size'):
block_shape = quant_config.weight_block_size
else:
block_shape = None
per_act_token_quant = False
per_out_ch_quant = False
quant_dtype: Optional[torch.dtype] = None
input_quant = get_quant_config_input_quant(quant_config)
weight_quant = get_quant_config_weight_quant(quant_config)
if input_quant is not None:
per_act_token_quant = (input_quant.strategy
== QuantizationStrategy.TOKEN
if input_quant is not None else False)
if input_quant.num_bits == 8:
if input_quant.type == QuantizationType.FLOAT:
quant_dtype = torch.float8_e4m3fn
elif input_quant.type == QuantizationType.INT:
quant_dtype = torch.int8
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn
if weight_quant is not None:
per_out_ch_quant = (
weight_quant.strategy == QuantizationStrategy.CHANNEL)
if quant_dtype is not None:
_quant_config = FusedMoEQuantConfig(
quant_dtype=quant_dtype,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
)
else:
_quant_config = FusedMoEQuantConfig()
logger.warning_once("MoE DP setup unable to determine "
"quantization scheme or unsupported "
"quantization type. This model will "
"not run with DP enabled.")
else:
_quant_config = quant_config
return FusedMoEConfig(
num_experts=num_experts,
experts_per_token=experts_per_token,
hidden_dim=hidden_dim,
num_local_experts=num_local_experts,
moe_parallel_config=moe_parallel_config,
in_dtype=in_dtype,
quant_config=_quant_config,
max_num_tokens=max_num_tokens,
)

View File

@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
@@ -202,26 +203,47 @@ def run_cutlass_moe_fp8(
# TODO (bnell): split class batched vs. non-batched?
# maybe remove need for passing aq to workspace_shapes
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_experts_per_worker: int,
out_dtype: torch.dtype,
per_act_token: bool,
per_out_ch: bool,
out_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
per_out_ch_quant: bool,
block_shape: Optional[list[int]] = None,
use_batched_format: bool = False,
):
super().__init__()
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=per_out_ch_quant,
block_shape=block_shape,
))
assert max_experts_per_worker > 0
self.max_experts_per_worker = max_experts_per_worker
self.out_dtype = out_dtype
self.per_act_token = per_act_token
self.per_out_ch = per_out_ch
self.use_batched_format = use_batched_format
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.use_batched_format:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
else:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return not self.use_batched_format
def supports_expert_map(self) -> bool:
return not self.use_batched_format
def workspace_shapes(
self,
a: torch.Tensor,
@@ -245,7 +267,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = (M * topk, max(2 * N, K))
workspace2 = (M * topk, N)
output = (M * topk, K)
return (workspace1, workspace2, output, self.out_dtype)
return (workspace1, workspace2, output,
self.out_dtype if self.out_dtype is not None else a.dtype)
def apply(
self,
@@ -270,13 +293,14 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
activation_callable = lambda i, o: self.activation(activation, i, o)
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
activation_callable, global_num_experts,
expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2,
expert_num_tokens, self.out_dtype,
self.per_act_token, self.per_out_ch,
self.use_batched_format)
in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
a2_scale, workspace13, workspace2, expert_num_tokens,
self.out_dtype if self.out_dtype is not None else in_dtype,
self.per_act_token_quant, self.per_out_ch_quant,
self.use_batched_format)
def cutlass_moe_fp8(
@@ -287,6 +311,7 @@ def cutlass_moe_fp8(
topk_ids: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
per_act_token: bool,
activation: str = "silu",
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
@@ -330,22 +355,18 @@ def cutlass_moe_fp8(
Returns:
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
"""
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
per_out_ch = w1_scale.numel() != w1_q.size(0)
out_dtype = a.dtype
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
0)
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
quant_dtype=torch.float8_e4m3fn,
per_channel_quant=per_act_token,
),
MoEPrepareAndFinalizeNoEP(),
CutlassExpertsFp8(
max_experts_per_worker=global_num_experts,
out_dtype=out_dtype,
per_act_token=per_act_token,
per_out_ch=per_out_ch,
max_experts_per_worker=num_experts,
out_dtype=a.dtype,
per_act_token_quant=per_act_token,
per_out_ch_quant=per_out_ch,
use_batched_format=False,
),
)
@ -358,7 +379,7 @@ def cutlass_moe_fp8(
topk_ids,
False,
activation,
global_num_experts if global_num_experts != -1 else w1_q.size(0),
num_experts,
expert_map,
w1_scale,
w2_scale,

View File

@ -7,13 +7,13 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_permute)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm, round_up
logger = init_logger(__name__)
@ -65,16 +65,31 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self):
super().__init__()
self.block_shape = deep_gemm_block_shape()
super().__init__(
FusedMoEQuantConfig(
quant_dtype=torch.float8_e4m3fn,
per_act_token_quant=False,
block_shape=deep_gemm_block_shape(),
))
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert self.block_shape is not None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
num_experts = global_num_experts
@ -107,6 +122,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expert_num_tokens: Optional[torch.Tensor],
):
import deep_gemm as dg
assert self.block_shape is not None
a1q = hidden_states
_, N, K = w1.size()
@ -213,8 +229,7 @@ def deep_gemm_moe_fp8(
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
"""
fn = mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
block_shape=deep_gemm_block_shape()),
MoEPrepareAndFinalizeNoEP(),
DeepGemmExperts(),
)
return fn(

View File

@ -6,6 +6,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
@ -15,22 +16,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
Prepare/Finalize using DeepEP High-Throughput kernels.
"""
def __init__(self,
buffer: deep_ep.Buffer,
world_size: int,
rank: int,
dp_size: int,
rank_expert_offset: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None):
def __init__(self, buffer: deep_ep.Buffer, world_size: int, rank: int,
dp_size: int, rank_expert_offset: int):
super().__init__()
self.buffer = buffer
self.world_size = world_size
self.rank = rank
self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset
self.quant_dtype = quant_dtype
self.block_shape = block_shape
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
@ -39,6 +32,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> Optional[int]:
return None
@ -55,13 +52,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return None
return deep_ep.Buffer.get_combine_config(self.dp_size)
def _do_quant(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor], per_act_token: bool):
tokens, token_scales = moe_kernel_quantize_input(
tokens, token_scales, self.quant_dtype, per_act_token,
self.block_shape)
return tokens, token_scales
def _do_dispatch(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
@ -130,43 +120,51 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
a1 = a1 * topk_weights.to(a1.dtype)
# Check if there is a block_shape, or if we can infer the quantization
# schemes from the scales.
per_token_quant = None
if all([x is None for x in [self.block_shape, a1_scale, a2_scale]
]) and self.quant_dtype is not None:
if all([
x is None
for x in [quant_config.block_shape, a1_scale, a2_scale]
]) and quant_config.quant_dtype is not None:
# Quantization required despite none of the inputs suggesting
# quantization. Fall back to per_token_dynamic quant.
per_token_quant = True
else:
per_token_quant = ((self.block_shape is not None) or
(a1_scale is not None and a1_scale.numel() != 1)
or (a2_scale is not None
and a2_scale.numel() != 1))
per_token_quant = False
if per_token_quant:
a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=True,
block_shape=quant_config.block_shape,
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=rank_topk_ids,
rank_topk_weights=rank_topk_weights,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
else:
# DeepEP kernels only support dispatching per-token-quant
@ -175,15 +173,18 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=rank_topk_ids,
rank_topk_weights=rank_topk_weights,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
# quantize now
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = self._do_quant(expert_x,
a1_scale,
per_act_token=False)
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
expert_topk_weights)
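# Illustrative sketch of the decision rule above: with the quantization
# parameters now carried by quant_config, the DeepEP HT prepare path falls
# back to per-token dynamic quantization only when a quant dtype is requested
# but neither a block_shape nor any activation scales are provided. A
# standalone helper mirroring that branch (the helper name is illustrative):
def _needs_per_token_fallback(quant_config, a1_scale, a2_scale) -> bool:
    return (quant_config.quant_dtype is not None
            and quant_config.block_shape is None
            and a1_scale is None and a2_scale is None)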

View File

@ -5,11 +5,13 @@ import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
maybe_fix_scales, moe_kernel_quantize_input)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
def dequant_fp8(expert_x_fp8: torch.Tensor,
@ -35,30 +37,30 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP low-latency kernels are compiled only for certain
# specific hidden sizes.
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168]
def __init__(self,
buffer: deep_ep.Buffer,
max_tokens_per_rank: int,
world_size: int,
dp_size: int,
max_tokens_per_rank: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
use_fp8_dispatch: bool = False):
super().__init__()
self.buffer = buffer
self.max_tokens_per_rank = max_tokens_per_rank
self.world_size = world_size
self.dp_size = dp_size
self.quant_dtype = quant_dtype
self.block_shape = block_shape
self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
self.handle = None
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_tokens_per_rank
@ -66,12 +68,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return torch.int64
def _do_quant(
self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = self.block_shape[1] if self.block_shape is not None else None
block_k = block_shape[1] if block_shape is not None else None
if self.use_fp8_dispatch:
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
# DeepEP kernels did the quantization for us.
@ -84,32 +91,20 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
assert isinstance(x, torch.Tensor)
# Check if there is a block_shape, or if we can infer the quantization
# schemes from the scales.
per_token_quant = None
if all([v is None for v in [self.block_shape, a1_scale, a2_scale]
]) and self.quant_dtype is not None:
# Quantization required despite none of the inputs suggesting
# quantization. Fall back to per_token_dynamic quant.
per_token_quant = True
else:
per_token_quant = ((self.block_shape is not None) or
(a1_scale is not None and a1_scale.numel() != 1)
or (a2_scale is not None
and a2_scale.numel() != 1))
assert not per_act_token_quant
num_experts, max_tokens, hidden_dim = x.size()
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
per_token_quant,
self.block_shape)
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
per_act_token_quant,
block_shape)
x = x.view((num_experts, -1, hidden_dim))
if per_token_quant:
if quant_dtype is not None:
assert x_scales is not None
x_scales = x_scales.view(num_experts, max_tokens, -1)
x_scales = maybe_fix_scales(x_scales, num_experts)
return x, x_scales
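# Shape bookkeeping for the batched quantization above, as a small sketch
# (sizes are made up; assumes fp8 with the 128-column DeepEP groups, i.e.
# block_shape == [128, 128]):
import torch

E, T, H, group = 4, 8, 2048, 128
x = torch.empty(E, T, H)                        # batched activations
flat = x.view(-1, H)                            # (E*T, H) fed to the quant kernel
scales = torch.empty(flat.size(0), H // group)  # one scale per 128-col group
x_batched = flat.view(E, -1, H)                 # restored to (E, max_tokens, H)
scales_batched = scales.view(E, T, -1)          # -> (4, 8, 16)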
@ -118,11 +113,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
@ -142,24 +138,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"low_latency kernels doesn't support dispatching per-token scales")
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch
expert_x, expert_num_tokens, self.handle, event, hook = \
self.buffer.low_latency_dispatch(a1,
rank_topk_ids,
topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=False)
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
a1.dtype)
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return (expert_x, expert_x_scale, expert_num_tokens, None, None)

View File

@ -8,6 +8,7 @@ import triton
import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import (
@ -317,8 +318,8 @@ def invoke_moe_batched_triton_kernel(
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: torch.Tensor,
B_scale: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: torch.Tensor,
# Quantization schemes
use_fp8_w8a8: bool,
@ -387,14 +388,23 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
that the PPLX dispatch/combine kernels use.
"""
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
rank: int):
def __init__(
self,
max_num_tokens: int,
world_size: int,
dp_size: int,
rank: int,
):
super().__init__()
self.world_size = world_size
self.dp_size = dp_size
self.rank = rank
self.max_num_tokens = max_num_tokens
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens
@ -411,6 +421,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
assert a1.dim() == 2
@ -435,22 +446,35 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_local_experts = num_experts // self.world_size
if quant_config.quant_dtype is None:
b_type = a1.dtype
else:
b_type = quant_config.quant_dtype
b_a1 = torch.zeros(
(num_local_experts, self.max_num_tokens, hidden_dim),
dtype=a1.dtype,
dtype=b_type,
device=a1.device)
b_a1_scale = None
assert quant_config.quant_dtype is None, "quantization NYI"
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
rows = torch.count_nonzero(topks.flatten())
b_a1[expert_id -
first_expert, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[expert_id - first_expert] = rows
if rows == 0:
continue
idx = expert_id - first_expert
b_a1[idx, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[idx] = rows
return b_a1, a1_scale, tokens_per_expert, None, None
assert b_a1_scale is None or b_a1_scale.ndim == 3
return b_a1, b_a1_scale, tokens_per_expert, None, None
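# A tiny standalone illustration of the scatter above (a sketch with made-up
# sizes, rank 0 so first_expert == 0): tokens routed to each local expert are
# copied into rows of the (num_local_experts, max_num_tokens, hidden) buffer
# and counted.
import torch

topk_ids = torch.tensor([[0], [1], [0], [2]])  # 4 tokens, topk=1
a1 = torch.arange(4 * 2, dtype=torch.float32).view(4, 2)
num_local_experts, max_num_tokens = 4, 4
b_a1 = torch.zeros(num_local_experts, max_num_tokens, 2)
tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32)
for expert_id in range(num_local_experts):
    topks = torch.any(topk_ids == expert_id, dim=1).flatten()
    rows = int(torch.count_nonzero(topks))
    if rows == 0:
        continue
    b_a1[expert_id, :rows, :] = a1[topks]
    tokens_per_expert[expert_id] = rows
# tokens_per_expert -> tensor([2, 1, 1, 0], dtype=torch.int32)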
def finalize(
self,
@ -480,7 +504,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
output[topks] = output[topks] + rhs
class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A reference MoE expert class that operates on expert batched format,
i.e. E x max_num_tokens x K. This is the format that the pplx
@ -497,11 +521,17 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
per_act_token_quant: bool = False,
):
super().__init__()
assert block_shape is None
assert block_m is None
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_fp8_w8a8, "NYI"
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
@ -510,9 +540,19 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.world_size = world_size
self.dp_size = dp_size
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def workspace_shapes(
self,
a: torch.Tensor,
@ -554,20 +594,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert hidden_states.dim() == 3
assert expert_num_tokens is not None
max_num_tokens = self.max_num_tokens
num_dp = self.world_size // self.dp_size
num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), (
f"{num_local_experts} == {w1.size(0)}")
N = w1.size(1) // 2
# Not cudagraph friendly
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
for expert in range(num_local_experts):
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
if (torch.compiler.is_compiling()
@ -575,6 +607,10 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
num = hidden_states.shape[1]
else:
num = int(expert_num_tokens[expert].item())
if num == 0:
continue
tmp = _resize_cache(workspace2, (num, N))
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
self.activation(activation, tmp, input)
@ -590,34 +626,53 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_num_tokens: Optional[int] = None,
max_num_tokens: int,
world_size: int,
dp_size: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
world_size: int = 1,
dp_size: int = 1,
):
super().__init__()
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int8_w8a8 = use_int8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.per_channel_quant = per_channel_quant
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.dp_size = dp_size
assert world_size > 0
assert dp_size > 0
assert dp_size <= world_size
assert max_num_tokens > 0
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
assert self.block_shape is None, "NYI"
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.BatchedExperts,
mk.FusedMoEActivationFormat.BatchedExperts)
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def workspace_shapes(
self,
a: torch.Tensor,
@ -630,10 +685,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
local_num_experts: int,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
num_dp = self.world_size
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
max_num_tokens = self.max_num_tokens
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
output = (num_experts, max_num_tokens * num_dp, K)
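# Worked example of the workspace shapes above (a sketch with made-up sizes;
# note num_dp is now self.world_size): with world_size=2, max_num_tokens=64,
# 8 local experts, K=4096 and N=14336:
num_experts, max_num_tokens, num_dp, K, N = 8, 64, 2, 4096, 14336
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))  # (8, 128, 14336)
workspace2 = (num_experts, max_num_tokens * num_dp, N // 2)      # (8, 128, 7168)
output = (num_experts, max_num_tokens * num_dp, K)               # (8, 128, 4096)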
@ -708,7 +762,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13,
@ -734,6 +787,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
config=config,
block_shape=self.block_shape)
intermediate_cache2.fill_(0)
# TODO: would be nice to use expert_num_tokens here to reduce
# garbage compute
self.activation(activation, intermediate_cache2.view(-1, N // 2),
@ -745,8 +800,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
A_scale=a2_scale,
qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None,
per_channel_quant=self.per_channel_quant,
quant_dtype=self.quant_dtype,
per_act_token_quant=self.per_act_token_quant,
block_shape=self.block_shape)
qintermediate_cache2 = qintermediate_cache2.view(

View File

@ -12,6 +12,10 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, get_config_quant_dtype)
# yapf: enable
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_valid_deep_gemm, deep_gemm_moe_fp8)
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
@ -980,20 +984,6 @@ def get_config_dtype_str(
return None
# TODO (bnell): use scalar_type instead of bools?
def get_config_qtype(
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
) -> Optional[torch.dtype]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
return torch.int8
return None
def inplace_fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@ -1262,10 +1252,10 @@ def fused_experts_impl(
use_int4_w4a16=use_int4_w4a16,
dtype=hidden_states.dtype)
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
get_config_func = functools.partial(
try_get_optimal_moe_config,
@ -1332,8 +1322,8 @@ def fused_experts_impl(
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
A=curr_hidden_states,
A_scale=a1_scale,
qtype=qtype,
per_channel_quant=per_channel_quant,
quant_dtype=qtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
sorted_token_ids, expert_ids, num_tokens_post_padded = (
@ -1373,8 +1363,8 @@ def fused_experts_impl(
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
A=intermediate_cache2,
A_scale=a2_scale,
qtype=qtype,
per_channel_quant=per_channel_quant,
quant_dtype=qtype,
per_act_token_quant=per_channel_quant,
block_shape=block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
@ -1521,30 +1511,41 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):
super().__init__()
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.use_fp8_w8a8 = use_fp8_w8a8
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a8 = use_int8_w8a8
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.block_m = block_m
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16)
self.per_channel_quant = per_channel_quant
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
return (mk.FusedMoEActivationFormat.Standard,
mk.FusedMoEActivationFormat.Standard)
def supports_chunking(self) -> bool:
return True
def supports_expert_map(self) -> bool:
return True
def workspace_shapes(
self,
a: torch.Tensor,
@ -1660,7 +1661,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
self.activation(activation, intermediate_cache2,
@ -1669,8 +1670,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
a2q_scale: Optional[torch.Tensor] = None
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
self.block_shape)
intermediate_cache2, a2_scale, self.quant_dtype,
self.per_act_token_quant, self.block_shape)
invoke_fused_moe_kernel(qintermediate_cache2,
w2,
@ -1690,7 +1691,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
per_channel_quant=self.per_channel_quant,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape)
@ -1699,27 +1700,17 @@ def modular_triton_fused_moe(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> mk.FusedMoEModularKernel:
qtype = get_config_qtype(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
)
return mk.FusedMoEModularKernel(
MoEPrepareAndFinalizeNoEP(
quant_dtype=qtype,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
),
MoEPrepareAndFinalizeNoEP(),
TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
),
)
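# Hedged usage sketch of the function above: building a modular Triton MoE
# kernel with the renamed per_act_token_quant argument (the flag values below
# are illustrative only).
mk_kernel = modular_triton_fused_moe(
    use_fp8_w8a8=True,
    use_int8_w8a8=False,
    use_int8_w8a16=False,
    use_int4_w4a16=False,
    per_act_token_quant=True,
    block_shape=None,
)
# mk_kernel is a FusedMoEModularKernel pairing MoEPrepareAndFinalizeNoEP with
# TritonExperts, both operating in the Standard activation format.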

View File

@ -3,27 +3,30 @@
from abc import abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Literal, Optional, Union, overload
from typing import Callable, Literal, Optional, overload
import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_world_group,
tensor_model_parallel_all_reduce)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
# yapf: disable
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEParallelConfig)
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import (
@ -36,14 +39,12 @@ from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
if current_platform.is_cuda_alike():
from .fused_batched_moe import BatchedTritonExperts
from .fused_moe import TritonExperts, fused_experts
from .modular_kernel import (FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
if has_pplx():
from .pplx_prepare_finalize import PplxPrepareAndFinalize
from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes)
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize)
else:
fused_experts = None # type: ignore
@ -60,209 +61,10 @@ if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__)
@dataclass
class FusedMoEParallelConfig:
tp_size: int
dp_size: int
ep_size: int
tp_rank: int
dp_rank: int
ep_rank: int
use_ep: bool # whether to use EP or not
@property
def use_all2all_kernels(self):
return self.dp_size > 1 and self.use_ep
@property
def use_pplx_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
@property
def use_deepep_ht_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
@property
def use_deepep_ll_kernels(self):
return (self.use_all2all_kernels
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
@staticmethod
def make(tp_size_: int, dp_size_: int,
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input tp_size_,
dp_size_, ep_size_ and vllm's parallel config, determine what
levels of parallelism to use in the fused moe layer.
Args:
tp_size_ (int): tp_size passed into the FusedMoE constructor.
dp_size_ (int): dp_size passed into the FusedMoE constructor.
ep_size_ (int): ep_size passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vllm's parallel config
object.
Examples:
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
we simply return the sizes unaltered and the ranks set to 0.
Expert Parallelism is considered only when either dp_size_ or tp_size_
is non-trivial.
When TP = 2, DP = 1 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
legend : {size, rank}
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
When TP = 1, DP = 2 and EP = False, the configuration on different
devices,
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 2 devices.
When TP = 2, DP = 2 and EP = False, the configuration on different
devices,
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
When TP = 2, DP = 1 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
When TP = 1, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
When TP = 2, DP = 2 and EP = True, the configuration on different
devices,
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
- Comment: There are 2 engine instances and the experts are split
between the 4 devices.
"""
def flatten_tp_across_dp(dp_rank: int):
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size_ * tp_size_ devices. Update tp_size
# and tp_rank so we shard across all devices.
tp_size = dp_size_ * tp_size_
tp_rank = dp_rank * tp_size_ + tp_rank
return tp_size, tp_rank
use_ep = (dp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel)
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
if not use_ep:
return FusedMoEParallelConfig(tp_size=tp_size,
tp_rank=tp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
ep_rank=0,
use_ep=False)
# DP + EP / TP + EP / DP + TP + EP
assert use_ep
# In EP, each device owns a set of experts fully. There is no tensor
# parallelism; update tp_size, tp_rank, ep_size and ep_rank to reflect that.
ep_size = tp_size
ep_rank = tp_rank
return FusedMoEParallelConfig(tp_size=1,
tp_rank=0,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
use_ep=True)
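# Numeric illustration of flatten_tp_across_dp above (a sketch following the
# TP=2, DP=2 example in the docstring): with tp_size_=2 and dp_size_=2 the
# flattened TP group spans all 4 devices.
tp_size_, dp_size_ = 2, 2
for dp_rank in range(dp_size_):
    for local_tp_rank in range(tp_size_):
        tp_size = dp_size_ * tp_size_                 # 4
        tp_rank = dp_rank * tp_size_ + local_tp_rank  # 0, 1, 2, 3
        # with enable_expert_parallel: ep_size = tp_size = 4, ep_rank = tp_rank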
# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
num_experts: int
experts_per_token: int
hidden_dim: int
num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
quant_dtype: torch.dtype = None
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
def __post_init__(self):
if self.dp_size > 1:
logger.debug("Using MOEConfig::max_num_tokens=%d",
self.max_num_tokens)
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
@property
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@property
def tp_rank(self):
return self.moe_parallel_config.tp_rank
@property
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@property
def use_ep(self):
return self.moe_parallel_config.use_ep
@property
def use_pplx_kernels(self):
return self.moe_parallel_config.use_pplx_kernels
@property
def use_deepep_ht_kernels(self):
return self.moe_parallel_config.use_deepep_ht_kernels
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
@ -270,21 +72,9 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"
def get_quant_config_input_activations(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(
"input_activations")
else:
return None
class FusedMoEMethodBase(QuantizeMethodBase):
moe: MoEConfig
moe: FusedMoEConfig
@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
@ -292,23 +82,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError
def init_prepare_finalize(self, moe: MoEConfig,
def init_prepare_finalize(self, moe: FusedMoEConfig,
quant_config: Optional[QuantizationConfig]):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.moe = moe
quant_dtype = None
act_quant_block_size = None
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if isinstance(quant_config, Fp8Config):
act_quant_block_size = quant_config.weight_block_size
quant_dtype = torch.float8_e4m3fn
prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
DeepEPHTPrepareAndFinalize,
DeepEPLLPrepareAndFinalize]] = None
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
if moe.use_pplx_kernels:
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
moe.quant_dtype,
per_act_token_quant=moe.per_act_token_quant,
block_shape=moe.block_shape,
)
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
@ -318,14 +110,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(
0 if moe.quant_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
# Intranode pplx a2a takes a group name while internode does not.
@ -335,9 +121,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
handle = all2all_manager.get_handle(all_to_all_args)
input_activations = get_quant_config_input_activations(
quant_config)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
@ -345,10 +128,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.quant_dtype,
per_act_token=(input_activations.strategy
== QuantizationStrategy.TOKEN
if input_activations is not None else False),
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
@ -362,8 +141,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank *
moe.num_local_experts,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
)
elif moe.use_deepep_ll_kernels:
@ -380,25 +157,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# Note : We may want to use FP8 dispatch even otherwise just to
# reduce data movement
assert act_quant_block_size is not None
use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
and act_quant_block_size[1]
== DEEPEP_QUANT_BLOCK_SIZE)
use_fp8_dispatch = (moe.quant_config is not None
and moe.quant_config.quant_dtype
== current_platform.fp8_dtype()
and moe.quant_config.block_shape
== DEEPEP_QUANT_BLOCK_SHAPE)
# Note (varun): Whether to use FP8 dispatch or not needs some
# profiling. Turning it off for now.
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
max_tokens_per_rank=moe.max_num_tokens,
quant_dtype=quant_dtype,
block_shape=act_quant_block_size,
use_fp8_dispatch=use_fp8_dispatch,
)
self.topk_indices_dtype = None
if prepare_finalize is not None:
logger.debug("%s", prepare_finalize.__class__.__name__)
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
experts = self.select_gemm_impl(prepare_finalize, moe)
self.fused_experts = FusedMoEModularKernel(
@ -407,13 +184,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
)
def select_gemm_impl(
self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
"Subclass must select appropriate gemm implementation"
" based on the prepare_finalize")
f"{self.__class__.__name__} must select appropriate gemm "
"implementation based on the prepare_finalize")
@abstractmethod
def apply(
@ -445,7 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, moe: MoEConfig):
def __init__(self, moe: FusedMoEConfig):
super().__init__()
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
@ -458,44 +237,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else:
self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
moe: Optional[MoEConfig]):
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
assert self.fused_experts == fused_experts
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
) is not None
if use_batched_experts:
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
logger.debug("BatchedTritonExperts %s", self.moe)
assert self.moe.dp_size == all2all_manager.dp_world_size
experts = BatchedTritonExperts(
return BatchedTritonExperts(
max_num_tokens=self.moe.max_num_tokens,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
return experts
return TritonExperts()
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -883,13 +648,18 @@ class FusedMoE(torch.nn.Module):
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
tp_size_ = (tp_size if tp_size is not None else
get_tensor_model_parallel_world_size())
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
world_size_ = get_world_group().world_size
vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
tp_size_=(tp_size if tp_size is not None else
get_tensor_model_parallel_world_size()),
dp_size_=(dp_size if dp_size is not None else
get_dp_group().world_size),
tp_size_=tp_size_,
dp_size_=dp_size_,
world_size_=world_size_,
vllm_parallel_config=vllm_config.parallel_config))
self.global_num_experts = num_experts + num_redundant_experts
@ -948,25 +718,22 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
# Only support float8 for now.
quant_dtype = params_dtype
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8
and input_activations.type == QuantizationType.FLOAT):
quant_dtype = torch.float8_e4m3fn
if vllm_config.model_config is not None:
model_dtype = vllm_config.model_config.dtype
else:
# TODO (bnell): This is a hack to get test_mixtral_moe to work
# since model_config is not set in the pytest test.
model_dtype = params_dtype
moe = MoEConfig(
moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=params_dtype,
quant_dtype=quant_dtype,
in_dtype=model_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
quant_config=quant_config,
)
self.moe_config = moe
self.quant_config = quant_config
@ -1017,16 +784,15 @@ class FusedMoE(torch.nn.Module):
self.batched_router_logits: Optional[torch.Tensor] = None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
(moe.max_num_tokens, self.hidden_size),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
# Note here we use `num_experts` which is logical expert count
self.batched_router_logits = torch.zeros(
(envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts),
dtype=act_dtype,
(moe.max_num_tokens, num_experts),
dtype=moe.in_dtype,
device=torch.cuda.current_device())
@property
@ -1588,7 +1354,7 @@ class FusedMoE(torch.nn.Module):
assert (self.batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore

View File

@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from enum import Enum
from math import prod
from typing import Optional
from typing import Optional, final
import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv
@ -82,6 +84,18 @@ def _moe_problem_size(
return E, M, N, K, topk
class FusedMoEActivationFormat(Enum):
"""
The standard activation format (num_tokens, hidden dim).
"""
Standard = "standard",
"""
The batched experts format (num experts, max tokens per expert, hidden dim)
"""
BatchedExperts = "batched_experts",
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
@ -99,6 +113,7 @@ class FusedMoEPrepareAndFinalize(ABC):
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
@ -148,6 +163,15 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
@property
@abstractmethod
def activation_format(self) -> FusedMoEActivationFormat:
"""
A property indicating the output format of the activations for the
'prepare' method.
"""
raise NotImplementedError
@abstractmethod
def topk_indices_dtype(self) -> Optional[torch.dtype]:
"""
@ -176,6 +200,41 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
above.
"""
def __init__(
self,
quant_config: Optional[FusedMoEQuantConfig],
):
if quant_config is not None:
self.quant_config = quant_config
else:
self.quant_config = FusedMoEQuantConfig()
@property
@abstractmethod
def activation_formats(
self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
"""
A property which is a tuple of the input and output activation formats
for the 'apply' method.
"""
raise NotImplementedError
@property
def quant_dtype(self) -> Optional[torch.dtype]:
return self.quant_config.quant_dtype
@property
def block_shape(self) -> Optional[list[int]]:
return self.quant_config.block_shape
@property
def per_act_token_quant(self) -> bool:
return self.quant_config.per_act_token_quant
@property
def per_out_ch_quant(self) -> bool:
return self.quant_config.per_out_ch_quant
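# Minimal sketch of the new pattern (field names follow the constructor calls
# elsewhere in this change): an experts implementation hands its quantization
# parameters to the base class once, and the properties above simply forward
# to that config.
cfg = FusedMoEQuantConfig(
    quant_dtype=torch.float8_e4m3fn,
    per_act_token_quant=False,
    per_out_ch_quant=False,
    block_shape=[128, 128],
)
# Inside a subclass, after super().__init__(cfg):
#   self.quant_dtype         -> torch.float8_e4m3fn
#   self.block_shape         -> [128, 128]
#   self.per_act_token_quant -> False
#   self.per_out_ch_quant    -> False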
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
@abstractmethod
def supports_chunking(self) -> bool:
@ -185,6 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
@abstractmethod
def supports_expert_map(self) -> bool:
"""
A flag indicating whether or not this class supports expert maps
"""
raise NotImplementedError
@abstractmethod
def workspace_shapes(
self,
@ -297,6 +363,7 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
return None
@final
class FusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
@ -318,6 +385,12 @@ class FusedMoEModularKernel(torch.nn.Module):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
def forward(
self,
@ -383,8 +456,16 @@ class FusedMoEModularKernel(torch.nn.Module):
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, a1_scale, a2_scale, topk_weights, topk_ids,
global_num_experts, expert_map, apply_router_weight_on_input)
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids

View File

@ -6,33 +6,76 @@ import pplx_kernels as pplx
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
from vllm.utils import cdiv, round_up
def pplx_hidden_dim_scale_bytes(
max_num_tokens: int,
hidden_dim: int,
in_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
):
# All pplx byte sizes must be 16-byte aligned.
align = 16
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
if quant_dtype is not None:
assert quant_dtype.itemsize == 1
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
elem_size = torch.float32.itemsize
if per_act_token_quant:
# per-token
assert block_shape is None
hidden_scale_bytes = elem_size
elif block_shape is not None:
# per-group
block_size = block_shape[1]
num_blocks = cdiv(hidden_dim, block_size)
hidden_scale_bytes = num_blocks * elem_size
else:
# per-tensor
hidden_scale_bytes = elem_size
else:
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
hidden_scale_bytes = 0
return (
round_up(hidden_dim_bytes, align),
round_up(hidden_scale_bytes, align),
)
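# Worked example for the helper above (a sketch; numbers assume fp8 with
# 128x128 block quantization, so every byte count is already 16-byte aligned):
#   hidden_dim_bytes   = round_up(7168 * 1, 16)            = 7168
#   hidden_scale_bytes = round_up(cdiv(7168, 128) * 4, 16) = 224
hdb, hsb = pplx_hidden_dim_scale_bytes(
    max_num_tokens=64,
    hidden_dim=7168,
    in_dtype=torch.bfloat16,
    quant_dtype=torch.float8_e4m3fn,
    per_act_token_quant=False,
    block_shape=[128, 128],
)
assert (hdb, hsb) == (7168, 224)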
# The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll.
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def __init__(self,
a2a: pplx.AllToAll,
max_num_tokens: int,
world_size: int,
rank: int,
dp_size: int,
quant_dtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token: bool = False):
def __init__(
self,
a2a: pplx.AllToAll,
max_num_tokens: int,
world_size: int,
rank: int,
dp_size: int,
):
super().__init__()
assert max_num_tokens > 0
self.a2a = a2a
self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.rank = rank
self.dp_size = dp_size
self.quant_dtype = quant_dtype
self.per_act_token = per_act_token
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
def max_num_tokens_per_rank(self) -> Optional[int]:
return self.max_num_tokens
@ -45,36 +88,43 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
rank_topk_weights: torch.Tensor,
rank_topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
assert rank_topk_ids.size(0) == num_tokens
assert topk_ids.size(0) == num_tokens
# assert expert_map is None, "NYI"
# Is this always going to be a1.device?
device = a1.device
if apply_router_weight_on_input:
topk = rank_topk_ids.size(1)
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1")
a1 = a1 * rank_topk_weights.to(a1.dtype)
a1 = a1 * topk_weights.to(a1.dtype)
repeat_cols = 4
repeat_rows = 1 if self.per_act_token else a1.size(0)
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
a1q, a1q_scale = moe_kernel_quantize_input(
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
self.per_act_token, self.block_shape)
a1, (None if quant_config.per_act_token_quant else a1_scale),
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape)
if a1q_scale is not None:
if a1q_scale.numel() == 1:
orig_a_scale_block_shape = 1
else:
orig_a_scale_block_shape = a1q_scale.shape[-1]
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
# rem_experts need to be 0 for pplx to work properly.
@ -98,15 +148,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_x_scale: Optional[torch.Tensor] = None
if a1q.dtype.itemsize == 1:
float32_size = torch.float32.itemsize
block_size = (self.block_shape[0] if self.block_shape is not None
else 1) * float32_size
block_size = (quant_config.block_shape[1]
if quant_config.block_shape is not None else 1)
expert_x_scale = torch.empty(
(
num_local_experts,
expert_x.size(1),
(expert_x.size(2) + block_size - 1) // block_size,
),
(num_local_experts, expert_x.size(1),
round_up(
(expert_x.size(2) + block_size - 1) // block_size, 4)),
dtype=torch.float32,
device=device,
)
@ -121,11 +168,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=rank_topk_ids,
indices=topk_ids,
bound_m=bound_m,
)
if expert_x_scale is not None:
expert_x_scale = expert_x_scale[:, :, 0:1]
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
return expert_x, expert_x_scale, expert_num_tokens, None, None

View File

@ -5,6 +5,7 @@ from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
_moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.utils import (
@ -13,16 +14,9 @@ from vllm.model_executor.layers.fused_moe.utils import (
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
def __init__(
self,
quant_dtype: Optional[torch.dtype] = None,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
):
super().__init__()
self.per_channel_quant = per_channel_quant
self.block_shape = block_shape
self.quant_dtype = quant_dtype
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard
def max_num_tokens_per_rank(self) -> Optional[int]:
return None
@ -39,7 +33,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool = False,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
@ -50,10 +45,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
self.quant_dtype,
self.per_channel_quant,
self.block_shape)
a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1_scale, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
return a1q, a1q_scale, None, None, None

View File

@ -5,6 +5,7 @@ from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
@ -12,34 +13,59 @@ from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
allow_deep_gemm: bool = False):
super().__init__()
self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
block_m=block_m)
self.allow_deep_gemm = allow_deep_gemm
self.use_fp8_w8a8 = use_fp8_w8a8
def __init__(
self,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
allow_deep_gemm: bool = False,
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
))
self.triton_expert = TritonExperts(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
)
self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant
and use_fp8_w8a8)
self.deep_gemm_expert = DeepGemmExperts(
) if self.allow_deep_gemm else None
@property
def activation_formats(
self
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
assert (self.deep_gemm_expert is None
or self.triton_expert.activation_formats
== self.deep_gemm_expert.activation_formats)
return self.triton_expert.activation_formats
def supports_chunking(self) -> bool:
dge = self.deep_gemm_expert
te = self.triton_expert
return ((dge is None or dge.supports_chunking())
and (te is None or te.supports_chunking()))
def supports_expert_map(self) -> bool:
dge = self.deep_gemm_expert
te = self.triton_expert
return ((dge is None or dge.supports_expert_map())
and (te is None or te.supports_expert_map()))
def workspace_shapes(
self,
a: torch.Tensor,
@ -83,9 +109,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
N = w1.size(1)
use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
use_deep_gemm = (self.allow_deep_gemm
and _valid_deep_gemm(hidden_states, w1, w2))
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
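A construction sketch for the refactored class (values are illustrative): the DeepGEMM path is now gated up front on fp8 w8a8 without per-token activation quantization, and at run time only when _valid_deep_gemm accepts the problem shape, replacing the old explicit N > 512 check.

from vllm.model_executor.layers.fused_moe import TritonOrDeepGemmExperts

experts = TritonOrDeepGemmExperts(
    use_fp8_w8a8=True,
    per_act_token_quant=False,   # per-token activation quant disables DeepGEMM
    block_shape=[128, 128],      # illustrative block-quant shape
    allow_deep_gemm=True,
)
# experts.deep_gemm_expert is None whenever the constructor gate fails, and
# supports_chunking() / supports_expert_map() only return True when both the
# Triton expert and the DeepGEMM expert (if present) agree.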

View File

@ -37,6 +37,7 @@ def _fp8_quantize(
A, A_scale = ops.scaled_fp8_quant(
A, A_scale, use_per_token_if_dynamic=per_act_token)
else:
assert not per_act_token
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_fp8(A, block_k)
@ -64,6 +65,7 @@ def _int8_quantize(
"int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
else:
assert not per_act_token
assert len(block_shape) == 2
_, block_k = block_shape[0], block_shape[1]
A, A_scale = per_token_group_quant_int8(A, block_k)
@ -75,16 +77,15 @@ def _int8_quantize(
def moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
qtype: Optional[torch.dtype],
per_channel_quant: bool,
quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if qtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
elif qtype == torch.int8:
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
assert A_scale is None
return A, A_scale
@ -96,3 +97,17 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
else:
return m[idx, ...]
# TODO(bnell): better name
def maybe_fix_scales(scales: Optional[torch.Tensor],
num_experts: int) -> Optional[torch.Tensor]:
if scales is not None and scales.ndim < 3:
if scales.numel() == 1:
scales = scales.view(1)
scales = torch.repeat_interleave(scales, num_experts,
dim=0).view(num_experts, 1, 1)
else:
scales = scales.view(num_experts, -1, scales.size(-1))
return scales
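The new maybe_fix_scales helper normalizes scale tensors into a 3-D per-expert layout. A small worked example (shapes chosen purely for illustration; the import path follows the file shown in this hunk):

import torch
from vllm.model_executor.layers.fused_moe.utils import maybe_fix_scales

num_experts = 4
# A single per-tensor scale is repeated once per expert -> (4, 1, 1).
scalar = torch.tensor(0.5)
assert maybe_fix_scales(scalar, num_experts).shape == (4, 1, 1)
# 2-D per-channel scales are viewed as (num_experts, -1, last_dim) -> (4, 1, 128).
channel = torch.rand(num_experts, 128)
assert maybe_fix_scales(channel, num_experts).shape == (4, 1, 128)
# Already-3-D scales (and None) pass through unchanged.
assert maybe_fix_scales(None, num_experts) is None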

View File

@ -13,8 +13,10 @@ from compressed_tensors.quantization import (ActivationOrdering,
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe import (
CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig,
FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, fused_experts)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
from vllm.model_executor.layers.quantization.utils import replace_parameter
@ -32,14 +34,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import has_pplx
if current_platform.is_cuda_alike():
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize)
if has_pplx():
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
logger = init_logger(__name__)
@ -569,15 +563,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
requires_grad=False)
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
else:
from vllm.model_executor.layers.fused_moe import fused_experts
self.fused_experts_func = fused_experts
if self.use_marlin:
elif self.use_marlin:
prepare_moe_fp8_layer_for_marlin(layer, False)
# Activations not quantized for marlin.
del layer.w13_input_scale
del layer.w2_input_scale
self.fused_experts_func = None
else:
self.fused_experts_func = fused_experts
def apply(
self,
@ -653,6 +646,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=global_num_experts,
expert_map=expert_map)
assert self.fused_experts_func is not None
return self.fused_experts_func(
hidden_states=x,
w1=layer.w13_weight,
@ -826,28 +821,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassExpertsFp8)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
assert moe is not None
use_batched_format = (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts)
num_experts = (moe.num_local_experts
if use_batched_format else moe.num_experts)
max_experts_per_worker = (
(moe.num_experts + prepare_finalize.world_size - 1) //
prepare_finalize.world_size)
experts = CutlassExpertsFp8(
max_experts_per_worker,
num_experts,
moe.in_dtype,
self.input_quant.strategy == QuantizationStrategy.TOKEN,
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
use_batched_format=True,
use_batched_format=use_batched_format,
)
if has_pplx() and isinstance(
prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
# no expert_map support in this case
self.disable_expert_map = True
self.disable_expert_map = not experts.supports_expert_map()
return experts
def apply(
@ -888,7 +882,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32)
indices_type=self.topk_indices_dtype,
)
return self.fused_experts(
x,
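Both this method and Fp8MoEMethod.select_gemm_impl further down in this commit now pick between batched and standard expert implementations the same way. The toy sketch below uses stub classes, not vLLM code, to capture that pattern; whether the batched path really refuses an expert_map is an assumption, noted in the comments.

from enum import Enum, auto

class ActivationFormat(Enum):
    Standard = auto()
    BatchedExperts = auto()

class StubExperts:
    def __init__(self, use_batched_format: bool):
        self.use_batched_format = use_batched_format
    def supports_expert_map(self) -> bool:
        # Assumption for this sketch: the batched path cannot honor an
        # expert_map, mirroring the old pplx special case.
        return not self.use_batched_format

def select(activation_format, num_experts: int, num_local_experts: int):
    # Branch on the advertised activation format, then ask the experts
    # object about its own capabilities instead of isinstance checks.
    batched = activation_format == ActivationFormat.BatchedExperts
    experts = StubExperts(use_batched_format=batched)
    n = num_local_experts if batched else num_experts
    return experts, n, not experts.supports_expert_map()

_, n, disable_expert_map = select(ActivationFormat.BatchedExperts, 64, 8)
assert n == 8 and disable_expert_map   # local experts only; expert_map disabled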

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional
import torch
import torch.nn.functional as F
@ -13,8 +13,11 @@ import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe import (
BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
TritonOrDeepGemmExperts)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@ -777,44 +780,46 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize, moe):
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
experts: Optional[Union[BatchedTritonOrDeepGemmExperts,
TritonOrDeepGemmExperts]] = None
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
use_batched_experts = max_num_tokens_per_rank is not None
if use_batched_experts:
experts = BatchedTritonOrDeepGemmExperts(
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
logger.debug(
"BatchedTritonOrDeepGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
self.quant_config.weight_block_size, False)
return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
world_size=prepare_finalize.world_size,
dp_size=prepare_finalize.dp_size,
world_size=prepare_finalize.
world_size, # type: ignore [attr-defined]
dp_size=prepare_finalize.
dp_size, # type: ignore [attr-defined]
use_fp8_w8a8=True,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
per_channel_quant=False,
block_shape=self.quant_config.weight_block_size,
per_act_token_quant=False,
allow_deep_gemm=self.allow_deep_gemm,
)
else:
experts = TritonOrDeepGemmExperts(
logger.debug(
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, self.quant_config.weight_block_size,
False)
return TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
assert experts is not None
return experts
def apply(
self,
layer: torch.nn.Module,