mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. (#19298)
Signed-off-by: Varun <vsundarr@redhat.com> Co-authored-by: Varun <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
b8089195b4
commit
5cf2daea9a
@ -274,7 +274,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
|
||||
chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
|
||||
chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
|
||||
|
||||
b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
|
||||
b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
|
||||
a_chunk,
|
||||
None,
|
||||
None,
|
||||
|
@ -233,16 +233,11 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
# Defaults for internode and intranode are taken from DeepEP tests.
|
||||
num_nvl_bytes = 1024 * 1024 * 1024
|
||||
num_qps_per_rank = num_local_experts
|
||||
num_rdma_bytes = None
|
||||
|
||||
if self.internode:
|
||||
num_rdma_bytes = 1024 * 1024 * 1024
|
||||
else:
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts)
|
||||
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
|
||||
num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
|
||||
hidden=token_hidden_size,
|
||||
num_ranks=num_ep_ranks,
|
||||
num_experts=num_global_experts)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(group=self.cpu_group,
|
||||
|
@ -110,6 +110,7 @@ if TYPE_CHECKING:
|
||||
VLLM_DP_SIZE: int = 1
|
||||
VLLM_DP_MASTER_IP: str = ""
|
||||
VLLM_DP_MASTER_PORT: int = 0
|
||||
VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
|
||||
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
||||
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
||||
VLLM_TPU_BUCKET_PADDING_GAP: int = 0
|
||||
@ -761,6 +762,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_DP_MASTER_PORT":
|
||||
lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
|
||||
|
||||
# Randomize inputs during dummy runs when using Data Parallel
|
||||
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
|
||||
lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
|
||||
|
||||
# Whether to use S3 path for model loading in CI via RunAI Streamer
|
||||
"VLLM_CI_USE_S3":
|
||||
lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
|
||||
|
@ -80,11 +80,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
topk: int,
|
||||
num_experts: int,
|
||||
) -> tuple[int, int, torch.dtype]:
|
||||
|
||||
block_m = self.block_shape[0]
|
||||
M_sum = (M * topk) + num_experts * (block_m - 1)
|
||||
M_sum = round_up(M_sum, block_m)
|
||||
workspace1 = M_sum * max(N * 2, K)
|
||||
workspace2 = M_sum * N
|
||||
workspace2 = M_sum * max(N, K)
|
||||
|
||||
return (workspace1, workspace2, a.dtype)
|
||||
|
||||
def apply(
|
||||
@ -135,26 +137,31 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
# Note: M_sum is different than the pre-permuted shape of a1q.
|
||||
M_sum = a1q.size(0)
|
||||
workspace1 = _resize_cache(workspace13, (M_sum, N))
|
||||
workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
|
||||
workspace3 = _resize_cache(workspace13, (M_sum, K))
|
||||
|
||||
mm1_out = _resize_cache(workspace13, (M_sum, N))
|
||||
act_out = _resize_cache(workspace2, (M_sum, N // 2))
|
||||
quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
|
||||
(M_sum, N // 2))
|
||||
mm2_out = _resize_cache(workspace2, (M_sum, K))
|
||||
out = _resize_cache(workspace13, (inv_perm.size(0), K))
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
|
||||
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
|
||||
|
||||
self.activation(activation, workspace2, workspace1.view(-1, N))
|
||||
self.activation(activation, act_out, mm1_out.view(-1, N))
|
||||
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
|
||||
a2q, a2q_scale = per_token_group_quant_fp8(act_out,
|
||||
self.block_shape[1],
|
||||
column_major_scales=True)
|
||||
column_major_scales=True,
|
||||
out_q=quant_out)
|
||||
|
||||
dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
|
||||
(a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
|
||||
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
|
||||
|
||||
workspace3 = workspace3[inv_perm, ...]
|
||||
torch.index_select(mm2_out, 0, inv_perm, out=out)
|
||||
|
||||
return workspace3
|
||||
return out
|
||||
|
||||
|
||||
def deep_gemm_moe_fp8(
|
||||
|
@ -5,6 +5,7 @@ import deep_ep
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input)
|
||||
|
||||
@ -193,20 +194,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
apply_router_weight_on_input: bool,
|
||||
output_dtype: torch.dtype):
|
||||
|
||||
hidden_dim = fused_expert_output.size(-1)
|
||||
if fused_expert_output.ndim == 2:
|
||||
hidden_dim = fused_expert_output.size(-1)
|
||||
fused_expert_output = fused_expert_output.view(
|
||||
num_tokens, -1, hidden_dim)
|
||||
|
||||
if not apply_router_weight_on_input:
|
||||
# The DeepEP combine kernels don't do the topk weight
|
||||
# multiplication. We multiply the weights locally.
|
||||
fused_expert_output = fused_expert_output.to(torch.float32)
|
||||
fused_expert_output = fused_expert_output * topk_weights.view(
|
||||
fused_expert_output.size(0), -1, 1)
|
||||
fused_expert_output = fused_expert_output.to(output_dtype)
|
||||
m_x_topk = fused_expert_output.size(0)
|
||||
fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
|
||||
|
||||
return fused_expert_output.sum(dim=1).to(output_dtype)
|
||||
out = torch.empty((num_tokens, hidden_dim),
|
||||
device=fused_expert_output.device,
|
||||
dtype=output_dtype)
|
||||
ops.moe_sum(fused_expert_output, out)
|
||||
|
||||
return out
|
||||
|
||||
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
|
@ -18,7 +18,7 @@ def _moe_permute(
|
||||
expert_map: Optional[torch.Tensor],
|
||||
block_m: int,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor]]:
|
||||
torch.Tensor]:
|
||||
"""
|
||||
Determine the sorted_token_ids, expert_ids for the given problem size.
|
||||
Permute the hidden states and scales according to `sorted_token_ids`.
|
||||
|
@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
|
||||
row = g_id // groups_per_row
|
||||
row_g_id = g_id % groups_per_row
|
||||
|
||||
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
|
||||
y_q_ptr += g_id * group_size
|
||||
# Ensure offset calculations use int64 to prevent overflow
|
||||
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
|
||||
group_size)
|
||||
y_ptr += y_ptr_offset
|
||||
|
||||
y_q_ptr_offset = g_id.to(tl.int64) * group_size
|
||||
y_q_ptr += y_q_ptr_offset
|
||||
y_s_ptr += g_id
|
||||
|
||||
cols = tl.arange(0, BLOCK) # N <= BLOCK
|
||||
@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
|
||||
row = g_id // groups_per_row
|
||||
row_g_id = g_id % groups_per_row
|
||||
|
||||
y_ptr += (row * y_row_stride) + (row_g_id * group_size)
|
||||
y_q_ptr += g_id * group_size
|
||||
# Ensure offset calculations use int64 to prevent overflow
|
||||
y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
|
||||
group_size)
|
||||
y_ptr += y_ptr_offset
|
||||
|
||||
y_q_ptr_offset = g_id.to(tl.int64) * group_size
|
||||
y_q_ptr += y_q_ptr_offset
|
||||
|
||||
# Convert g_id the flattened block coordinate to 2D so we can index
|
||||
# into the output y_scales matrix
|
||||
blocks_per_row = y_num_columns // group_size
|
||||
scale_col = g_id % blocks_per_row
|
||||
scale_row = g_id // blocks_per_row
|
||||
y_s_ptr += scale_col * y_s_col_stride + scale_row
|
||||
# Ensure offset calculation uses int64 for y_s_ptr
|
||||
y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
|
||||
tl.int64)
|
||||
y_s_ptr += y_s_ptr_offset
|
||||
|
||||
cols = tl.arange(0, BLOCK) # group_size <= BLOCK
|
||||
mask = cols < group_size
|
||||
@ -311,6 +324,7 @@ def per_token_group_quant_fp8(
|
||||
eps: float = 1e-10,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
column_major_scales: bool = False,
|
||||
out_q: Optional[torch.Tensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
@ -321,6 +335,8 @@ def per_token_group_quant_fp8(
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
column_major_scales: Outputs scales in column major.
|
||||
out_q: Optional output tensor. If not provided, function will create.
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
@ -335,7 +351,11 @@ def per_token_group_quant_fp8(
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
assert out_q is None or out_q.shape == x.shape
|
||||
x_q = out_q
|
||||
if x_q is None:
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
if column_major_scales:
|
||||
|
@ -5,6 +5,7 @@ import copy
|
||||
import gc
|
||||
import time
|
||||
import weakref
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@ -12,6 +13,7 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import AttentionType, get_attn_backend
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadataBuilder)
|
||||
@ -1727,6 +1729,35 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
return prompt_logprobs_dict
|
||||
|
||||
@contextmanager
|
||||
def maybe_randomize_inputs(self, input_ids: torch.Tensor):
|
||||
"""
|
||||
Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
|
||||
This is to help balance expert-selection
|
||||
- during profile_run
|
||||
- during DP rank dummy run
|
||||
"""
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
|
||||
if not randomize_inputs:
|
||||
yield
|
||||
else:
|
||||
import functools
|
||||
|
||||
@functools.cache
|
||||
def rand_input_ids() -> torch.Tensor:
|
||||
return torch.randint_like(
|
||||
self.input_ids,
|
||||
low=0,
|
||||
high=self.model_config.get_vocab_size(),
|
||||
dtype=input_ids.dtype)
|
||||
|
||||
logger.debug("Randomizing dummy data for DP Rank")
|
||||
input_ids.copy_(rand_input_ids()[:input_ids.size(0)],
|
||||
non_blocking=True)
|
||||
yield
|
||||
input_ids.fill_(0)
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_run(
|
||||
self,
|
||||
@ -1807,7 +1838,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
intermediate_tensors = self.sync_and_slice_intermediate_tensors(
|
||||
num_tokens, None, False)
|
||||
|
||||
with set_forward_context(
|
||||
with self.maybe_randomize_inputs(input_ids), set_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_tokens,
|
||||
|
Reference in New Issue
Block a user