From 5cf2daea9ada71232f87180d38bde2c2044a1a61 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Mon, 9 Jun 2025 10:50:39 -0400 Subject: [PATCH] [Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. (#19298) Signed-off-by: Varun Co-authored-by: Varun --- tests/kernels/moe/test_pplx_moe.py | 2 +- .../device_communicators/all2all.py | 15 +++------ vllm/envs.py | 5 +++ .../layers/fused_moe/deep_gemm_moe.py | 29 +++++++++------- .../fused_moe/deepep_ht_prepare_finalize.py | 16 +++++---- .../layers/fused_moe/moe_permute_unpermute.py | 2 +- .../layers/quantization/utils/fp8_utils.py | 32 ++++++++++++++---- vllm/v1/worker/gpu_model_runner.py | 33 ++++++++++++++++++- 8 files changed, 98 insertions(+), 36 deletions(-) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index bbfe31d0e6..0b48bbef6c 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -274,7 +274,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) - b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare( + b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare( a_chunk, None, None, diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index cab2496bfb..35f2fd0ba9 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -233,16 +233,11 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): # Defaults for internode and intranode are taken from DeepEP tests. num_nvl_bytes = 1024 * 1024 * 1024 num_qps_per_rank = num_local_experts - num_rdma_bytes = None - - if self.internode: - num_rdma_bytes = 1024 * 1024 * 1024 - else: - num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( - num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, - hidden=token_hidden_size, - num_ranks=num_ep_ranks, - num_experts=num_global_experts) + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, + hidden=token_hidden_size, + num_ranks=num_ep_ranks, + num_experts=num_global_experts) assert num_rdma_bytes is not None return dict(group=self.cpu_group, diff --git a/vllm/envs.py b/vllm/envs.py index 9d18a13895..9511ed1cb4 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -110,6 +110,7 @@ if TYPE_CHECKING: VLLM_DP_SIZE: int = 1 VLLM_DP_MASTER_IP: str = "" VLLM_DP_MASTER_PORT: int = 0 + VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False VLLM_MARLIN_USE_ATOMIC_ADD: bool = False VLLM_V0_USE_OUTLINES_CACHE: bool = False VLLM_TPU_BUCKET_PADDING_GAP: int = 0 @@ -761,6 +762,10 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_DP_MASTER_PORT": lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")), + # Randomize inputs during dummy runs when using Data Parallel + "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": + lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1", + # Whether to use S3 path for model loading in CI via RunAI Streamer "VLLM_CI_USE_S3": lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1", diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index c00e849b4e..436c632be9 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ 
b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -80,11 +80,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): topk: int, num_experts: int, ) -> tuple[int, int, torch.dtype]: + block_m = self.block_shape[0] M_sum = (M * topk) + num_experts * (block_m - 1) M_sum = round_up(M_sum, block_m) workspace1 = M_sum * max(N * 2, K) - workspace2 = M_sum * N + workspace2 = M_sum * max(N, K) + return (workspace1, workspace2, a.dtype) def apply( @@ -135,26 +137,31 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # Note: M_sum is different than the pre-permuted shape of a1q. M_sum = a1q.size(0) - workspace1 = _resize_cache(workspace13, (M_sum, N)) - workspace2 = _resize_cache(workspace2, (M_sum, N // 2)) - workspace3 = _resize_cache(workspace13, (M_sum, K)) + + mm1_out = _resize_cache(workspace13, (M_sum, N)) + act_out = _resize_cache(workspace2, (M_sum, N // 2)) + quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn), + (M_sum, N // 2)) + mm2_out = _resize_cache(workspace2, (M_sum, K)) + out = _resize_cache(workspace13, (inv_perm.size(0), K)) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids) + (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) - self.activation(activation, workspace2, workspace1.view(-1, N)) + self.activation(activation, act_out, mm1_out.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None - a2q, a2q_scale = per_token_group_quant_fp8(workspace2, + a2q, a2q_scale = per_token_group_quant_fp8(act_out, self.block_shape[1], - column_major_scales=True) + column_major_scales=True, + out_q=quant_out) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) + (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) - workspace3 = workspace3[inv_perm, ...] + torch.index_select(mm2_out, 0, inv_perm, out=out) - return workspace3 + return out def deep_gemm_moe_fp8( diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 48cf01638a..8c21d8aa53 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -5,6 +5,7 @@ import deep_ep import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input) @@ -193,20 +194,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): apply_router_weight_on_input: bool, output_dtype: torch.dtype): + hidden_dim = fused_expert_output.size(-1) if fused_expert_output.ndim == 2: - hidden_dim = fused_expert_output.size(-1) fused_expert_output = fused_expert_output.view( num_tokens, -1, hidden_dim) if not apply_router_weight_on_input: # The DeepEP combine kernels don't do the topk weight # multiplication. We multiply the weights locally. 
- fused_expert_output = fused_expert_output.to(torch.float32) - fused_expert_output = fused_expert_output * topk_weights.view( - fused_expert_output.size(0), -1, 1) - fused_expert_output = fused_expert_output.to(output_dtype) + m_x_topk = fused_expert_output.size(0) + fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1)) - return fused_expert_output.sum(dim=1).to(output_dtype) + out = torch.empty((num_tokens, hidden_dim), + device=fused_expert_output.device, + dtype=output_dtype) + ops.moe_sum(fused_expert_output, out) + + return out def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index 89481e5bd6..20ee0d9f78 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -18,7 +18,7 @@ def _moe_permute( expert_map: Optional[torch.Tensor], block_m: int, ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor, - Optional[torch.Tensor]]: + torch.Tensor]: """ Determine the sorted_token_ids, expert_ids for the given problem size. Permute the hidden states and scales according to `sorted_token_ids`. diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 270979c8e9..08dc99e075 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -234,8 +234,13 @@ def _per_token_group_quant_fp8( row = g_id // groups_per_row row_g_id = g_id % groups_per_row - y_ptr += (row * y_row_stride) + (row_g_id * group_size) - y_q_ptr += g_id * group_size + # Ensure offset calculations use int64 to prevent overflow + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * + group_size) + y_ptr += y_ptr_offset + + y_q_ptr_offset = g_id.to(tl.int64) * group_size + y_q_ptr += y_q_ptr_offset y_s_ptr += g_id cols = tl.arange(0, BLOCK) # N <= BLOCK @@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor( row = g_id // groups_per_row row_g_id = g_id % groups_per_row - y_ptr += (row * y_row_stride) + (row_g_id * group_size) - y_q_ptr += g_id * group_size + # Ensure offset calculations use int64 to prevent overflow + y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * + group_size) + y_ptr += y_ptr_offset + + y_q_ptr_offset = g_id.to(tl.int64) * group_size + y_q_ptr += y_q_ptr_offset # Convert g_id the flattened block coordinate to 2D so we can index # into the output y_scales matrix blocks_per_row = y_num_columns // group_size scale_col = g_id % blocks_per_row scale_row = g_id // blocks_per_row - y_s_ptr += scale_col * y_s_col_stride + scale_row + # Ensure offset calculation uses int64 for y_s_ptr + y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to( + tl.int64) + y_s_ptr += y_s_ptr_offset cols = tl.arange(0, BLOCK) # group_size <= BLOCK mask = cols < group_size @@ -311,6 +324,7 @@ def per_token_group_quant_fp8( eps: float = 1e-10, dtype: Optional[torch.dtype] = None, column_major_scales: bool = False, + out_q: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. 
It converts the tensor values into signed float8 values and returns the @@ -321,6 +335,8 @@ def per_token_group_quant_fp8( eps: The minimum to avoid dividing zero. dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + column_major_scales: Outputs scales in column major. + out_q: Optional output tensor. If not provided, function will create. Returns: tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. @@ -335,7 +351,11 @@ def per_token_group_quant_fp8( fp8_min = finfo.min fp8_max = finfo.max - x_q = torch.empty_like(x, device=x.device, dtype=dtype) + assert out_q is None or out_q.shape == x.shape + x_q = out_q + if x_q is None: + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size N = group_size if column_major_scales: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c39aea3d7e..175404efe0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5,6 +5,7 @@ import copy import gc import time import weakref +from contextlib import contextmanager from typing import TYPE_CHECKING, Any, Optional, Union import numpy as np @@ -12,6 +13,7 @@ import torch import torch.distributed import torch.nn as nn +import vllm.envs as envs from vllm.attention import AttentionType, get_attn_backend from vllm.attention.backends.abstract import (AttentionBackend, AttentionMetadataBuilder) @@ -1727,6 +1729,35 @@ class GPUModelRunner(LoRAModelRunnerMixin): return prompt_logprobs_dict + @contextmanager + def maybe_randomize_inputs(self, input_ids: torch.Tensor): + """ + Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set. + This is to help balance expert-selection + - during profile_run + - during DP rank dummy run + """ + dp_size = self.vllm_config.parallel_config.data_parallel_size + randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1 + if not randomize_inputs: + yield + else: + import functools + + @functools.cache + def rand_input_ids() -> torch.Tensor: + return torch.randint_like( + self.input_ids, + low=0, + high=self.model_config.get_vocab_size(), + dtype=input_ids.dtype) + + logger.debug("Randomizing dummy data for DP Rank") + input_ids.copy_(rand_input_ids()[:input_ids.size(0)], + non_blocking=True) + yield + input_ids.fill_(0) + @torch.inference_mode() def _dummy_run( self, @@ -1807,7 +1838,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): intermediate_tensors = self.sync_and_slice_intermediate_tensors( num_tokens, None, False) - with set_forward_context( + with self.maybe_randomize_inputs(input_ids), set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens,
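Editorial note, not part of the patch: the new `out_q` argument to `per_token_group_quant_fp8` (fp8_utils.py hunk above) lets the caller pass the destination buffer, so the quantized tensor is written into an existing workspace instead of a fresh allocation; this is how `DeepGemmExperts.apply` now reuses `workspace13` for `quant_out`. Below is a minimal sketch of the call with illustrative shapes and device; only the signature comes from the patch.

    import torch
    from vllm.model_executor.layers.quantization.utils.fp8_utils import (
        per_token_group_quant_fp8)

    x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
    group_size = 128
    # Caller-owned fp8 buffer (e.g. a slice of a persistent workspace),
    # reused across calls instead of being allocated inside the helper.
    q_buf = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    x_q, x_s = per_token_group_quant_fp8(
        x, group_size, column_major_scales=True, out_q=q_buf)
    assert x_q.data_ptr() == q_buf.data_ptr()  # written in place, no new alloc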
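Editorial note, not part of the patch: the `.to(tl.int64)` casts in the two Triton group-quant kernels guard against 32-bit overflow of the pointer offsets. `row * y_row_stride` and `g_id * group_size` are element offsets into a flattened activation, and after the MoE permutation the row count is roughly M * topk plus per-expert padding, so the largest offset can exceed 2**31 - 1. Rough arithmetic with assumed (not patch-provided) sizes:

    # Illustrative numbers only: a DeepSeek-scale hidden size and a large
    # permuted token count push element offsets past the int32 range.
    int32_max = 2**31 - 1            # 2_147_483_647
    hidden = 7168                    # assumed hidden size
    m_sum = 512 * 1024               # assumed permuted rows (M * topk, padded)
    last_offset = m_sum * hidden     # ~3.8e9 elements
    print(last_offset > int32_max)   # True -> offsets must be computed in int64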
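Usage note, editorial: `VLLM_RANDOMIZE_DP_DUMMY_INPUTS` is opt-in and only takes effect when the data-parallel size is greater than 1; `maybe_randomize_inputs` then fills `input_ids` with random token ids for the profile/dummy runs so expert selection is not identical across DP ranks, and zeroes the buffer again on exit. A sketch of enabling it; the flag name is from the envs.py hunk, the rest is illustrative.

    import os
    # Must be set before vLLM reads its environment variables.
    os.environ["VLLM_RANDOMIZE_DP_DUMMY_INPUTS"] = "1"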