[Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. (#19298)

Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-06-09 10:50:39 -04:00
committed by GitHub
parent b8089195b4
commit 5cf2daea9a
8 changed files with 98 additions and 36 deletions

View File

@ -274,7 +274,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
a_chunk,
None,
None,

View File

@ -233,16 +233,11 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = 1024 * 1024 * 1024
num_qps_per_rank = num_local_experts
num_rdma_bytes = None
if self.internode:
num_rdma_bytes = 1024 * 1024 * 1024
else:
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
hidden=token_hidden_size,
num_ranks=num_ep_ranks,
num_experts=num_global_experts)
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
hidden=token_hidden_size,
num_ranks=num_ep_ranks,
num_experts=num_global_experts)
assert num_rdma_bytes is not None
return dict(group=self.cpu_group,

View File

@ -110,6 +110,7 @@ if TYPE_CHECKING:
VLLM_DP_SIZE: int = 1
VLLM_DP_MASTER_IP: str = ""
VLLM_DP_MASTER_PORT: int = 0
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
VLLM_V0_USE_OUTLINES_CACHE: bool = False
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
@ -761,6 +762,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DP_MASTER_PORT":
lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
# Randomize inputs during dummy runs when using Data Parallel
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
# Whether to use S3 path for model loading in CI via RunAI Streamer
"VLLM_CI_USE_S3":
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",

View File

@ -80,11 +80,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk: int,
num_experts: int,
) -> tuple[int, int, torch.dtype]:
block_m = self.block_shape[0]
M_sum = (M * topk) + num_experts * (block_m - 1)
M_sum = round_up(M_sum, block_m)
workspace1 = M_sum * max(N * 2, K)
workspace2 = M_sum * N
workspace2 = M_sum * max(N, K)
return (workspace1, workspace2, a.dtype)
def apply(
@ -135,26 +137,31 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# Note: M_sum is different than the pre-permuted shape of a1q.
M_sum = a1q.size(0)
workspace1 = _resize_cache(workspace13, (M_sum, N))
workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
workspace3 = _resize_cache(workspace13, (M_sum, K))
mm1_out = _resize_cache(workspace13, (M_sum, N))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
(M_sum, N // 2))
mm2_out = _resize_cache(workspace2, (M_sum, K))
out = _resize_cache(workspace13, (inv_perm.size(0), K))
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
self.activation(activation, workspace2, workspace1.view(-1, N))
self.activation(activation, act_out, mm1_out.view(-1, N))
a2q_scale: Optional[torch.Tensor] = None
a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
a2q, a2q_scale = per_token_group_quant_fp8(act_out,
self.block_shape[1],
column_major_scales=True)
column_major_scales=True,
out_q=quant_out)
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
workspace3 = workspace3[inv_perm, ...]
torch.index_select(mm2_out, 0, inv_perm, out=out)
return workspace3
return out
def deep_gemm_moe_fp8(

View File

@ -5,6 +5,7 @@ import deep_ep
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input)
@ -193,20 +194,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
output_dtype: torch.dtype):
hidden_dim = fused_expert_output.size(-1)
if fused_expert_output.ndim == 2:
hidden_dim = fused_expert_output.size(-1)
fused_expert_output = fused_expert_output.view(
num_tokens, -1, hidden_dim)
if not apply_router_weight_on_input:
# The DeepEP combine kernels don't do the topk weight
# multiplication. We multiply the weights locally.
fused_expert_output = fused_expert_output.to(torch.float32)
fused_expert_output = fused_expert_output * topk_weights.view(
fused_expert_output.size(0), -1, 1)
fused_expert_output = fused_expert_output.to(output_dtype)
m_x_topk = fused_expert_output.size(0)
fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
return fused_expert_output.sum(dim=1).to(output_dtype)
out = torch.empty((num_tokens, hidden_dim),
device=fused_expert_output.device,
dtype=output_dtype)
ops.moe_sum(fused_expert_output, out)
return out
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,

View File

@ -18,7 +18,7 @@ def _moe_permute(
expert_map: Optional[torch.Tensor],
block_m: int,
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
torch.Tensor]:
"""
Determine the sorted_token_ids, expert_ids for the given problem size.
Permute the hidden states and scales according to `sorted_token_ids`.

View File

@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
y_s_ptr += g_id
cols = tl.arange(0, BLOCK) # N <= BLOCK
@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
row = g_id // groups_per_row
row_g_id = g_id % groups_per_row
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
y_q_ptr += g_id * group_size
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
group_size)
y_ptr += y_ptr_offset
y_q_ptr_offset = g_id.to(tl.int64) * group_size
y_q_ptr += y_q_ptr_offset
# Convert g_id the flattened block coordinate to 2D so we can index
# into the output y_scales matrix
blocks_per_row = y_num_columns // group_size
scale_col = g_id % blocks_per_row
scale_row = g_id // blocks_per_row
y_s_ptr += scale_col * y_s_col_stride + scale_row
# Ensure offset calculation uses int64 for y_s_ptr
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
tl.int64)
y_s_ptr += y_s_ptr_offset
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
mask = cols < group_size
@ -311,6 +324,7 @@ def per_token_group_quant_fp8(
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
out_q: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
@ -321,6 +335,8 @@ def per_token_group_quant_fp8(
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
@ -335,7 +351,11 @@ def per_token_group_quant_fp8(
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
assert out_q is None or out_q.shape == x.shape
x_q = out_q
if x_q is None:
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
if column_major_scales:

View File

@ -5,6 +5,7 @@ import copy
import gc
import time
import weakref
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
@ -12,6 +13,7 @@ import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadataBuilder)
@ -1727,6 +1729,35 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return prompt_logprobs_dict
@contextmanager
def maybe_randomize_inputs(self, input_ids: torch.Tensor):
"""
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
This is to help balance expert-selection
- during profile_run
- during DP rank dummy run
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
if not randomize_inputs:
yield
else:
import functools
@functools.cache
def rand_input_ids() -> torch.Tensor:
return torch.randint_like(
self.input_ids,
low=0,
high=self.model_config.get_vocab_size(),
dtype=input_ids.dtype)
logger.debug("Randomizing dummy data for DP Rank")
input_ids.copy_(rand_input_ids()[:input_ids.size(0)],
non_blocking=True)
yield
input_ids.fill_(0)
@torch.inference_mode()
def _dummy_run(
self,
@ -1807,7 +1838,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
num_tokens, None, False)
with set_forward_context(
with self.maybe_randomize_inputs(input_ids), set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_tokens,