Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
Compare commits: d31f7844f8...compile-ep (1 commit)

Commit 787384dd4a
@@ -1401,6 +1401,66 @@ class FusedMoE(CustomOp):
         self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
         self.logical_replica_count = logical_replica_count[moe_layer_idx]
 
+    @staticmethod
+    @torch.compile(dynamic=True,
+                   backend=current_platform.simple_compile_backend)
+    def handle_eplb(
+        topk_ids: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        expert_load_view: torch.Tensor,
+        indices_type: torch.dtype,
+    ) -> torch.Tensor:
+        # 1. Convert the logical expert ids to physical expert ids.
+        # Directly select a random replica for each logical expert.
+
+        # TODO: maybe optimize this by using specialized kernels,
+        # or compute pseudo-random indices by modulo.
+
+        # In case `indices_type` is not `torch.long` or `torch.int`,
+        # e.g. `torch.uint32` as required by dispatch/combine kernels.
+        topk_ids_long = topk_ids.long()
+        replica_indices = (
+            torch.rand_like(topk_ids, dtype=torch.float) *
+            logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
+        physical_ids = logical_to_physical_map[topk_ids_long].gather(
+            -1, replica_indices).squeeze(-1)
+
+        topk_ids = physical_ids
+
+        # 2. Record expert load metrics.
+
+        # TODO(bowen): When using `FusedMoEModularKernel`, this
+        # can be done in a more unified way, since
+        # `FusedMoEPrepareAndFinalize` will return the expert
+        # token count, in some cases directly from the kernel.
+        # However, many code paths currently do not use the
+        # modular kernel, e.g. those calling `fused_experts`,
+        # so we keep the logic here for now.
+        #
+        # If a later refactor moves all MoE kernel calls to the
+        # modular kernel, this logic can move there as well for
+        # better efficiency.
+
+        # `expert_load_view`: (num_physical_experts,)
+
+        topk_ids_flatten = topk_ids.flatten()
+
+        # Performance optimization:
+        # `masked_fill` is significantly faster than `masked_select`.
+        invalid_mask = topk_ids_flatten < 0
+        # Replace invalid expert ids with 0 (just a dummy position)
+        # to avoid out-of-bounds errors in `scatter_add_`.
+        index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
+        # `src` is the valid mask: 1 for valid entries, 0 for invalid.
+        src = ~invalid_mask
+
+        expert_load_view.scatter_add_(dim=0,
+                                      index=index.long(),
+                                      src=src.to(expert_load_view))
+
+        return topk_ids.to(dtype=indices_type)
+
     @staticmethod
     def select_experts(
         hidden_states: torch.Tensor,
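For readers skimming the diff, here is a minimal, self-contained sketch of the replica-selection math in `handle_eplb`, run on toy tensors. It is not part of the commit; the map, replica counts, and token shapes below are invented for illustration:

    import torch

    # Toy setup: 4 logical experts, up to 2 replicas each, 6 physical slots.
    logical_to_physical_map = torch.tensor([
        [0, 1],  # logical expert 0 has replicas at physical slots 0 and 1
        [2, 2],  # logical expert 1 has one replica at slot 2 (entry padded)
        [3, 4],  # logical expert 2 has replicas at slots 3 and 4
        [5, 5],  # logical expert 3 has one replica at slot 5 (entry padded)
    ])
    logical_replica_count = torch.tensor([2, 1, 2, 1])

    # Router output: 3 tokens, top-2 logical expert ids each.
    topk_ids = torch.tensor([[0, 2], [1, 3], [2, 0]], dtype=torch.int32)

    # Same math as the diff: a uniform sample in [0, 1) scaled by the
    # per-expert replica count and truncated picks one replica per entry.
    topk_ids_long = topk_ids.long()
    replica_indices = (
        torch.rand_like(topk_ids, dtype=torch.float) *
        logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
    physical_ids = logical_to_physical_map[topk_ids_long].gather(
        -1, replica_indices).squeeze(-1)
    print(physical_ids)  # e.g. tensor([[1, 4], [2, 5], [3, 0]])

Scaling a uniform sample in [0, 1) by the replica count and truncating yields an index in [0, count), so experts with a single replica always select index 0, and the padded entries in the map are never reached.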
@@ -1480,56 +1540,12 @@ class FusedMoE(CustomOp):
             assert expert_load_view is not None
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
 
-            # 1. Convert the logical expert ids to physical expert ids.
-            # Directly select a random replica for each logical expert.
-
-            # TODO: maybe optimize this by using specialized kernels,
-            # or compute pseudo-random indices by modulo.
-
-            # In case `indices_type` is not `torch.long` or `torch.int`,
-            # e.g. `torch.uint32` as required by dispatch/combine kernels.
-            topk_ids_long = topk_ids.long()
-            replica_indices = (
-                torch.rand_like(topk_ids, dtype=torch.float) *
-                logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
-            physical_ids = logical_to_physical_map[topk_ids_long].gather(
-                -1, replica_indices).squeeze(-1)
-
-            topk_ids = physical_ids
-
-            # 2. Record expert load metrics.
-
-            # TODO(bowen): When using `FusedMoEModularKernel`, this
-            # can be done in a more unified way, since
-            # `FusedMoEPrepareAndFinalize` will return the expert
-            # token count, in some cases directly from the kernel.
-            # However, many code paths currently do not use the
-            # modular kernel, e.g. those calling `fused_experts`,
-            # so we keep the logic here for now.
-            #
-            # If a later refactor moves all MoE kernel calls to the
-            # modular kernel, this logic can move there as well for
-            # better efficiency.
-
-            # `expert_load_view`: (num_physical_experts,)
-
-            topk_ids_flatten = topk_ids.flatten()
-
-            # Performance optimization:
-            # `masked_fill` is significantly faster than `masked_select`.
-            invalid_mask = topk_ids_flatten < 0
-            # Replace invalid expert ids with 0 (just a dummy position)
-            # to avoid out-of-bounds errors in `scatter_add_`.
-            index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
-            # `src` is the valid mask: 1 for valid entries, 0 for invalid.
-            src = ~invalid_mask
-
-            expert_load_view.scatter_add_(dim=0,
-                                          index=index.long(),
-                                          src=src.to(expert_load_view))
-
-            topk_ids = topk_ids.to(dtype=indices_type)
+            topk_ids = FusedMoE.handle_eplb(
+                topk_ids=topk_ids,
+                logical_replica_count=logical_replica_count,
+                logical_to_physical_map=logical_to_physical_map,
+                expert_load_view=expert_load_view,
+                indices_type=indices_type)
 
         assert topk_ids.dtype == indices_type or indices_type is None