Compare commits

...

1 Commits

Author    SHA1          Message                                            Date
          787384dd4a    updated                                            2025-08-28 01:25:55 +00:00
                        Signed-off-by: Robert Shaw <robshaw@redhat.com>


@@ -1401,6 +1401,66 @@ class FusedMoE(CustomOp):
         self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
         self.logical_replica_count = logical_replica_count[moe_layer_idx]
+    @staticmethod
+    @torch.compile(dynamic=True,
+                   backend=current_platform.simple_compile_backend)
+    def handle_eplb(
+        topk_ids: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        expert_load_view: torch.Tensor,
+        indices_type: torch.dtype,
+    ) -> torch.Tensor:
+        # 1. Convert the logical expert ids to physical expert ids
+        # Directly select a random replica for each logical expert
+        # TODO: maybe optimize this by using specified kernels,
+        # or compute pseudo-random indices by modulo
+        # In case `indices_type` is not `torch.long` or `torch.int`,
+        # e.g. `torch.uint32` as required by dispatch/combine kernels
+        topk_ids_long = topk_ids.long()
+        replica_indices = (
+            torch.rand_like(topk_ids, dtype=torch.float) *
+            logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
+        physical_ids = logical_to_physical_map[topk_ids_long].gather(
+            -1, replica_indices).squeeze(-1)
+        topk_ids = physical_ids
+        # 2. Record expert load metrics.
+        # TODO(bowen): When using `FusedMoEModularKernel`, this
+        # can be done in a more unified way, since
+        # `FusedMoEPrepareAndFinalize` will return the expert
+        # token count, in some cases directly from the kernel.
+        # However, now there are many code paths not using
+        # the modular kernel, e.g. calling `fused_experts`,
+        # so we decide to keep the logic here.
+        #
+        # If later refactor moved all the MoE kernel calls
+        # to the modular kernel, we can move this logic there
+        # to achieve better efficiency.
+        # `expert_load_view`: (num_physical_experts,)
+        topk_ids_flatten = topk_ids.flatten()
+        # Performance optimization:
+        # `masked_fill` is significantly faster than `masked_select`
+        invalid_mask = topk_ids_flatten < 0
+        # Replace invalid expert ids with 0 (just a dummy position)
+        # to avoid out-of-bounds errors in scatter_add_
+        index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
+        # `src` is the valid mask, which is 1 for valid and 0 for invalid
+        src = ~invalid_mask
+        expert_load_view.scatter_add_(dim=0,
+                                      index=index.long(),
+                                      src=src.to(expert_load_view))
+        return topk_ids.to(dtype=indices_type)
     @staticmethod
     def select_experts(
         hidden_states: torch.Tensor,
@@ -1480,56 +1540,12 @@ class FusedMoE(CustomOp):
             assert expert_load_view is not None
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
-            # 1. Convert the logical expert ids to physical expert ids
-            # Directly select a random replica for each logical expert
-            # TODO: maybe optimize this by using specified kernels,
-            # or compute pseudo-random indices by modulo
-            # In case `indices_type` is not `torch.long` or `torch.int`,
-            # e.g. `torch.uint32` as required by dispatch/combine kernels
-            topk_ids_long = topk_ids.long()
-            replica_indices = (
-                torch.rand_like(topk_ids, dtype=torch.float) *
-                logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
-            physical_ids = logical_to_physical_map[topk_ids_long].gather(
-                -1, replica_indices).squeeze(-1)
-            topk_ids = physical_ids
-            # 2. Record expert load metrics.
-            # TODO(bowen): When using `FusedMoEModularKernel`, this
-            # can be done in a more unified way, since
-            # `FusedMoEPrepareAndFinalize` will return the expert
-            # token count, in some cases directly from the kernel.
-            # However, now there are many code paths not using
-            # the modular kernel, e.g. calling `fused_experts`,
-            # so we decide to keep the logic here.
-            #
-            # If later refactor moved all the MoE kernel calls
-            # to the modular kernel, we can move this logic there
-            # to achieve better efficiency.
-            # `expert_load_view`: (num_physical_experts,)
-            topk_ids_flatten = topk_ids.flatten()
-            # Performance optimization:
-            # `masked_fill` is significantly faster than `masked_select`
-            invalid_mask = topk_ids_flatten < 0
-            # Replace invalid expert ids with 0 (just a dummy position)
-            # to avoid out-of-bounds errors in scatter_add_
-            index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
-            # `src` is the valid mask, which is 1 for valid and 0 for invalid
-            src = ~invalid_mask
-            expert_load_view.scatter_add_(dim=0,
-                                          index=index.long(),
-                                          src=src.to(expert_load_view))
-            topk_ids = topk_ids.to(dtype=indices_type)
+            topk_ids = FusedMoE.handle_eplb(
+                topk_ids=topk_ids,
+                logical_replica_count=logical_replica_count,
+                logical_to_physical_map=logical_to_physical_map,
+                expert_load_view=expert_load_view,
+                indices_type=indices_type)
         assert topk_ids.dtype == indices_type or indices_type is None
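Likewise, here is a toy sketch of the load-recording step (step 2), showing why invalid ids are redirected to position 0 with `masked_fill` before `scatter_add_`. The names, shapes, and values below are illustrative only, not part of the commit.

import torch

num_physical_experts = 6
expert_load_view = torch.zeros(num_physical_experts, dtype=torch.int64)

# Physical expert ids per (token, top-k slot); -1 marks an invalid/padded slot.
topk_ids = torch.tensor([[4, 2], [1, -1]])
topk_ids_flatten = topk_ids.flatten()

invalid_mask = topk_ids_flatten < 0
# Redirect invalid ids to index 0 so scatter_add_ never indexes out of bounds;
# their src values are 0, so slot 0 is not over-counted.
index = topk_ids_flatten.masked_fill(invalid_mask, 0)
src = (~invalid_mask).to(expert_load_view)
expert_load_view.scatter_add_(dim=0, index=index.long(), src=src)
print(expert_load_view)  # tensor([0, 1, 1, 0, 1, 0])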