Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.11.0rc2...compile-ep (1 commit)

Commit 787384dd4a
@@ -1401,6 +1401,66 @@ class FusedMoE(CustomOp):
         self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
         self.logical_replica_count = logical_replica_count[moe_layer_idx]
 
+    @staticmethod
+    @torch.compile(dynamic=True,
+                   backend=current_platform.simple_compile_backend)
+    def handle_eplb(
+        topk_ids: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        expert_load_view: torch.Tensor,
+        indices_type: torch.dtype,
+    ) -> torch.Tensor:
+        # 1. Convert the logical expert ids to physical expert ids
+        # Directly select a random replica for each logical expert
+
+        # TODO: maybe optimize this by using specified kernels,
+        # or compute pseudo-random indices by modulo
+
+        # In case `indices_type` is not `torch.long` or `torch.int`,
+        # e.g. `torch.uint32` as required by dispatch/combine kernels
+        topk_ids_long = topk_ids.long()
+        replica_indices = (
+            torch.rand_like(topk_ids, dtype=torch.float) *
+            logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
+        physical_ids = logical_to_physical_map[topk_ids_long].gather(
+            -1, replica_indices).squeeze(-1)
+
+        topk_ids = physical_ids
+
+        # 2. Record expert load metrics.
+
+        # TODO(bowen): When using `FusedMoEModularKernel`, this
+        # can be done in a more unified way, since
+        # `FusedMoEPrepareAndFinalize` will return the expert
+        # token count, in some cases directly from the kernel.
+        # However, now there are many code paths not using
+        # the modular kernel, e.g. calling `fused_experts`,
+        # so we decide to keep the logic here.
+        #
+        # If later refactor moved all the MoE kernel calls
+        # to the modular kernel, we can move this logic there
+        # to achieve better efficiency.
+
+        # `expert_load_view`: (num_physical_experts,)
+
+        topk_ids_flatten = topk_ids.flatten()
+
+        # Performance optimization:
+        # `masked_fill` is significantly faster than `masked_select`
+        invalid_mask = topk_ids_flatten < 0
+        # Replace invalid expert ids with 0 (just a dummy position)
+        # to avoid out-of-bounds errors in scatter_add_
+        index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
+        # `src` is the valid mask, which is 1 for valid and 0 for invalid
+        src = ~invalid_mask
+
+        expert_load_view.scatter_add_(dim=0,
+                                      index=index.long(),
+                                      src=src.to(expert_load_view))
+
+        return topk_ids.to(dtype=indices_type)
+
     @staticmethod
     def select_experts(
         hidden_states: torch.Tensor,
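
Step 1 of the new helper (random replica selection) is a gather over the logical-to-physical map. The following is a minimal standalone sketch of that arithmetic; the tensor shapes, values, and the padding convention for unused replica slots are assumed here purely for illustration and are not taken from the diff itself.

import torch

# Toy EPLB setup (assumed values): 4 logical experts, up to 2 physical
# replicas each. Row i lists the physical ids backing logical expert i;
# unused slots are padded with -1 (the padding is never selected because
# the drawn replica index is always below the replica count).
logical_to_physical_map = torch.tensor([
    [0, 4],   # logical 0 -> physical 0 or 4
    [1, -1],  # logical 1 -> physical 1 (single replica)
    [2, 5],   # logical 2 -> physical 2 or 5
    [3, -1],  # logical 3 -> physical 3 (single replica)
])
logical_replica_count = torch.tensor([2, 1, 2, 1])

# Per-token top-k routing decisions in logical expert ids, in a narrow
# integer dtype as dispatch/combine kernels may require.
topk_ids = torch.tensor([[0, 2], [1, 3]], dtype=torch.int32)

# Same recipe as the diff: draw a uniform number per selection, scale it
# by the replica count, floor it, and gather that replica's physical id.
topk_ids_long = topk_ids.long()
replica_indices = (
    torch.rand_like(topk_ids, dtype=torch.float) *
    logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
physical_ids = logical_to_physical_map[topk_ids_long].gather(
    -1, replica_indices).squeeze(-1)
print(physical_ids)  # e.g. tensor([[4, 2], [1, 3]])

The hunk above wraps this logic in @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend), presumably so that the remapping and the load recording run as one compiled graph over inputs whose token count varies between steps, rather than as eager op-by-op dispatch.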
@@ -1480,56 +1540,12 @@ class FusedMoE(CustomOp):
             assert expert_load_view is not None
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
 
-            # 1. Convert the logical expert ids to physical expert ids
-            # Directly select a random replica for each logical expert
-
-            # TODO: maybe optimize this by using specified kernels,
-            # or compute pseudo-random indices by modulo
-
-            # In case `indices_type` is not `torch.long` or `torch.int`,
-            # e.g. `torch.uint32` as required by dispatch/combine kernels
-            topk_ids_long = topk_ids.long()
-            replica_indices = (
-                torch.rand_like(topk_ids, dtype=torch.float) *
-                logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
-            physical_ids = logical_to_physical_map[topk_ids_long].gather(
-                -1, replica_indices).squeeze(-1)
-
-            topk_ids = physical_ids
-
-            # 2. Record expert load metrics.
-
-            # TODO(bowen): When using `FusedMoEModularKernel`, this
-            # can be done in a more unified way, since
-            # `FusedMoEPrepareAndFinalize` will return the expert
-            # token count, in some cases directly from the kernel.
-            # However, now there are many code paths not using
-            # the modular kernel, e.g. calling `fused_experts`,
-            # so we decide to keep the logic here.
-            #
-            # If later refactor moved all the MoE kernel calls
-            # to the modular kernel, we can move this logic there
-            # to achieve better efficiency.
-
-            # `expert_load_view`: (num_physical_experts,)
-
-            topk_ids_flatten = topk_ids.flatten()
-
-            # Performance optimization:
-            # `masked_fill` is significantly faster than `masked_select`
-            invalid_mask = topk_ids_flatten < 0
-            # Replace invalid expert ids with 0 (just a dummy position)
-            # to avoid out-of-bounds errors in scatter_add_
-            index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
-            # `src` is the valid mask, which is 1 for valid and 0 for invalid
-            src = ~invalid_mask
-
-            expert_load_view.scatter_add_(dim=0,
-                                          index=index.long(),
-                                          src=src.to(expert_load_view))
-
-            topk_ids = topk_ids.to(dtype=indices_type)
+            topk_ids = FusedMoE.handle_eplb(
+                topk_ids=topk_ids,
+                logical_replica_count=logical_replica_count,
+                logical_to_physical_map=logical_to_physical_map,
+                expert_load_view=expert_load_view,
+                indices_type=indices_type)
 
         assert topk_ids.dtype == indices_type or indices_type is None
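
Step 2 (expert load recording) uses masked_fill_ plus scatter_add_ instead of masked_select, so the tensor stays dense and no data-dependent shapes are produced. Below is a minimal self-contained sketch of that accumulation; the num_physical_experts of 6 and the -1 entry standing in for an invalid slot are assumed for illustration only.

import torch

num_physical_experts = 6
expert_load_view = torch.zeros(num_physical_experts, dtype=torch.int64)

# Toy physical expert ids after replica selection; -1 marks an invalid slot.
physical_topk_ids = torch.tensor([[4, 2], [1, -1]])
topk_ids_flatten = physical_topk_ids.flatten()

# Invalid slots are redirected to index 0 so scatter_add_ never goes
# out of bounds; their `src` entries are 0, so they add nothing there.
invalid_mask = topk_ids_flatten < 0
index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
src = ~invalid_mask

expert_load_view.scatter_add_(dim=0,
                              index=index.long(),
                              src=src.to(expert_load_view))
print(expert_load_view)  # tensor([0, 1, 1, 0, 1, 0])

Because the invalid entries contribute zero to the dummy position, the per-expert counts are exact while the kernel avoids the dynamic output shape that masked_select would introduce.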