Mirror of https://github.com/vllm-project/vllm.git
[Kernels] Add activation chunking logic to FusedMoEModularKernel (#19168)
Signed-off-by: Bill Nell <bnell@redhat.com>
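This change teaches FusedMoEModularKernel.forward to split the activation (token) dimension M into chunks of VLLM_FUSED_MOE_CHUNK_SIZE whenever the expert implementation reports supports_chunking(), sizing the intermediate workspaces for a single chunk and calling fused_experts.apply() once per chunk. Below is a minimal, standalone sketch of the chunking arithmetic only, not the vLLM implementation itself; process_chunk stands in for the real apply() call and is purely illustrative.

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as provided by vllm.utils.cdiv.
    return -(a // -b)

def chunked_apply(M: int, chunk_size: int, supports_chunking: bool, process_chunk) -> None:
    # Split M tokens into cdiv(M, chunk_size) chunks when chunking is supported,
    # otherwise fall back to a single chunk covering all M tokens.
    if supports_chunking:
        num_chunks = cdiv(M, chunk_size)
    else:
        chunk_size, num_chunks = M, 1
    for chunk in range(num_chunks):
        begin = chunk * chunk_size
        end = min((chunk + 1) * chunk_size, M)
        process_chunk(begin, end)

# Example: 300 tokens with chunk_size=128 -> chunks [0, 128), [128, 256), [256, 300).
chunked_apply(300, 128, True, lambda b, e: print(b, e))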
@@ -29,6 +29,7 @@ MNK_FACTORS = [
    (224, 1024, 1536),
    (224, 3072, 1024),
    (224, 3072, 1536),
    (1024 * 128, 1024, 1024),
]

vllm_config = VllmConfig(parallel_config=ParallelConfig(
@@ -15,7 +15,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import opcheck, stack_and_dev, torch_moe
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.fused_moe import (
    fused_topk, modular_triton_fused_moe)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
    fused_moe as iterative_moe)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
@@ -76,6 +77,13 @@ def test_fused_moe(
    else:
        e_map = None

    m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=False,
                                           use_int8_w8a8=False,
                                           use_int8_w8a16=False,
                                           use_int4_w4a16=False,
                                           per_channel_quant=False,
                                           block_shape=None)

    with set_current_vllm_config(vllm_config):
        torch_output = torch_moe(a, w1, w2, score, topk, e_map)
        iterative_output = iterative_moe(a,
@@ -103,7 +111,20 @@ def test_fused_moe(
                                  expert_map=e_map,
                                  renormalize=False)

        topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
        m_triton_output = m_fused_moe(a,
                                      w1,
                                      w2,
                                      topk_weights,
                                      topk_ids,
                                      global_num_experts=e,
                                      expert_map=e_map)

    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
    torch.testing.assert_close(m_triton_output,
                               torch_output,
                               atol=2e-2,
                               rtol=0)
    torch.testing.assert_close(iterative_output,
                               torch_output,
                               atol=2e-2,
@ -1,10 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@ -14,6 +15,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .deepep_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
try:
|
||||
from pplx_kernels import AllToAll
|
||||
from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id,
|
||||
@ -64,6 +67,7 @@ def pplx_cutlass_moe(
|
||||
out_dtype,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
group_name: Optional[str],
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
@ -84,7 +88,7 @@ def pplx_cutlass_moe(
|
||||
else:
|
||||
scale_elems = (hidden_dim + block_size - 1) // block_size
|
||||
|
||||
ata = AllToAll.internode(
|
||||
args = dict(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
@ -96,6 +100,12 @@ def pplx_cutlass_moe(
|
||||
hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize,
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
ata = AllToAll.internode(**args)
|
||||
else:
|
||||
args["group_name"] = group_name
|
||||
ata = AllToAll.intranode(**args)
|
||||
|
||||
w1 = w1.to(device)
|
||||
w2 = w2.to(device)
|
||||
w1_scale = w1_scale.to(device)
|
||||
@ -113,7 +123,10 @@ def pplx_cutlass_moe(
|
||||
)
|
||||
|
||||
experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size,
|
||||
out_dtype, per_act_token, per_out_ch)
|
||||
out_dtype,
|
||||
per_act_token,
|
||||
per_out_ch,
|
||||
use_batched_format=True)
|
||||
|
||||
fused_cutlass_experts = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
@ -184,11 +197,17 @@ def _pplx_moe(
|
||||
w2_full: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
use_internode: bool,
|
||||
):
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights,
|
||||
@ -196,7 +215,7 @@ def _pplx_moe(
|
||||
pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale,
|
||||
w2_scale, topk_weights, topk_ids,
|
||||
a1_scale, out_dtype, per_act_token,
|
||||
per_out_ch)
|
||||
per_out_ch, group_name)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_output.device)
|
||||
@ -207,7 +226,8 @@ def _pplx_moe(
|
||||
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0)
|
||||
|
||||
nvshmem_finalize()
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [2, 224])
|
||||
@ -218,6 +238,7 @@ def _pplx_moe(
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
@ -232,6 +253,7 @@ def test_cutlass_moe_pplx(
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_internode: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
@ -284,4 +306,5 @@ def test_cutlass_moe_pplx(
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q,
|
||||
w1_scale, w2_scale, topk_weights, topk_ids, a_scale1,
|
||||
dtype, a, w1_d, w2_d, per_act_token, per_out_ch)
|
||||
dtype, a, w1_d, w2_d, per_act_token, per_out_ch,
|
||||
use_internode)
|
||||
|
@ -18,7 +18,6 @@ try:
|
||||
except ImportError:
|
||||
has_pplx = False
|
||||
|
||||
from tests.pplx_utils import ProcessGroupInfo, parallel_launch
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import override_config
|
||||
@ -30,6 +29,8 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .deepep_utils import ProcessGroupInfo, parallel_launch
|
||||
|
||||
requires_pplx = pytest.mark.skipif(
|
||||
not has_pplx,
|
||||
reason="Requires PPLX kernels",
|
||||
@ -153,7 +154,10 @@ def batched_moe(
|
||||
num_experts = w1.shape[0]
|
||||
|
||||
fused_experts = FusedMoEModularKernel(
|
||||
BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
|
||||
BatchedPrepareAndFinalize(max_num_tokens=a.shape[0],
|
||||
world_size=1,
|
||||
dp_size=1,
|
||||
rank=0),
|
||||
BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))
|
||||
|
||||
return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)
|
||||
@ -229,9 +233,15 @@ def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
|
||||
return t[(r * chunk):(r + 1) * chunk]
|
||||
|
||||
|
||||
def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
|
||||
topk_weight: torch.Tensor, topk_ids: torch.Tensor,
|
||||
num_experts: int) -> torch.Tensor:
|
||||
def pplx_prepare_finalize(
|
||||
pgi: ProcessGroupInfo,
|
||||
dp_size: int,
|
||||
a: torch.Tensor,
|
||||
topk_weight: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_experts: int,
|
||||
group_name: Optional[str],
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
|
||||
@ -245,7 +255,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
|
||||
world_size = pgi.world_size
|
||||
max_num_tokens = rank_chunk(num_tokens, 0, world_size)
|
||||
|
||||
ata = AllToAll.internode(
|
||||
args = dict(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
@ -259,6 +269,12 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
|
||||
torch.float32.itemsize)),
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
ata = AllToAll.internode(**args)
|
||||
else:
|
||||
args["group_name"] = group_name
|
||||
ata = AllToAll.intranode(**args)
|
||||
|
||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
@ -318,11 +334,19 @@ def _pplx_prepare_finalize(
|
||||
score: torch.Tensor,
|
||||
topk: torch.Tensor,
|
||||
num_experts: int,
|
||||
use_internode: bool,
|
||||
):
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
group_name = None
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
device = pgi.device
|
||||
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
@ -335,14 +359,15 @@ def _pplx_prepare_finalize(
|
||||
a.dtype)
|
||||
|
||||
pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
|
||||
num_experts)
|
||||
num_experts, group_name)
|
||||
|
||||
torch_output = chunk_by_rank(torch_output, pgi.rank,
|
||||
pgi.world_size).to(pplx_output.device)
|
||||
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
nvshmem_finalize()
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
# TODO (bnell): this test point does not work for odd M due to how the test is
|
||||
@ -353,6 +378,7 @@ def _pplx_prepare_finalize(
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@requires_pplx
|
||||
def test_pplx_prepare_finalize(
|
||||
mnk: tuple[int, int, int],
|
||||
@ -360,6 +386,7 @@ def test_pplx_prepare_finalize(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_internode: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
m, n, k = mnk
|
||||
@ -369,10 +396,11 @@ def test_pplx_prepare_finalize(
|
||||
score = torch.randn((m, e), device=device, dtype=dtype)
|
||||
|
||||
parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
|
||||
topk, e)
|
||||
topk, e, use_internode)
|
||||
|
||||
|
||||
def pplx_moe(
|
||||
group_name: Optional[str],
|
||||
rank: int,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
@ -394,7 +422,7 @@ def pplx_moe(
|
||||
topk = topk_ids.shape[1]
|
||||
max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
|
||||
|
||||
ata = AllToAll.internode(
|
||||
args = dict(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_experts=num_experts,
|
||||
experts_per_token=topk,
|
||||
@ -408,6 +436,12 @@ def pplx_moe(
|
||||
torch.float32.itemsize)),
|
||||
)
|
||||
|
||||
if group_name is None:
|
||||
ata = AllToAll.internode(**args)
|
||||
else:
|
||||
args["group_name"] = group_name
|
||||
ata = AllToAll.intranode(**args)
|
||||
|
||||
topk_ids = topk_ids.to(dtype=torch.uint32)
|
||||
|
||||
prepare_finalize = PplxPrepareAndFinalize(
|
||||
@ -522,11 +556,18 @@ def _pplx_moe(
|
||||
w2: torch.Tensor,
|
||||
score: torch.Tensor,
|
||||
topk: int,
|
||||
use_internode: bool,
|
||||
):
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
if use_internode:
|
||||
uid = nvshmem_get_unique_id(
|
||||
) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
|
||||
torch.distributed.broadcast(uid, src=0)
|
||||
nvshmem_init(uid, pgi.rank, pgi.world_size)
|
||||
group_name = None
|
||||
else:
|
||||
group_ranks = list(range(pgi.world_size))
|
||||
cpu_group = torch.distributed.new_group(group_ranks, backend="gloo")
|
||||
group_name = cpu_group.group_name
|
||||
|
||||
m, k = a.shape
|
||||
e, _, n = w2.shape
|
||||
@ -536,8 +577,8 @@ def _pplx_moe(
|
||||
with set_current_vllm_config(vllm_config), override_config(moe_config):
|
||||
topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
|
||||
pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
|
||||
topk_weight, topk_ids)
|
||||
pplx_output = pplx_moe(group_name, pgi.rank, pgi.world_size, dp_size,
|
||||
a, w1, w2, topk_weight, topk_ids)
|
||||
# TODO (bnell): fix + re-enable
|
||||
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
|
||||
# topk_ids)
|
||||
@ -548,7 +589,8 @@ def _pplx_moe(
|
||||
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
|
||||
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
nvshmem_finalize()
|
||||
if use_internode:
|
||||
nvshmem_finalize()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
|
||||
@ -556,6 +598,7 @@ def _pplx_moe(
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
|
||||
@pytest.mark.parametrize("use_internode", [False])
|
||||
@requires_pplx
|
||||
def test_pplx_moe(
|
||||
mnk: tuple[int, int, int],
|
||||
@ -563,6 +606,7 @@ def test_pplx_moe(
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
world_dp_size: tuple[int, int],
|
||||
use_internode: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
m, n, k = mnk
|
||||
@ -572,4 +616,5 @@ def test_pplx_moe(
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
|
||||
parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
|
||||
use_internode)
|
||||
|
@ -13,7 +13,8 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
_valid_deep_gemm_shape, deep_gemm_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk, modular_triton_fused_moe)
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
@ -45,7 +46,7 @@ N = [128, 512, 7168, 7748, 13824]
|
||||
K = [256, 3884, 4096, 13824, 16384]
|
||||
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
|
||||
# and its hidden size is 7168.
|
||||
M_moe = [1, 2, 7, 83, 128, 2048]
|
||||
M_moe = [1, 2, 7, 83, 128, 2048, 1024 * 128]
|
||||
M_moe_dg = [128, 192, 1335, 2048]
|
||||
N_moe = [128, 256, 1024, 4608] # [13824]
|
||||
K_moe = [256, 512, 7168] # [13824]
|
||||
@ -214,6 +215,13 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
|
||||
use_int8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
per_channel_quant=False,
|
||||
block_shape=block_size)
|
||||
|
||||
# Set the context to avoid lots of warning spam.
|
||||
with set_current_vllm_config(vllm_config):
|
||||
out = fused_moe(
|
||||
@ -231,6 +239,16 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
|
||||
block_size)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
m_out = m_fused_moe(a,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=E,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s)
|
||||
|
||||
#print(f"{out.sum()=}")
|
||||
#print(f"{ref_out.sum()=}")
|
||||
|
||||
@ -239,6 +257,11 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.03
|
||||
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(m_out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.03
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor,
|
||||
|
@ -1,123 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
"tcp://localhost:29500",
|
||||
worker,
|
||||
) + args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
|
||||
def parallel_launch_from_env(
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, P], None],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Launches a worker function in parallel across all processes in the current
|
||||
environment. The environment must have the following variables set:
|
||||
- WORLD_SIZE: The total number of processes.
|
||||
- WORLD_LOCAL_SIZE: The number of processes on the current node.
|
||||
- NODE_RANK: The rank of the current
|
||||
- MASTER_ADDR: The address of the master process.
|
||||
- MASTER_PORT: The port of the master process.
|
||||
"""
|
||||
assert not kwargs
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
world_local_size = int(os.environ["WORLD_LOCAL_SIZE"])
|
||||
node_rank = int(os.environ["NODE_RANK"])
|
||||
assert "MASTER_ADDR" in os.environ
|
||||
assert "MASTER_PORT" in os.environ
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_local_size,
|
||||
node_rank,
|
||||
"env://",
|
||||
worker,
|
||||
) + args,
|
||||
nprocs=world_local_size,
|
||||
join=True,
|
||||
)
|
@ -36,6 +36,9 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert (len(self.block_shape) == 2 and all(
|
||||
[v == self.DEEPGEMM_BLOCK_SHAPE for v in self.block_shape]))
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -45,17 +48,19 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
num_dp = self.world_size // self.dp_size
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
|
||||
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
||||
output = (num_experts, max_num_tokens * num_dp, K)
|
||||
return (workspace13, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -72,7 +77,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
import deep_gemm as dg
|
||||
assert hidden_states.ndim == 3
|
||||
|
||||
@ -89,7 +94,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||
workspace2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))
|
||||
workspace3 = _resize_cache(workspace13, (E, max_num_tokens, K))
|
||||
|
||||
# (from deepgemm docs) : A value hint (which is a value on CPU)
|
||||
# for the M expectation of each batch, correctly setting this value
|
||||
@ -118,8 +122,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale),
|
||||
(w2, w2_scale),
|
||||
out=workspace3,
|
||||
out=output,
|
||||
masked_m=expert_num_tokens,
|
||||
expected_m=expected_m)
|
||||
|
||||
return workspace3
|
||||
|
@ -64,6 +64,15 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
block_shape=self.block_shape, # type: ignore[arg-type]
|
||||
) if (self.allow_deep_gemm and is_fp8_128_block_quantized) else None
|
||||
|
||||
assert (self.batched_deep_gemm_experts is not None
|
||||
or self.batched_triton_experts is not None)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return ((bdge is None or bdge.supports_chunking())
|
||||
and (bte is None or bte.supports_chunking()))
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -73,7 +82,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
@ -87,6 +96,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -103,7 +113,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
use_batched_deep_gemm_experts = (self.allow_deep_gemm
|
||||
and self.batched_deep_gemm_experts
|
||||
is not None)
|
||||
@ -111,7 +121,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
if use_batched_deep_gemm_experts else
|
||||
self.batched_triton_experts)
|
||||
assert experts is not None
|
||||
return experts.apply(hidden_states, w1, w2, topk_ids, activation,
|
||||
global_num_experts, expert_map, w1_scale,
|
||||
w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale,
|
||||
workspace13, workspace2, expert_num_tokens)
|
||||
experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
|
||||
global_num_experts, expert_map, w1_scale, w2_scale,
|
||||
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
|
||||
workspace2, expert_num_tokens)
|
||||
|
@ -14,6 +14,7 @@ from vllm.scalar_type import scalar_types
|
||||
|
||||
|
||||
def run_cutlass_moe_fp8(
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -31,7 +32,8 @@ def run_cutlass_moe_fp8(
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
) -> torch.Tensor:
|
||||
use_batched_format: bool,
|
||||
):
|
||||
a1q = hidden_states
|
||||
|
||||
assert w1_scale is not None
|
||||
@ -61,23 +63,20 @@ def run_cutlass_moe_fp8(
|
||||
if expert_map is not None:
|
||||
assert expert_num_tokens is None
|
||||
|
||||
# We have two modes: PPLX and non-PPLX. We differentiate them by checking
|
||||
# if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX
|
||||
# uses to track the number of tokens per expert).
|
||||
# In the non-PPLX mode, the input tokens are not padded: thus, the shape
|
||||
# We have two modes: batched experts and non-batched experts.
|
||||
# In the non-batched mode, the input tokens are not padded: thus, the shape
|
||||
# of the input is [total_num_tokens, hidden_size]. The input and output
|
||||
# require shuffling by a_map and c_map such that the tokens assigned to
|
||||
# each expert are contiguous.
|
||||
# In the PPLX mode, the input tokens are padded per expert to ensure that
|
||||
# the PPLX dispatch and combine functions work correctly: thus, the shape
|
||||
# In the batched mode, the input tokens are padded per expert to ensure that
|
||||
# the batched dispatch and combine functions work correctly: thus, the shape
|
||||
# of the input is [num_experts, max_num_tokens_per_expert, hidden_size].
|
||||
# The PPLX input and output require no shuffling by a_map and c_map since
|
||||
# The batched input and output require no shuffling by a_map and c_map since
|
||||
# their tokens are already contiguous for each expert as a result of
|
||||
# the dispatch function.
|
||||
is_pplx = expert_num_tokens is not None
|
||||
|
||||
M = a1q.shape[0] # no pplx
|
||||
padded_M = a1q.shape[1] # pplx
|
||||
M = a1q.shape[0] # non batched expert M
|
||||
padded_M = a1q.shape[1] # batched expert M
|
||||
_, K, N = w2.shape
|
||||
device = a1q.device
|
||||
|
||||
@ -95,7 +94,9 @@ def run_cutlass_moe_fp8(
|
||||
topk = local_topk_ids.shape[1]
|
||||
local_E = w1.shape[0]
|
||||
|
||||
if is_pplx:
|
||||
if use_batched_format:
|
||||
assert expert_num_tokens is not None
|
||||
|
||||
expert_offsets = torch.empty((local_E),
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
@ -167,7 +168,7 @@ def run_cutlass_moe_fp8(
|
||||
device=device,
|
||||
dtype=torch.int64)
|
||||
|
||||
if is_pplx:
|
||||
if use_batched_format:
|
||||
c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
|
||||
c2 = _resize_cache(workspace2, (local_E * padded_M, N))
|
||||
c3 = _resize_cache(workspace13, (local_E * padded_M, K))
|
||||
@ -192,12 +193,15 @@ def run_cutlass_moe_fp8(
|
||||
problem_sizes2, ab_strides2, ab_strides2, c_strides2,
|
||||
per_act_token, per_out_ch)
|
||||
|
||||
if is_pplx:
|
||||
return c3.reshape(local_E, padded_M, K)
|
||||
if use_batched_format:
|
||||
output.copy_(c3.reshape(local_E, padded_M, K), non_blocking=True)
|
||||
else:
|
||||
return c3[c_map].view(M, topk, K)
|
||||
# We can't do this inplace because output may point to the same tensor
|
||||
# as c3.
|
||||
output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
|
||||
|
||||
|
||||
# TODO (bnell): split class batched vs. non-batched?
|
||||
class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
@ -206,12 +210,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
out_dtype: torch.dtype,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
use_batched_format: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_experts_per_worker = max_experts_per_worker
|
||||
self.out_dtype = out_dtype
|
||||
self.per_act_token = per_act_token
|
||||
self.per_out_ch = per_out_ch
|
||||
self.use_batched_format = use_batched_format
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return not self.use_batched_format
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
@ -222,14 +231,24 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
padded_M = aq.shape[1]
|
||||
workspace1 = self.max_experts_per_worker * padded_M * max(N, K)
|
||||
workspace2 = self.max_experts_per_worker * padded_M * (N // 2)
|
||||
return (workspace1, workspace2, self.out_dtype)
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
workspace1: tuple[int, ...] = ()
|
||||
workspace2: tuple[int, ...] = ()
|
||||
output: tuple[int, ...] = ()
|
||||
if self.use_batched_format:
|
||||
padded_M = aq.shape[1]
|
||||
workspace1 = (self.max_experts_per_worker, padded_M, max(N, K))
|
||||
workspace2 = (self.max_experts_per_worker, padded_M, (N // 2))
|
||||
output = (self.max_experts_per_worker, padded_M, K)
|
||||
else:
|
||||
workspace1 = (M * topk, max(2 * N, K))
|
||||
workspace2 = (M * topk, N)
|
||||
output = (M * topk, K)
|
||||
return (workspace1, workspace2, output, self.out_dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -246,16 +265,17 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
|
||||
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
|
||||
activation_callable = lambda i, o: self.activation(activation, i, o)
|
||||
return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids,
|
||||
activation_callable, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2,
|
||||
expert_num_tokens, self.out_dtype,
|
||||
self.per_act_token, self.per_out_ch)
|
||||
run_cutlass_moe_fp8(output, hidden_states, w1, w2, topk_ids,
|
||||
activation_callable, global_num_experts,
|
||||
expert_map, w1_scale, w2_scale, a1q_scale,
|
||||
a2_scale, workspace13, workspace2,
|
||||
expert_num_tokens, self.out_dtype,
|
||||
self.per_act_token, self.per_out_ch,
|
||||
self.use_batched_format)
|
||||
|
||||
|
||||
def cutlass_moe_fp8(
|
||||
@ -325,6 +345,7 @@ def cutlass_moe_fp8(
|
||||
out_dtype=out_dtype,
|
||||
per_act_token=per_act_token,
|
||||
per_out_ch=per_out_ch,
|
||||
use_batched_format=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -70,6 +70,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
super().__init__()
|
||||
self.block_shape = deep_gemm_block_shape()
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -79,18 +82,18 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
block_m = self.block_shape[0]
|
||||
M_sum = (M * topk) + num_experts * (block_m - 1)
|
||||
M_sum = round_up(M_sum, block_m)
|
||||
workspace1 = M_sum * max(N * 2, K)
|
||||
workspace2 = M_sum * max(N, K)
|
||||
|
||||
return (workspace1, workspace2, a.dtype)
|
||||
workspace1 = (M_sum, max(N * 2, K))
|
||||
workspace2 = (M_sum, max(N, K))
|
||||
output = (M * topk, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -107,7 +110,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
import deep_gemm as dg
|
||||
|
||||
a1q = hidden_states
|
||||
@ -143,7 +146,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
|
||||
(M_sum, N // 2))
|
||||
mm2_out = _resize_cache(workspace2, (M_sum, K))
|
||||
out = _resize_cache(workspace13, (inv_perm.size(0), K))
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
|
||||
@ -159,9 +161,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
|
||||
|
||||
torch.index_select(mm2_out, 0, inv_perm, out=out)
|
||||
|
||||
return out
|
||||
torch.index_select(mm2_out, 0, inv_perm, out=output)
|
||||
|
||||
|
||||
def deep_gemm_moe_fp8(
|
||||
|
@ -335,9 +335,6 @@ def invoke_moe_batched_triton_kernel(
|
||||
BLOCK_M = config['BLOCK_SIZE_M']
|
||||
BLOCK_N = config['BLOCK_SIZE_N']
|
||||
BLOCK_K = config['BLOCK_SIZE_K']
|
||||
assert (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()
|
||||
or max_num_tokens % BLOCK_M == 0)
|
||||
|
||||
grid = (expert_num_tokens.size(0), triton.cdiv(max_num_tokens, BLOCK_M) *
|
||||
triton.cdiv(B.size(1), BLOCK_N))
|
||||
@ -390,8 +387,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
that the PPLX dispatch/combine kernels use.
|
||||
"""
|
||||
|
||||
def __init__(self, max_num_tokens: Optional[int], world_size: int,
|
||||
dp_size: int, rank: int):
|
||||
def __init__(self, max_num_tokens: int, world_size: int, dp_size: int,
|
||||
rank: int):
|
||||
super().__init__()
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
@ -430,14 +427,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
num_tokens, hidden_dim = a1.size()
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
if self.max_num_tokens is None:
|
||||
tokens_per_expert = torch.bincount(topk_ids.view(-1),
|
||||
minlength=num_experts)
|
||||
self.max_num_tokens = int(tokens_per_expert.max().item())
|
||||
else:
|
||||
tokens_per_expert = torch.zeros(num_experts,
|
||||
dtype=torch.int,
|
||||
device=a1.device)
|
||||
tokens_per_expert = torch.zeros(num_experts,
|
||||
dtype=torch.int,
|
||||
device=a1.device)
|
||||
|
||||
assert num_experts % self.world_size == 0
|
||||
|
||||
@ -497,9 +489,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
world_size: int,
|
||||
dp_size: int,
|
||||
max_num_tokens: Optional[int] = None,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
@ -518,6 +510,9 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.world_size = world_size
|
||||
self.dp_size = dp_size
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -527,18 +522,16 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
num_dp = self.world_size // self.dp_size
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
#print(f"WORKSPACE {max_num_tokens} {num_dp}")
|
||||
workspace13 = num_experts * max_num_tokens * num_dp * K
|
||||
workspace2 = max_num_tokens * num_dp * N
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
workspace13 = (num_experts, self.max_num_tokens * num_dp, K)
|
||||
workspace2 = (self.max_num_tokens * num_dp, N)
|
||||
return (workspace13, workspace2, workspace13, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -555,20 +548,12 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
assert hidden_states.dim() == 3
|
||||
assert expert_num_tokens is not None
|
||||
hidden_dim = hidden_states.size(-1)
|
||||
|
||||
if self.max_num_tokens is None:
|
||||
max_num_tokens = hidden_states.size(1)
|
||||
else:
|
||||
max_num_tokens = self.max_num_tokens
|
||||
|
||||
max_num_tokens = self.max_num_tokens
|
||||
num_dp = self.world_size // self.dp_size
|
||||
num_experts = global_num_experts
|
||||
out = _resize_cache(workspace13,
|
||||
(num_experts, max_num_tokens * num_dp, hidden_dim))
|
||||
num_local_experts = w1.size(0)
|
||||
assert num_local_experts == w1.size(0), (
|
||||
f"{num_local_experts} == {w1.size(0)}")
|
||||
@ -585,15 +570,13 @@ class BatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# Indexing expert_num_tokens doesn't work w/cudagraphs or inductor
|
||||
if (torch.compiler.is_compiling()
|
||||
or torch.cuda.is_current_stream_capturing()):
|
||||
num = max_num_tokens * num_dp
|
||||
num = hidden_states.shape[1]
|
||||
else:
|
||||
num = int(expert_num_tokens[expert].item())
|
||||
tmp = _resize_cache(workspace2, (num, N))
|
||||
input = hidden_states[expert, :num, :] @ w1[expert].transpose(0, 1)
|
||||
self.activation(activation, tmp, input)
|
||||
out[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
|
||||
return out
|
||||
output[expert, :num, :] = tmp @ w2[expert].transpose(0, 1)
|
||||
|
||||
|
||||
class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
@ -630,6 +613,9 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert not use_int4_w4a16, "NYI"
|
||||
assert self.block_shape is None, "NYI"
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -639,17 +625,19 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
num_dp = self.world_size // self.dp_size
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
workspace13 = num_experts * max_num_tokens * num_dp * max(K, N)
|
||||
workspace2 = num_experts * max_num_tokens * num_dp * (N // 2)
|
||||
return (workspace13, workspace2, a.dtype)
|
||||
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
|
||||
output = (num_experts, max_num_tokens * num_dp, K)
|
||||
return (workspace13, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -666,7 +654,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
@ -723,8 +711,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
(E, max_num_tokens, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2,
|
||||
(E, max_num_tokens, N // 2))
|
||||
intermediate_cache3 = _resize_cache(workspace13,
|
||||
(E, max_num_tokens, K))
|
||||
|
||||
# MM1
|
||||
invoke_moe_batched_triton_kernel(A=hidden_states,
|
||||
@ -761,7 +747,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
invoke_moe_batched_triton_kernel(A=qintermediate_cache2,
|
||||
B=w2,
|
||||
C=intermediate_cache3,
|
||||
C=output,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
compute_type=compute_type,
|
||||
A_scale=a2q_scale,
|
||||
@ -772,4 +758,3 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
config=config,
|
||||
block_shape=self.block_shape)
|
||||
return intermediate_cache3
|
||||
|
@ -1542,6 +1542,9 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
use_int4_w4a16=use_int4_w4a16)
|
||||
self.per_channel_quant = per_channel_quant
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return True
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -1551,14 +1554,15 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
factor = num_experts if a.dim() == 3 else 1
|
||||
workspace1 = M * topk * max(N * 2, K) * factor
|
||||
workspace2 = M * topk * N * factor
|
||||
return (workspace1, workspace2, a.dtype)
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
workspace1 = (M, topk, max(N * 2, K))
|
||||
workspace2 = (M, topk, N)
|
||||
output = (M, topk, K)
|
||||
return (workspace1, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -1575,7 +1579,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
# Check constraints.
|
||||
if self.use_int4_w4a16:
|
||||
assert hidden_states.size(-1) // 2 == w1.size(2), (
|
||||
@ -1632,8 +1636,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
(num_tokens, top_k_num, N))
|
||||
intermediate_cache2 = _resize_cache(workspace2,
|
||||
(num_tokens * top_k_num, N // 2))
|
||||
intermediate_cache3 = _resize_cache(workspace13,
|
||||
(num_tokens, top_k_num, K))
|
||||
|
||||
sorted_token_ids, expert_ids, num_tokens_post_padded = (
|
||||
moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'],
|
||||
@ -1671,7 +1673,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
invoke_fused_moe_kernel(qintermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
output,
|
||||
a2q_scale,
|
||||
w2_scale,
|
||||
w2_zp,
|
||||
@ -1690,8 +1692,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
per_channel_quant=self.per_channel_quant,
|
||||
block_shape=self.block_shape)
|
||||
|
||||
return intermediate_cache3
|
||||
|
||||
|
||||
def modular_triton_fused_moe(
|
||||
use_fp8_w8a8: bool,
|
||||
|
@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from math import prod
from typing import Optional

import torch

import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv

#
# This file defines a set of base classes used to make MoE kernels more modular.
# The goal is to be able to utilize different communication mechanisms with
@ -115,9 +120,9 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
- quantized + dispatched a.
|
||||
- quantized + dispatched a1_scales.
|
||||
- Optional tensor as big as number of local experts that contains the
|
||||
number of tokens assigned to each local expert.
|
||||
number of tokens assigned to each local expert.
|
||||
- Optional dispatched expert topk IDs
|
||||
- Optional dispatched expert topk weight
|
||||
- Optional dispatched expert topk weight
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -159,7 +164,7 @@ class FusedMoEPrepareAndFinalize(ABC):
|
||||
Some PrepareFinalize All2All implementations are batched. Meaning,
|
||||
they can processes only as set of tokens at a time. This
|
||||
function returns the batch size i.e the maximum number of tokens
|
||||
the implementation can process at a time.
|
||||
the implementation can process at a time.
|
||||
Return None if there are no such restrictions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -171,6 +176,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
    above.
    """

    # TODO (bnell): make this return a CHUNK_SIZE or None instead?
    @abstractmethod
    def supports_chunking(self) -> bool:
        """
        A flag indicating whether or not this class supports activation
        chunking.
        """
        raise NotImplementedError

    @abstractmethod
    def workspace_shapes(
        self,
@@ -181,19 +195,22 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
        K: int,
        topk: int,
        num_experts: int,
    ) -> tuple[int, int, torch.dtype]:
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        """
        Compute the number of elements for the temporary outputs of the two
        gemms and activation in the fused expert function. Since the
        gemms are independent, the workspace for the first gemm can be shared
        with the workspace for the last gemm.
        Compute the shapes for the temporary and final outputs of the two gemms
        and activation in the fused expert function. Since the gemms are
        independent, the workspace for the first gemm can be shared with the
        workspace for the last gemm.

        Returns a tuple of:
        - Number of workspace13 elements: must be large enough to hold the
        - workspace13 shape tuple: must be large enough to hold the
          result of either expert gemm.
        - Number of workspace2 elements: must be large enough to hold the
        - workspace2 shape tuple: must be large enough to hold the
          result of the activation function.
        - output shape tuple: must be exact size of the final gemm output.
        - Workspace type: The dtype to use for the workspace tensors.
        - Note: in order for activation chunking to work, the first dimension
          of each tuple must be the number of tokens.
        """
        raise NotImplementedError
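For illustration only, a hypothetical subclass might implement these two hooks as follows, mirroring the TritonExperts version shown earlier in this diff; the class name and shapes here are an example, not part of the API.

class ExampleExperts:  # would subclass FusedMoEPermuteExpertsUnpermute

    def supports_chunking(self) -> bool:
        # Non-batched experts can process an arbitrary slice of tokens.
        return True

    def workspace_shapes(self, a, aq, M, N, K, topk, num_experts):
        # The leading dimension of every shape is the token count M so that
        # activation chunking can slice workspaces along the token dimension.
        workspace13 = (M, topk, max(N * 2, K))  # shared by gemm1 and gemm3
        workspace2 = (M, topk, N)               # holds the activation output
        output = (M, topk, K)                   # exact final gemm output shape
        return (workspace13, workspace2, output, a.dtype)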
|
||||
|
||||
@ -210,6 +227,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -226,12 +244,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
"""
|
||||
This function computes the intermediate result of a Mixture of Experts
|
||||
(MoE) layer using two sets of weights, w1 and w2.
|
||||
|
||||
Parameters:
|
||||
- output: (torch.Tensor): The unweighted, unreduced output tensor.
|
||||
- hidden_states: (torch.Tensor): The (quantized) input tensor to the MoE
|
||||
layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
@@ -259,13 +278,20 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
          function.
        - expert_num_tokens: An optional tensor containing the number of tokens
          assigned to each expert when using batched experts format input.

        Returns:
        - torch.Tensor: The unweighted, unreduced output tensor
        """
        raise NotImplementedError


def _chunk_scales(scales: Optional[torch.Tensor], start: int,
                  end: int) -> Optional[torch.Tensor]:
    if scales is not None:
        if scales.numel() == 1:
            return scales
        else:
            return scales[start:end]
    return None
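A quick usage note on _chunk_scales above: per-tensor scales (a single element) are passed through unchanged, while per-token scales are sliced down to the current chunk. A small illustrative check, assuming torch is available:

import torch

per_tensor = torch.tensor([0.5])   # numel() == 1 -> returned as-is
per_token = torch.rand(300, 1)     # one scale per token -> sliced

assert _chunk_scales(per_tensor, 128, 256) is per_tensor
assert _chunk_scales(per_token, 128, 256).shape[0] == 128
assert _chunk_scales(None, 0, 128) is None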
|
||||
|
||||
|
||||
class FusedMoEModularKernel(torch.nn.Module):
|
||||
"""
|
||||
This class combines a FusedMoEPrepareAndFinalize instance and
|
||||
@ -288,61 +314,6 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
self.prepare_finalize = prepare_finalize
|
||||
self.fused_experts = fused_experts
|
||||
|
||||
def _do_fused_experts(
|
||||
self,
|
||||
a1: torch.Tensor, # input to forward fn
|
||||
a1q: torch.Tensor, # output of prepare fn
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_num_tokens: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
w1_scale: Optional[torch.Tensor],
|
||||
w2_scale: Optional[torch.Tensor],
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
|
||||
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
|
||||
|
||||
# Use a1 here to decipher the correct workspace datatype
|
||||
workspace13_shape, workspace2_shape, workspace_dtype = (
|
||||
self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
|
||||
global_num_experts))
|
||||
|
||||
# We can reuse the memory between cache1 and cache3 because by the time
|
||||
# we need cache3, we're done with cache1
|
||||
workspace13 = torch.zeros(workspace13_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
workspace2 = torch.zeros(workspace2_shape,
|
||||
device=a1.device,
|
||||
dtype=workspace_dtype)
|
||||
|
||||
fused_out = self.fused_experts.apply(
|
||||
a1q,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale,
|
||||
workspace13=workspace13,
|
||||
workspace2=workspace2,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
)
|
||||
|
||||
return fused_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@ -408,12 +379,14 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
_expert_topk_weights) = self.prepare_finalize.prepare(
|
||||
a1, a1_scale, a2_scale, topk_weights, topk_ids,
|
||||
global_num_experts, expert_map, apply_router_weight_on_input)
|
||||
|
||||
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
|
||||
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
|
||||
topk_weights = (topk_weights if _expert_topk_weights is None else
|
||||
_expert_topk_weights)
|
||||
|
||||
fused_out = None
|
||||
|
||||
if a1q.numel() == 0:
|
||||
# This happens when none of the tokens from the all2all reach this
|
||||
# EP rank. Also, note that this is only relevant for CUDAGraph
|
||||
@ -423,22 +396,107 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
# and can never run into the tensor.numel() == 0 case.
|
||||
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
|
||||
else:
|
||||
fused_out = self._do_fused_experts(
|
||||
a1=a1,
|
||||
a1q=a1q,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_ids=topk_ids,
|
||||
expert_num_tokens=expert_num_tokens,
|
||||
activation=activation,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
w1_zp=w1_zp,
|
||||
w2_zp=w2_zp,
|
||||
a1q_scale=a1q_scale,
|
||||
a2_scale=a2_scale)
|
||||
            _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

            if self.fused_experts.supports_chunking():
                CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
                num_chunks = cdiv(M, CHUNK_SIZE)
            else:
                CHUNK_SIZE = M
                num_chunks = 1

            if num_chunks == 1:
                (workspace13_shape, workspace2_shape, fused_out_shape,
                 workspace_dtype) = self.fused_experts.workspace_shapes(
                     a1, a1q, M, N, K, top_k, global_num_experts)
            else:
                # Use the full M to get the final output shape.
                _, _, fused_out_shape, _ = (
                    self.fused_experts.workspace_shapes(
                        a1, a1q, M, N, K, top_k, global_num_experts))
                # Use the CHUNK_SIZE to get the workspace shapes.
                workspace13_shape, workspace2_shape, _, workspace_dtype = (
                    self.fused_experts.workspace_shapes(
                        a1, a1q, CHUNK_SIZE, N, K, top_k, global_num_experts))

            # We can reuse the memory between cache1 and cache3 because by the
            # time we need cache3, we're done with cache1.
            workspace13 = torch.zeros(prod(workspace13_shape),
                                      device=a1.device,
                                      dtype=workspace_dtype)
            workspace2 = torch.zeros(prod(workspace2_shape),
                                     device=a1.device,
                                     dtype=workspace_dtype)

            if num_chunks == 1:
                fused_out = _resize_cache(workspace13, fused_out_shape)

                self.fused_experts.apply(
                    fused_out,
                    a1q,
                    w1,
                    w2,
                    topk_ids,
                    activation=activation,
                    global_num_experts=global_num_experts,
                    expert_map=expert_map,
                    w1_scale=w1_scale,
                    w2_scale=w2_scale,
                    w1_zp=w1_zp,
                    w2_zp=w2_zp,
                    a1q_scale=a1q_scale,
                    a2_scale=a2_scale,
                    workspace13=workspace13,
                    workspace2=workspace2,
                    expert_num_tokens=expert_num_tokens,
                )
            else:
                # The leading output dimension may not be equal to M, so
                # we compute output indices separately.
                M_out = fused_out_shape[0]
                assert M_out >= M
                factor = M_out // M
                assert factor > 0
                OUT_CHUNK_SIZE = CHUNK_SIZE * factor

                fused_out = torch.empty(fused_out_shape,
                                        device=a1q.device,
                                        dtype=workspace_dtype)

                assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks, (
                    f"{cdiv(M_out, OUT_CHUNK_SIZE)} == {num_chunks}")

                for chunk in range(num_chunks):
                    begin_chunk_idx = chunk * CHUNK_SIZE
                    end_chunk_idx = min((chunk + 1) * CHUNK_SIZE, M)
                    begin_out_idx = chunk * OUT_CHUNK_SIZE
                    end_out_idx = min((chunk + 1) * OUT_CHUNK_SIZE, M_out)
                    curr_a1q = a1q[begin_chunk_idx:end_chunk_idx]
                    curr_a1q_scale = _chunk_scales(a1q_scale, begin_chunk_idx,
                                                   end_chunk_idx)
                    curr_a2_scale = _chunk_scales(a2_scale, begin_chunk_idx,
                                                  end_chunk_idx)
                    curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]

                    self.fused_experts.apply(
                        fused_out[begin_out_idx:end_out_idx],
                        curr_a1q,
                        w1,
                        w2,
                        curr_topk_ids,
                        activation=activation,
                        global_num_experts=global_num_experts,
                        expert_map=expert_map,
                        w1_scale=w1_scale,
                        w2_scale=w2_scale,
                        w1_zp=w1_zp,
                        w2_zp=w2_zp,
                        a1q_scale=curr_a1q_scale,
                        a2_scale=curr_a2_scale,
                        workspace13=workspace13,
                        workspace2=workspace2,
                        expert_num_tokens=expert_num_tokens,
                    )

        self.prepare_finalize.finalize(output, fused_out, topk_weights,
                                       topk_ids, apply_router_weight_on_input)
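To make the output-index arithmetic above concrete, here is a small worked example with illustrative values only: when the expert's output shape is (M * topk, K), as in the non-batched CUTLASS path, M_out = M * topk, so factor = topk and each output chunk covers CHUNK_SIZE * topk rows.

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

M, topk, CHUNK_SIZE = 300, 2, 128     # illustrative values
M_out = M * topk                      # leading dim of a (M * topk, K) output
factor = M_out // M                   # 2
OUT_CHUNK_SIZE = CHUNK_SIZE * factor  # 256
num_chunks = cdiv(M, CHUNK_SIZE)      # 3

assert cdiv(M_out, OUT_CHUNK_SIZE) == num_chunks
for chunk in range(num_chunks):
    # input rows, then output rows, for each chunk
    print(chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, M),
          chunk * OUT_CHUNK_SIZE, min((chunk + 1) * OUT_CHUNK_SIZE, M_out))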
|
||||
|
@ -34,6 +34,12 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
self.deep_gemm_expert = DeepGemmExperts(
|
||||
) if self.allow_deep_gemm else None
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
dge = self.deep_gemm_expert
|
||||
te = self.triton_expert
|
||||
return ((dge is None or dge.supports_chunking())
|
||||
and (te is None or te.supports_chunking()))
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
@ -43,7 +49,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
K: int,
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
@ -57,6 +63,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
@ -73,45 +80,31 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_num_tokens: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
):
|
||||
N = w1.size(1)
|
||||
if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
|
||||
and _valid_deep_gemm(hidden_states, w1, w2)):
|
||||
assert self.deep_gemm_expert is not None
|
||||
return self.deep_gemm_expert.apply(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
w1_zp,
|
||||
w2_zp,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_num_tokens,
|
||||
)
|
||||
else:
|
||||
return self.triton_expert.apply(
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
w1_zp,
|
||||
w2_zp,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_num_tokens,
|
||||
)
|
||||
|
||||
use_deep_gemm = (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512
|
||||
and _valid_deep_gemm(hidden_states, w1, w2))
|
||||
|
||||
experts = self.deep_gemm_expert if use_deep_gemm else self.triton_expert
|
||||
assert experts is not None
|
||||
|
||||
experts.apply(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
w1_zp,
|
||||
w2_zp,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_num_tokens,
|
||||
)
|
||||
|
@ -562,9 +562,12 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
|
||||
(moe.num_experts + prepare_finalize.world_size - 1) //
|
||||
prepare_finalize.world_size)
|
||||
experts = CutlassExpertsFp8(
|
||||
max_experts_per_worker, moe.in_dtype,
|
||||
max_experts_per_worker,
|
||||
moe.in_dtype,
|
||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
|
||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||
use_batched_format=True,
|
||||
)
|
||||
|
||||
if has_pplx and isinstance(
|
||||
prepare_finalize,
|
||||
|