Compare commits

...

20 Commits

Author SHA1 Message Date
1236aebf0e Merge remote-tracking branch 'origin/main' into fp8_ep_dp 2025-06-02 14:53:27 -04:00
95c40f9b09 hacks
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-30 02:33:58 +00:00
a0efd3106c hack fix MoEConfig.quant_dtype
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-30 02:08:21 +00:00
e69879996f re-enable cudagraph+torch.compile
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-30 00:12:54 +00:00
922165cba3 fp8 + pplx tests + fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-29 21:25:33 +00:00
12ea698498 pplx + fp8 test
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-29 18:50:37 +00:00
caca0b718a fixes
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-29 02:08:22 +00:00
d86e3f0172 lint
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:56 +00:00
3ca8322b74 lint
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:56 +00:00
03b41b6cad fix merge
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:56 +00:00
cad6447664 fix
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:56 +00:00
c169b05541 merge
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:56 +00:00
468d16654a cleanup quantization
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:53 +00:00
909f234faa stuff
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:27 +00:00
f8510587c2 tests + fix
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:27 +00:00
9cfebf51ba basic working test
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:27 +00:00
77f95b99a6 test
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:27 +00:00
bbe888d033 wip
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:27 +00:00
25ed6738d4 wip
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:27 +00:00
e568e401da fp8 support
Signed-off-by: Bill Nell <bnell@redhat.com>
2025-05-28 23:40:27 +00:00
9 changed files with 729 additions and 121 deletions

View File

@ -1,18 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Optional
import pytest
import torch
import triton.language as tl
import vllm._custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts,
invoke_moe_batched_triton_kernel)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import round_up
NUM_EXPERTS = [8, 64]
TOP_KS = [1, 2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@dataclass
class BatchedMMConfig:
dtype: torch.dtype
in_dtype: torch.dtype
out_dtype: torch.dtype
num_experts: int
max_tokens_per_expert: int
K: int
@ -28,17 +48,26 @@ class BatchedMMTensors:
@staticmethod
def make_tensors(config: BatchedMMConfig):
if config.in_dtype == torch.torch.float8_e4m3fn:
config_in_dtype = torch.bfloat16
else:
config_in_dtype = config.in_dtype
A = torch.randn(
(config.num_experts, config.max_tokens_per_expert, config.K),
device="cuda",
dtype=config.dtype) / 10
dtype=config_in_dtype) / 10
B = torch.randn((config.num_experts, config.N, config.K),
device="cuda",
dtype=config.dtype)
dtype=config_in_dtype)
C = torch.zeros(
(config.num_experts, config.max_tokens_per_expert, config.N),
device="cuda",
dtype=config.dtype)
dtype=config.out_dtype)
A = A.to(config.in_dtype)
B = B.to(config.in_dtype)
num_expert_tokens = torch.randint(low=0,
high=config.max_tokens_per_expert,
size=(config.num_experts, ),
@ -47,16 +76,96 @@ class BatchedMMTensors:
return BatchedMMTensors(A, B, C, num_expert_tokens)
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_expert_tokens: torch.Tensor) -> torch.Tensor:
def native_w8a8_block_matmul(A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size,
output_dtype=torch.bfloat16):
"""This function performs matrix multiplication with block-wise
quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and
fp8 data types.
It takes two input tensors `A` and `B` (int8) with scales `As` and
`Bs` (float32).
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32).contiguous()
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], (
f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}")
assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}"
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N, )
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
]
B_tiles = [[
B[
j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
]
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
def ref_impl(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
num_expert_tokens: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
block_shape: Optional[list[int]],
) -> torch.Tensor:
num_expert_tokens_cpu = num_expert_tokens.clone()
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
num_experts = num_expert_tokens.size(0)
for e in range(num_experts):
num_tokens = num_expert_tokens_cpu[e]
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
if A.dtype == torch.torch.float8_e4m3fn:
if False:
tmp = native_w8a8_block_matmul(A[e, :, :],
B[e].transpose(0, 1), A_scale,
B_scale, block_shape)
else:
tmp = ops.cutlass_scaled_mm(A[e, :, :], B[e].transpose(0, 1),
A_scale, B_scale, torch.bfloat16)
C[e, :num_tokens, :] = tmp[:num_tokens, :]
else:
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
return C
@ -66,22 +175,45 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
[32, 64, 128, 192, 224, 256, 512])
@pytest.mark.parametrize("K", [128, 256, 1024])
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
"dtype",
[torch.torch.float8_e4m3fn, torch.float32, torch.float16, torch.bfloat16])
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
N: int, dtype: torch.dtype):
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
if dtype == torch.torch.float8_e4m3fn:
in_dtype = dtype
out_dtype = torch.bfloat16
else:
in_dtype = dtype
out_dtype = dtype
config = BatchedMMConfig(in_dtype, out_dtype, num_experts,
max_tokens_per_expert, K, N)
tensors = BatchedMMTensors.make_tensors(config)
test_output = tensors.C
ref_output = test_output.clone()
ref_output2 = test_output.clone()
compute_tl_dtype = {
torch.float16: tl.float16,
torch.bfloat16: tl.bfloat16,
torch.float32: tl.float32
}[test_output.dtype]
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
block_shape = [16, 16, 32] # 16 for k if not fp8
if use_fp8_w8a8:
A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device)
B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device)
quant_block_shape = [1, 1]
else:
A_scale = None
B_scale = None
quant_block_shape = None
invoke_moe_batched_triton_kernel(
tensors.A,
tensors.B,
@ -89,21 +221,30 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
tensors.num_expert_tokens,
compute_tl_dtype,
# Quantization data
None,
None,
A_scale,
B_scale,
None,
# Quantization schemes
False,
use_fp8_w8a8,
False,
False,
config={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 16
})
"BLOCK_SIZE_M": block_shape[0],
"BLOCK_SIZE_N": block_shape[1],
"BLOCK_SIZE_K": block_shape[2],
},
block_shape=quant_block_shape,
)
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
tensors.num_expert_tokens)
ref_output = ref_output.to(dtype=out_dtype)
ref_output = ref_impl(tensors.A.to(dtype=out_dtype),
tensors.B.to(dtype=out_dtype), ref_output,
tensors.num_expert_tokens, A_scale, B_scale,
block_shape[-2:])
ref_output2 = ref_impl(tensors.A, tensors.B, ref_output2,
tensors.num_expert_tokens, A_scale, B_scale,
block_shape[-2:])
rtol, atol = {
torch.float16: (6e-2, 6e-2),
@ -111,4 +252,154 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
torch.float32: (1e-2, 1e-2),
}[test_output.dtype]
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol)
torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol)
def batched_moe(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
per_act_token: bool = False,
) -> torch.Tensor:
max_num_tokens = round_up(a.shape[0], 64)
fused_experts = FusedMoEModularKernel(
BatchedPrepareAndFinalize(max_num_tokens,
world_size=1,
dp_size=1,
rank=0,
qtype=qtype,
block_shape=block_shape,
per_act_token=per_act_token),
BatchedTritonExperts(max_num_tokens=max_num_tokens,
dp_size=1,
world_size=1,
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
block_shape=block_shape))
return fused_experts(a,
w1,
w2,
topk_weight,
topk_ids,
w1_scale=w1_scale,
w2_scale=w2_scale)
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
a: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
if use_fp8_w8a8:
a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
else:
a_scale = None
out = torch.zeros(M * topk,
w2.shape[1],
dtype=torch.bfloat16,
device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
if not use_fp8_w8a8:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
else:
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
torch.bfloat16)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
torch.bfloat16)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@pytest.mark.parametrize("m", [32, 45, 64]) #[1, 33, 64, 222])
@pytest.mark.parametrize("n", [128, 512, 1024, 2048])
@pytest.mark.parametrize("k", [128, 512, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
def test_fused_moe_batched_experts(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
):
current_platform.seed_everything(7)
block_shape = [128, 128]
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
use_fp8_w8a8 = dtype == torch.torch.float8_e4m3fn
qtype = dtype if dtype == torch.torch.float8_e4m3fn else None
if use_fp8_w8a8:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (2 * n + block_n - 1) // block_n
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
k_tiles_w2 = (n + block_k - 1) // block_k
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
factor_for_scale = 1e-2
w1_s = torch.rand(
(e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
device="cuda") * factor_for_scale
w2_s = torch.rand(
(e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
device="cuda") * factor_for_scale
else:
w1_s = None
w2_s = None
with set_current_vllm_config(vllm_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, qtype, block_shape)
baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s,
w2_s, use_fp8_w8a8, block_shape)
torch.testing.assert_close(baseline_output,
batched_output,
atol=2e-2,
rtol=0)

View File

@ -33,7 +33,10 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.platforms import current_platform
from vllm.utils import round_up
PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
(222, 2048, 1024)]
@ -74,6 +77,11 @@ class ProcessGroupInfo:
device: torch.device
@pytest.fixture(scope="function", autouse=True)
def use_pplx_backend(monkeypatch):
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "pplx")
def _worker_parallel_launch(
local_rank: int,
world_size: int,
@ -275,6 +283,70 @@ def batched_moe(
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
def native_w8a8_block_matmul(A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size,
output_dtype=torch.bfloat16):
"""This function performs matrix multiplication with block-wise
quantization using native torch.
It is agnostic to the input data type and can be used for both int8 and
fp8 data types.
It takes two input tensors `A` and `B` (int8) with scales `As` and
`Bs` (float32).
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32).contiguous()
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1], (
f"{(A.shape[-1] + block_k - 1) // block_k} == {As.shape[-1]}")
assert A.shape[:-1] == As.shape[:-1], f"{A.shape} == {As.shape}"
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N, )
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
]
B_tiles = [[
B[
j * block_n:min((j + 1) * block_n, N),
i * block_k:min((i + 1) * block_k, K),
] for i in range(k_tiles)
] for j in range(n_tiles)]
C_tiles = [
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
]
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
# Note: same as torch_moe but with fused_topk factored out.
def torch_moe2(
a: torch.Tensor,
@ -282,17 +354,44 @@ def torch_moe2(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
use_fp8_w8a8: bool = False,
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
M, K = a.shape
topk = topk_ids.shape[1]
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
if use_fp8_w8a8:
a, a_scale = per_token_group_quant_fp8(a, block_shape[1])
else:
a_scale = None
out = torch.zeros(M * topk,
w2.shape[1],
dtype=torch.bfloat16,
device=a.device)
num_experts = w1.shape[0]
for i in range(num_experts):
mask = (topk_ids == i).view(-1)
if mask.sum():
out[mask] = SiluAndMul()(
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
if not use_fp8_w8a8:
tmp1 = a[mask] @ w1[i].transpose(0, 1)
tmp2 = SiluAndMul()(tmp1)
out[mask] = tmp2 @ w2[i].transpose(0, 1)
else:
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
w1_scale[i], block_shape,
torch.bfloat16)
tmp2 = SiluAndMul()(tmp1)
tmp2, b_scale = per_token_group_quant_fp8(tmp2, block_shape[1])
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
w2_scale[i], block_shape,
torch.bfloat16)
return (out.view(M, -1, w2.shape[1]) *
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
@ -497,6 +596,10 @@ def pplx_moe(
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
use_compile: bool = True,
use_cudagraphs: bool = True,
) -> torch.Tensor:
@ -506,9 +609,17 @@ def pplx_moe(
device = torch.device("cuda", rank)
hidden_dim = a.shape[1]
num_experts = w1.shape[0]
block_size = 128
block_size = block_shape[1] if block_shape is not None else 128
topk = topk_ids.shape[1]
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
if qtype is not None:
a_dtype = qtype
# This is probably not right
scale_bytes = round_up(((hidden_dim + block_size - 1) // block_size) * torch.float32.itemsize, 16)
else:
a_dtype = a.dtype
scale_bytes = 0
ata = AllToAll.internode(
max_num_tokens=max_num_tokens,
@ -518,10 +629,8 @@ def pplx_moe(
world_size=world_size,
dp_size=dp_size,
hidden_dim=hidden_dim,
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
((hidden_dim + block_size - 1) // block_size *
torch.float32.itemsize)),
hidden_dim_bytes=hidden_dim * a_dtype.itemsize,
hidden_dim_scale_bytes=scale_bytes,
)
topk_ids = topk_ids.to(dtype=torch.uint32)
@ -532,11 +641,15 @@ def pplx_moe(
world_size,
rank,
dp_size,
quant_dtype=qtype,
block_shape=block_shape,
)
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
world_size=world_size,
dp_size=dp_size)
dp_size=dp_size,
use_fp8_w8a8=qtype==torch.float8_e4m3fn,
block_shape=block_shape)
fused_experts = FusedMoEModularKernel(
prepare_finalize,
@ -552,6 +665,13 @@ def pplx_moe(
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
if w1_scale is not None:
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
else:
w1_scale_chunk = None
w2_scale_chunk = None
if use_compile:
_fused_experts = torch.compile(fused_experts,
backend='inductor',
@ -564,6 +684,8 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
if use_cudagraphs:
@ -576,6 +698,8 @@ def pplx_moe(
w2_chunk,
chunk_topk_weight,
chunk_topk_ids,
w1_scale=w1_scale_chunk,
w2_scale=w2_scale_chunk,
global_num_experts=num_experts)
torch.cuda.synchronize()
@ -638,6 +762,10 @@ def _pplx_moe(
w2: torch.Tensor,
score: torch.Tensor,
topk: int,
w1_s: Optional[torch.Tensor] = None,
w2_s: Optional[torch.Tensor] = None,
qtype: Optional[torch.dtype] = None,
block_shape: Optional[list[int]] = None,
):
uid = nvshmem_get_unique_id(
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
@ -649,11 +777,20 @@ def _pplx_moe(
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
use_fp8_w8a8 = qtype == torch.float8_e4m3fn
device = torch.device("cuda", pgi.rank)
a = a.to(device)
w1 = w1.to(device)
w2 = w2.to(device)
w1_s = w1_s.to(device) if w1_s is not None else None
w2_s = w2_s.to(device) if w2_s is not None else None
with set_current_vllm_config(vllm_config), override_config(moe_config):
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids, w1_s, w2_s, use_fp8_w8a8, block_shape)
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
topk_weight, topk_ids)
topk_weight, topk_ids, w1_s, w2_s, qtype, block_shape)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
@ -670,7 +807,7 @@ def _pplx_moe(
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_moe(
@ -683,9 +820,40 @@ def test_pplx_moe(
current_platform.seed_everything(7)
m, n, k = mnk
world_size, dp_size = world_dp_size
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if use_fp8_w8a8:
block_shape = [128, 128]
quant_type = torch.float8_e4m3fn
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (2 * n + block_n - 1) // block_n
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
k_tiles_w2 = (n + block_k - 1) // block_k
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
w1 = w1.clamp(min=fp8_min, max=fp8_max).to(dtype)
w2 = w2.clamp(min=fp8_min, max=fp8_max).to(dtype)
factor_for_scale = 1e-2
w1_s = torch.rand(
(e, n_tiles_w1, k_tiles_w1), dtype=torch.float32,
device="cuda") * factor_for_scale
w2_s = torch.rand(
(e, n_tiles_w2, k_tiles_w2), dtype=torch.float32,
device="cuda") * factor_for_scale
else:
block_shape = None
quant_type = None
w1_s = None
w2_s = None
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk, w1_s, w2_s, quant_type, block_shape)

View File

@ -83,6 +83,9 @@ class PPLXAll2AllManager(All2AllManagerBase):
assert has_pplx, "pplx_kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install pplx_kernels." # noqa
super().__init__(cpu_group)
# Intranode doesn't work yet.
self.internode = True
if self.internode:
# inter-node communication needs nvshmem,
# intra-node communication uses p2p mapping directly

View File

@ -4,7 +4,8 @@ from contextlib import contextmanager
from typing import Any, Optional
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
MOE_DP_CHUNK_SIZE, FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
_config: Optional[dict[str, Any]] = None
@ -29,6 +30,7 @@ __all__ = [
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",
"MOE_DP_CHUNK_SIZE",
]
if HAS_TRITON:

View File

@ -9,7 +9,9 @@ import triton.language as tl
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, try_get_optimal_moe_config)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.fused_moe.utils import (
_resize_cache, moe_kernel_quantize_input)
from vllm.utils import round_up
@triton.jit
@ -315,8 +317,8 @@ def invoke_moe_batched_triton_kernel(
expert_num_tokens: torch.Tensor, # [E]
compute_type: tl.dtype,
# Quantization data
A_scale: torch.Tensor,
B_scale: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
B_zp: torch.Tensor,
# Quantization schemes
use_fp8_w8a8: bool,
@ -335,7 +337,7 @@ def invoke_moe_batched_triton_kernel(
BLOCK_K = config['BLOCK_SIZE_K']
assert (torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()
or max_num_tokens % BLOCK_M == 0)
or max_num_tokens % BLOCK_M == 0), f"{max_num_tokens} {BLOCK_M}"
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
triton.cdiv(B.size(1), BLOCK_N))
@ -388,13 +390,22 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
that the PPLX dispatch/combine kernels use.
"""
def __init__(self, max_num_tokens: Optional[int], world_size: int,
dp_size: int, rank: int):
def __init__(self,
max_num_tokens: Optional[int],
world_size: int,
dp_size: int,
rank: int,
qtype: Optional[torch.dtype] = None,
per_act_token: bool = False,
block_shape: Optional[list[int]] = None):
super().__init__()
self.world_size = world_size
self.dp_size = dp_size
self.rank = rank
self.max_num_tokens = max_num_tokens
self.per_act_token = per_act_token
self.block_shape = block_shape
self.qtype = qtype
def prepare(
self,
@ -436,20 +447,47 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
b_a1 = torch.zeros(
(num_local_experts, self.max_num_tokens, hidden_dim),
dtype=a1.dtype,
dtype=self.qtype if self.qtype is not None else a1.dtype,
device=a1.device)
if self.qtype is not None:
_, block_k = self.block_shape
k_tiles = (hidden_dim + block_k - 1) // block_k
b_a1_scale = torch.zeros(
(num_local_experts, self.max_num_tokens, k_tiles),
dtype=torch.float32,
device=a1.device)
else:
assert a1_scale is None
b_a1_scale = None
first_expert = num_local_experts * self.rank
last_expert = first_expert + num_local_experts
for expert_id in range(first_expert, last_expert):
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
rows = torch.count_nonzero(topks.flatten())
b_a1[expert_id -
first_expert, :rows, :] = a1[:topks.numel()][topks]
tokens_per_expert[expert_id - first_expert] = rows
rhs = a1[:topks.numel()][topks]
idx = expert_id - first_expert
if self.qtype is not None:
if a1_scale is not None:
rhs_a1_scale = a1_scale[:topks.numel()][topks]
else:
rhs_a1_scale = None
b_a1[idx, :rows, :], b_a1_scale[idx, :rows] = (
moe_kernel_quantize_input(
rhs,
rhs_a1_scale,
self.qtype,
self.per_act_token,
self.block_shape,
))
else:
b_a1[idx, :rows, :] = rhs
return b_a1, a1_scale, tokens_per_expert
tokens_per_expert[idx] = rows
return b_a1, b_a1_scale, tokens_per_expert
def finalize(
self,
@ -499,15 +537,15 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
block_m: Optional[int] = None,
):
super().__init__()
assert block_shape is None
assert block_m is None
assert not use_fp8_w8a8, "NYI"
assert not use_int8_w8a8, "NYI"
assert not use_int8_w8a16, "NYI"
assert not use_int4_w4a16, "NYI"
self.max_num_tokens = max_num_tokens
self.world_size = world_size
self.dp_size = dp_size
self.use_fp8_w8a8 = use_fp8_w8a8
self.block_shape = block_shape
def workspace_shapes(
self,
@ -522,7 +560,6 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
workspace13 = num_experts * max_num_tokens * num_dp * K
workspace2 = max_num_tokens * num_dp * N
return (workspace13, workspace2, a.dtype)
@ -579,6 +616,7 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
else:
num = int(expert_num_tokens[expert].item())
tmp = _resize_cache(workspace2, (num, N))
assert not self.use_fp8_w8a8
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
self.activation(activation, tmp, input)
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
@ -586,6 +624,61 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
return out
def batched_moe_kernel_quantize_input(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
num_tokens: int,
E: int,
N: int,
expert_num_tokens: torch.Tensor,
qtype: Optional[torch.dtype],
per_channel_quant: bool,
block_shape: Optional[list[int]] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if (True or
torch.compiler.is_compiling()
or torch.cuda.is_current_stream_capturing()):
# Note: this does a bunch of extra work because expert_num_tokens is ignored
# but it does support torch.compile + cudagraphs.
hidden_dim = A.size(-1)
if block_shape is not None:
block_shape = [block_shape[1], block_shape[0]]
assert A_scale is None or A_scale.dim() == 2
A_q, A_q_scale = moe_kernel_quantize_input(
A.view(-1, hidden_dim),
A_scale,
qtype,
per_channel_quant,
block_shape)
A_q = A_q.view(E, -1, hidden_dim)
if A_q_scale is not None:
A_q_scale = A_q_scale.view(E, -1, A_q_scale.size(-1))
return A_q, A_q_scale
if qtype is not None:
assert block_shape is not None
A_q = torch.empty_like(A, dtype=qtype)
block_n, block_k = block_shape
n_tiles = ((N // 2) + block_n - 1) // block_n
scale_shape = (E, num_tokens, n_tiles)
A_q_scale = torch.empty(scale_shape,
dtype=torch.float32,
device=A.device)
for e in range(E):
num_tokens = expert_num_tokens[e]
if num_tokens > 0:
A_q[e, :num_tokens, :], tmp_scale = moe_kernel_quantize_input(
A[e, :num_tokens],
A_scale[e, :num_tokens] if A_scale else None, qtype,
per_channel_quant, [block_k, block_n])
A_q_scale[e, :tmp_scale.shape[0]] = tmp_scale
return A_q, A_q_scale
else:
return A, A_scale
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""
A Triton based MoE expert class that operates on expert batched format,
@ -595,12 +688,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_num_tokens: Optional[int] = None,
max_num_tokens: int,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
block_shape: Optional[list[int]] = None,
per_act_token: bool = False,
world_size: int = 1,
dp_size: int = 1,
):
@ -610,11 +704,13 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.use_int4_w4a16 = use_int4_w4a16
self.use_int8_w8a16 = use_int8_w8a16
self.block_shape = block_shape
self.max_num_tokens = max_num_tokens
assert not use_int8_w8a8, "NYI"
assert not use_int4_w4a16, "NYI"
self.world_size = world_size
self.dp_size = dp_size
self.per_act_token = per_act_token
self.qtype = torch.float8_e4m3fn if self.use_fp8_w8a8 else None
self.max_num_tokens = max_num_tokens
def workspace_shapes(
self,
@ -627,10 +723,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
) -> tuple[int, int, torch.dtype]:
assert a.dim() == 2
num_dp = self.world_size // self.dp_size
max_num_tokens = a.size(
0) if self.max_num_tokens is None else self.max_num_tokens
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
workspace13 = num_experts * self.max_num_tokens * num_dp * max(K, N)
workspace2 = num_experts * self.max_num_tokens * num_dp * (N // 2)
return (workspace13, workspace2, a.dtype)
def apply(
@ -702,7 +796,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
raise ValueError(
f"Unsupported compute_type: {hidden_states.dtype}")
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N))
@ -730,15 +823,11 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
self.activation(activation, intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N))
#qintermediate_cache2 = intermediate_cache2
a2q_scale = a2_scale
# TODO (varun) : support w8a8
assert not self.use_fp8_w8a8
#if self.use_fp8_w8a8:
# qintermediate_cache2, a2q_scale = _fp8_quantize(
# intermediate_cache2, a2_scale, self.block_shape)
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
intermediate_cache2, a2_scale, num_tokens, E, N, expert_num_tokens,
self.qtype, self.per_act_token, self.block_shape)
invoke_moe_batched_triton_kernel(A=intermediate_cache2,
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
B=w2,
C=intermediate_cache3,
expert_num_tokens=expert_num_tokens,

View File

@ -1520,11 +1520,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
use_fp8_w8a8: bool,
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
per_channel_quant: bool,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
per_channel_quant: bool = False,
block_shape: Optional[list[int]] = None,
block_m: Optional[int] = None,
):

View File

@ -8,6 +8,9 @@ from typing import Callable, Optional
import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
QuantizationStrategy,
QuantizationType)
from torch.nn.parameter import UninitializedParameter
import vllm.envs as envs
@ -26,7 +29,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op
from vllm.utils import direct_register_custom_op, cdiv
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
@ -56,7 +59,7 @@ logger = init_logger(__name__)
# Note: this limit is somewhat arbitrary and might be changed later.
# The size of the activations will be E x MOE_DP_CHUNK_SIZE x hidden_dim.
MOE_DP_CHUNK_SIZE = 256
MOE_DP_CHUNK_SIZE = 128
@dataclass
@ -72,7 +75,7 @@ class FusedMoEParallelConfig:
@property
def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and \
return self.dp_size > 1 and self.use_ep and has_pplx and \
envs.VLLM_ALL2ALL_BACKEND == "pplx"
@staticmethod
@ -191,7 +194,8 @@ class MoEConfig:
num_local_experts: int
moe_parallel_config: FusedMoEParallelConfig
in_dtype: torch.dtype # The activation type.
in_dtype: torch.dtype # The post quantization activation type.
quant_dtype: Optional[torch.dtype] = None
# TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128
@ -238,6 +242,18 @@ class FusedMoeWeightScaleSupported(Enum):
BLOCK = "block"
def get_quant_config_input_activations(
quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
and "Linear" in quant_config.target_scheme_map and
"input_activations" in quant_config.target_scheme_map["Linear"]):
return quant_config.target_scheme_map["Linear"].get(
"input_activations")
else:
return None
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
@ -253,6 +269,17 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize = None
if moe.use_pplx_kernels:
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
if moe.quant_dtype is not None and moe.quant_dtype.itemsize == 1:
hidden_dim_bytes = moe.hidden_dim * moe.quant_dtype.itemsize
hidden_scale_bytes = (cdiv(moe.hidden_dim, moe.block_size) *
torch.float32.itemsize)
else:
hidden_dim_bytes = moe.hidden_dim * moe.in_dtype.itemsize
hidden_scale_bytes = 0
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
@ -262,18 +289,17 @@ class FusedMoEMethodBase(QuantizeMethodBase):
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
group_name=all2all_manager.cpu_group.group_name,
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
if not all2all_manager.internode:
all_to_all_args["group_name"] = \
all2all_manager.cpu_group.group_name
handle = all2all_manager.get_handle(all_to_all_args)
logger.debug("PplxPrepareAndFinalize")
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
@ -281,7 +307,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
quant_dtype=moe.quant_dtype,
)
if prepare_finalize is not None:
@ -346,33 +372,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
return BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
return experts
return TritonExperts()
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
@ -785,14 +796,32 @@ class FusedMoE(torch.nn.Module):
from vllm_hpu_extension.ops import DynamicFusedMOE
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
logger.debug("MODEL DTYPE %s", vllm_config.model_config.dtype)
quant_dtype: Optional[torch.dtype] = None
if quant_config is not None:
input_activations = get_quant_config_input_activations(
quant_config)
if (input_activations is not None
and input_activations.num_bits == 8):
if input_activations.type == QuantizationType.FLOAT:
quant_dtype = torch.float8_e4m3fn
elif input_activations.type == QuantizationType.INT:
quant_dtype = torch.int8
# Total hack
if quant_config.__class__.__name__ == "Fp8Config":
quant_dtype = torch.float8_e4m3fn
logger.info("QUANT_DTYPE %s", quant_dtype)
moe = MoEConfig(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype,
in_dtype=vllm_config.model_config.dtype,
quant_dtype=quant_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
)
self.moe_config = moe
@ -832,15 +861,14 @@ class FusedMoE(torch.nn.Module):
self.batched_hidden_states: Optional[torch.Tensor] = None
self.batched_router_logits: Optional[torch.Tensor] = None
if self.moe_parallel_config.use_pplx_kernels:
act_dtype = vllm_config.model_config.dtype
self.batched_hidden_states = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.hidden_size),
dtype=act_dtype,
dtype=vllm_config.model_config.dtype,
device=torch.cuda.current_device())
self.batched_router_logits = torch.zeros(
(MOE_DP_CHUNK_SIZE, self.global_num_experts),
dtype=act_dtype,
dtype=vllm_config.model_config.dtype,
device=torch.cuda.current_device())
@property
@ -1251,7 +1279,7 @@ class FusedMoE(torch.nn.Module):
assert (self.batched_hidden_states.size(0) # type: ignore
>= chunk_size)
assert (self.batched_router_logits.size(0) # type: ignore
assert (self.batched_router_logits.size(0) # type: ignore
>= chunk_size)
staged_hidden_states = self.batched_hidden_states[:
chunk_size, :] # type: ignore

View File

@ -66,6 +66,10 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
per_act_token,
self.block_shape)
if a1q_scale is not None and a1q_scale.dim() == 1:
assert a1q_scale.numel() == 1
a1q_scale = a1q_scale.view(1, 1)
# rem_experts need to be 0 for pplx to work properly.
rem_experts = num_experts % self.world_size
assert rem_experts == 0
@ -90,7 +94,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
float32_size = torch.float32.itemsize
block_size = (self.block_shape[0] if self.block_shape is not None
else 1) * float32_size
expert_x_scale = torch.empty(
expert_x_scale = torch.zeros(
(
num_experts,
expert_x.size(1),

View File

@ -11,9 +11,10 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_ep_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.fused_moe import (MOE_DP_CHUNK_SIZE, FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
@ -461,9 +462,13 @@ class Fp8MoEMethod(FusedMoEMethodBase):
self.fused_experts = functools.partial( # type: ignore
fused_experts,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm)
self.use_pplx_kernels = False
self.rocm_aiter_moe_enabled = False
def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
@ -764,19 +769,38 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w2_input_scale
def select_gemm_impl(self, prepare_finalize):
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
PplxPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts)
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
"Marlin and ROCm AITER are not supported with all2all yet.")
experts = TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
return experts
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts(fp8)")
self.use_pplx_kernels = True
return BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=all2all_manager.world_size,
dp_size=all2all_manager.tp_group.world_size,
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
per_act_token=False, #?
)
else:
logger.debug("TritonOrDeepGemmExperts(fp8)")
return TritonOrDeepGemmExperts(
use_fp8_w8a8=True,
block_shape=self.quant_config.weight_block_size,
allow_deep_gemm=self.allow_deep_gemm,
)
def apply(
self,
@ -807,7 +831,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
indices_type=torch.uint32 if self.use_pplx_kernels else None)
if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
@ -854,7 +878,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,