mirror of https://github.com/vllm-project/vllm.git
Support expert parallel in Transformers backend (#26162)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -32,8 +32,9 @@ If the Transformers model implementation follows all the steps in [writing a cus
 - All the features listed in the [compatibility matrix](../features/README.md#feature-x-feature)
 - Any combination of the following vLLM parallelisation schemes:
     - Data parallel
+    - Expert parallel
     - Pipeline parallel
     - Tensor parallel
 
 Checking if the modeling backend is Transformers is as simple as:
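The snippet the docs show after that sentence is not captured in this extraction. A minimal sketch of one way to perform the check, assuming vLLM's LLM.apply_model helper; the model name is a placeholder:

from vllm import LLM

# Placeholder model name; substitute the model you want to inspect.
llm = LLM(model="some-org/some-model", model_impl="transformers")

# Prints the concrete model class on each worker; a Transformers-backend
# model shows up as one of the Transformers* wrapper classes
# (subclasses of TransformersBase, e.g. TransformersForCausalLM).
llm.apply_model(lambda model: print(type(model)))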
@@ -22,6 +22,7 @@ import torch.nn as nn
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config.utils import getattr_iter
+from vllm.distributed import get_dp_group, get_ep_group
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -40,42 +41,54 @@ class TransformersFusedMoE(FusedMoE):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self._top_k_index: torch.Tensor = None
+        self._topk_ids: torch.Tensor = None
 
         def custom_routing_function(hidden_states, gating_output, topk,
                                     renormalize):
-            """Return `top_k_weights` from `gating_output` and the
-            `top_k_index` we stored in the layer earlier."""
-            return gating_output, self._top_k_index
+            """Return `topk_weights` from `gating_output` and the
+            `topk_ids` we stored in the layer earlier."""
+            topk_weights = gating_output
+            topk_ids = self._topk_ids
+            # Handle all gather in expert parallel
+            if topk_ids.size(0) != hidden_states.size(0):
+                dp_metadata = get_forward_context().dp_metadata
+                sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
+                is_sp = self.is_sequence_parallel
+                dist_group = get_ep_group() if is_sp else get_dp_group()
+                assert sizes[dist_group.rank_in_group] == topk_ids.shape[0]
+                topk_ids, = dist_group.all_gatherv([topk_ids], 0, sizes)
+            return topk_weights, topk_ids
 
         self.custom_routing_function = custom_routing_function
 
-    def forward(self, hidden_states: torch.Tensor, top_k_index: torch.Tensor,
-                top_k_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, topk_ids: torch.Tensor,
+                topk_weights: torch.Tensor, **kwargs: Any) -> torch.Tensor:
         """In Transformers `experts.forward` will have this signature.
 
         We discard any extra kwargs because we cannot use them here."""
-        return torch.ops.vllm.transformers_moe_forward(hidden_states,
-                                                       top_k_index,
-                                                       top_k_weights,
-                                                       self.layer_name)
+        return torch.ops.vllm.transformers_moe_forward(
+            hidden_states,
+            topk_ids.to(torch.int32),
+            topk_weights.to(torch.float32),
+            self.layer_name,
+        )
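To make the all-gather branch in custom_routing_function above concrete: under data/expert parallel each rank stores topk_ids only for its own tokens, while FusedMoE may hand the routing function hidden states that were already gathered across ranks, so the ids must be gathered with the same per-rank chunk sizes. A single-process shape sketch in plain PyTorch, with no real collectives; the sizes and top_k below are made up:

import torch

top_k = 2
sizes = [3, 5, 2]  # hypothetical tokens per rank, as dp_metadata would report
per_rank_topk_ids = [torch.randint(0, 8, (n, top_k)) for n in sizes]

rank = 1
local_topk_ids = per_rank_topk_ids[rank]
# Mirrors the assert in the diff: the local chunk matches this rank's size.
assert local_topk_ids.shape[0] == sizes[rank]

# What all_gatherv along dim 0 produces: one row of topk_ids for every token
# in the gathered hidden_states.
gathered_topk_ids = torch.cat(per_rank_topk_ids, dim=0)
assert gathered_topk_ids.shape == (sum(sizes), top_k)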
 
 
 def transformers_moe_forward(hidden_states: torch.Tensor,
-                             top_k_index: torch.Tensor,
-                             top_k_weights: torch.Tensor,
+                             topk_ids: torch.Tensor,
+                             topk_weights: torch.Tensor,
                              layer_name: str) -> torch.Tensor:
-    """Store the `top_k_index` in the layer and call the actual forward."""
+    """Store the `topk_ids` in the layer and call the actual forward."""
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
-    self._top_k_index = top_k_index
+    self._topk_ids = topk_ids
     # Clone hidden_states because it will be mutated in-place in FusedMoE
-    return self.forward_impl(hidden_states.clone(), top_k_weights)
+    return self.forward_impl(hidden_states.clone(), topk_weights)
 
 
 def transformers_moe_forward_fake(hidden_states: torch.Tensor,
-                                  top_k_index: torch.Tensor,
-                                  top_k_weights: torch.Tensor,
+                                  topk_ids: torch.Tensor,
+                                  topk_weights: torch.Tensor,
                                   layer_name: str) -> torch.Tensor:
     return torch.empty_like(hidden_states)
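The hunk calls torch.ops.vllm.transformers_moe_forward and supplies both a real and a _fake implementation; the call that registers the pair as a custom op lives outside this hunk. A hedged sketch of the general real-plus-fake pattern using PyTorch's public torch.library API, with an illustrative namespace and op name, not necessarily how vLLM registers it internally:

import torch
from torch.library import custom_op, register_fake

@custom_op("demo::moe_forward", mutates_args=())
def moe_forward(hidden_states: torch.Tensor, topk_ids: torch.Tensor,
                topk_weights: torch.Tensor) -> torch.Tensor:
    # Real (eager) implementation; a real backend would dispatch into the
    # fused MoE kernels here.
    return hidden_states.clone()

@register_fake("demo::moe_forward")
def _moe_forward_fake(hidden_states, topk_ids, topk_weights):
    # Shape/dtype-only implementation used under FakeTensor tracing,
    # mirroring transformers_moe_forward_fake above.
    return torch.empty_like(hidden_states)

torch.compile and other tracers then treat torch.ops.demo.moe_forward as an opaque op whose output shape they learn from the fake implementation, which is why the diff keeps the two functions' signatures in sync.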
@@ -96,9 +109,6 @@ class TransformersMoEBase(TransformersBase):
         self.check_version("4.57.0.dev0", "MoE models support")
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_expert_parallel:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel yet.")
         if self.parallel_config.enable_eplb:
             raise NotImplementedError(
                 "Transformers backend does not support expert parallel load "
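With that guard removed, expert parallel can be requested for a Transformers-backend MoE model the same way as for native models. A hedged usage sketch; the model name is a placeholder, and the flags assume the standard engine options enable_expert_parallel and model_impl:

from vllm import LLM

llm = LLM(
    model="some-org/some-moe-model",   # placeholder MoE checkpoint
    model_impl="transformers",         # force the Transformers backend
    tensor_parallel_size=2,
    enable_expert_parallel=True,       # previously raised NotImplementedError
)
print(llm.generate("Hello")[0].outputs[0].text)

Note that enable_eplb still raises, per the guard that remains in the hunk above.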