deepep HT dispatch no abstraction
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py (new file, 189 lines)
@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# torchrun --nproc_per_node=2 vllm/model_executor/layers/moe/ep_kernels_no_abstraction.py  # noqa: E501
import os

import torch
import torch.distributed as dist
from deep_ep import Buffer, EventOverlap

# Communication buffer (will be allocated at runtime)
_buffer: Buffer | None = None

# Set the number of SMs to use
# NOTES: this is a static variable
Buffer.set_num_sms(24)


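# Lazily (re)allocate the shared DeepEP buffer, growing it whenever the group or the
# required NVLink/RDMA buffer sizes change.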
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
    global _buffer

    # NOTES: you may also replace `get_*_config` with your own auto-tuned results
    # obtained from the tests
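    # Take the max of the dispatch-config and combine-config size hints so a single
    # allocation can serve both phases.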
    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (
        Buffer.get_dispatch_config(group.size()),
        Buffer.get_combine_config(group.size()),
    ):
        num_nvl_bytes = max(
            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
        )
        num_rdma_bytes = max(
            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
        )

    # Allocate a buffer if none exists yet or the existing one is too small
    if (
        _buffer is None
        or _buffer.group != group
        or _buffer.num_nvl_bytes < num_nvl_bytes
        or _buffer.num_rdma_bytes < num_rdma_bytes
    ):
        _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer


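# Bytes per token row used to size the buffers; the element size is clamped to at
# least 2 bytes (presumably so low-precision inputs still reserve BF16-sized space).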
def get_hidden_bytes(x: torch.Tensor | tuple[torch.Tensor, torch.Tensor]) -> int:
    t = x[0] if isinstance(x, tuple) else x
    return t.size(1) * max(t.element_size(), 2)


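# Send each token (with its top-k indices and weights) to the ranks that own the
# selected experts; the returned `handle` is later needed by the matching combine.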
def dispatch_forward(
    x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    previous_event: EventOverlap | None = None,
) -> tuple[
    torch.Tensor | tuple[torch.Tensor, torch.Tensor],
    torch.Tensor,
    torch.Tensor,
    list,
    tuple,
    EventOverlap,
]:
    # NOTES: an optional `previous_event` is a captured CUDA event that the dispatch
    # kernel should wait on; it may be useful for communication-computation overlap.
    # For more information, please refer to the docs of `Buffer.dispatch`.
    global _buffer
    assert _buffer is not None

    # Calculate layout before actual dispatch
    (
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
        is_token_in_rank,
        previous_event,
    ) = _buffer.get_dispatch_layout(
        topk_idx,
        num_experts,
        previous_event=previous_event,
        async_finish=True,
        allocate_on_comm_stream=previous_event is not None,
    )
    # Do MoE dispatch
    # NOTES: the CPU will wait for the GPU's signal to arrive, so this is not
    # compatible with CUDA graph unless you specify `num_worst_tokens`, and that
    # flag is for intranode only. For more advanced usage, please refer to the docs
    # of the `dispatch` function.
    (
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    ) = _buffer.dispatch(
        x,
        topk_idx=topk_idx,
        topk_weights=topk_weights,
        num_tokens_per_rank=num_tokens_per_rank,
        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
        is_token_in_rank=is_token_in_rank,
        num_tokens_per_expert=num_tokens_per_expert,
        previous_event=previous_event,
        async_finish=True,
        allocate_on_comm_stream=True,
    )
    # For event management, please refer to the docs of the `EventOverlap` class
    return (
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    )


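# Small smoke test (2 ranks; see the torchrun command at the top of the file):
# dispatch random tokens through DeepEP, then build an all_gather reference below.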
if __name__ == "__main__":
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
    )
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    rank = int(os.environ.get("RANK", 0))
    torch.cuda.set_device(local_rank)

    group = dist.group.WORLD
    num_experts = 8
    local_batch_size = 4
    hidden_size = 128
    local_num_experts = num_experts // group.size()
    x = torch.randn(local_batch_size, hidden_size, device="cuda", dtype=torch.bfloat16)
    hidden_bytes = get_hidden_bytes(x)
    get_buffer(group, hidden_bytes)
    topk = 4

    expert_weights = torch.randn(
        local_batch_size,
        num_experts,
        dtype=torch.float32,
        device="cuda",
    )
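    # Random routing scores; each token keeps its top-k experts (global expert ids).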
    topk_weights, topk_idx = torch.topk(expert_weights, topk, dim=1)

    # Dispatch
    (
        recv_x,
        recv_topk_idx,
        recv_topk_weights,
        num_recv_tokens_per_expert_list,
        handle,
        event,
    ) = dispatch_forward(
        x,
        topk_idx,
        topk_weights,
        num_experts,
    )
    # print(f"rank {rank} recv_x: {recv_x.shape=}")
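    # recv_topk_idx appears to hold rank-local expert slots, with -1 marking entries
    # routed to other ranks; shift the valid entries by rank * local_num_experts so
    # the printout shows global expert ids (the -1 padding is left unchanged).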
    output_recv_topk_idx = recv_topk_idx + torch.where(
        recv_topk_idx == -1,
        0,
        rank * local_num_experts,
    )
    print(f"rank {rank} recv_topk_idx: {recv_topk_idx.shape=} {output_recv_topk_idx}")
    # print(
    #     f"rank {rank} recv_topk_weights: {recv_topk_weights.shape=} "
    #     f"{recv_topk_weights}"
    # )

    # Dispatch naive
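    # Reference path: replicate every rank's inputs with all_gather so the DeepEP
    # dispatch output above can be sanity-checked against the full routing tables.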
    all_x = [torch.empty_like(x) for _ in range(world_size)]
    all_topk_idx = [torch.empty_like(topk_idx) for _ in range(world_size)]
    all_topk_weights = [torch.empty_like(topk_weights) for _ in range(world_size)]

    dist.all_gather(all_x, x)
    dist.all_gather(all_topk_idx, topk_idx)
    dist.all_gather(all_topk_weights, topk_weights)

    all_x = torch.cat(all_x, dim=0)
    all_topk_idx = torch.cat(all_topk_idx, dim=0)
    all_topk_weights = torch.cat(all_topk_weights, dim=0)

    assert isinstance(all_topk_idx, torch.Tensor)
    print(f"rank {rank} all_topk_idx: {all_topk_idx.shape=} {all_topk_idx}")