Support expert parallel in Transformers backend (#26162)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Harry Mellor
2025-10-04 05:35:04 +01:00
committed by GitHub
parent ea507c3a93
commit d3d649efec
2 changed files with 32 additions and 21 deletions
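With this change, enabling expert parallel for a Transformers-backend MoE model no longer raises `NotImplementedError`. A minimal usage sketch, assuming an MoE checkpoint whose Transformers implementation is compatible with the backend (the model name below is a placeholder; `model_impl`, `tensor_parallel_size` and `enable_expert_parallel` are existing engine arguments):

```python
from vllm import LLM

# Placeholder model name; substitute any MoE checkpoint supported by the
# Transformers modeling backend.
llm = LLM(
    model="some-org/some-moe-model",
    model_impl="transformers",       # force the Transformers backend
    tensor_parallel_size=2,
    enable_expert_parallel=True,     # accepted after this commit
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```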


@@ -32,8 +32,9 @@ If the Transformers model implementation follows all the steps in [writing a cus
 - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
 - Any combination of the following vLLM parallelisation schemes:
     - Data parallel
-    - Pipeline parallel
     - Tensor parallel
+    - Expert parallel
+    - Pipeline parallel
 Checking if the modeling backend is Transformers is as simple as:
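The code block that follows this sentence in the docs lies outside the hunk. As a sketch of one way to perform the check, assuming `LLM.apply_model` is available in your vLLM version (the model name is a placeholder):

```python
from vllm import LLM

llm = LLM(model="some-org/some-model", model_impl="transformers")  # placeholder name
# Print the resolved implementation class; with the Transformers modeling
# backend this is one of the Transformers* wrapper classes, e.g.
# TransformersForCausalLM, rather than a native vLLM model class.
llm.apply_model(lambda model: print(type(model)))
```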


@@ -22,6 +22,7 @@ import torch.nn as nn
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config.utils import getattr_iter
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -40,42 +41,54 @@ class TransformersFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._top_k_index: torch.Tensor = None
+        self._topk_ids: torch.Tensor = None
         def custom_routing_function(hidden_states, gating_output, topk,
                                     renormalize):
-            """Return `top_k_weights` from `gating_output` and the
-            `top_k_index` we stored in the layer earlier."""
-            return gating_output, self._top_k_index
+            """Return `topk_weights` from `gating_output` and the
+            `topk_ids` we stored in the layer earlier."""
+            topk_weights = gating_output
+            topk_ids = self._topk_ids
+            # Handle all gather in expert parallel
+            if topk_ids.size(0) != hidden_states.size(0):
+                dp_metadata = get_forward_context().dp_metadata
+                sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+                is_sp = self.is_sequence_parallel
+                dist_group = get_ep_group() if is_sp else get_dp_group()
+                assert sizes[dist_group.rank_in_group] == topk_ids.shape[0]
+                topk_ids, = dist_group.all_gatherv([topk_ids], 0, sizes)
+            return topk_weights, topk_ids
         self.custom_routing_function = custom_routing_function
-    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
-                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor,
+                topk_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """In Transformers `experts.forward` will have this signature.
         We discard any extra kwargs because we cannot use them here."""
-        return torch.ops.vllm.transformers_moe_forward(hidden_states,
-                                                       top_k_index,
-                                                       top_k_weights,
-                                                       self.layer_name)
+        return torch.ops.vllm.transformers_moe_forward(
+            hidden_states,
+            topk_ids.to(torch.int32),
+            topk_weights.to(torch.float32),
+            self.layer_name,
+        )
 def transformers_moe_forward(hidden_states: torch.Tensor,
-                             top_k_index: torch.Tensor,
-                             top_k_weights: torch.Tensor,
+                             topk_ids: torch.Tensor,
+                             topk_weights: torch.Tensor,
                              layer_name: str) -> torch.Tensor:
-    """Store the `top_k_index` in the layer and call the actual forward."""
+    """Store the `topk_ids` in the layer and call the actual forward."""
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._top_k_index = top_k_index
+    self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), top_k_weights)
+    return self.forward_impl(hidden_states.clone(), topk_weights)
 def transformers_moe_forward_fake(hidden_states: torch.Tensor,
-                                  top_k_index: torch.Tensor,
-                                  top_k_weights: torch.Tensor,
+                                  topk_ids: torch.Tensor,
+                                  topk_weights: torch.Tensor,
                                   layer_name: str) -> torch.Tensor:
     return torch.empty_like(hidden_states)
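`transformers_moe_forward_fake` exists so that `torch.compile` can trace `torch.ops.vllm.transformers_moe_forward` without running the real MoE computation. A self-contained sketch of that custom-op/fake-implementation pattern using the plain `torch.library` API (assumes PyTorch >= 2.4; the `demo::` namespace and function names are illustrative, not vLLM's actual registration helper):

```python
import torch


# Toy custom op mirroring the shape of transformers_moe_forward: the real
# implementation does the work, the fake one only describes output
# shape/dtype for tracing, like transformers_moe_forward_fake above.
@torch.library.custom_op("demo::moe_forward", mutates_args=())
def demo_moe_forward(hidden_states: torch.Tensor, topk_ids: torch.Tensor,
                     topk_weights: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real computation.
    return hidden_states.clone()


@demo_moe_forward.register_fake
def _(hidden_states: torch.Tensor, topk_ids: torch.Tensor,
      topk_weights: torch.Tensor) -> torch.Tensor:
    # Shape-only implementation used during tracing/compilation.
    return torch.empty_like(hidden_states)


h = torch.randn(4, 8)
ids = torch.zeros(4, 2, dtype=torch.int32)
w = torch.ones(4, 2)
out = torch.ops.demo.moe_forward(h, ids, w)
assert out.shape == h.shape
```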
@@ -96,9 +109,6 @@ class TransformersMoEBase(TransformersBase):
         self.check_version("4.57.0.dev0", "MoE models support")
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-        if self.parallel_config.enable_expert_parallel:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel yet.")
         if self.parallel_config.enable_eplb:
             raise NotImplementedError(
                 "Transformers backend does not support expert parallel load "