mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Kernels] MoE refactor (#19636)
Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: ElizaWszola <ewszola@redhat.com>
This commit is contained in:
@ -113,6 +113,7 @@ def bench_run(
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
@ -124,7 +125,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -148,7 +150,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -227,6 +230,7 @@ def bench_run(
|
||||
"w2_q": w2_q,
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -287,12 +291,13 @@ def bench_run(
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
190
tests/kernels/moe/parallel_utils.py
Normal file
190
tests/kernels/moe/parallel_utils.py
Normal file
@ -0,0 +1,190 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
DeepEP test utilities
|
||||
"""
|
||||
import dataclasses
|
||||
import importlib
|
||||
import os
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
|
||||
if has_deep_ep:
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||
worker,
|
||||
) + args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
|
||||
## DeepEP specific utils
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPHTArgs:
|
||||
num_local_experts: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPLLArgs:
|
||||
max_tokens_per_rank: int
|
||||
hidden_size: int
|
||||
num_experts: int
|
||||
use_fp8_dispatch: bool
|
||||
|
||||
|
||||
def make_deepep_ht_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
ht_args: DeepEPHTArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
import deep_ep
|
||||
|
||||
# high throughput a2a
|
||||
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
|
||||
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=low_latency_mode,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
return DeepEPHTPrepareAndFinalize(buffer=buffer,
|
||||
world_size=pgi.world_size,
|
||||
rank=pgi.rank,
|
||||
dp_size=dp_size,
|
||||
rank_expert_offset=pgi.rank *
|
||||
ht_args.num_local_experts)
|
||||
|
||||
|
||||
def make_deepep_ll_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ll_args: DeepEPLLArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
import deep_ep
|
||||
|
||||
# low-latency a2a
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
|
||||
pgi.world_size, deepep_ll_args.num_experts)
|
||||
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=deepep_ll_args.num_experts //
|
||||
pgi.world_size)
|
||||
|
||||
return DeepEPLLPrepareAndFinalize(
|
||||
buffer=buffer,
|
||||
world_size=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
|
||||
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
|
||||
)
|
||||
|
||||
|
||||
def make_deepep_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ht_args: Optional[DeepEPHTArgs],
|
||||
deepep_ll_args: Optional[DeepEPLLArgs],
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
if deepep_ht_args is not None:
|
||||
assert deepep_ll_args is None
|
||||
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
|
||||
block_shape)
|
||||
|
||||
assert deepep_ll_args is not None
|
||||
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
|
||||
block_shape)
|
@ -2,18 +2,59 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton.language as tl
|
||||
|
||||
from tests.kernels.moe.utils import (batched_moe,
|
||||
make_quantized_test_activations,
|
||||
make_test_weights, triton_moe)
|
||||
from tests.kernels.quant_utils import native_batched_masked_quant_matmul
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
invoke_moe_batched_triton_kernel)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 128, 2048),
|
||||
(1, 512, 512),
|
||||
(1, 1024, 128),
|
||||
(1, 1024, 2048),
|
||||
(32, 128, 128),
|
||||
(32, 512, 512),
|
||||
(32, 1024, 2048),
|
||||
(45, 128, 128),
|
||||
(45, 128, 2048),
|
||||
(45, 512, 512),
|
||||
(45, 1024, 128),
|
||||
(45, 1024, 2048),
|
||||
(64, 128, 128),
|
||||
(64, 512, 512),
|
||||
(64, 1024, 2048),
|
||||
(222, 128, 128),
|
||||
(222, 128, 2048),
|
||||
(222, 512, 512),
|
||||
(222, 1024, 128),
|
||||
(222, 1024, 2048),
|
||||
]
|
||||
NUM_EXPERTS = [8, 64]
|
||||
TOP_KS = [1, 2, 6]
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchedMMConfig:
|
||||
dtype: torch.dtype
|
||||
in_dtype: torch.dtype
|
||||
quant_dtype: Optional[torch.dtype]
|
||||
out_dtype: torch.dtype
|
||||
num_experts: int
|
||||
max_tokens_per_expert: int
|
||||
K: int
|
||||
@ -32,79 +73,127 @@ class BatchedMMTensors:
|
||||
A = torch.randn(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.K),
|
||||
device="cuda",
|
||||
dtype=config.dtype) / 10
|
||||
dtype=config.in_dtype) / 10
|
||||
B = torch.randn((config.num_experts, config.N, config.K),
|
||||
device="cuda",
|
||||
dtype=config.dtype)
|
||||
dtype=config.in_dtype)
|
||||
C = torch.zeros(
|
||||
(config.num_experts, config.max_tokens_per_expert, config.N),
|
||||
device="cuda",
|
||||
dtype=config.dtype)
|
||||
dtype=config.out_dtype)
|
||||
|
||||
num_expert_tokens = torch.randint(low=0,
|
||||
high=config.max_tokens_per_expert,
|
||||
size=(config.num_experts, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
|
||||
return BatchedMMTensors(A, B, C, num_expert_tokens)
|
||||
|
||||
|
||||
def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
num_expert_tokens: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
num_expert_tokens_cpu = num_expert_tokens.clone()
|
||||
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
|
||||
num_experts = num_expert_tokens.size(0)
|
||||
|
||||
for e in range(num_experts):
|
||||
num_tokens = num_expert_tokens_cpu[e]
|
||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_experts", [16, 32])
|
||||
@pytest.mark.parametrize("num_experts", [8, 16, 32])
|
||||
@pytest.mark.parametrize("max_tokens_per_expert",
|
||||
[32, 64, 128, 192, 224, 256, 512])
|
||||
@pytest.mark.parametrize("K", [128, 256, 1024])
|
||||
@pytest.mark.parametrize("N", [128, 256, 512, 1024])
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("block_shape", [None])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False])
|
||||
def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
N: int, dtype: torch.dtype):
|
||||
N: int, dtype: torch.dtype,
|
||||
block_shape: Optional[list[int]],
|
||||
per_act_token_quant: bool):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N)
|
||||
tensors = BatchedMMTensors.make_tensors(config)
|
||||
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||
|
||||
test_output = tensors.C
|
||||
ref_output = test_output.clone()
|
||||
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
|
||||
pytest.skip("Don't test blocking for non-quantized types.")
|
||||
|
||||
if per_act_token_quant and block_shape is not None:
|
||||
pytest.skip("Skip illegal quantization test.")
|
||||
|
||||
if dtype.itemsize == 1:
|
||||
act_dtype = torch.bfloat16
|
||||
quant_dtype = dtype
|
||||
else:
|
||||
act_dtype = dtype
|
||||
quant_dtype = None
|
||||
|
||||
num_expert_tokens = torch.randint(low=0,
|
||||
high=max_tokens_per_expert,
|
||||
size=(num_experts, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
|
||||
A, A_q, A_scale = make_quantized_test_activations(
|
||||
num_experts,
|
||||
max_tokens_per_expert,
|
||||
K,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant)
|
||||
|
||||
B, B_q, B_scale, _, _, _ = make_test_weights(
|
||||
num_experts,
|
||||
N // 2,
|
||||
K,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
out_shape = (num_experts, max_tokens_per_expert, N)
|
||||
test_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||
ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||
q_ref_output = torch.zeros(out_shape, dtype=act_dtype, device="cuda")
|
||||
|
||||
compute_tl_dtype = {
|
||||
torch.float16: tl.float16,
|
||||
torch.bfloat16: tl.bfloat16,
|
||||
torch.float32: tl.float32
|
||||
}[test_output.dtype]
|
||||
|
||||
assert A_q.dtype == B_q.dtype
|
||||
|
||||
invoke_moe_batched_triton_kernel(
|
||||
tensors.A,
|
||||
tensors.B,
|
||||
A_q,
|
||||
B_q,
|
||||
test_output,
|
||||
tensors.num_expert_tokens,
|
||||
num_expert_tokens,
|
||||
compute_tl_dtype,
|
||||
# Quantization data
|
||||
None,
|
||||
None,
|
||||
A_scale,
|
||||
B_scale,
|
||||
None,
|
||||
# Quantization schemes
|
||||
False,
|
||||
use_fp8_w8a8,
|
||||
False,
|
||||
False,
|
||||
config={
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 16
|
||||
})
|
||||
"BLOCK_SIZE_K": 16 if dtype.itemsize > 1 else 32
|
||||
},
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
ref_output = ref_impl(tensors.A, tensors.B, ref_output,
|
||||
tensors.num_expert_tokens)
|
||||
ref_output = native_batched_masked_quant_matmul(
|
||||
A,
|
||||
B,
|
||||
ref_output,
|
||||
num_expert_tokens,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
q_ref_output = native_batched_masked_quant_matmul(A_q, B_q, q_ref_output,
|
||||
num_expert_tokens,
|
||||
A_scale, B_scale,
|
||||
block_shape)
|
||||
|
||||
rtol, atol = {
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
@ -112,4 +201,98 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
|
||||
torch.float32: (1e-2, 1e-2),
|
||||
}[test_output.dtype]
|
||||
|
||||
torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(ref_output, test_output, atol=atol, rtol=rtol)
|
||||
torch.testing.assert_close(test_output, q_ref_output, atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("m", "n", "k"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False])
|
||||
@pytest.mark.parametrize("block_shape", [None])
|
||||
def test_fused_moe_batched_experts(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
use_fp8_w8a8 = dtype == torch.float8_e4m3fn
|
||||
|
||||
if not use_fp8_w8a8 and (per_act_token_quant or block_shape is not None):
|
||||
pytest.skip("Skip quantization test for non-quantized type")
|
||||
|
||||
if per_act_token_quant and block_shape is not None or topk > e:
|
||||
pytest.skip("Skip illegal quantization test.")
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
if dtype.itemsize == 1:
|
||||
act_dtype = torch.bfloat16
|
||||
quant_dtype = dtype
|
||||
else:
|
||||
act_dtype = dtype
|
||||
quant_dtype = None
|
||||
|
||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
block_shape=block_shape,
|
||||
in_dtype=act_dtype,
|
||||
quant_dtype=quant_dtype)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
batched_output = batched_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
baseline_output = torch_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
triton_output = triton_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
baseline_output,
|
||||
atol=2e-2,
|
||||
rtol=2e-2)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
batched_output,
|
||||
atol=2e-2,
|
||||
rtol=2e-2)
|
||||
|
296
tests/kernels/moe/test_block_fp8.py
Normal file
296
tests/kernels/moe/test_block_fp8.py
Normal file
@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
dg_available = False
|
||||
try:
|
||||
import deep_gemm
|
||||
dg_available = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
# Test configurations
|
||||
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 512),
|
||||
(1, 128, 7168),
|
||||
(1, 1024, 7168),
|
||||
(1, 4608, 128),
|
||||
(1, 4608, 512),
|
||||
(1, 4608, 7168),
|
||||
(83, 128, 128),
|
||||
(83, 512, 512),
|
||||
(83, 1024, 7168),
|
||||
(83, 4608, 512),
|
||||
(83, 4608, 7168),
|
||||
(128, 128, 128),
|
||||
(128, 512, 512),
|
||||
(128, 1024, 7168),
|
||||
(128, 4608, 512),
|
||||
(128, 4608, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4608, 512),
|
||||
(2048, 4608, 7168),
|
||||
(8192, 128, 128),
|
||||
(8192, 512, 512),
|
||||
(8192, 128, 7168),
|
||||
(8192, 1024, 7168),
|
||||
(8192, 4608, 512),
|
||||
(8192, 4608, 7168),
|
||||
]
|
||||
|
||||
MNK_FACTORS_DG = [
|
||||
(128, 128, 128),
|
||||
(128, 512, 512),
|
||||
(128, 128, 7168),
|
||||
(128, 1024, 7168),
|
||||
(128, 4608, 128),
|
||||
(128, 4608, 512),
|
||||
(128, 4608, 7168),
|
||||
(192, 128, 128),
|
||||
(192, 512, 512),
|
||||
(192, 1024, 7168),
|
||||
(192, 4608, 512),
|
||||
(192, 4608, 7168),
|
||||
(1335, 128, 128),
|
||||
(1335, 1024, 7168),
|
||||
(1335, 4608, 512),
|
||||
(1335, 4608, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 512, 512),
|
||||
(2048, 128, 7168),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4608, 128),
|
||||
(2048, 4608, 512),
|
||||
(2048, 4608, 7168),
|
||||
]
|
||||
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
E = [2, 8, 16] # [128, 256]
|
||||
TOP_KS = [1, 2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids,
|
||||
block_shape):
|
||||
"""Fused moe with block-wise quantization using native torch."""
|
||||
B, D = a.shape
|
||||
topk = topk_ids.size(1)
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
|
||||
a_q = a_q.to(torch.float32)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(
|
||||
act_out, block_k)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
pytest.importorskip("torch.cuda")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_cuda():
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("E", E)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
|
||||
monkeypatch):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "2048")
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
|
||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_block_fp8_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
w1_s,
|
||||
w2_s,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
block_size,
|
||||
)
|
||||
|
||||
out = fused_experts(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
|
||||
m_out = m_fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
)
|
||||
|
||||
# 0.039 only needed for [40000-4608-7168-2-1-block_size852-dtype852-0]
|
||||
tol = 0.035 if M < 40000 else 0.039
|
||||
torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
|
||||
torch.testing.assert_close(m_out, ref_out, atol=tol, rtol=tol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS_DG)
|
||||
@pytest.mark.parametrize("E", E)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
monkeypatch):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||
|
||||
chunk_size = 1024
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
block_size = [block_m, block_m]
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
|
||||
and current_platform.is_cuda_alike())
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids, block_size)
|
||||
|
||||
if use_compile:
|
||||
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
|
||||
backend="inductor",
|
||||
fullgraph=True)
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(topk_weights, 0)
|
||||
torch._dynamo.mark_dynamic(topk_ids, 0)
|
||||
else:
|
||||
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
|
||||
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
if use_cudagraph:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(out, ref_out, atol=0.035, rtol=0.035)
|
147
tests/kernels/moe/test_block_int8.py
Normal file
147
tests/kernels/moe/test_block_int8.py
Normal file
@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_int8,
|
||||
native_w8a8_block_matmul)
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (7, 0):
|
||||
pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(1, 128, 128),
|
||||
(1, 512, 512),
|
||||
(1, 128, 7168),
|
||||
(1, 1024, 7168),
|
||||
(1, 4096, 128),
|
||||
(1, 4096, 512),
|
||||
(1, 4096, 7168),
|
||||
(33, 128, 128),
|
||||
(33, 512, 512),
|
||||
(33, 128, 7168),
|
||||
(33, 1024, 7168),
|
||||
(33, 4096, 128),
|
||||
(33, 4096, 512),
|
||||
(33, 4096, 7168),
|
||||
(128, 128, 128),
|
||||
(128, 512, 512),
|
||||
(128, 1024, 7168),
|
||||
(128, 4096, 512),
|
||||
(128, 4096, 7168),
|
||||
(222, 128, 128),
|
||||
(222, 512, 512),
|
||||
(222, 1024, 7168),
|
||||
(222, 4096, 512),
|
||||
(222, 4096, 7168),
|
||||
(2048, 128, 128),
|
||||
(2048, 1024, 7168),
|
||||
(2048, 4096, 512),
|
||||
(2048, 4096, 7168),
|
||||
]
|
||||
|
||||
E = [8, 24]
|
||||
TOP_KS = [2, 6]
|
||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
# For test
|
||||
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
"""This function performs fused moe with block-wise quantization using
|
||||
native torch."""
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(
|
||||
act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("M", "N", "K"), MNK_FACTORS)
|
||||
@pytest.mark.parametrize("E", E)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
|
||||
native torch reference."""
|
||||
torch.manual_seed(seed)
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(E,
|
||||
N,
|
||||
K,
|
||||
dtype,
|
||||
torch.int8,
|
||||
per_act_token_quant=False,
|
||||
block_shape=block_size)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_size)
|
||||
|
||||
# Check results
|
||||
torch.testing.assert_close(out, ref_out, atol=0.065, rtol=0.065)
|
@ -97,11 +97,9 @@ class MOETensors8Bit(MOETensors):
|
||||
n_b_scales = 2 * n if per_out_channel else 1
|
||||
k_b_scales = k if per_out_channel else 1
|
||||
# Get the right scale for tests.
|
||||
_, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
|
||||
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
|
||||
a_scale,
|
||||
use_per_token_if_dynamic=per_act_token)
|
||||
a_q, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, None, use_per_token_if_dynamic=per_act_token)
|
||||
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||
|
||||
@ -187,6 +185,7 @@ def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||
assert not any([
|
||||
t is None for t in [
|
||||
@ -203,7 +202,8 @@ def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
'topk_ids': topk_ids,
|
||||
'w1_scale': moe_tensors.w1_scale,
|
||||
'w2_scale': moe_tensors.w2_scale,
|
||||
'a1_scale': moe_tensors.a_scale
|
||||
'per_act_token': per_act_token,
|
||||
'a1_scale': None #moe_tensors.a_scale
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
@ -254,11 +254,13 @@ def test_cutlass_moe_8_bit_no_graph(
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids, per_act_token)
|
||||
|
||||
# Note 5.5 only needed for larger problem sizes, 5 works ok for
|
||||
# the rest.
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=5e-2,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@ -303,7 +305,8 @@ def test_cutlass_moe_8_bit_cuda_graph(
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids,
|
||||
per_act_token)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
@ -359,6 +362,7 @@ def test_cutlass_moe_8_bit_EP(
|
||||
cutlass_output = run_8_bit(mt,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
num_local_experts=e // ep_size)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
|
@ -1,6 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Test DeepEP + DeepGEMM integration
|
||||
Test DeepEP + DeepGEMM integration
|
||||
DeepGEMM are gemm kernels specialized for the
|
||||
fp8 block-quantized case.
|
||||
"""
|
||||
@ -17,12 +17,11 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm
|
||||
|
||||
from .utils import ProcessGroupInfo, parallel_launch
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
from .utils import make_test_weights
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
@ -30,10 +29,9 @@ if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
if has_deep_gemm():
|
||||
import deep_gemm
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
@ -60,25 +58,6 @@ def next_power_of_2(x):
|
||||
return 2**math.ceil(math.log2(x))
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(deep_gemm.ceil_div(m, 128) * 128,
|
||||
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
@ -86,43 +65,11 @@ def make_block_quant_fp8_weights(
|
||||
block_size: list[int],
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1, w2, w1q, w2q, w1_scale, w2_scale
|
||||
Return weights w1q, w2q, w1_scale, w2_scale
|
||||
"""
|
||||
dtype = torch.bfloat16
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10
|
||||
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
||||
|
||||
w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10
|
||||
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
|
||||
k_tiles_w1 = (k + block_k - 1) // block_k
|
||||
n_tiles_w2 = (k + block_n - 1) // block_n
|
||||
k_tiles_w2 = (n + block_k - 1) // block_k
|
||||
|
||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
|
||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
|
||||
|
||||
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128)
|
||||
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
|
||||
|
||||
for i in range(e):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
||||
|
||||
return w1, w2, w1_s, w2_s
|
||||
w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
|
||||
e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
|
||||
return w1q, w2q, w1_scale, w2_scale
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -132,6 +79,7 @@ class TestConfig:
|
||||
k: int
|
||||
n: int
|
||||
num_experts: int
|
||||
per_act_token_quant: bool
|
||||
block_size: list[int]
|
||||
# configs for testing low-latency kernels
|
||||
low_latency: bool
|
||||
@ -150,8 +98,7 @@ class TestTensors:
|
||||
def make(config: TestConfig, rank) -> "TestTensors":
|
||||
|
||||
dtype = torch.bfloat16
|
||||
topk, m, k, block_size = (config.topk, config.m, config.k,
|
||||
config.block_size)
|
||||
topk, m, k = (config.topk, config.m, config.k)
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
@ -159,9 +106,7 @@ class TestTensors:
|
||||
rank_tokens = torch.randn(
|
||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0
|
||||
rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max)
|
||||
|
||||
block_k = block_size[1]
|
||||
_, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k)
|
||||
rank_token_scales = None
|
||||
|
||||
topk_ids = torch.randint(
|
||||
low=0,
|
||||
@ -201,10 +146,12 @@ def make_ll_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
q_dtype=q_dtype,
|
||||
block_shape=test_config.block_size)
|
||||
|
||||
fused_experts = BatchedDeepGemmExperts(max_num_tokens=max_tokens_per_rank,
|
||||
world_size=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
block_shape=test_config.block_size)
|
||||
fused_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_tokens_per_rank,
|
||||
world_size=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
block_shape=test_config.block_size,
|
||||
per_act_token_quant=test_config.per_act_token_quant)
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
@ -426,6 +373,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
"""
|
||||
Tests for High-Throughput DeepEP + DeepGemm integration.
|
||||
"""
|
||||
import deep_gemm
|
||||
|
||||
m, n, k = mnk
|
||||
current_platform.seed_everything(7)
|
||||
@ -442,6 +390,7 @@ def test_ht_deepep_deepgemm_moe(mnk: tuple[int, int, int], num_experts: int,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=False,
|
||||
use_fp8_dispatch=None)
|
||||
@ -474,10 +423,14 @@ USE_FP8_DISPATCH = [False]
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@requires_deep_ep
|
||||
@requires_deep_gemm
|
||||
def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
|
||||
int], num_experts: int, topk: int,
|
||||
use_fp8_dispatch: bool, block_size: list[int],
|
||||
world_dp_size: tuple[int, int]):
|
||||
def test_ll_deepep_deepgemm_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
use_fp8_dispatch: bool,
|
||||
block_size: list[int],
|
||||
world_dp_size: tuple[int, int],
|
||||
):
|
||||
"""
|
||||
Tests for Low-Latency DeepEP + DeepGemm integration.
|
||||
"""
|
||||
@ -495,6 +448,7 @@ def test_ll_deepep_deepgemm_moe(mnk: tuple[int, int,
|
||||
k=k,
|
||||
n=n,
|
||||
num_experts=num_experts,
|
||||
per_act_token_quant=False,
|
||||
block_size=block_size,
|
||||
low_latency=True,
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
|
@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import has_deep_ep
|
||||
|
||||
from .utils import ProcessGroupInfo, parallel_launch
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
@ -31,7 +31,7 @@ if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
from .utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
from .parallel_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a
|
||||
|
||||
requires_deep_ep = pytest.mark.skipif(
|
||||
not has_deep_ep(),
|
||||
@ -102,10 +102,6 @@ class TestTensors:
|
||||
rank_tokens = torch.randn(
|
||||
(config.m, config.k), device="cuda", dtype=token_dtype) / 10
|
||||
rank_token_scales = None
|
||||
if config.dtype == torch.float8_e4m3fn:
|
||||
# low_latency_mode kernels dont support per-token quant.
|
||||
_, rank_token_scales = ops.scaled_fp8_quant(
|
||||
rank_tokens, use_per_token_if_dynamic=not low_latency_mode)
|
||||
|
||||
topk = torch.randint(low=0,
|
||||
high=config.num_experts,
|
||||
@ -121,11 +117,18 @@ class TestTensors:
|
||||
config=config)
|
||||
|
||||
|
||||
def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool, hidden_size: int, dp_size: int,
|
||||
num_experts: int, num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
use_fp8_dispatch: bool) -> FusedMoEModularKernel:
|
||||
def make_modular_kernel(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool,
|
||||
hidden_size: int,
|
||||
dp_size: int,
|
||||
num_experts: int,
|
||||
num_local_experts: int,
|
||||
q_dtype: Optional[torch.dtype],
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
) -> FusedMoEModularKernel:
|
||||
|
||||
is_quantized = q_dtype is not None
|
||||
|
||||
@ -152,6 +155,7 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
deepep_ll_args = ll_args)
|
||||
|
||||
if low_latency_mode:
|
||||
assert not per_act_token_quant, "not supported in ll mode"
|
||||
fused_experts = BatchedTritonExperts(
|
||||
max_num_tokens=MAX_TOKENS_PER_RANK,
|
||||
world_size=pgi.world_size,
|
||||
@ -159,25 +163,37 @@ def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False)
|
||||
use_int4_w4a16=False,
|
||||
per_act_token_quant=False,
|
||||
)
|
||||
else:
|
||||
fused_experts = TritonExperts(use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_channel_quant=False)
|
||||
fused_experts = TritonExperts(
|
||||
use_fp8_w8a8=is_quantized,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
)
|
||||
|
||||
mk = FusedMoEModularKernel(prepare_finalize=a2a,
|
||||
fused_experts=fused_experts)
|
||||
return mk
|
||||
|
||||
|
||||
def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool, dp_size: int,
|
||||
test_tensors: TestTensors, w1: torch.Tensor,
|
||||
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], num_experts: int,
|
||||
use_fp8_dispatch: bool) -> torch.Tensor:
|
||||
def deep_ep_moe_impl(
|
||||
pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
low_latency_mode: bool,
|
||||
dp_size: int,
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
num_experts: int,
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
) -> torch.Tensor:
|
||||
|
||||
num_local_experts = w1.size(0)
|
||||
|
||||
@ -199,11 +215,9 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
# Make modular kernel
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode,
|
||||
hidden_size, dp_size,
|
||||
num_experts,
|
||||
num_local_experts, q_dtype,
|
||||
use_fp8_dispatch)
|
||||
mk: FusedMoEModularKernel = make_modular_kernel(
|
||||
pg, pgi, low_latency_mode, hidden_size, dp_size, num_experts,
|
||||
num_local_experts, q_dtype, use_fp8_dispatch, per_act_token_quant)
|
||||
|
||||
out_hidden_states = torch.empty_like(test_tensors.rank_tokens)
|
||||
total_num_tokens = test_tensors.rank_tokens.size(0)
|
||||
@ -257,9 +271,15 @@ def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo,
|
||||
return out_hidden_states
|
||||
|
||||
|
||||
def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
|
||||
w2: torch.Tensor, w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool):
|
||||
def torch_moe_impl(
|
||||
test_tensors: TestTensors,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
using_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
|
||||
a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk,
|
||||
test_tensors.topk_weights)
|
||||
@ -267,6 +287,7 @@ def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor,
|
||||
# The DeepEP implementation is requested to dispatch using FP8.
|
||||
# For numerical stability for testing, emulate the fp8 dispatch by
|
||||
# blockwise quant and de-quant.
|
||||
assert not per_act_token_quant
|
||||
a = test_tensors.rank_tokens
|
||||
aq, aq_scale = per_token_group_quant_fp8(a, 128)
|
||||
a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view(
|
||||
@ -310,6 +331,7 @@ def _deep_ep_moe(
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
use_fp8_dispatch: bool,
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
|
||||
if not low_latency_mode:
|
||||
@ -331,7 +353,8 @@ def _deep_ep_moe(
|
||||
with set_current_vllm_config(VllmConfig()):
|
||||
# Reference
|
||||
torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale,
|
||||
w2_scale, use_fp8_dispatch)
|
||||
w2_scale, use_fp8_dispatch,
|
||||
per_act_token_quant)
|
||||
|
||||
# Splice experts for this rank.
|
||||
num_local_experts = config.num_experts // pgi.world_size
|
||||
@ -356,6 +379,7 @@ def _deep_ep_moe(
|
||||
w2_scale_ep,
|
||||
config.num_experts,
|
||||
use_fp8_dispatch,
|
||||
per_act_token_quant,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(
|
||||
@ -384,10 +408,16 @@ DTYPES = [torch.bfloat16, torch.float8_e4m3fn]
|
||||
@pytest.mark.parametrize("num_experts", [32])
|
||||
@pytest.mark.parametrize("topk", [6])
|
||||
@pytest.mark.parametrize("world_dp_size", [(2, 1)])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@requires_deep_ep
|
||||
def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
||||
num_experts: int, topk: int, world_dp_size: tuple[int,
|
||||
int]):
|
||||
def test_deep_ep_moe(
|
||||
dtype: torch.dtype,
|
||||
mnk: tuple[int, int, int],
|
||||
num_experts: int,
|
||||
topk: int,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
):
|
||||
low_latency_mode = False
|
||||
use_fp8_dispatch = False
|
||||
m, n, k = mnk
|
||||
@ -404,7 +434,8 @@ def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
|
||||
per_act_token_quant)
|
||||
|
||||
|
||||
MNKs = [
|
||||
@ -454,4 +485,5 @@ def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int],
|
||||
w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype)
|
||||
|
||||
parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size,
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch)
|
||||
config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch,
|
||||
False)
|
||||
|
@ -17,6 +17,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
@ -142,6 +143,10 @@ def test_fused_moe(
|
||||
# Setup test data
|
||||
#
|
||||
|
||||
#
|
||||
# Setup test data
|
||||
#
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
@ -169,7 +174,7 @@ def test_fused_moe(
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_channel_quant=False,
|
||||
per_act_token_quant=False,
|
||||
block_shape=None)
|
||||
|
||||
def m_fused_moe(
|
||||
@ -365,6 +370,13 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
||||
if dtype == torch.float32:
|
||||
pytest.skip("AITER ROCm test skip for float32")
|
||||
|
||||
monkeypatch.setenv('RANK', "0")
|
||||
monkeypatch.setenv('LOCAL_RANK', "0")
|
||||
monkeypatch.setenv('WORLD_SIZE', "1")
|
||||
monkeypatch.setenv('MASTER_ADDR', 'localhost')
|
||||
monkeypatch.setenv('MASTER_PORT', '12345')
|
||||
init_distributed_environment()
|
||||
|
||||
# Instantiate our and huggingface's MoE blocks
|
||||
vllm_config.compilation_config.static_forward_context = dict()
|
||||
with (set_current_vllm_config(vllm_config),
|
||||
|
@ -14,7 +14,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
|
||||
pytest.skip("Nvfp4 Requires compute capability of 10 or above.",
|
||||
allow_module_level=True)
|
||||
|
||||
MNK_FACTORS = [
|
||||
|
@ -15,7 +15,7 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .utils import ProcessGroupInfo, parallel_launch
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
try:
|
||||
from pplx_kernels import AllToAll
|
||||
@ -93,7 +93,7 @@ def pplx_cutlass_moe(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
rank=rank,
|
||||
world_size=pgi.world_size,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1
|
||||
@ -118,8 +118,6 @@ def pplx_cutlass_moe(
|
||||
pgi.world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token=per_act_token,
|
||||
)
|
||||
|
||||
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
|
||||
|
@ -18,18 +18,20 @@ try:
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from tests.kernels.moe.utils import make_test_weights, naive_batched_moe
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk, override_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
|
||||
get_default_config)
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import get_default_config
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
from .utils import ProcessGroupInfo, parallel_launch
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
requires_pplx = pytest.mark.skipif(
|
||||
not has_pplx,
|
||||
@ -144,25 +146,6 @@ def torch_batched_moe(
|
||||
return torch_finalize(out, topk_weight, topk_ids)
|
||||
|
||||
|
||||
def batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1,
|
||||
rank=0),
|
||||
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 512, 1024])
|
||||
@ -188,7 +171,7 @@ def test_fused_moe_batched_experts(
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
baseline_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
batched_output = naive_batched_moe(a, w1, w2, topk_weight, topk_ids)
|
||||
|
||||
torch.testing.assert_close(baseline_output,
|
||||
torch_output,
|
||||
@ -226,7 +209,6 @@ def pplx_prepare_finalize(
|
||||
|
||||
topk = topk_ids.shape[1]
|
||||
num_tokens, hidden_dim = a.shape
|
||||
block_size = 128
|
||||
device = pgi.device
|
||||
rank = pgi.rank
|
||||
world_size = pgi.world_size
|
||||
@ -241,9 +223,7 @@ def pplx_prepare_finalize(
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
||||
((hidden_dim + block_size - 1) // block_size *
|
||||
torch.float32.itemsize)),
|
||||
hidden_dim_scale_bytes=0,
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
@ -260,7 +240,6 @@ def pplx_prepare_finalize(
|
||||
world_size,
|
||||
rank,
|
||||
dp_size,
|
||||
a.dtype,
|
||||
)
|
||||
|
||||
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
|
||||
@ -276,6 +255,7 @@ def pplx_prepare_finalize(
|
||||
num_experts,
|
||||
None,
|
||||
False,
|
||||
FusedMoEQuantConfig(),
|
||||
)
|
||||
|
||||
b_a = b_a * 1.5
|
||||
@ -350,6 +330,7 @@ def _pplx_prepare_finalize(
|
||||
# TODO (bnell): this test point does not work for odd M due to how the test is
|
||||
# written, not due to limitations of the pplx kernels. The pplx_moe
|
||||
# test below is able to deal with odd M.
|
||||
# TODO (bnell) add fp8 tests
|
||||
@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@ -386,18 +367,31 @@ def pplx_moe(
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
qtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_compile: bool = False,
|
||||
use_cudagraphs: bool = True,
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
PplxPrepareAndFinalize, pplx_hidden_dim_scale_bytes)
|
||||
|
||||
device = torch.device("cuda", rank)
|
||||
hidden_dim = a.shape[1]
|
||||
num_experts = w1.shape[0]
|
||||
block_size = 128
|
||||
topk = topk_ids.shape[1]
|
||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
||||
max_num_tokens = round_up(rank_chunk(a.shape[0], 0, world_size), 64)
|
||||
|
||||
hidden_dim_bytes, scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||
max_num_tokens,
|
||||
hidden_dim,
|
||||
a.dtype,
|
||||
qtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
args = dict(
|
||||
max_num_tokens=max_num_tokens,
|
||||
@ -407,10 +401,8 @@ def pplx_moe(
|
||||
world_size=world_size,
|
||||
dp_size=dp_size,
|
||||
hidden_dim=hidden_dim,
|
||||
hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
|
||||
hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else
|
||||
((hidden_dim + block_size - 1) // block_size *
|
||||
torch.float32.itemsize)),
|
||||
hidden_dim_bytes=hidden_dim_bytes,
|
||||
hidden_dim_scale_bytes=scale_bytes,
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
@ -429,9 +421,11 @@ def pplx_moe(
|
||||
dp_size,
|
||||
)
|
||||
|
||||
experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
|
||||
experts = BatchedTritonExperts(max_num_tokens=max_num_tokens,
|
||||
world_size=world_size,
|
||||
dp_size=dp_size)
|
||||
dp_size=dp_size,
|
||||
use_fp8_w8a8=qtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@ -447,6 +441,13 @@ def pplx_moe(
|
||||
w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
|
||||
w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
|
||||
|
||||
if w1_scale is not None:
|
||||
w1_scale_chunk = chunk_by_rank(w1_scale, rank, world_size).to(device)
|
||||
w2_scale_chunk = chunk_by_rank(w2_scale, rank, world_size).to(device)
|
||||
else:
|
||||
w1_scale_chunk = None
|
||||
w2_scale_chunk = None
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
@ -465,6 +466,8 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
if use_cudagraphs:
|
||||
@ -477,6 +480,8 @@ def pplx_moe(
|
||||
w2_chunk,
|
||||
chunk_topk_weight,
|
||||
chunk_topk_ids,
|
||||
w1_scale=w1_scale_chunk,
|
||||
w2_scale=w2_scale_chunk,
|
||||
global_num_experts=num_experts)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@ -505,9 +510,9 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
experts = BatchedExperts(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1)
|
||||
experts = NaiveBatchedExperts(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1)
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@ -539,7 +544,12 @@ def _pplx_moe(
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
use_internode: bool,
|
||||
w1_s: Optional[torch.Tensor] = None,
|
||||
w2_s: Optional[torch.Tensor] = None,
|
||||
qtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_internode: bool = False,
|
||||
):
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
@ -557,11 +567,28 @@ def _pplx_moe(
|
||||
|
||||
moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
|
||||
|
||||
device = torch.device("cuda", pgi.rank)
|
||||
a = a.to(device)
|
||||
w1 = w1.to(device)
|
||||
w2 = w2.to(device)
|
||||
w1_s = w1_s.to(device) if w1_s is not None else None
|
||||
w2_s = w2_s.to(device) if w2_s is not None else None
|
||||
|
||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
torch_output = torch_experts(a, w1, w2, topk_weight, topk_ids)
|
||||
torch_output = torch_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape)
|
||||
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
||||
a, w1, w2, topk_weight, topk_ids)
|
||||
a, w1, w2, topk_weight, topk_ids, w1_s, w2_s,
|
||||
qtype, per_act_token_quant, block_shape)
|
||||
# TODO (bnell): fix + re-enable
|
||||
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
||||
# topk_ids)
|
||||
@ -581,6 +608,8 @@ def _pplx_moe(
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@pytest.mark.parametrize("per_act_token_quant", [False, True])
|
||||
@pytest.mark.parametrize("block_shape", [None, [128, 128]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@requires_pplx
|
||||
def test_pplx_moe(
|
||||
@ -589,15 +618,33 @@ def test_pplx_moe(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
world_dp_size: tuple[int, int],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
use_internode: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
m, n, k = mnk
|
||||
world_size, dp_size = world_dp_size
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
use_fp8_w8a8 = True
|
||||
quant_dtype = dtype
|
||||
else:
|
||||
use_fp8_w8a8 = False
|
||||
quant_dtype = None
|
||||
|
||||
if not use_fp8_w8a8 and per_act_token_quant and block_shape is not None:
|
||||
pytest.skip("Skip quantization test for non-quantized type")
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(e,
|
||||
n,
|
||||
k,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=block_shape)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
|
||||
w1_s, w2_s, quant_dtype, per_act_token_quant, block_shape,
|
||||
use_internode)
|
||||
|
@ -1,194 +1,249 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
DeepEP test utilities
|
||||
"""
|
||||
import dataclasses
|
||||
import importlib
|
||||
import os
|
||||
import traceback
|
||||
from typing import Callable, Optional
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
has_deep_ep = importlib.util.find_spec("deep_ep") is not None
|
||||
if has_deep_ep:
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import (per_block_cast_to_fp8,
|
||||
per_block_cast_to_int8)
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.utils import round_up
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
def triton_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_channel_quant=per_act_token_quant,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
def batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||
worker,
|
||||
) + args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
world_size=1,
|
||||
dp_size=1,
|
||||
rank=0),
|
||||
BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
world_size=1,
|
||||
dp_size=1,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
## DeepEP specific utils
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPHTArgs:
|
||||
num_local_experts: int
|
||||
def naive_batched_moe(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
max_num_tokens = round_up(a.shape[0], 64)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class DeepEPLLArgs:
|
||||
max_tokens_per_rank: int
|
||||
hidden_size: int
|
||||
num_experts: int
|
||||
use_fp8_dispatch: bool
|
||||
|
||||
|
||||
def make_deepep_ht_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
ht_args: DeepEPHTArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
import deep_ep
|
||||
|
||||
# high throughput a2a
|
||||
num_nvl_bytes = 1024 * 1024 * 1024 # 1GB
|
||||
num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=low_latency_mode,
|
||||
num_qps_per_rank=num_qps_per_rank)
|
||||
return DeepEPHTPrepareAndFinalize(buffer=buffer,
|
||||
world_size=pgi.world_size,
|
||||
rank=pgi.rank,
|
||||
dp_size=dp_size,
|
||||
rank_expert_offset=pgi.rank *
|
||||
ht_args.num_local_experts,
|
||||
quant_dtype=q_dtype,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def make_deepep_ll_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ll_args: DeepEPLLArgs,
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
|
||||
import deep_ep
|
||||
|
||||
# low-latency a2a
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size,
|
||||
pgi.world_size, deepep_ll_args.num_experts)
|
||||
|
||||
buffer = deep_ep.Buffer(group=pg,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=deepep_ll_args.num_experts //
|
||||
pgi.world_size)
|
||||
|
||||
return DeepEPLLPrepareAndFinalize(
|
||||
buffer=buffer,
|
||||
world_size=pgi.world_size,
|
||||
dp_size=dp_size,
|
||||
max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank,
|
||||
quant_dtype=q_dtype,
|
||||
block_shape=block_shape,
|
||||
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch,
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(max_num_tokens,
|
||||
world_size=1,
|
||||
dp_size=1,
|
||||
rank=0),
|
||||
NaiveBatchedExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
dp_size=1,
|
||||
world_size=1,
|
||||
use_fp8_w8a8=quant_dtype == torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
)
|
||||
|
||||
return fused_experts(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale)
|
||||
|
||||
def make_deepep_a2a(pg: ProcessGroup,
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
deepep_ht_args: Optional[DeepEPHTArgs],
|
||||
deepep_ll_args: Optional[DeepEPLLArgs],
|
||||
q_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
if deepep_ht_args is not None:
|
||||
assert deepep_ll_args is None
|
||||
return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype,
|
||||
block_shape)
|
||||
|
||||
assert deepep_ll_args is not None
|
||||
return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype,
|
||||
block_shape)
|
||||
def chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
end: int) -> Optional[torch.Tensor]:
|
||||
if scales is not None:
|
||||
if scales.numel() == 1:
|
||||
return scales
|
||||
else:
|
||||
return scales[start:end]
|
||||
return None
|
||||
|
||||
|
||||
def make_quantized_test_activations(
|
||||
E: int,
|
||||
m: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
a = torch.randn((E, m, k), device="cuda", dtype=in_dtype) / 10
|
||||
a_q = a
|
||||
a_scale = None
|
||||
|
||||
if quant_dtype is not None:
|
||||
assert (quant_dtype == torch.float8_e4m3fn
|
||||
or quant_dtype == torch.int8), "only fp8/int8 supported"
|
||||
a_q = torch.zeros_like(a, dtype=quant_dtype)
|
||||
a_scale_l = [None] * E
|
||||
for e in range(E):
|
||||
a_q[e], a_scale_l[e] = moe_kernel_quantize_input(
|
||||
a[e], None, quant_dtype, per_act_token_quant, block_shape)
|
||||
a_scale = torch.stack(a_scale_l)
|
||||
|
||||
if not per_act_token_quant and block_shape is None:
|
||||
a_scale = a_scale.view(E, 1, 1)
|
||||
|
||||
return a, a_q, a_scale
|
||||
|
||||
|
||||
def moe_quantize_weights(
|
||||
w: torch.Tensor,
|
||||
w_s: Optional[torch.Tensor],
|
||||
quant_dtype: Optional[torch.dtype],
|
||||
per_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
assert (quant_dtype == torch.float8_e4m3fn
|
||||
or quant_dtype == torch.int8), "only fp8/int8 supported"
|
||||
|
||||
if block_shape is not None:
|
||||
assert not per_token_quant
|
||||
if quant_dtype == torch.int8:
|
||||
w, w_s = per_block_cast_to_int8(w, block_shape)
|
||||
else:
|
||||
w, w_s = per_block_cast_to_fp8(w, block_shape)
|
||||
else:
|
||||
if quant_dtype == torch.int8:
|
||||
w, w_s = ops.scaled_int8_quant(
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||
else:
|
||||
w, w_s = ops.scaled_fp8_quant(
|
||||
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||
|
||||
return w, w_s
|
||||
|
||||
|
||||
def make_test_weight(
|
||||
e: int,
|
||||
rows: int,
|
||||
cols: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
||||
|
||||
if quant_dtype is not None:
|
||||
w_l = [None] * e
|
||||
w_s_l = [None] * e
|
||||
for idx in range(e):
|
||||
w_l[idx], w_s_l[idx] = moe_quantize_weights(
|
||||
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
|
||||
|
||||
w = torch.stack(w_l)
|
||||
w_s = torch.stack(w_s_l)
|
||||
if w_s.ndim == 2:
|
||||
assert w_s.shape[-1] == 1
|
||||
w_s = w_s.view(-1, 1, 1)
|
||||
|
||||
if block_shape is not None:
|
||||
block_n, block_k = block_shape
|
||||
n_tiles = (rows + block_n - 1) // block_n
|
||||
k_tiles = (cols + block_k - 1) // block_k
|
||||
assert w_s.shape == (e, n_tiles, k_tiles)
|
||||
else:
|
||||
w = w_16
|
||||
w_s = None
|
||||
|
||||
return w_16, w, w_s
|
||||
|
||||
|
||||
def make_test_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
|
||||
torch.Tensor, Optional[torch.Tensor]]:
|
||||
return (
|
||||
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
||||
per_act_token_quant),
|
||||
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
||||
per_act_token_quant),
|
||||
)
|
||||
|
@ -5,7 +5,10 @@ from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
group_broadcast)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import round_up
|
||||
|
||||
# Using the default value (240.0) from pytorch will cause accuracy
|
||||
# issue on dynamic quantization models. Here use 224.0 for rocm.
|
||||
@ -94,9 +97,15 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
|
||||
return ref_out, ref_scale.view((1, ))
|
||||
|
||||
|
||||
def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
||||
As: torch.Tensor, Bs: torch.Tensor, block_size,
|
||||
output_dtype):
|
||||
def native_w8a8_block_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: list[int],
|
||||
output_dtype: torch.dtype,
|
||||
compute_type: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization using native torch.
|
||||
It is agnostic to the input data type and can be used for both int8 and
|
||||
@ -106,8 +115,8 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
||||
`Bs` (float32).
|
||||
The output is returned in the specified `output_dtype`.
|
||||
"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
A = A.to(compute_type)
|
||||
B = B.to(compute_type)
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||
assert len(block_size) == 2
|
||||
@ -122,11 +131,11 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
||||
As = As.reshape(M, As.shape[-1])
|
||||
n_tiles = (N + block_n - 1) // block_n
|
||||
k_tiles = (K + block_k - 1) // block_k
|
||||
assert n_tiles == Bs.shape[0]
|
||||
assert k_tiles == Bs.shape[1]
|
||||
assert n_tiles == Bs.shape[0], f"{n_tiles} == {Bs.shape[0]}"
|
||||
assert k_tiles == Bs.shape[1], f"{k_tiles} == {Bs.shape[1]}"
|
||||
|
||||
C_shape = (M, N)
|
||||
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
|
||||
C = torch.zeros(C_shape, dtype=compute_type, device=A.device)
|
||||
|
||||
A_tiles = [
|
||||
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
|
||||
@ -152,3 +161,152 @@ def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
|
||||
|
||||
C = C.reshape(origin_C_shape).to(output_dtype)
|
||||
return C
|
||||
|
||||
|
||||
def native_per_token_group_quant_fp8(x,
|
||||
group_size,
|
||||
eps=1e-10,
|
||||
dtype=torch.float8_e4m3fn):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch."""
|
||||
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
|
||||
"be divisible by `group_size`")
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1,
|
||||
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def native_per_token_group_quant_int8(x,
|
||||
group_size,
|
||||
eps=1e-10,
|
||||
dtype=torch.int8):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch.
|
||||
|
||||
It converts the tensor values into int8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
"""
|
||||
assert (x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` must be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
iinfo = torch.iinfo(dtype)
|
||||
int8_min = iinfo.min
|
||||
int8_max = iinfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
# Use float32 for scale calculation for stability
|
||||
amax = x_.abs().max(dim=-1,
|
||||
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / int8_max
|
||||
x_q = (x_.to(torch.float32) / x_s).round().clamp(
|
||||
min=int8_min, max=int8_max).to(dtype) # Round before clamping
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
DEFAULT_BLOCK_SHAPE = [128, 128]
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
block_m, block_n = block_shape
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
def per_block_cast_to_int8(
|
||||
x: torch.Tensor,
|
||||
block_shape: list[int] = DEFAULT_BLOCK_SHAPE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
block_m, block_n = block_shape
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((round_up(m, block_m), round_up(n, block_n)),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (256.0 / x_amax)).to(torch.int8)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 256.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
def dequant(
|
||||
t: torch.Tensor,
|
||||
scale: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]],
|
||||
per_act_token_quant: bool,
|
||||
out_dtype: Optional[torch.dtype] = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
if scale is not None:
|
||||
f32 = torch.float32
|
||||
if per_act_token_quant or block_shape is None:
|
||||
return (t.to(f32) * scale).to(out_dtype)
|
||||
else:
|
||||
return (t.to(f32) * group_broadcast(scale, t.shape)).to(out_dtype)
|
||||
else:
|
||||
return t.to(out_dtype)
|
||||
|
||||
|
||||
def native_batched_masked_quant_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
C: torch.Tensor,
|
||||
num_expert_tokens: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor] = None,
|
||||
B_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_expert_tokens_cpu = num_expert_tokens.clone()
|
||||
num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu")
|
||||
num_experts = num_expert_tokens.size(0)
|
||||
|
||||
for e in range(num_experts):
|
||||
num_tokens = num_expert_tokens_cpu[e]
|
||||
if A.dtype.itemsize == 1 and block_shape is not None:
|
||||
assert A_scale is not None and B_scale is not None
|
||||
tmp = native_w8a8_block_matmul(A[e], B[e], A_scale[e], B_scale[e],
|
||||
block_shape, C.dtype)
|
||||
C[e, :num_tokens, :] = tmp[:num_tokens, :]
|
||||
elif A.dtype.itemsize == 1 and block_shape is None:
|
||||
assert A_scale is not None and B_scale is not None
|
||||
A_dq = dequant(A[e], A_scale[e], block_shape, per_act_token_quant)
|
||||
B_dq = dequant(B[e], B_scale[e], block_shape, per_act_token_quant)
|
||||
C[e, :num_tokens, :] = (
|
||||
A_dq[:num_tokens] @ B_dq.transpose(0, 1)).to(C.dtype)
|
||||
else:
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1)
|
||||
|
||||
return C
|
||||
|
@ -7,16 +7,10 @@ import itertools
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from tests.kernels.quant_utils import (native_per_token_group_quant_fp8,
|
||||
native_w8a8_block_matmul,
|
||||
per_block_cast_to_fp8)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
@ -46,78 +40,10 @@ N = [128, 512, 7168, 7748, 13824]
|
||||
K = [256, 3884, 4096, 13824, 16384]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
|
||||
M_moe_dg = [128, 192, 1335, 2048]
|
||||
N_moe = [128, 256, 1024, 4608] # [13824]
|
||||
K_moe = [256, 512, 7168] # [13824]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
E = [2, 8, 16, 24] # [128, 256]
|
||||
TOP_KS = [1, 2, 6]
|
||||
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
def native_per_token_group_quant_fp8(x,
|
||||
group_size,
|
||||
eps=1e-10,
|
||||
dtype=torch.float8_e4m3fn):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch."""
|
||||
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
|
||||
"be divisible by `group_size`")
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
amax = x_.abs().max(dim=-1,
|
||||
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / fp8_max
|
||||
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
"""Fused moe with block-wise quantization using native torch."""
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
|
||||
a_q = a_q.to(torch.float32)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_fp8(
|
||||
act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
# Skip all tests if CUDA is not available
|
||||
pytest.importorskip("torch.cuda")
|
||||
|
||||
@ -177,111 +103,6 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,E,topk,block_size,dtype,seed",
|
||||
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS, BLOCK_SIZE, DTYPES,
|
||||
SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test; topk={topk} > E={E}")
|
||||
|
||||
torch.manual_seed(seed)
|
||||
factor_for_scale = 1e-2
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
w1_bf16 = (torch.rand(
|
||||
(E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
|
||||
w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
del w1_bf16
|
||||
|
||||
w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
|
||||
w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||
del w2_bf16
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
|
||||
w1_s = torch.rand(
|
||||
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
|
||||
w2_s = torch.rand(
|
||||
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_channel_quant=False,
|
||||
block_shape=block_size)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_fp8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_size)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
m_out = m_fused_moe(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=E,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s)
|
||||
|
||||
#print(f"{out.sum()=}")
|
||||
#print(f"{ref_out.sum()=}")
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.03
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.03
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(deep_gemm.ceil_div(m, 128) * 128,
|
||||
deep_gemm.ceil_div(n, block_size_n) * block_size_n),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,block_size,out_dtype,seed",
|
||||
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS))
|
||||
@ -324,187 +145,3 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
def fp8_perm(m, idx):
|
||||
if torch.is_floating_point(m) and torch.finfo(m.dtype).bits == 8:
|
||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||
else:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
def _moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
|
||||
M, K = a.shape
|
||||
|
||||
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
|
||||
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
|
||||
|
||||
num_tokens = topk * M
|
||||
|
||||
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
|
||||
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
|
||||
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
|
||||
|
||||
a = fp8_perm(a, sorted_token_ids // topk)
|
||||
if a_s is not None:
|
||||
a_s = a_s[sorted_token_ids // topk]
|
||||
|
||||
return a, a_s, m_indices, inv_perm
|
||||
|
||||
|
||||
def _moe_unpermute(out, inv_perm, topk, K, topk_weight):
|
||||
M = topk_weight.shape[0]
|
||||
out = out[inv_perm, ...]
|
||||
tmp_out = out.view(-1, topk, K)
|
||||
return (tmp_out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_shape):
|
||||
"""Fused moe with block-wise quantization using DeepGemm grouped gemm."""
|
||||
num_groups = w1.shape[0]
|
||||
M, K = a.shape
|
||||
N = w2.shape[-1]
|
||||
|
||||
topk_weight, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score.float(), topk, False)
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
|
||||
a_q, a_s = per_token_group_quant_fp8(a, block_m)
|
||||
|
||||
a_q, a_s, m_indices, inv_perm = _moe_permute(a_q, a_s, topk_ids,
|
||||
num_groups, topk, block_m)
|
||||
|
||||
inter_out = torch.zeros((a_q.shape[0], N * 2),
|
||||
dtype=torch.bfloat16,
|
||||
device=a.device)
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s),
|
||||
inter_out, m_indices)
|
||||
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
|
||||
|
||||
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
|
||||
|
||||
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
|
||||
|
||||
final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight)
|
||||
|
||||
return final_out
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M,N,K,E,topk,seed",
|
||||
itertools.product(M_moe_dg, N_moe, K_moe, E, TOP_KS, SEEDS))
|
||||
@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.")
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
|
||||
monkeypatch):
|
||||
if topk > E:
|
||||
pytest.skip(f"Skipping test: topk={topk} > E={E}")
|
||||
|
||||
if not _valid_deep_gemm_shape(M, N, K):
|
||||
pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}")
|
||||
|
||||
chunk_size = 1024
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", str(chunk_size))
|
||||
|
||||
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
|
||||
block_size = [block_m, block_m]
|
||||
dtype = torch.bfloat16
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
w1_bf16 = ((torch.rand((E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 *
|
||||
fp8_max).clamp(min=fp8_min, max=fp8_max)
|
||||
|
||||
w2_bf16 = ((torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 *
|
||||
fp8_max).clamp(min=fp8_min, max=fp8_max)
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = ((2 * N) + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
|
||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn)
|
||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn)
|
||||
|
||||
w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
||||
w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
||||
|
||||
w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous()
|
||||
w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous()
|
||||
|
||||
assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128)
|
||||
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
|
||||
|
||||
for i in range(E):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i])
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i])
|
||||
|
||||
# Note: for now use_compile will error out if the problem size is
|
||||
# large enough to trigger chunking. I'm leaving the flag and
|
||||
# setup code in case we are able to revisit this later.
|
||||
use_compile = False
|
||||
|
||||
use_cudagraph = (chunk_size < M and N >= 1024 and K >= 1024
|
||||
and current_platform.is_cuda_alike())
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
if M >= 128:
|
||||
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
|
||||
score, topk, block_size)
|
||||
else:
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
|
||||
topk, block_size)
|
||||
|
||||
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||
a, score.float(), topk, False)
|
||||
|
||||
if use_compile:
|
||||
deep_gemm_moe_fp8_fn = torch.compile(deep_gemm_moe_fp8,
|
||||
backend="inductor",
|
||||
fullgraph=True)
|
||||
torch._dynamo.mark_dynamic(a, 0)
|
||||
torch._dynamo.mark_dynamic(topk_weights, 0)
|
||||
torch._dynamo.mark_dynamic(topk_ids, 0)
|
||||
else:
|
||||
deep_gemm_moe_fp8_fn = deep_gemm_moe_fp8
|
||||
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
if use_cudagraph:
|
||||
out.fill_(0)
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
out = deep_gemm_moe_fp8_fn(a, w1, w2, w1_s, w2_s, topk_weights,
|
||||
topk_ids)
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
#print(f"{out.sum()=}")
|
||||
#print(f"{ref_out.sum()=}")
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
|
||||
assert rel_diff < 0.03
|
||||
|
@ -8,9 +8,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.quantization.utils.int8_utils import (
|
||||
w8a8_block_int8_matmul)
|
||||
from vllm.platforms import current_platform
|
||||
@ -23,82 +21,10 @@ vllm_config = VllmConfig()
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
# For test
|
||||
def native_per_token_group_quant_int8(x,
|
||||
group_size,
|
||||
eps=1e-10,
|
||||
dtype=torch.int8):
|
||||
"""Function to perform per-token-group quantization on an input tensor
|
||||
`x` using native torch.
|
||||
|
||||
It converts the tensor values into int8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
"""
|
||||
assert (x.shape[-1] % group_size == 0
|
||||
), "the last dimension of `x` cannot be divisible by `group_size`"
|
||||
assert x.is_contiguous(), "`x` is not contiguous"
|
||||
|
||||
iinfo = torch.iinfo(dtype)
|
||||
int8_min = iinfo.min
|
||||
int8_max = iinfo.max
|
||||
|
||||
x_ = x.reshape(x.numel() // group_size, group_size)
|
||||
# Use float32 for scale calculation for stability
|
||||
amax = x_.abs().max(dim=-1,
|
||||
keepdim=True)[0].clamp(min=eps).to(torch.float32)
|
||||
x_s = amax / int8_max
|
||||
x_q = (x_.to(torch.float32) / x_s).round().clamp(
|
||||
min=int8_min, max=int8_max).to(dtype) # Round before clamping
|
||||
x_q = x_q.reshape(x.shape)
|
||||
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
|
||||
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
# For test
|
||||
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
"""This function performs fused moe with block-wise quantization using
|
||||
native torch."""
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
a_q, a_s = native_per_token_group_quant_int8(a, block_k)
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
inter_out = native_w8a8_block_matmul(a_q[mask],
|
||||
w1[i],
|
||||
a_s[mask],
|
||||
w1_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
act_out_q, act_out_s = native_per_token_group_quant_int8(
|
||||
act_out, block_k)
|
||||
act_out = act_out.to(torch.float32)
|
||||
out[mask] = native_w8a8_block_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
block_shape,
|
||||
output_dtype=a.dtype)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33, 64, 222]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8, 24]
|
||||
TOP_KS = [2, 6]
|
||||
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
|
||||
BLOCK_SIZE = [[128, 128]]
|
||||
SEEDS = [0]
|
||||
@ -140,63 +66,3 @@ def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.001
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"M, N, K, E, topk, block_size, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
"""Tests the fused_moe kernel with W8A8 INT8 block quantization against a
|
||||
native torch reference."""
|
||||
torch.manual_seed(seed)
|
||||
# Use a smaller factor for scale initialization to prevent large
|
||||
# values/overflow especially when output dtype might be float16
|
||||
factor_for_scale = 1e-2
|
||||
int8_info = torch.iinfo(torch.int8)
|
||||
int8_max, int8_min = int8_info.max, int8_info.min
|
||||
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
w1_fp32 = (torch.rand(
|
||||
(E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
|
||||
w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||
|
||||
w1_s = (torch.rand(
|
||||
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
|
||||
w2_s = (torch.rand(
|
||||
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int8_w8a8=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=block_size,
|
||||
)
|
||||
ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_size)
|
||||
|
||||
# Check results
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.06
|
||||
|
@ -13,8 +13,11 @@ import pytest
|
||||
import torch
|
||||
from torch._prims_common import TensorLikeType
|
||||
|
||||
from tests.kernels.quant_utils import native_w8a8_block_matmul
|
||||
from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.platforms.interface import _Backend
|
||||
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
|
||||
STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
|
||||
@ -1054,32 +1057,77 @@ def compute_max_diff(output, output_ref):
|
||||
torch.abs(output_ref))
|
||||
|
||||
|
||||
def torch_experts(a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def torch_experts(
|
||||
a: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_act_token_quant=False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
assert (global_num_experts == -1
|
||||
or (global_num_experts == w1.shape[0] and expert_map is None)
|
||||
or (expert_map is not None
|
||||
and global_num_experts == expert_map.shape[0]))
|
||||
|
||||
M, K = a.shape
|
||||
topk = topk_ids.shape[1]
|
||||
B, D = a.shape
|
||||
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
|
||||
a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
|
||||
|
||||
out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
a, a_scale = moe_kernel_quantize_input(a, None, quant_dtype,
|
||||
per_act_token_quant, block_shape)
|
||||
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
topk_ids = topk_ids.view(-1)
|
||||
if expert_map is not None:
|
||||
topk_ids = expert_map[topk_ids]
|
||||
for i in range(w1.shape[0]):
|
||||
|
||||
for i in range(num_experts):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(
|
||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
if quant_dtype is None:
|
||||
tmp1 = a[mask] @ w1[i].transpose(0, 1)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
out[mask] = tmp2 @ w2[i].transpose(0, 1)
|
||||
elif block_shape is not None:
|
||||
assert (a_scale is not None and w1_scale is not None
|
||||
and w2_scale is not None)
|
||||
tmp1 = native_w8a8_block_matmul(a[mask], w1[i], a_scale[mask],
|
||||
w1_scale[i], block_shape,
|
||||
out.dtype)
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
tmp2, b_scale = moe_kernel_quantize_input(
|
||||
tmp2, None, quant_dtype, per_act_token_quant, block_shape)
|
||||
|
||||
out[mask] = native_w8a8_block_matmul(tmp2, w2[i], b_scale,
|
||||
w2_scale[i], block_shape,
|
||||
out.dtype)
|
||||
else:
|
||||
assert (a_scale is not None and w1_scale is not None
|
||||
and w2_scale is not None)
|
||||
f32 = torch.float32
|
||||
scales = a_scale if a_scale.numel() == 1 else a_scale[mask]
|
||||
tmp1 = a[mask].to(f32) * scales
|
||||
w1_dq = (w1[i].to(f32) * w1_scale[i]).transpose(0, 1)
|
||||
tmp1 = tmp1 @ w1_dq
|
||||
tmp2 = SiluAndMul()(tmp1)
|
||||
w2_dq = (w2[i].to(f32) * w2_scale[i]).transpose(0, 1)
|
||||
out[mask] = (tmp2 @ w2_dq).to(out.dtype)
|
||||
|
||||
return (out.view(M, -1, w2.shape[1]) *
|
||||
topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
def torch_moe(a: torch.Tensor,
|
||||
|
@ -1274,7 +1274,7 @@ def scaled_fp8_quant(
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
assert scale.numel() == 1
|
||||
assert scale.numel() == 1, f"{scale.shape}"
|
||||
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
|
||||
|
||||
return output, scale
|
||||
|
@ -4,8 +4,12 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize)
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: Optional[dict[str, Any]] = None
|
||||
@ -26,8 +30,12 @@ def get_config() -> Optional[dict[str, Any]]:
|
||||
|
||||
__all__ = [
|
||||
"FusedMoE",
|
||||
"FusedMoEConfig",
|
||||
"FusedMoEMethodBase",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"override_config",
|
||||
"get_config",
|
||||
]
|
||||
@ -36,11 +44,21 @@ if HAS_TRITON:
|
||||
# import to register the custom ops
|
||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||
import vllm.model_executor.layers.fused_moe.fused_moe # noqa
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
cutlass_moe_fp4, cutlass_moe_fp8)
|
||||
CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts, fused_experts, fused_moe, fused_topk,
|
||||
get_config_file_name, grouped_topk)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
__all__ += [
|
||||
"fused_moe",
|
||||
@ -50,5 +68,11 @@ if HAS_TRITON:
|
||||
"grouped_topk",
|
||||
"cutlass_moe_fp8",
|
||||
"cutlass_moe_fp4",
|
||||
"CutlassExpertsFp8",
|
||||
"TritonExperts",
|
||||
"BatchedTritonExperts",
|
||||
"DeepGemmExperts",
|
||||
"BatchedDeepGemmExperts",
|
||||
"TritonOrDeepGemmExperts",
|
||||
"BatchedTritonOrDeepGemmExperts",
|
||||
]
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
@ -179,28 +180,44 @@ def silu_mul_fp8_quant_deep_gemm(
|
||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
# The Deep Gemm kernels only support block size of 128
|
||||
DEEPGEMM_BLOCK_SHAPE = 128
|
||||
DEEPGEMM_BLOCK_SHAPE: list[int] = [128, 128]
|
||||
|
||||
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
|
||||
block_shape: list[int]):
|
||||
def __init__(self,
|
||||
max_num_tokens: int,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
block_shape: list[int],
|
||||
per_act_token_quant=False):
|
||||
"""
|
||||
max_num_tokens: Maximum number of tokens from a DP Rank
|
||||
world_size: Number of EP ranks
|
||||
dp_size: Number of data-parallel ranks
|
||||
block_shape: Block quantization block shape
|
||||
"""
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert self.block_shape == self.DEEPGEMM_BLOCK_SHAPE
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.block_shape = block_shape
|
||||
|
||||
assert (len(self.block_shape) == 2 and all(
|
||||
[v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -248,6 +265,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
):
|
||||
import deep_gemm as dg
|
||||
assert hidden_states.ndim == 3
|
||||
assert self.block_shape is not None
|
||||
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
|
||||
@ -20,43 +21,45 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
allow_deep_gemm: bool = False):
|
||||
super().__init__()
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
block_shape=block_shape,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
))
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.per_channel_quant = per_channel_quant
|
||||
self.block_shape = block_shape
|
||||
self.allow_deep_gemm = allow_deep_gemm
|
||||
|
||||
# BatchedTritonKernel doesn't support block quantization
|
||||
# at the moment.
|
||||
self.batched_triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape,
|
||||
world_size=self.world_size,
|
||||
dp_size=self.dp_size) if self.block_shape is None else None
|
||||
dp_size=self.dp_size,
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_act_token_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape,
|
||||
) if self.block_shape is None else None
|
||||
|
||||
is_fp8_128_block_quantized = (
|
||||
use_fp8_w8a8 and self.block_shape
|
||||
== BatchedDeepGemmExperts.DEEPGEMM_BLOCK_SHAPE)
|
||||
|
||||
is_fp8_128_block_quantized = (self.use_fp8_w8a8
|
||||
and self.block_shape is not None
|
||||
and len(self.block_shape) == 2 and all(
|
||||
[b == 128
|
||||
for b in self.block_shape]))
|
||||
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=self.max_num_tokens,
|
||||
world_size=self.world_size,
|
||||
@ -67,12 +70,31 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert (self.batched_deep_gemm_experts is not None
|
||||
or self.batched_triton_experts is not None)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.batched_triton_experts is not None:
|
||||
assert (self.batched_deep_gemm_experts is None
|
||||
or self.batched_deep_gemm_experts.activation_formats
|
||||
== self.batched_triton_experts.activation_formats)
|
||||
return self.batched_triton_experts.activation_formats
|
||||
else:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.activation_formats
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return ((bdge is None or bdge.supports_chunking())
|
||||
and (bte is None or bte.supports_chunking()))
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return ((bdge is None or bdge.supports_expert_map())
|
||||
and (bte is None or bte.supports_expert_map()))
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -87,7 +109,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None:
|
||||
if self.allow_deep_gemm:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.workspace_shapes(
|
||||
a, aq, M, N, K, topk, global_num_experts, local_num_experts)
|
||||
else:
|
||||
|
410
vllm/model_executor/layers/fused_moe/config.py
Normal file
410
vllm/model_executor/layers/fused_moe/config.py
Normal file
@ -0,0 +1,410 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_quant_config_quantization_args(
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
prop_name: str,
|
||||
) -> Optional[QuantizationArgs]:
|
||||
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||
and "Linear" in quant_config.target_scheme_map and
|
||||
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||
return quant_config.target_scheme_map["Linear"].get(prop_name)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_quant_config_input_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config,
|
||||
"input_activations")
|
||||
|
||||
|
||||
def get_quant_config_weight_quant(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
return _get_quant_config_quantization_args(quant_config, "weights")
|
||||
|
||||
|
||||
# TODO (bnell): use scalar_type instead of bools?
|
||||
def get_config_quant_dtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
) -> Optional[torch.dtype]:
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantConfig:
|
||||
# The post quantization activation type.
|
||||
quant_dtype: Optional[torch.dtype] = None
|
||||
per_act_token_quant: bool = False
|
||||
per_out_ch_quant: bool = False
|
||||
block_shape: Optional[list[int]] = None
|
||||
|
||||
# TODO: add col major flag?
|
||||
# add detailed quant info for input, intermediates, weights, etc?
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
assert sum([
|
||||
int(flag) for flag in [
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
use_int8_w8a16,
|
||||
use_int4_w4a16,
|
||||
]
|
||||
]) <= 1, "Quantization flags are mutually exclusive."
|
||||
|
||||
quant_dtype = get_config_quant_dtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
)
|
||||
return FusedMoEQuantConfig(
|
||||
quant_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEParallelConfig:
|
||||
tp_size: int
|
||||
dp_size: int
|
||||
ep_size: int
|
||||
tp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
world_size: int
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
|
||||
@property
|
||||
def use_all2all_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int, world_size_: int,
|
||||
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
Determine MoE parallel configuration. Based on the input tp_size_,
|
||||
dp_size_, ep_size_ and vllm's parallel config, determine what
|
||||
level's of parallelism to use in the fused moe layer.
|
||||
|
||||
Args:
|
||||
tp_size_ (int): tp_size passed into the FusedMoE constructor.
|
||||
dp_size_ (int): dp_size passed into the FusedMoE constructor.
|
||||
ep_size_ (int): ep_size passed into the FusedMoE constructor.
|
||||
world_size_ (int): the world size of the current All2All manager.
|
||||
vllm_parallel_config (ParallelConfig): vllm's parallel config
|
||||
object.
|
||||
|
||||
Examples:
|
||||
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
|
||||
we simply return the sizes unaltered and the ranks set to 0.
|
||||
|
||||
Expert Parallelism is considered only when either dp_size_ or tp_size_
|
||||
is non trivial.
|
||||
|
||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||
legend : {size, rank}
|
||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||
- Comment : Tensors are sharded across 2 devices.
|
||||
|
||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 2 decvices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
||||
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
||||
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 4 devices.
|
||||
|
||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||
- Comment: The experts are split between the 2 devices.
|
||||
|
||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
||||
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
||||
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
def flatten_tp_across_dp(dp_rank: int):
|
||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
tp_size = dp_size_ * tp_size_
|
||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||
return tp_size, tp_rank
|
||||
|
||||
use_ep = (dp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel)
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
world_size=world_size_,
|
||||
use_ep=False)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
assert use_ep
|
||||
# In EP, each device owns a set of experts fully. There is no tensor
|
||||
# parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
||||
ep_size = tp_size
|
||||
ep_rank = tp_rank
|
||||
return FusedMoEParallelConfig(tp_size=1,
|
||||
tp_rank=0,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
world_size=world_size_,
|
||||
use_ep=True)
|
||||
|
||||
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||
@dataclass
|
||||
class FusedMoEConfig:
|
||||
num_experts: int
|
||||
experts_per_token: int
|
||||
hidden_dim: int
|
||||
|
||||
num_local_experts: int
|
||||
moe_parallel_config: FusedMoEParallelConfig
|
||||
|
||||
# The activation type.
|
||||
in_dtype: torch.dtype
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dp_size > 1:
|
||||
logger.debug("Using FusedMoEConfig::max_num_tokens=%d",
|
||||
self.max_num_tokens)
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.quant_dtype
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.block_shape
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_act_token_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
if self.quant_config is not None:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
else:
|
||||
return False
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def world_size(self):
|
||||
return self.moe_parallel_config.world_size
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
|
||||
@property
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ll_kernels
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
num_experts: int,
|
||||
experts_per_token: int,
|
||||
hidden_dim: int,
|
||||
num_local_experts: int,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
in_dtype: torch.dtype,
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config: Optional[Union[FusedMoEQuantConfig,
|
||||
QuantizationConfig]] = None
|
||||
) -> "FusedMoEConfig":
|
||||
|
||||
_quant_config: Optional[FusedMoEQuantConfig] = None
|
||||
|
||||
if quant_config is not None and isinstance(quant_config,
|
||||
QuantizationConfig):
|
||||
if hasattr(quant_config, 'weight_block_size'):
|
||||
block_shape = quant_config.weight_block_size
|
||||
else:
|
||||
block_shape = None
|
||||
per_act_token_quant = False
|
||||
per_out_ch_quant = False
|
||||
quant_dtype: Optional[torch.dtype] = None
|
||||
|
||||
input_quant = get_quant_config_input_quant(quant_config)
|
||||
weight_quant = get_quant_config_weight_quant(quant_config)
|
||||
|
||||
if input_quant is not None:
|
||||
per_act_token_quant = (input_quant.strategy
|
||||
== QuantizationStrategy.TOKEN
|
||||
if input_quant is not None else False)
|
||||
|
||||
if input_quant.num_bits == 8:
|
||||
if input_quant.type == QuantizationType.FLOAT:
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
elif input_quant.type == QuantizationType.INT:
|
||||
quant_dtype = torch.int8
|
||||
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
if quant_dtype is None and isinstance(quant_config, Fp8Config):
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
if weight_quant is not None:
|
||||
per_out_ch_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
|
||||
if quant_dtype is not None:
|
||||
_quant_config = FusedMoEQuantConfig(
|
||||
quant_dtype=quant_dtype,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
else:
|
||||
_quant_config = FusedMoEQuantConfig()
|
||||
logger.warning_once("MoE DP setup unable to determine "
|
||||
"quantization scheme or unsupported "
|
||||
"quantization type. This model will "
|
||||
"not run with DP enabled.")
|
||||
else:
|
||||
_quant_config = quant_config
|
||||
|
||||
return FusedMoEConfig(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=experts_per_token,
|
||||
hidden_dim=hidden_dim,
|
||||
num_local_experts=num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=in_dtype,
|
||||
quant_config=_quant_config,
|
||||
max_num_tokens=max_num_tokens,
|
||||
)
|
@ -7,6 +7,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _fp8_perm, _resize_cache
|
||||
@ -202,26 +203,47 @@ def run_cutlass_moe_fp8(
|
||||
|
||||
|
||||
# TODO (bnell): split class batched vs. non-batched?
|
||||
# maybe remove need for passing aq to workspace_shapes
|
||||
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_experts_per_worker: int,
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
out_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert max_experts_per_worker > 0
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.out_dtype = out_dtype
|
||||
self.per_act_token = per_act_token
|
||||
self.per_out_ch = per_out_ch
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.use_batched_format:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
else:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return not self.use_batched_format
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return not self.use_batched_format
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -245,7 +267,8 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
workspace2 = (M * topk, N)
|
||||
output = (M * topk, K)
|
||||
return (workspace1, workspace2, output, self.out_dtype)
|
||||
return (workspace1, workspace2, output,
|
||||
self.out_dtype if self.out_dtype is not None else a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -270,13 +293,14 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
activation_callable = lambda i, o: self.activation(activation, i, o)
|
||||
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
|
||||
activation_callable, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2,
|
||||
expert_num_tokens, self.out_dtype,
|
||||
self.per_act_token, self.per_out_ch,
|
||||
self.use_batched_format)
|
||||
in_dtype = hidden_states.dtype
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2, expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
self.per_act_token_quant, self.per_out_ch_quant,
|
||||
self.use_batched_format)
|
||||
|
||||
|
||||
def cutlass_moe_fp8(
|
||||
@ -287,6 +311,7 @@ def cutlass_moe_fp8(
|
||||
topk_ids: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
activation: str = "silu",
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
@ -330,22 +355,18 @@ def cutlass_moe_fp8(
|
||||
Returns:
|
||||
- torch.Tensor: The fp16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
|
||||
a2_scale.numel() != 1 if a2_scale is not None else False)
|
||||
per_out_ch = w1_scale.numel() != w1_q.size(0)
|
||||
|
||||
out_dtype = a.dtype
|
||||
num_experts = global_num_experts if global_num_experts != -1 else w1_q.size(
|
||||
0)
|
||||
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_channel_quant=per_act_token,
|
||||
),
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
CutlassExpertsFp8(
|
||||
max_experts_per_worker=global_num_experts,
|
||||
out_dtype=out_dtype,
|
||||
per_act_token=per_act_token,
|
||||
per_out_ch=per_out_ch,
|
||||
max_experts_per_worker=num_experts,
|
||||
out_dtype=a.dtype,
|
||||
per_act_token_quant=per_act_token,
|
||||
per_out_ch_quant=per_out_ch,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
@ -358,7 +379,7 @@ def cutlass_moe_fp8(
|
||||
topk_ids,
|
||||
False,
|
||||
activation,
|
||||
global_num_experts if global_num_experts != -1 else w1_q.size(0),
|
||||
num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
|
@ -7,13 +7,13 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_permute)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
_resize_cache, per_token_group_quant_fp8)
|
||||
from vllm.utils import has_deep_gemm, round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -65,16 +65,31 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
|
||||
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.block_shape = deep_gemm_block_shape()
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig(
|
||||
quant_dtype=torch.float8_e4m3fn,
|
||||
per_act_token_quant=False,
|
||||
block_shape=deep_gemm_block_shape(),
|
||||
))
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
|
||||
topk: int, global_num_experts: int, local_num_experts: int
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert self.block_shape is not None
|
||||
# We use global_num_experts due to how moe_align_block_size handles
|
||||
# expert_maps.
|
||||
num_experts = global_num_experts
|
||||
@ -107,6 +122,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
):
|
||||
import deep_gemm as dg
|
||||
assert self.block_shape is not None
|
||||
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
@ -213,8 +229,7 @@ def deep_gemm_moe_fp8(
|
||||
- torch.Tensor: The bfloat16 output tensor after applying the MoE layer.
|
||||
"""
|
||||
fn = mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(quant_dtype=torch.float8_e4m3fn,
|
||||
block_shape=deep_gemm_block_shape()),
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
DeepGemmExperts(),
|
||||
)
|
||||
return fn(
|
||||
|
@ -6,6 +6,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
@ -15,22 +16,14 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
Prepare/Finalize using DeepEP High-Throughput kernels.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
buffer: deep_ep.Buffer,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
dp_size: int,
|
||||
rank_expert_offset: int,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None):
|
||||
def __init__(self, buffer: deep_ep.Buffer, world_size: int, rank: int,
|
||||
dp_size: int, rank_expert_offset: int):
|
||||
super().__init__()
|
||||
self.buffer = buffer
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.dp_size = dp_size
|
||||
self.rank_expert_offset = rank_expert_offset
|
||||
self.quant_dtype = quant_dtype
|
||||
self.block_shape = block_shape
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
@ -39,6 +32,10 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
|
||||
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return None
|
||||
|
||||
@ -55,13 +52,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return None
|
||||
return deep_ep.Buffer.get_combine_config(self.dp_size)
|
||||
|
||||
def _do_quant(self, tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor], per_act_token: bool):
|
||||
tokens, token_scales = moe_kernel_quantize_input(
|
||||
tokens, token_scales, self.quant_dtype, per_act_token,
|
||||
self.block_shape)
|
||||
return tokens, token_scales
|
||||
|
||||
def _do_dispatch(self, tokens: torch.Tensor,
|
||||
token_scales: Optional[torch.Tensor],
|
||||
rank_topk_ids: torch.Tensor,
|
||||
@ -130,43 +120,51 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
rank_topk_weights: torch.Tensor,
|
||||
rank_topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = rank_topk_ids.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1")
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Check if there is a block_shape / or if we can infer the quantization
|
||||
# schemes from the scales.
|
||||
per_token_quant = None
|
||||
if all([x is None for x in [self.block_shape, a1_scale, a2_scale]
|
||||
]) and self.quant_dtype is not None:
|
||||
if all([
|
||||
x is None
|
||||
for x in [quant_config.block_shape, a1_scale, a2_scale]
|
||||
]) and quant_config.quant_dtype is not None:
|
||||
# Quantization required despite none of the inputs suggesting
|
||||
# quantization. Fallback to per_token_dynamic quant.
|
||||
per_token_quant = True
|
||||
else:
|
||||
per_token_quant = ((self.block_shape is not None) or
|
||||
(a1_scale is not None and a1_scale.numel() != 1)
|
||||
or (a2_scale is not None
|
||||
and a2_scale.numel() != 1))
|
||||
per_token_quant = False
|
||||
|
||||
if per_token_quant:
|
||||
a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1,
|
||||
a1_scale,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=True,
|
||||
block_shape=quant_config.block_shape,
|
||||
)
|
||||
if a1q_scale is not None and a1q_scale.numel() == 1:
|
||||
a1q_scale = a1q_scale.view(1, 1)
|
||||
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1q,
|
||||
token_scales=a1q_scale,
|
||||
rank_topk_ids=rank_topk_ids,
|
||||
rank_topk_weights=rank_topk_weights,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts)
|
||||
else:
|
||||
# DeepEP kernels only support dispatching per-token-quant
|
||||
@ -175,15 +173,18 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
expert_topk_weights) = self._do_dispatch(
|
||||
tokens=a1,
|
||||
token_scales=None,
|
||||
rank_topk_ids=rank_topk_ids,
|
||||
rank_topk_weights=rank_topk_weights,
|
||||
rank_topk_ids=topk_ids,
|
||||
rank_topk_weights=topk_weights,
|
||||
num_experts=num_experts)
|
||||
# quantize now
|
||||
expert_x_scale = None
|
||||
if expert_x.numel() != 0:
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x,
|
||||
a1_scale,
|
||||
per_act_token=False)
|
||||
expert_x, expert_x_scale = moe_kernel_quantize_input(
|
||||
expert_x,
|
||||
a1_scale,
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=False,
|
||||
block_shape=quant_config.block_shape)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
|
||||
expert_topk_weights)
|
||||
|
@ -5,11 +5,13 @@ import deep_ep
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
maybe_fix_scales, moe_kernel_quantize_input)
|
||||
|
||||
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
|
||||
DEEPEP_QUANT_BLOCK_SIZE = 128
|
||||
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]
|
||||
|
||||
|
||||
def dequant_fp8(expert_x_fp8: torch.Tensor,
|
||||
@ -35,30 +37,30 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
# DeepEP low-latency kernels are compiled only for certain
|
||||
# specific hidden sizes.
|
||||
SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168]
|
||||
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 7168]
|
||||
|
||||
def __init__(self,
|
||||
buffer: deep_ep.Buffer,
|
||||
max_tokens_per_rank: int,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
max_tokens_per_rank: int,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
use_fp8_dispatch: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.buffer = buffer
|
||||
self.max_tokens_per_rank = max_tokens_per_rank
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.quant_dtype = quant_dtype
|
||||
self.block_shape = block_shape
|
||||
self.max_tokens_per_rank = max_tokens_per_rank
|
||||
self.use_fp8_dispatch = use_fp8_dispatch
|
||||
# The dispatch function returns a handle that the combine function
|
||||
# requires. We store the handle here so it is available to the
|
||||
# combine function.
|
||||
self.handle = None
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return self.max_tokens_per_rank
|
||||
|
||||
@ -66,12 +68,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
return torch.int64
|
||||
|
||||
def _do_quant(
|
||||
self, x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
a1_scale: Optional[torch.Tensor], a2_scale: Optional[torch.Tensor],
|
||||
a1_dtype: torch.dtype
|
||||
self,
|
||||
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
a1_dtype: torch.dtype,
|
||||
quant_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
block_k = self.block_shape[1] if self.block_shape is not None else None
|
||||
block_k = block_shape[1] if block_shape is not None else None
|
||||
if self.use_fp8_dispatch:
|
||||
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
|
||||
# DeepEP kernels did the quantization for us.
|
||||
@ -84,32 +91,20 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
assert isinstance(x, torch.Tensor)
|
||||
|
||||
# Check if there is a block_shape / or if we can infer the quantization
|
||||
# schemes from the scales.
|
||||
per_token_quant = None
|
||||
if all([v is None for v in [self.block_shape, a1_scale, a2_scale]
|
||||
]) and self.quant_dtype is not None:
|
||||
# Quantization required despite none of the inputs suggesting
|
||||
# quantization. Fallback to per_token_dynamic quant.
|
||||
per_token_quant = True
|
||||
else:
|
||||
per_token_quant = ((self.block_shape is not None) or
|
||||
(a1_scale is not None and a1_scale.numel() != 1)
|
||||
or (a2_scale is not None
|
||||
and a2_scale.numel() != 1))
|
||||
assert not per_act_token_quant
|
||||
|
||||
num_experts, max_tokens, hidden_dim = x.size()
|
||||
|
||||
# TODO (varun): Optimization - Use a batched version of quant
|
||||
x = x.view((-1, hidden_dim))
|
||||
x, x_scales = moe_kernel_quantize_input(x, a1_scale, self.quant_dtype,
|
||||
per_token_quant,
|
||||
self.block_shape)
|
||||
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
|
||||
per_act_token_quant,
|
||||
block_shape)
|
||||
x = x.view((num_experts, -1, hidden_dim))
|
||||
|
||||
if per_token_quant:
|
||||
if quant_dtype is not None:
|
||||
assert x_scales is not None
|
||||
x_scales = x_scales.view(num_experts, max_tokens, -1)
|
||||
x_scales = maybe_fix_scales(x_scales, num_experts)
|
||||
|
||||
return x, x_scales
|
||||
|
||||
@ -118,11 +113,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
rank_topk_weights: torch.Tensor,
|
||||
rank_topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
@ -142,24 +138,25 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
"low_latency kernels doesn't support dispatching per-token scales")
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = rank_topk_ids.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1")
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Dispatch
|
||||
expert_x, expert_num_tokens, self.handle, event, hook = \
|
||||
self.buffer.low_latency_dispatch(a1,
|
||||
rank_topk_ids,
|
||||
topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
async_finish=False,
|
||||
return_recv_hook=False)
|
||||
|
||||
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a2_scale,
|
||||
a1.dtype)
|
||||
expert_x, expert_x_scale = self._do_quant(
|
||||
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
|
||||
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
|
||||
|
||||
|
@ -8,6 +8,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
get_config_dtype_str, try_get_optimal_moe_config)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
@ -317,8 +318,8 @@ def invoke_moe_batched_triton_kernel(
|
||||
expert_num_tokens: torch.Tensor, # [E]
|
||||
compute_type: tl.dtype,
|
||||
# Quantization data
|
||||
A_scale: torch.Tensor,
|
||||
B_scale: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
B_zp: torch.Tensor,
|
||||
# Quantization schemes
|
||||
use_fp8_w8a8: bool,
|
||||
@ -387,14 +388,23 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
that the PPLX dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
|
||||
rank: int):
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
rank: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
self.rank = rank
|
||||
self.max_num_tokens = max_num_tokens
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return self.max_num_tokens
|
||||
|
||||
@ -411,6 +421,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
assert a1.dim() == 2
|
||||
@ -435,22 +446,35 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
num_local_experts = num_experts // self.world_size
|
||||
|
||||
if quant_config.quant_dtype is None:
|
||||
b_type = a1.dtype
|
||||
else:
|
||||
b_type = quant_config.quant_dtype
|
||||
|
||||
b_a1 = torch.zeros(
|
||||
(num_local_experts, self.max_num_tokens, hidden_dim),
|
||||
dtype=a1.dtype,
|
||||
dtype=b_type,
|
||||
device=a1.device)
|
||||
|
||||
b_a1_scale = None
|
||||
|
||||
assert quant_config.quant_dtype is None, "quantization NYI"
|
||||
|
||||
first_expert = num_local_experts * self.rank
|
||||
last_expert = first_expert + num_local_experts
|
||||
|
||||
for expert_id in range(first_expert, last_expert):
|
||||
topks = torch.any(topk_ids == expert_id, dim=1).flatten()
|
||||
rows = torch.count_nonzero(topks.flatten())
|
||||
b_a1[expert_id -
|
||||
first_expert, :rows, :] = a1[:topks.numel()][topks]
|
||||
tokens_per_expert[expert_id - first_expert] = rows
|
||||
if rows == 0:
|
||||
continue
|
||||
idx = expert_id - first_expert
|
||||
b_a1[idx, :rows, :] = a1[:topks.numel()][topks]
|
||||
tokens_per_expert[idx] = rows
|
||||
|
||||
return b_a1, a1_scale, tokens_per_expert, None, None
|
||||
assert b_a1_scale is None or b_a1_scale.ndim == 3
|
||||
|
||||
return b_a1, b_a1_scale, tokens_per_expert, None, None
|
||||
|
||||
def finalize(
|
||||
self,
|
||||
@ -480,7 +504,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
output[topks] = output[topks] + rhs
|
||||
|
||||
|
||||
class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
"""
|
||||
A reference MoE expert class that operates on expert batched format,
|
||||
i.e. E x max_num_tokens x K. This is the format that the pplx
|
||||
@ -497,11 +521,17 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_m: Optional[int] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
assert block_shape is None
|
||||
assert block_m is None
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert not use_fp8_w8a8, "NYI"
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
@ -510,9 +540,19 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -554,20 +594,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert hidden_states.dim() == 3
|
||||
assert expert_num_tokens is not None
|
||||
|
||||
max_num_tokens = self.max_num_tokens
|
||||
num_dp = self.world_size // self.dp_size
|
||||
num_local_experts = w1.size(0)
|
||||
assert num_local_experts == w1.size(0), (
|
||||
f"{num_local_experts} == {w1.size(0)}")
|
||||
|
||||
N = w1.size(1) // 2
|
||||
|
||||
# Not cudagraph friendly
|
||||
assert (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()
|
||||
or torch.all(expert_num_tokens <= max_num_tokens * num_dp)), (
|
||||
f"{expert_num_tokens} <= {max_num_tokens * num_dp}")
|
||||
|
||||
for expert in range(num_local_experts):
|
||||
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
|
||||
if (torch.compiler.is_compiling()
|
||||
@ -575,6 +607,10 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
num = hidden_states.shape[1]
|
||||
else:
|
||||
num = int(expert_num_tokens[expert].item())
|
||||
|
||||
if num == 0:
|
||||
continue
|
||||
|
||||
tmp = _resize_cache(workspace2, (num, N))
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
@ -590,34 +626,53 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
max_num_tokens: int,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
world_size: int = 1,
|
||||
dp_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int8_w8a16, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.block_shape = block_shape
|
||||
self.per_channel_quant = per_channel_quant
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
assert world_size > 0
|
||||
assert dp_size > 0
|
||||
assert dp_size <= world_size
|
||||
assert max_num_tokens > 0
|
||||
|
||||
assert not use_int8_w8a8, "NYI"
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
assert self.block_shape is None, "NYI"
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -630,10 +685,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
local_num_experts: int,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
num_dp = self.world_size // self.dp_size
|
||||
num_dp = self.world_size
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
max_num_tokens = self.max_num_tokens
|
||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
||||
output = (num_experts, max_num_tokens * num_dp, K)
|
||||
@ -708,7 +762,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
raise ValueError(
|
||||
f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
#print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}")
|
||||
# We can reuse the memory between these because by the time we need
|
||||
# cache3, we're done with cache1
|
||||
intermediate_cache1 = _resize_cache(workspace13,
|
||||
@ -734,6 +787,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
config=config,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
intermediate_cache2.fill_(0)
|
||||
|
||||
# TODO: would be nice to use expert_num_tokens here to reduce
|
||||
# garbage compute
|
||||
self.activation(activation, intermediate_cache2.view(-1, N // 2),
|
||||
@ -745,8 +800,8 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
A_scale=a2_scale,
|
||||
qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
quant_dtype=self.quant_dtype,
|
||||
per_act_token_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
qintermediate_cache2 = qintermediate_cache2.view(
|
||||
|
@ -12,6 +12,10 @@ import vllm.envs as envs
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig, get_config_quant_dtype)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
@ -980,20 +984,6 @@ def get_config_dtype_str(
|
||||
return None
|
||||
|
||||
|
||||
# TODO (bnell): use scalar_type instead of bools?
|
||||
def get_config_qtype(
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
) -> Optional[torch.dtype]:
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
return None
|
||||
|
||||
|
||||
def inplace_fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -1262,10 +1252,10 @@ def fused_experts_impl(
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
dtype=hidden_states.dtype)
|
||||
|
||||
qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
qtype = get_config_quant_dtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
@ -1332,8 +1322,8 @@ def fused_experts_impl(
|
||||
qcurr_hidden_states, a1q_scale = moe_kernel_quantize_input(
|
||||
A=curr_hidden_states,
|
||||
A_scale=a1_scale,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
@ -1373,8 +1363,8 @@ def fused_experts_impl(
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
A=intermediate_cache2,
|
||||
A_scale=a2_scale,
|
||||
qtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
quant_dtype=qtype,
|
||||
per_act_token_quant=per_channel_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
@ -1521,30 +1511,41 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_m: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
self.use_int4_w4a16 = use_int4_w4a16
|
||||
self.use_int8_w8a8 = use_int8_w8a8
|
||||
self.use_int8_w8a16 = use_int8_w8a16
|
||||
self.block_shape = block_shape
|
||||
self.block_m = block_m
|
||||
self.qtype = get_config_qtype(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
self.per_channel_quant = per_channel_quant
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.Standard,
|
||||
mk.FusedMoEActivationFormat.Standard)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -1660,7 +1661,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
self.activation(activation, intermediate_cache2,
|
||||
@ -1669,8 +1670,8 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2, a2_scale, self.qtype, self.per_channel_quant,
|
||||
self.block_shape)
|
||||
intermediate_cache2, a2_scale, self.quant_dtype,
|
||||
self.per_act_token_quant, self.block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
@ -1690,7 +1691,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int8_w8a8=self.use_int8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
per_channel_quant=self.per_act_token_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
|
||||
@ -1699,27 +1700,17 @@ def modular_triton_fused_moe(
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
use_int4_w4a16: bool,
|
||||
per_channel_quant: bool,
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> mk.FusedMoEModularKernel:
|
||||
qtype = get_config_qtype(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
)
|
||||
return mk.FusedMoEModularKernel(
|
||||
MoEPrepareAndFinalizeNoEP(
|
||||
quant_dtype=qtype,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
MoEPrepareAndFinalizeNoEP(),
|
||||
TritonExperts(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
),
|
||||
)
|
||||
|
@ -3,27 +3,30 @@
|
||||
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Callable, Literal, Optional, Union, overload
|
||||
from typing import Callable, Literal, Optional, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from compressed_tensors.quantization import (QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType)
|
||||
from torch.nn.parameter import UninitializedParameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_world_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.eplb.eplb_state import EplbState
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEModularKernel,
|
||||
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
|
||||
is_rocm_aiter_moe_enabled)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
@ -36,14 +39,12 @@ from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx
|
||||
if current_platform.is_cuda_alike():
|
||||
from .fused_batched_moe import BatchedTritonExperts
|
||||
from .fused_moe import TritonExperts, fused_experts
|
||||
from .modular_kernel import (FusedMoEModularKernel,
|
||||
FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize)
|
||||
if has_pplx():
|
||||
from .pplx_prepare_finalize import PplxPrepareAndFinalize
|
||||
from .pplx_prepare_finalize import (PplxPrepareAndFinalize,
|
||||
pplx_hidden_dim_scale_bytes)
|
||||
if has_deep_ep():
|
||||
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
|
||||
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
|
||||
from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
else:
|
||||
fused_experts = None # type: ignore
|
||||
@ -60,209 +61,10 @@ if current_platform.is_tpu():
|
||||
from .moe_pallas import fused_moe as fused_moe_pallas
|
||||
else:
|
||||
fused_moe_pallas = None # type: ignore
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEParallelConfig:
|
||||
tp_size: int
|
||||
dp_size: int
|
||||
ep_size: int
|
||||
tp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
|
||||
@property
|
||||
def use_all2all_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
Determine MoE parallel configuration. Based on the input tp_size_,
|
||||
dp_size_, ep_size_ and vllm's parallel config, determine what
|
||||
level's of parallelism to use in the fused moe layer.
|
||||
|
||||
Args:
|
||||
tp_size_ (int): tp_size passed into the FusedMoE constructor.
|
||||
dp_size_ (int): dp_size passed into the FusedMoE constructor.
|
||||
ep_size_ (int): ep_size passed into the FusedMoE constructor.
|
||||
vllm_parallel_config (ParallelConfig): vllm's parallel config
|
||||
object.
|
||||
|
||||
Examples:
|
||||
When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
|
||||
we simply return the sizes unaltered and the ranks set to 0.
|
||||
|
||||
Expert Parallelism is considered only when either dp_size_ or tp_size_
|
||||
is non trivial.
|
||||
|
||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||
legend : {size, rank}
|
||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||
- Comment : Tensors are sharded across 2 devices.
|
||||
|
||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 2 decvices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
||||
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
||||
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 4 devices.
|
||||
|
||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||
- Comment: The experts are split between the 2 devices.
|
||||
|
||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||
devices,
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
||||
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
||||
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
def flatten_tp_across_dp(dp_rank: int):
|
||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
tp_size = dp_size_ * tp_size_
|
||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||
return tp_size, tp_rank
|
||||
|
||||
use_ep = (dp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel)
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
use_ep=False)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
assert use_ep
|
||||
# In EP, each device owns a set of experts fully. There is no tensor
|
||||
# parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
||||
ep_size = tp_size
|
||||
ep_rank = tp_rank
|
||||
return FusedMoEParallelConfig(tp_size=1,
|
||||
tp_rank=0,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
use_ep=True)
|
||||
|
||||
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||
@dataclass
|
||||
class MoEConfig:
|
||||
num_experts: int
|
||||
experts_per_token: int
|
||||
hidden_dim: int
|
||||
|
||||
num_local_experts: int
|
||||
moe_parallel_config: FusedMoEParallelConfig
|
||||
|
||||
in_dtype: torch.dtype # The activation type.
|
||||
quant_dtype: torch.dtype = None
|
||||
|
||||
# TODO: add more quantization params, blocked, per-token, etc.
|
||||
block_size: int = 128
|
||||
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dp_size > 1:
|
||||
logger.debug("Using MOEConfig::max_num_tokens=%d",
|
||||
self.max_num_tokens)
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
|
||||
@property
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ll_kernels
|
||||
|
||||
|
||||
class FusedMoeWeightScaleSupported(Enum):
|
||||
TENSOR = "tensor"
|
||||
CHANNEL = "channel"
|
||||
@ -270,21 +72,9 @@ class FusedMoeWeightScaleSupported(Enum):
|
||||
BLOCK = "block"
|
||||
|
||||
|
||||
def get_quant_config_input_activations(
|
||||
quant_config: Optional[QuantizationConfig]
|
||||
) -> Optional[QuantizationArgs]:
|
||||
if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
|
||||
and "Linear" in quant_config.target_scheme_map and
|
||||
"input_activations" in quant_config.target_scheme_map["Linear"]):
|
||||
return quant_config.target_scheme_map["Linear"].get(
|
||||
"input_activations")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
moe: MoEConfig
|
||||
moe: FusedMoEConfig
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
@ -292,23 +82,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
raise NotImplementedError
|
||||
|
||||
def init_prepare_finalize(self, moe: MoEConfig,
|
||||
def init_prepare_finalize(self, moe: FusedMoEConfig,
|
||||
quant_config: Optional[QuantizationConfig]):
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
self.moe = moe
|
||||
quant_dtype = None
|
||||
act_quant_block_size = None
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
if isinstance(quant_config, Fp8Config):
|
||||
act_quant_block_size = quant_config.weight_block_size
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
|
||||
prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
|
||||
DeepEPHTPrepareAndFinalize,
|
||||
DeepEPLLPrepareAndFinalize]] = None
|
||||
prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None
|
||||
|
||||
if moe.use_pplx_kernels:
|
||||
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
|
||||
moe.max_num_tokens,
|
||||
moe.hidden_dim,
|
||||
moe.in_dtype,
|
||||
moe.quant_dtype,
|
||||
per_act_token_quant=moe.per_act_token_quant,
|
||||
block_shape=moe.block_shape,
|
||||
)
|
||||
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
num_experts=moe.num_experts,
|
||||
@ -318,14 +110,8 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
hidden_dim=moe.hidden_dim,
|
||||
hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to sizeof(float32)
|
||||
hidden_dim_scale_bytes=(
|
||||
0 if moe.quant_dtype.itemsize != 1 else
|
||||
((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
|
||||
torch.float32.itemsize)),
|
||||
hidden_dim_bytes=hidden_dim_bytes,
|
||||
hidden_dim_scale_bytes=hidden_scale_bytes,
|
||||
)
|
||||
|
||||
# Intranode pplx a2a takes a group name while internode does not.
|
||||
@ -335,9 +121,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
handle = all2all_manager.get_handle(all_to_all_args)
|
||||
|
||||
input_activations = get_quant_config_input_activations(
|
||||
quant_config)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
handle,
|
||||
max_num_tokens=moe.max_num_tokens,
|
||||
@ -345,10 +128,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
rank=all2all_manager.rank,
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
quant_dtype=moe.quant_dtype,
|
||||
per_act_token=(input_activations.strategy
|
||||
== QuantizationStrategy.TOKEN
|
||||
if input_activations is not None else False),
|
||||
)
|
||||
elif moe.use_deepep_ht_kernels:
|
||||
assert moe.dp_size == all2all_manager.dp_world_size
|
||||
@ -362,8 +141,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
dp_size=all2all_manager.dp_world_size,
|
||||
rank_expert_offset=all2all_manager.rank *
|
||||
moe.num_local_experts,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=act_quant_block_size,
|
||||
)
|
||||
|
||||
elif moe.use_deepep_ll_kernels:
|
||||
@ -380,25 +157,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
|
||||
# Note : We may want to use FP8 dispatch even otherwise just to
|
||||
# reduce datamovement
|
||||
assert act_quant_block_size is not None
|
||||
use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
|
||||
and act_quant_block_size[1]
|
||||
== DEEPEP_QUANT_BLOCK_SIZE)
|
||||
use_fp8_dispatch = (moe.quant_config is not None
|
||||
and moe.quant_config.quant_dtype
|
||||
== current_platform.fp8_dtype()
|
||||
and moe.quant_config.block_shape
|
||||
== DEEPEP_QUANT_BLOCK_SHAPE)
|
||||
|
||||
# Note (varun): Whether to use FP8 dispatch or not needs some
|
||||
# profiling. Turning it off for now.
|
||||
prepare_finalize = DeepEPLLPrepareAndFinalize(
|
||||
handle,
|
||||
max_tokens_per_rank=moe.max_num_tokens,
|
||||
world_size=all2all_manager.world_size,
|
||||
dp_size=all2all_manager.dp_world_size,
|
||||
max_tokens_per_rank=moe.max_num_tokens,
|
||||
quant_dtype=quant_dtype,
|
||||
block_shape=act_quant_block_size,
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
)
|
||||
|
||||
self.topk_indices_dtype = None
|
||||
if prepare_finalize is not None:
|
||||
logger.debug("%s", prepare_finalize.__class__.__name__)
|
||||
self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
|
||||
experts = self.select_gemm_impl(prepare_finalize, moe)
|
||||
self.fused_experts = FusedMoEModularKernel(
|
||||
@ -407,13 +184,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
# based on the all2all implementation, select the appropriate
|
||||
# gemm implementation
|
||||
raise NotImplementedError(
|
||||
"Subclass must select appropriate gemm implementation"
|
||||
" based on the prepare_finalize")
|
||||
f"{self.__class__.__name__} must select appropriate gemm "
|
||||
"implementation based on the prepare_finalize")
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
@ -445,7 +224,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
"""MoE method without quantization."""
|
||||
|
||||
def __init__(self, moe: MoEConfig):
|
||||
def __init__(self, moe: FusedMoEConfig):
|
||||
super().__init__()
|
||||
self.fused_experts = fused_experts # type: ignore
|
||||
self.topk_indices_dtype = None
|
||||
@ -458,44 +237,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
else:
|
||||
self.rocm_aiter_fused_experts = None # type: ignore
|
||||
|
||||
def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: Optional[MoEConfig]):
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
|
||||
assert self.fused_experts == fused_experts
|
||||
|
||||
all2all_manager = get_ep_group().device_communicator.all2all_manager
|
||||
assert all2all_manager is not None
|
||||
|
||||
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
|
||||
|
||||
use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
|
||||
) is not None
|
||||
if use_batched_experts:
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
logger.debug("BatchedTritonExperts %s", self.moe)
|
||||
assert self.moe.dp_size == all2all_manager.dp_world_size
|
||||
experts = BatchedTritonExperts(
|
||||
return BatchedTritonExperts(
|
||||
max_num_tokens=self.moe.max_num_tokens,
|
||||
world_size=all2all_manager.world_size,
|
||||
# dp_size actually means tp_size, bug in pplx kernels
|
||||
dp_size=all2all_manager.tp_group.world_size,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
per_channel_quant=False,
|
||||
)
|
||||
else:
|
||||
logger.debug("TritonExperts %s", self.moe)
|
||||
experts = TritonExperts(
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
block_shape=None,
|
||||
per_channel_quant=False,
|
||||
)
|
||||
return experts
|
||||
return TritonExperts()
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
@ -883,13 +648,18 @@ class FusedMoE(torch.nn.Module):
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
tp_size_ = (tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size())
|
||||
dp_size_ = (dp_size
|
||||
if dp_size is not None else get_dp_group().world_size)
|
||||
world_size_ = get_world_group().world_size
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.moe_parallel_config: FusedMoEParallelConfig = (
|
||||
FusedMoEParallelConfig.make(
|
||||
tp_size_=(tp_size if tp_size is not None else
|
||||
get_tensor_model_parallel_world_size()),
|
||||
dp_size_=(dp_size if dp_size is not None else
|
||||
get_dp_group().world_size),
|
||||
tp_size_=tp_size_,
|
||||
dp_size_=dp_size_,
|
||||
world_size_=world_size_,
|
||||
vllm_parallel_config=vllm_config.parallel_config))
|
||||
|
||||
self.global_num_experts = num_experts + num_redundant_experts
|
||||
@ -948,25 +718,22 @@ class FusedMoE(torch.nn.Module):
|
||||
from vllm_hpu_extension.ops import DynamicFusedMOE
|
||||
self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)
|
||||
|
||||
# Only support float8 for now.
|
||||
quant_dtype = params_dtype
|
||||
if quant_config is not None:
|
||||
input_activations = get_quant_config_input_activations(
|
||||
quant_config)
|
||||
if (input_activations is not None
|
||||
and input_activations.num_bits == 8
|
||||
and input_activations.type == QuantizationType.FLOAT):
|
||||
quant_dtype = torch.float8_e4m3fn
|
||||
if vllm_config.model_config is not None:
|
||||
model_dtype = vllm_config.model_config.dtype
|
||||
else:
|
||||
# TODO (bnell): This is a hack to get test_mixtral_moe to work
|
||||
# since model_config is not set in the pytest test.
|
||||
model_dtype = params_dtype
|
||||
|
||||
moe = MoEConfig(
|
||||
moe = FusedMoEConfig.make(
|
||||
num_experts=self.global_num_experts,
|
||||
experts_per_token=top_k,
|
||||
hidden_dim=hidden_size,
|
||||
num_local_experts=self.local_num_experts,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
in_dtype=params_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
in_dtype=model_dtype,
|
||||
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.moe_config = moe
|
||||
self.quant_config = quant_config
|
||||
@ -1017,16 +784,15 @@ class FusedMoE(torch.nn.Module):
|
||||
self.batched_router_logits: Optional[torch.Tensor] = None
|
||||
if (self.moe_parallel_config.use_pplx_kernels
|
||||
or self.moe_parallel_config.use_deepep_ll_kernels):
|
||||
act_dtype = vllm_config.model_config.dtype
|
||||
self.batched_hidden_states = torch.zeros(
|
||||
(envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
|
||||
dtype=act_dtype,
|
||||
(moe.max_num_tokens, self.hidden_size),
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
# Note here we use `num_experts` which is logical expert count
|
||||
self.batched_router_logits = torch.zeros(
|
||||
(envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts),
|
||||
dtype=act_dtype,
|
||||
(moe.max_num_tokens, num_experts),
|
||||
dtype=moe.in_dtype,
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
@property
|
||||
@ -1588,7 +1354,7 @@ class FusedMoE(torch.nn.Module):
|
||||
|
||||
assert (self.batched_hidden_states.size(0) # type: ignore
|
||||
>= chunk_size)
|
||||
assert (self.batched_router_logits.size(0) # type: ignore
|
||||
assert (self.batched_router_logits.size(0) # type: ignore
|
||||
>= chunk_size)
|
||||
staged_hidden_states = self.batched_hidden_states[:
|
||||
chunk_size, :] # type: ignore
|
||||
|
@ -1,12 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from math import prod
|
||||
from typing import Optional
|
||||
from typing import Optional, final
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.utils import cdiv
|
||||
|
||||
@ -82,6 +84,18 @@ def _moe_problem_size(
|
||||
return E, M, N, K, topk
|
||||
|
||||
|
||||
class FusedMoEActivationFormat(Enum):
|
||||
"""
|
||||
The standard activation format (num_tokens, hidden dim).
|
||||
"""
|
||||
Standard = "standard",
|
||||
"""
|
||||
The batched experts format (num experts, max tokens per expert, hidden dim)
|
||||
"""
|
||||
BatchedExperts = "batched_experts",
|
||||
|
||||
|
||||
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
|
||||
class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
An abstract base class for the [Quantize-Prepare] and [Finalize] steps
|
||||
@ -99,6 +113,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
@ -148,6 +163,15 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_format(self) -> FusedMoEActivationFormat:
|
||||
"""
|
||||
A property indicating the output format of the activations for the
|
||||
'prepare' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def topk_indices_dtype(self) -> Optional[torch.dtype]:
|
||||
"""
|
||||
@ -176,6 +200,41 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
above.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_config: Optional[FusedMoEQuantConfig],
|
||||
):
|
||||
if quant_config is not None:
|
||||
self.quant_config = quant_config
|
||||
else:
|
||||
self.quant_config = FusedMoEQuantConfig()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def activation_formats(
|
||||
self) -> tuple[FusedMoEActivationFormat, FusedMoEActivationFormat]:
|
||||
"""
|
||||
A property which is a tuple of the input and output activation formats
|
||||
for the 'apply' method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
return self.quant_config.per_act_token_quant
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
# TODO (bnell): make this return a CHUNK_SIZE or None instead?
|
||||
@abstractmethod
|
||||
def supports_chunking(self) -> bool:
|
||||
@ -185,6 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def supports_expert_map(self) -> bool:
|
||||
"""
|
||||
A flag indicating whether or not this class supports expert maps
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@ -297,6 +363,7 @@ def _chunk_scales(scales: Optional[torch.Tensor], start: int,
|
||||
return None
|
||||
|
||||
|
||||
@final
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
@ -318,6 +385,12 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
assert prepare_finalize.activation_format == \
|
||||
fused_experts.activation_formats[0], (
|
||||
f"{prepare_finalize.__class__.__name__}."
|
||||
f"{prepare_finalize.activation_format} == "
|
||||
f"{fused_experts.__class__.__name__}."
|
||||
f"{fused_experts.activation_formats[0]}")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -383,8 +456,16 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
|
||||
(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1, a1_scale, a2_scale, topk_weights, topk_ids,
|
||||
global_num_experts, expert_map, apply_router_weight_on_input)
|
||||
a1,
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
apply_router_weight_on_input,
|
||||
self.fused_experts.quant_config,
|
||||
)
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
|
@ -6,33 +6,76 @@ import pplx_kernels as pplx
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
from vllm.utils import cdiv, round_up
|
||||
|
||||
|
||||
def pplx_hidden_dim_scale_bytes(
|
||||
max_num_tokens: int,
|
||||
hidden_dim: int,
|
||||
in_dtype: torch.dtype,
|
||||
quant_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
):
|
||||
# All pplx byte sizes must be 16-byte aligned.
|
||||
align = 16
|
||||
|
||||
# For blocked per token: set to
|
||||
# ceil_div(hidden_dim, block_size) * sizeof(float32)
|
||||
# For per-token: set to 4 * sizeof(float32) (x4 for alignment)
|
||||
if quant_dtype is not None:
|
||||
assert quant_dtype.itemsize == 1
|
||||
hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
|
||||
elem_size = torch.float32.itemsize
|
||||
|
||||
if per_act_token_quant:
|
||||
# per-token
|
||||
assert block_shape is None
|
||||
hidden_scale_bytes = elem_size
|
||||
elif block_shape is not None:
|
||||
# per-group
|
||||
block_size = block_shape[1]
|
||||
num_blocks = cdiv(hidden_dim, block_size)
|
||||
hidden_scale_bytes = num_blocks * elem_size
|
||||
else:
|
||||
# per-tensor
|
||||
hidden_scale_bytes = elem_size
|
||||
else:
|
||||
hidden_dim_bytes = hidden_dim * in_dtype.itemsize
|
||||
hidden_scale_bytes = 0
|
||||
|
||||
return (
|
||||
round_up(hidden_dim_bytes, align),
|
||||
round_up(hidden_scale_bytes, align),
|
||||
)
|
||||
|
||||
|
||||
# The max_num_tokens, world_size and dp_size must be the same
|
||||
# as the ones used to create the AllToAll.
|
||||
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
def __init__(self,
|
||||
a2a: pplx.AllToAll,
|
||||
max_num_tokens: int,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
dp_size: int,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
per_act_token: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
a2a: pplx.AllToAll,
|
||||
max_num_tokens: int,
|
||||
world_size: int,
|
||||
rank: int,
|
||||
dp_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
assert max_num_tokens > 0
|
||||
self.a2a = a2a
|
||||
self.block_shape = block_shape
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.dp_size = dp_size
|
||||
self.quant_dtype = quant_dtype
|
||||
self.per_act_token = per_act_token
|
||||
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.BatchedExperts
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return self.max_num_tokens
|
||||
@ -45,36 +88,43 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
rank_topk_weights: torch.Tensor,
|
||||
rank_topk_ids: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
num_tokens = a1.size(0) # M
|
||||
hidden_dim = a1.size(-1) # K
|
||||
|
||||
assert rank_topk_ids.size(0) == num_tokens
|
||||
assert topk_ids.size(0) == num_tokens
|
||||
# assert expert_map is None, "NYI"
|
||||
|
||||
# Is this always going to be a1.device?
|
||||
device = a1.device
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
topk = rank_topk_ids.size(1)
|
||||
topk = topk_ids.size(1)
|
||||
# TODO: this only works for topK=1, will need to update for topK>1
|
||||
assert topk == 1, (
|
||||
"apply_router_weight_on_input is only implemented for topk=1")
|
||||
a1 = a1 * rank_topk_weights.to(a1.dtype)
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
repeat_cols = 4
|
||||
repeat_rows = 1 if self.per_act_token else a1.size(0)
|
||||
repeat_rows = 1 if quant_config.per_act_token_quant else a1.size(0)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
|
||||
self.per_act_token, self.block_shape)
|
||||
a1, (None if quant_config.per_act_token_quant else a1_scale),
|
||||
quant_dtype=quant_config.quant_dtype,
|
||||
per_act_token_quant=quant_config.per_act_token_quant,
|
||||
block_shape=quant_config.block_shape)
|
||||
|
||||
if a1q_scale is not None:
|
||||
if a1q_scale.numel() == 1:
|
||||
orig_a_scale_block_shape = 1
|
||||
else:
|
||||
orig_a_scale_block_shape = a1q_scale.shape[-1]
|
||||
a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)
|
||||
|
||||
# rem_experts need to be 0 for pplx to work properly.
|
||||
@ -98,15 +148,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
expert_x_scale: Optional[torch.Tensor] = None
|
||||
if a1q.dtype.itemsize == 1:
|
||||
float32_size = torch.float32.itemsize
|
||||
block_size = (self.block_shape[0] if self.block_shape is not None
|
||||
else 1) * float32_size
|
||||
block_size = (quant_config.block_shape[1]
|
||||
if quant_config.block_shape is not None else 1)
|
||||
expert_x_scale = torch.empty(
|
||||
(
|
||||
num_local_experts,
|
||||
expert_x.size(1),
|
||||
(expert_x.size(2) + block_size - 1) // block_size,
|
||||
),
|
||||
(num_local_experts, expert_x.size(1),
|
||||
round_up(
|
||||
(expert_x.size(2) + block_size - 1) // block_size, 4)),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
@ -121,11 +168,11 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
out_expert_x_scale=expert_x_scale,
|
||||
dp_x=a1q,
|
||||
dp_x_scale=a1q_scale,
|
||||
indices=rank_topk_ids,
|
||||
indices=topk_ids,
|
||||
bound_m=bound_m,
|
||||
)
|
||||
if expert_x_scale is not None:
|
||||
expert_x_scale = expert_x_scale[:, :, 0:1]
|
||||
expert_x_scale = expert_x_scale[:, :, :orig_a_scale_block_shape]
|
||||
|
||||
return expert_x, expert_x_scale, expert_num_tokens, None, None
|
||||
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||
_moe_unpermute_and_reduce)
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
@ -13,16 +14,9 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
|
||||
class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_dtype: Optional[torch.dtype] = None,
|
||||
per_channel_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.per_channel_quant = per_channel_quant
|
||||
self.block_shape = block_shape
|
||||
self.quant_dtype = quant_dtype
|
||||
@property
|
||||
def activation_format(self) -> mk.FusedMoEActivationFormat:
|
||||
return mk.FusedMoEActivationFormat.Standard
|
||||
|
||||
def max_num_tokens_per_rank(self) -> Optional[int]:
|
||||
return None
|
||||
@ -39,7 +33,8 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
apply_router_weight_on_input: bool = False,
|
||||
apply_router_weight_on_input: bool,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
|
||||
@ -50,10 +45,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
|
||||
self.quant_dtype,
|
||||
self.per_channel_quant,
|
||||
self.block_shape)
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
a1, a1_scale, quant_config.quant_dtype,
|
||||
quant_config.per_act_token_quant, quant_config.block_shape)
|
||||
|
||||
return a1q, a1q_scale, None, None, None
|
||||
|
||||
|
@ -5,6 +5,7 @@ from typing import Optional
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts, _valid_deep_gemm, _valid_deep_gemm_shape)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
@ -12,34 +13,59 @@ from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
|
||||
class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(self,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_channel_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
block_m: Optional[int] = None,
|
||||
allow_deep_gemm: bool = False):
|
||||
super().__init__()
|
||||
self.triton_expert = TritonExperts(use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_channel_quant=per_channel_quant,
|
||||
block_shape=block_shape,
|
||||
block_m=block_m)
|
||||
self.allow_deep_gemm = allow_deep_gemm
|
||||
self.use_fp8_w8a8 = use_fp8_w8a8
|
||||
def __init__(
|
||||
self,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
per_act_token_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
allow_deep_gemm: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
FusedMoEQuantConfig.make(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
))
|
||||
self.triton_expert = TritonExperts(
|
||||
use_fp8_w8a8=use_fp8_w8a8,
|
||||
use_int8_w8a8=use_int8_w8a8,
|
||||
use_int4_w4a16=use_int4_w4a16,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
self.allow_deep_gemm = (allow_deep_gemm and not per_act_token_quant
|
||||
and use_fp8_w8a8)
|
||||
self.deep_gemm_expert = DeepGemmExperts(
|
||||
) if self.allow_deep_gemm else None
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
assert (self.deep_gemm_expert is None
|
||||
or self.triton_expert.activation_formats
|
||||
== self.deep_gemm_expert.activation_formats)
|
||||
return self.triton_expert.activation_formats
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
dge = self.deep_gemm_expert
|
||||
te = self.triton_expert
|
||||
return ((dge is None or dge.supports_chunking())
|
||||
and (te is None or te.supports_chunking()))
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
dge = self.deep_gemm_expert
|
||||
te = self.triton_expert
|
||||
return ((dge is None or dge.supports_expert_map())
|
||||
and (te is None or te.supports_expert_map()))
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -83,9 +109,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
):
|
||||
N = w1.size(1)
|
||||
|
||||
use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
|
||||
use_deep_gemm = (self.allow_deep_gemm
|
||||
and _valid_deep_gemm(hidden_states, w1, w2))
|
||||
|
||||
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
||||
|
@ -37,6 +37,7 @@ def _fp8_quantize(
|
||||
A, A_scale = ops.scaled_fp8_quant(
|
||||
A, A_scale, use_per_token_if_dynamic=per_act_token)
|
||||
else:
|
||||
assert not per_act_token
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_fp8(A, block_k)
|
||||
@ -64,6 +65,7 @@ def _int8_quantize(
|
||||
"int8 quantization only supports block or channel-wise"
|
||||
A, A_scale = per_token_quant_int8(A)
|
||||
else:
|
||||
assert not per_act_token
|
||||
assert len(block_shape) == 2
|
||||
_, block_k = block_shape[0], block_shape[1]
|
||||
A, A_scale = per_token_group_quant_int8(A, block_k)
|
||||
@ -75,16 +77,15 @@ def _int8_quantize(
|
||||
def moe_kernel_quantize_input(
|
||||
A: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
qtype: Optional[torch.dtype],
|
||||
per_channel_quant: bool,
|
||||
quant_dtype: Optional[torch.dtype],
|
||||
per_act_token_quant: bool,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
if qtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||
elif qtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_channel_quant, block_shape)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
else:
|
||||
assert A_scale is None
|
||||
return A, A_scale
|
||||
|
||||
|
||||
@ -96,3 +97,17 @@ def _fp8_perm(m: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
|
||||
return m.view(dtype=torch.uint8)[idx, ...].view(dtype=m.dtype)
|
||||
else:
|
||||
return m[idx, ...]
|
||||
|
||||
|
||||
# TODO(bnell): better name
|
||||
def maybe_fix_scales(scales: Optional[torch.Tensor],
|
||||
num_experts: int) -> Optional[torch.Tensor]:
|
||||
if scales is not None and scales.ndim < 3:
|
||||
if scales.numel() == 1:
|
||||
scales = scales.view(1)
|
||||
scales = torch.repeat_interleave(scales, num_experts,
|
||||
dim=0).view(num_experts, 1, 1)
|
||||
else:
|
||||
scales = scales.view(num_experts, -1, scales.size(-1))
|
||||
|
||||
return scales
|
||||
|
@ -13,8 +13,10 @@ from compressed_tensors.quantization import (ActivationOrdering,
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
CutlassExpertsFp8, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig,
|
||||
FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported, fused_experts)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
|
||||
WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
@ -32,14 +34,6 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils import has_pplx
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize)
|
||||
if has_pplx():
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -569,15 +563,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
requires_grad=False)
|
||||
|
||||
self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
|
||||
else:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
if self.use_marlin:
|
||||
elif self.use_marlin:
|
||||
prepare_moe_fp8_layer_for_marlin(layer, False)
|
||||
# Activations not quantized for marlin.
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
self.fused_experts_func = None
|
||||
else:
|
||||
self.fused_experts_func = fused_experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@ -653,6 +646,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
|
||||
assert self.fused_experts_func is not None
|
||||
|
||||
return self.fused_experts_func(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@ -826,28 +821,27 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
def select_gemm_impl(self, prepare_finalize, moe):
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassExpertsFp8)
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
|
||||
assert moe is not None
|
||||
use_batched_format = (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
num_experts = (moe.num_local_experts
|
||||
if use_batched_format else moe.num_experts)
|
||||
|
||||
max_experts_per_worker = (
|
||||
(moe.num_experts + prepare_finalize.world_size - 1) //
|
||||
prepare_finalize.world_size)
|
||||
experts = CutlassExpertsFp8(
|
||||
max_experts_per_worker,
|
||||
num_experts,
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
use_batched_format=True,
|
||||
use_batched_format=use_batched_format,
|
||||
)
|
||||
|
||||
if has_pplx() and isinstance(
|
||||
prepare_finalize,
|
||||
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
|
||||
# no expert_map support in this case
|
||||
self.disable_expert_map = True
|
||||
self.disable_expert_map = not experts.supports_expert_map()
|
||||
return experts
|
||||
|
||||
def apply(
|
||||
@ -888,7 +882,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
indices_type=torch.uint32)
|
||||
indices_type=self.topk_indices_dtype,
|
||||
)
|
||||
|
||||
return self.fused_experts(
|
||||
x,
|
||||
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -13,8 +13,11 @@ import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
BatchedTritonOrDeepGemmExperts, FusedMoE, FusedMoEActivationFormat,
|
||||
FusedMoEConfig, FusedMoEMethodBase, FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize, FusedMoeWeightScaleSupported,
|
||||
TritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@ -777,44 +780,46 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
def select_gemm_impl(self, prepare_finalize, moe):
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
prepare_finalize: FusedMoEPrepareAndFinalize,
|
||||
moe: FusedMoEConfig,
|
||||
) -> FusedMoEPermuteExpertsUnpermute:
|
||||
assert not self.use_marlin and not self.rocm_aiter_moe_enabled, (
|
||||
"Marlin and ROCm AITER are not supported with all2all yet.")
|
||||
|
||||
experts: Optional[Union[BatchedTritonOrDeepGemmExperts,
|
||||
TritonOrDeepGemmExperts]] = None
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
use_batched_experts = max_num_tokens_per_rank is not None
|
||||
|
||||
if use_batched_experts:
|
||||
experts = BatchedTritonOrDeepGemmExperts(
|
||||
if (prepare_finalize.activation_format ==
|
||||
FusedMoEActivationFormat.BatchedExperts):
|
||||
max_num_tokens_per_rank = (
|
||||
prepare_finalize.max_num_tokens_per_rank())
|
||||
assert max_num_tokens_per_rank is not None
|
||||
logger.debug(
|
||||
"BatchedTritonOrDeepGemmExperts(%s): "
|
||||
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
|
||||
self.__class__.__name__, max_num_tokens_per_rank,
|
||||
self.quant_config.weight_block_size, False)
|
||||
return BatchedTritonOrDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
world_size=prepare_finalize.world_size,
|
||||
dp_size=prepare_finalize.dp_size,
|
||||
world_size=prepare_finalize.
|
||||
world_size, # type: ignore [attr-defined]
|
||||
dp_size=prepare_finalize.
|
||||
dp_size, # type: ignore [attr-defined]
|
||||
use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_channel_quant=False,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
per_act_token_quant=False,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
else:
|
||||
experts = TritonOrDeepGemmExperts(
|
||||
logger.debug(
|
||||
"TritonOrDeepGemmExperts(%s): block_size=%s, per_act_token=%s",
|
||||
self.__class__.__name__, self.quant_config.weight_block_size,
|
||||
False)
|
||||
return TritonOrDeepGemmExperts(
|
||||
use_fp8_w8a8=True,
|
||||
block_shape=self.quant_config.weight_block_size,
|
||||
allow_deep_gemm=self.allow_deep_gemm,
|
||||
)
|
||||
|
||||
assert experts is not None
|
||||
return experts
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
|
Reference in New Issue
Block a user