[Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. (#19298)

Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>

commit 5cf2daea9a (parent b8089195b4), committed via GitHub
@@ -274,7 +274,7 @@ def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
 
-    b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
+    b_a, b_a_scale, expert_num_tokens, _, _ = prepare_finalize.prepare(
         a_chunk,
         None,
         None,
@@ -233,16 +233,11 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
         # Defaults for internode and intranode are taken from DeepEP tests.
         num_nvl_bytes = 1024 * 1024 * 1024
         num_qps_per_rank = num_local_experts
-        num_rdma_bytes = None
-
-        if self.internode:
-            num_rdma_bytes = 1024 * 1024 * 1024
-        else:
-            num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
-                num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
-                hidden=token_hidden_size,
-                num_ranks=num_ep_ranks,
-                num_experts=num_global_experts)
+        num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
+            num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank,
+            hidden=token_hidden_size,
+            num_ranks=num_ep_ranks,
+            num_experts=num_global_experts)
 
         assert num_rdma_bytes is not None
         return dict(group=self.cpu_group,
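
With this change the low-latency RDMA buffer size is always taken from DeepEP's own hint rather than a fixed 1 GiB in the internode case. For reference, a minimal sketch of querying that hint in isolation; only the `deep_ep.Buffer.get_low_latency_rdma_size_hint` call and its keyword names come from the hunk above, the numbers are illustrative placeholders, and DeepEP must be installed.

```python
# Illustrative only: query DeepEP for the low-latency RDMA buffer size hint.
import deep_ep

num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
    num_max_dispatch_tokens_per_rank=128,  # max tokens a DP rank may dispatch (placeholder)
    hidden=7168,                           # token hidden size (placeholder)
    num_ranks=16,                          # expert-parallel world size (placeholder)
    num_experts=256)                       # global expert count (placeholder)

# The manager above returns this value (together with num_nvl_bytes and
# num_qps_per_rank) as kwargs used to construct the DeepEP buffer.
print(f"RDMA buffer size hint: {num_rdma_bytes} bytes")
```
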
@@ -110,6 +110,7 @@ if TYPE_CHECKING:
     VLLM_DP_SIZE: int = 1
     VLLM_DP_MASTER_IP: str = ""
     VLLM_DP_MASTER_PORT: int = 0
+    VLLM_RANDOMIZE_DP_DUMMY_INPUTS: bool = False
     VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
     VLLM_V0_USE_OUTLINES_CACHE: bool = False
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
@@ -761,6 +762,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_DP_MASTER_PORT":
     lambda: int(os.getenv("VLLM_DP_MASTER_PORT", "0")),
 
+    # Randomize inputs during dummy runs when using Data Parallel
+    "VLLM_RANDOMIZE_DP_DUMMY_INPUTS":
+    lambda: os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1",
+
     # Whether to use S3 path for model loading in CI via RunAI Streamer
     "VLLM_CI_USE_S3":
     lambda: os.environ.get("VLLM_CI_USE_S3", "0") == "1",
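
A small sketch of how the new flag resolves, mirroring the registered lambda above; the helper name is hypothetical.

```python
# Illustrative only: resolve the new flag the same way the registered lambda does.
import os

def is_dp_dummy_input_randomization_enabled() -> bool:
    # Only the literal value "1" enables randomized dummy inputs for DP dummy runs.
    return os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1"
```

Setting VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1 in the launch environment turns the behaviour on; any other value, or leaving it unset, keeps the dummy-run inputs as zeros.
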
@@ -80,11 +80,13 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         topk: int,
         num_experts: int,
     ) -> tuple[int, int, torch.dtype]:
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
         workspace1 = M_sum * max(N * 2, K)
-        workspace2 = M_sum * N
+        workspace2 = M_sum * max(N, K)
         return (workspace1, workspace2, a.dtype)
 
     def apply(
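
The sizing change is that `workspace2` must now be wide enough for `max(N, K)` rather than `N`, because in the reworked `apply` below the second grouped GEMM writes its `K`-wide output into `workspace2`. A worked example with made-up problem sizes:

```python
# Worked example (made-up sizes) of the workspace sizing above.
def round_up(x: int, m: int) -> int:
    return (x + m - 1) // m * m

M, topk, num_experts = 256, 8, 32   # tokens, experts per token, local experts (placeholders)
N, K = 4096, 7168                   # w1 output width, hidden size (placeholders)
block_m = 128                       # self.block_shape[0]

M_sum = round_up(M * topk + num_experts * (block_m - 1), block_m)
workspace1 = M_sum * max(N * 2, K)  # backs mm1_out, quant_out and the final out
workspace2 = M_sum * max(N, K)      # backs act_out (width N // 2) and mm2_out (width K)
print(M_sum, workspace1, workspace2)  # 6144 50331648 44040192
```
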
@@ -135,26 +137,31 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
         # Note: M_sum is different than the pre-permuted shape of a1q.
         M_sum = a1q.size(0)
-        workspace1 = _resize_cache(workspace13, (M_sum, N))
-        workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
-        workspace3 = _resize_cache(workspace13, (M_sum, K))
+
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (M_sum, N // 2))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+        out = _resize_cache(workspace13, (inv_perm.size(0), K))
 
         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
+            (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
 
-        self.activation(activation, workspace2, workspace1.view(-1, N))
+        self.activation(activation, act_out, mm1_out.view(-1, N))
 
         a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
+        a2q, a2q_scale = per_token_group_quant_fp8(act_out,
                                                    self.block_shape[1],
-                                                   column_major_scales=True)
+                                                   column_major_scales=True,
+                                                   out_q=quant_out)
 
         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
+            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
 
-        workspace3 = workspace3[inv_perm, ...]
+        torch.index_select(mm2_out, 0, inv_perm, out=out)
 
-        return workspace3
+        return out
 
 
 def deep_gemm_moe_fp8(
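
The reworked `apply` keeps every intermediate inside the two preallocated workspaces by aliasing views over them, so consecutive stages ping-pong between `workspace13` and `workspace2` without new allocations. A minimal standalone sketch of that aliasing pattern follows; `resize_cache` here is a stand-in for vLLM's `_resize_cache`, and the sizes are made up.

```python
# Minimal, self-contained sketch of the workspace-aliasing pattern used above.
import torch

def resize_cache(buf: torch.Tensor, shape: tuple[int, ...]) -> torch.Tensor:
    # Reinterpret the first numel elements of a flat buffer as the requested shape.
    numel = 1
    for d in shape:
        numel *= d
    return buf.flatten()[:numel].view(*shape)

M_sum, N, K = 64, 32, 48  # illustrative sizes
workspace13 = torch.empty(M_sum * max(N * 2, K), dtype=torch.bfloat16)
workspace2 = torch.empty(M_sum * max(N, K), dtype=torch.bfloat16)

mm1_out = resize_cache(workspace13, (M_sum, N))             # GEMM-1 output
act_out = resize_cache(workspace2, (M_sum, N // 2))         # activation output
quant_out = resize_cache(workspace13.view(torch.float8_e4m3fn),
                         (M_sum, N // 2))                    # fp8 quantized activation
mm2_out = resize_cache(workspace2, (M_sum, K))               # GEMM-2 output

# Each stage reads from one workspace and writes into the other, so an op's
# input and output never share the same backing buffer.
assert mm1_out.data_ptr() == quant_out.data_ptr() == workspace13.data_ptr()
assert act_out.data_ptr() == mm2_out.data_ptr() == workspace2.data_ptr()
```
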
@@ -5,6 +5,7 @@ import deep_ep
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
 
@@ -193,20 +194,23 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                                  apply_router_weight_on_input: bool,
                                  output_dtype: torch.dtype):
 
+        hidden_dim = fused_expert_output.size(-1)
         if fused_expert_output.ndim == 2:
-            hidden_dim = fused_expert_output.size(-1)
             fused_expert_output = fused_expert_output.view(
                 num_tokens, -1, hidden_dim)
 
         if not apply_router_weight_on_input:
             # The DeepEP combine kernels don't do the topk weight
             # multiplication. We multiply the weights locally.
-            fused_expert_output = fused_expert_output.to(torch.float32)
-            fused_expert_output = fused_expert_output * topk_weights.view(
-                fused_expert_output.size(0), -1, 1)
-            fused_expert_output = fused_expert_output.to(output_dtype)
+            m_x_topk = fused_expert_output.size(0)
+            fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
 
-        return fused_expert_output.sum(dim=1).to(output_dtype)
+        out = torch.empty((num_tokens, hidden_dim),
+                          device=fused_expert_output.device,
+                          dtype=output_dtype)
+        ops.moe_sum(fused_expert_output, out)
+
+        return out
 
     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
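
The finalize path above now weights the expert outputs in place and reduces them with `ops.moe_sum` into a preallocated buffer instead of materializing a float32 copy. Below is a plain-PyTorch sketch of the computation, assuming (as the removed code suggests) that `ops.moe_sum` sums over the topk dimension; shapes are illustrative.

```python
# Reference-only sketch of the weighting + reduction done above; not vLLM's kernel.
import torch

num_tokens, topk, hidden_dim = 8, 2, 16  # illustrative shapes
fused_expert_output = torch.randn(num_tokens, topk, hidden_dim)
topk_weights = torch.rand(num_tokens, topk)

# In-place weighting, matching fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))
fused_expert_output.mul_(topk_weights.view(num_tokens, -1, 1))

# Plain-torch stand-in for: out = torch.empty(...); ops.moe_sum(fused_expert_output, out)
out = fused_expert_output.sum(dim=1)
```
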
@@ -18,7 +18,7 @@ def _moe_permute(
     expert_map: Optional[torch.Tensor],
     block_m: int,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor, torch.Tensor,
-           Optional[torch.Tensor]]:
+           torch.Tensor]:
     """
     Determine the sorted_token_ids, expert_ids for the given problem size.
     Permute the hidden states and scales according to `sorted_token_ids`.
@@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row
 
-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset
     y_s_ptr += g_id
 
     cols = tl.arange(0, BLOCK)  # N <= BLOCK
@@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row
 
-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset
 
     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
     blocks_per_row = y_num_columns // group_size
     scale_col = g_id % blocks_per_row
     scale_row = g_id // blocks_per_row
-    y_s_ptr += scale_col * y_s_col_stride + scale_row
+    # Ensure offset calculation uses int64 for y_s_ptr
+    y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
+        tl.int64)
+    y_s_ptr += y_s_ptr_offset
 
     cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
     mask = cols < group_size
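
The casts matter because pointer offsets computed from 32-bit operands can wrap once the element offset exceeds 2**31 - 1, which is easy to hit when a large number of dispatched token rows meets a wide row stride. A back-of-the-envelope check with made-up sizes:

```python
# Illustrative only: when a 32-bit row offset would wrap but int64 stays exact.
num_rows = 600_000       # e.g. tokens * topk after dispatch (made up)
y_row_stride = 7_168     # row stride in elements (made up)

worst_offset = (num_rows - 1) * y_row_stride
print(worst_offset)                # 4300792832 for these sizes
print(worst_offset > 2**31 - 1)    # True -> overflows int32, fits easily in int64
```
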
@@ -311,6 +324,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
+    out_q: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -321,6 +335,8 @@ def per_token_group_quant_fp8(
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
         is supported for now.
+        column_major_scales: Outputs scales in column major.
+        out_q: Optional output tensor. If not provided, function will create.
     Returns:
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
@@ -335,7 +351,11 @@ def per_token_group_quant_fp8(
     fp8_min = finfo.min
     fp8_max = finfo.max
 
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    assert out_q is None or out_q.shape == x.shape
+    x_q = out_q
+    if x_q is None:
+        x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
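
With the new `out_q` argument a caller can hand the kernel a preallocated fp8 destination, as `DeepGemmExperts.apply` now does with `quant_out`. A hedged usage sketch follows; it assumes the usual vLLM import path for `per_token_group_quant_fp8` and a CUDA device, and the sizes are illustrative.

```python
# Hedged usage sketch of the new out_q parameter.
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)  # assumed import path

M_sum, N_half, group_size = 1024, 2048, 128
act_out = torch.randn(M_sum, N_half, dtype=torch.bfloat16, device="cuda")

# Preallocated fp8 destination, e.g. a view over an existing workspace buffer.
quant_out = torch.empty(M_sum, N_half, dtype=torch.float8_e4m3fn, device="cuda")

# Quantized values are written into quant_out instead of a freshly allocated tensor.
a2q, a2q_scale = per_token_group_quant_fp8(act_out,
                                           group_size,
                                           column_major_scales=True,
                                           out_q=quant_out)
assert a2q.data_ptr() == quant_out.data_ptr()
```
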
@@ -5,6 +5,7 @@ import copy
 import gc
 import time
 import weakref
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 import numpy as np
@@ -12,6 +13,7 @@ import torch
 import torch.distributed
 import torch.nn as nn
 
+import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadataBuilder)
@@ -1727,6 +1729,35 @@ class GPUModelRunner(LoRAModelRunnerMixin):
 
         return prompt_logprobs_dict
 
+    @contextmanager
+    def maybe_randomize_inputs(self, input_ids: torch.Tensor):
+        """
+        Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
+        This is to help balance expert-selection
+         - during profile_run
+         - during DP rank dummy run
+        """
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
+        if not randomize_inputs:
+            yield
+        else:
+            import functools
+
+            @functools.cache
+            def rand_input_ids() -> torch.Tensor:
+                return torch.randint_like(
+                    self.input_ids,
+                    low=0,
+                    high=self.model_config.get_vocab_size(),
+                    dtype=input_ids.dtype)
+
+            logger.debug("Randomizing dummy data for DP Rank")
+            input_ids.copy_(rand_input_ids()[:input_ids.size(0)],
+                            non_blocking=True)
+            yield
+            input_ids.fill_(0)
+
     @torch.inference_mode()
     def _dummy_run(
         self,
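
The new context manager temporarily fills the persistent `input_ids` buffer with random token ids so that expert routing is not identical across DP ranks during profile and dummy runs, then zeroes the buffer again on exit. A simplified, standalone sketch of the same pattern; the `vocab_size` and `enabled` arguments stand in for the config and env lookups used above.

```python
# Simplified sketch of the maybe_randomize_inputs pattern: overwrite a persistent
# buffer for the duration of a dummy run, then restore it on exit.
from contextlib import contextmanager
import torch

@contextmanager
def maybe_randomize(input_ids: torch.Tensor, vocab_size: int, enabled: bool):
    if not enabled:
        yield
        return
    input_ids.copy_(torch.randint_like(input_ids, low=0, high=vocab_size))
    try:
        yield
    finally:
        input_ids.fill_(0)  # restore the dummy-run default of all zeros

persistent_input_ids = torch.zeros(16, dtype=torch.int64)
with maybe_randomize(persistent_input_ids, vocab_size=32000, enabled=True):
    pass  # the dummy forward pass would run here with randomized token ids
assert int(persistent_input_ids.sum()) == 0
```
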
@@ -1807,7 +1838,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)
 
-        with set_forward_context(
+        with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,