# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Literal, Optional, Union, overload

import torch
import torch.nn.functional as F
from compressed_tensors.quantization import (QuantizationArgs,
                                             QuantizationStrategy,
                                             QuantizationType)
from torch.nn.parameter import UninitializedParameter

import vllm.envs as envs
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op, has_deep_ep, has_pplx

if current_platform.is_cuda_alike():
    from .fused_batched_moe import BatchedTritonExperts
    from .fused_moe import TritonExperts, fused_experts
    from .modular_kernel import (FusedMoEModularKernel,
                                 FusedMoEPermuteExpertsUnpermute,
                                 FusedMoEPrepareAndFinalize)
    if has_pplx():
        from .pplx_prepare_finalize import PplxPrepareAndFinalize
    if has_deep_ep():
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SIZE,
                                                 DeepEPLLPrepareAndFinalize)
else:
    fused_experts = None  # type: ignore
    FusedMoEPermuteExpertsUnpermute = None  # type: ignore
    FusedMoEPrepareAndFinalize = None  # type: ignore
if is_rocm_aiter_moe_enabled():
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
        rocm_aiter_grouped_topk as grouped_topk)
elif current_platform.is_cpu():
    pass
else:
    from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():
    from .moe_pallas import fused_moe as fused_moe_pallas
else:
    fused_moe_pallas = None  # type: ignore

logger = init_logger(__name__)


@dataclass
class FusedMoEParallelConfig:
    tp_size: int
    dp_size: int
    ep_size: int
    tp_rank: int
    dp_rank: int
    ep_rank: int

    use_ep: bool  # whether to use EP or not

    @property
    def use_all2all_kernels(self):
        return self.dp_size > 1 and self.use_ep

    @property
    def use_pplx_kernels(self):
        return (self.use_all2all_kernels
                and envs.VLLM_ALL2ALL_BACKEND == "pplx")

    @property
    def use_deepep_ht_kernels(self):
        return (self.use_all2all_kernels
                and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")

    @property
    def use_deepep_ll_kernels(self):
        return (self.use_all2all_kernels
                and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")

    @staticmethod
    def make(tp_size_: int, dp_size_: int,
             vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
        """
        Determine the MoE parallel configuration. Based on the input tp_size_,
        dp_size_ and vllm's parallel config, determine what level of
        parallelism to use in the fused moe layer.

        Args:
            tp_size_ (int): tp_size passed into the FusedMoE constructor.
            dp_size_ (int): dp_size passed into the FusedMoE constructor.
            vllm_parallel_config (ParallelConfig): vllm's parallel config
                object.

        Examples:
            When there is no parallelism requested, i.e. tp_size_ = dp_size_ = 1,
            we simply return the sizes unaltered and the ranks set to 0.

            Expert Parallelism is considered only when either dp_size_ or
            tp_size_ is non-trivial.

            When TP = 2, DP = 1 and EP = False, the configuration on different
            devices:
                - device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
                             legend : {size, rank}
                - device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
                - Comment : Tensors are sharded across 2 devices.

            When TP = 1, DP = 2 and EP = False, the configuration on different
            devices:
                - device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
                - device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
                - Comment: There are 2 engine instances and the tensors are
                  sharded across 2 devices.

            When TP = 2, DP = 2 and EP = False, the configuration on different
            devices:
                - device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
                - device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
                - device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
                - device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
                - Comment: There are 2 engine instances and the tensors are
                  sharded across 4 devices.

            When TP = 2, DP = 1 and EP = True, the configuration on different
            devices:
                - device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
                - device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
                - Comment: The experts are split between the 2 devices.

            When TP = 1, DP = 2 and EP = True, the configuration on different
            devices:
                - device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
                - device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
                - Comment: There are 2 engine instances and the experts are
                  split between the 2 devices.

            When TP = 2, DP = 2 and EP = True, the configuration on different
            devices:
                - device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
                - device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
                - device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
                - device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
                - Comment: There are 2 engine instances and the experts are
                  split between the 4 devices.
        """

        def flatten_tp_across_dp(dp_rank: int):
            tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
            # There are actually dp_size_ * tp_size_ devices. Update tp_size
            # and tp_rank so we shard across all devices.
            tp_size = dp_size_ * tp_size_
            tp_rank = dp_rank * tp_size_ + tp_rank
            return tp_size, tp_rank

        use_ep = (dp_size_ * tp_size_ > 1
                  and vllm_parallel_config.enable_expert_parallel)

        dp_size = dp_size_
        dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
        tp_size, tp_rank = flatten_tp_across_dp(dp_rank)

        if not use_ep:
            return FusedMoEParallelConfig(tp_size=tp_size,
                                          tp_rank=tp_rank,
                                          dp_size=dp_size,
                                          dp_rank=dp_rank,
                                          ep_size=1,
                                          ep_rank=0,
                                          use_ep=False)
        # DP + EP / TP + EP / DP + TP + EP
        assert use_ep
        # In EP, each device owns a set of experts fully. There is no tensor
        # parallelism; update tp_size, tp_rank, ep_size and ep_rank to
        # reflect that.
        ep_size = tp_size
        ep_rank = tp_rank
        return FusedMoEParallelConfig(tp_size=1,
                                      tp_rank=0,
                                      dp_size=dp_size,
                                      dp_rank=dp_rank,
                                      ep_size=ep_size,
                                      ep_rank=ep_rank,
                                      use_ep=True)
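

# Illustrative sketch (values are hypothetical, mirroring the TP=2/DP=2/EP=True
# example in the docstring above): with enable_expert_parallel set, `make`
# first flattens TP across DP (tp_size = dp_size_ * tp_size_ = 4) and then
# re-labels that flattened axis as EP, so each of the 4 ranks ends up with
#   FusedMoEParallelConfig(tp_size=1, tp_rank=0, dp_size=2, dp_rank=r // 2,
#                          ep_size=4, ep_rank=r, use_ep=True)
# where r is the flattened rank in [0, 4).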


# Adapted from pplx-kernels tests/all_to_all_utils.py
@dataclass
class MoEConfig:
    num_experts: int
    experts_per_token: int
    hidden_dim: int

    num_local_experts: int
    moe_parallel_config: FusedMoEParallelConfig

    in_dtype: torch.dtype  # The activation type.
    quant_dtype: torch.dtype = None

    # TODO: add more quantization params, blocked, per-token, etc.
    block_size: int = 128

    max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE

    def __post_init__(self):
        if self.dp_size > 1:
            logger.debug("Using MOEConfig::max_num_tokens=%d",
                         self.max_num_tokens)

    @property
    def tp_size(self):
        return self.moe_parallel_config.tp_size

    @property
    def dp_size(self):
        return self.moe_parallel_config.dp_size

    @property
    def ep_size(self):
        return self.moe_parallel_config.ep_size

    @property
    def tp_rank(self):
        return self.moe_parallel_config.tp_rank

    @property
    def dp_rank(self):
        return self.moe_parallel_config.dp_rank

    @property
    def ep_rank(self):
        return self.moe_parallel_config.ep_rank

    @property
    def use_ep(self):
        return self.moe_parallel_config.use_ep

    @property
    def use_pplx_kernels(self):
        return self.moe_parallel_config.use_pplx_kernels

    @property
    def use_deepep_ht_kernels(self):
        return self.moe_parallel_config.use_deepep_ht_kernels

    @property
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels


class FusedMoeWeightScaleSupported(Enum):
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"


def get_quant_config_input_activations(
        quant_config: Optional[QuantizationConfig]
) -> Optional[QuantizationArgs]:
    if (quant_config is not None and hasattr(quant_config, 'target_scheme_map')
            and "Linear" in quant_config.target_scheme_map and
            "input_activations" in quant_config.target_scheme_map["Linear"]):
        return quant_config.target_scheme_map["Linear"].get(
            "input_activations")
    else:
        return None


class FusedMoEMethodBase(QuantizeMethodBase):

    moe: MoEConfig

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        raise NotImplementedError

    def init_prepare_finalize(self, moe: MoEConfig,
                              quant_config: Optional[QuantizationConfig]):
        all2all_manager = get_ep_group().device_communicator.all2all_manager
        assert all2all_manager is not None

        self.moe = moe
        quant_dtype = None
        act_quant_block_size = None
        from vllm.model_executor.layers.quantization.fp8 import Fp8Config
        if isinstance(quant_config, Fp8Config):
            act_quant_block_size = quant_config.weight_block_size
            quant_dtype = torch.float8_e4m3fn

        prepare_finalize: Optional[Union[PplxPrepareAndFinalize,
                                         DeepEPHTPrepareAndFinalize,
                                         DeepEPLLPrepareAndFinalize]] = None
        if moe.use_pplx_kernels:
            all_to_all_args = dict(
                max_num_tokens=moe.max_num_tokens,
                num_experts=moe.num_experts,
                experts_per_token=moe.experts_per_token,  # topk
                rank=all2all_manager.rank,
                world_size=all2all_manager.world_size,
                # dp_size actually means tp_size, bug in pplx kernels
                dp_size=all2all_manager.tp_group.world_size,
                hidden_dim=moe.hidden_dim,
                hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize,
                # For blocked per token: set to
                # ceil_div(hidden_dim, block_size) * sizeof(float32)
                # For per-token: set to sizeof(float32)
                hidden_dim_scale_bytes=(
                    0 if moe.quant_dtype.itemsize != 1 else
                    ((moe.hidden_dim + moe.block_size - 1) // moe.block_size *
                     torch.float32.itemsize)),
            )

            # Intranode pplx a2a takes a group name while internode does not.
            if not all2all_manager.internode:
                all_to_all_args[
                    "group_name"] = all2all_manager.cpu_group.group_name

            handle = all2all_manager.get_handle(all_to_all_args)

            input_activations = get_quant_config_input_activations(
                quant_config)

            prepare_finalize = PplxPrepareAndFinalize(
                handle,
                max_num_tokens=moe.max_num_tokens,
                world_size=all2all_manager.world_size,
                rank=all2all_manager.rank,
                # dp_size actually means tp_size, bug in pplx kernels
                dp_size=all2all_manager.tp_group.world_size,
                quant_dtype=moe.quant_dtype,
                per_act_token=(input_activations.strategy
                               == QuantizationStrategy.TOKEN
                               if input_activations is not None else False),
            )
        elif moe.use_deepep_ht_kernels:
            assert moe.dp_size == all2all_manager.dp_world_size

            all_to_all_args = dict()
            handle = all2all_manager.get_handle(all_to_all_args)
            prepare_finalize = DeepEPHTPrepareAndFinalize(
                handle,
                world_size=all2all_manager.world_size,
                rank=all2all_manager.rank,
                dp_size=all2all_manager.dp_world_size,
                rank_expert_offset=all2all_manager.rank *
                moe.num_local_experts,
                quant_dtype=quant_dtype,
                block_shape=act_quant_block_size,
            )

        elif moe.use_deepep_ll_kernels:
            assert moe.dp_size == all2all_manager.dp_world_size

            all_to_all_args = dict(
                max_num_tokens_per_dp_rank=moe.max_num_tokens,
                token_hidden_size=moe.hidden_dim,
                num_ep_ranks=all2all_manager.world_size,
                num_global_experts=moe.num_experts,
                num_local_experts=moe.num_experts //
                all2all_manager.world_size)
            handle = all2all_manager.get_handle(all_to_all_args)

            # Note: We may want to use FP8 dispatch even otherwise, just to
            # reduce data movement.
            assert act_quant_block_size is not None
            use_fp8_dispatch = (quant_dtype == current_platform.fp8_dtype()
                                and act_quant_block_size[1]
                                == DEEPEP_QUANT_BLOCK_SIZE)

            # Note (varun): Whether to use FP8 dispatch or not needs some
            # profiling. Turning it off for now.
            prepare_finalize = DeepEPLLPrepareAndFinalize(
                handle,
                world_size=all2all_manager.world_size,
                dp_size=all2all_manager.dp_world_size,
                max_tokens_per_rank=moe.max_num_tokens,
                quant_dtype=quant_dtype,
                block_shape=act_quant_block_size,
                use_fp8_dispatch=use_fp8_dispatch,
            )

        self.topk_indices_dtype = None
        if prepare_finalize is not None:
            self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
            experts = self.select_gemm_impl(prepare_finalize, moe)
            self.fused_experts = FusedMoEModularKernel(
                prepare_finalize,
                experts,
            )

    def select_gemm_impl(
            self, prepare_finalize: FusedMoEPrepareAndFinalize,
            moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
        # Based on the all2all implementation, select the appropriate
        # gemm implementation.
        raise NotImplementedError(
            "Subclass must select appropriate gemm implementation"
            " based on the prepare_finalize")

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        raise NotImplementedError
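

# Illustrative arithmetic for the pplx `hidden_dim_scale_bytes` argument above
# (hypothetical sizes): with an fp8 quant dtype (itemsize == 1),
# hidden_dim=4096 and block_size=128, each token carries
# ceil(4096 / 128) * sizeof(float32) = 32 * 4 = 128 bytes of scales; for
# non-quantized dtypes (itemsize != 1) no scale bytes are allocated.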


@CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

    def __init__(self, moe: MoEConfig):
        super().__init__()
        self.fused_experts = fused_experts  # type: ignore
        self.topk_indices_dtype = None
        self.moe = moe

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
        if self.rocm_aiter_moe_enabled:
            from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
            self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
        else:
            self.rocm_aiter_fused_experts = None  # type: ignore

    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
                         moe: Optional[MoEConfig]):

        assert self.fused_experts == fused_experts

        all2all_manager = get_ep_group().device_communicator.all2all_manager
        assert all2all_manager is not None

        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None

        use_batched_experts = prepare_finalize.max_num_tokens_per_rank(
        ) is not None
        if use_batched_experts:
            logger.debug("BatchedTritonExperts %s", self.moe)
            assert self.moe.dp_size == all2all_manager.dp_world_size
            experts = BatchedTritonExperts(
                max_num_tokens=self.moe.max_num_tokens,
                world_size=all2all_manager.world_size,
                # dp_size actually means tp_size, bug in pplx kernels
                dp_size=all2all_manager.tp_group.world_size,
                use_fp8_w8a8=False,
                use_int8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
                block_shape=None,
                per_channel_quant=False,
            )
        else:
            logger.debug("TritonExperts %s", self.moe)
            experts = TritonExperts(
                use_fp8_w8a8=False,
                use_int8_w8a8=False,
                use_int8_w8a16=False,
                use_int4_w4a16=False,
                block_shape=None,
                per_channel_quant=False,
            )
        return experts

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Pad the weight tensor. This is an optimization on the ROCm platform,
        # which can benefit from tensors located far enough from one another
        # in memory.
        if (envs.VLLM_ROCM_MOE_PADDING and current_platform.is_rocm()
                and weight.stride(-1) == 1
                and (weight.stride(-2) * weight.element_size()) % 512 == 0):
            num_pad = 256 // weight.element_size()
            weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
            torch.cuda.empty_cache()
        return weight

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        super().process_weights_after_loading(layer)

        # Padding the weight for better performance on ROCm
        layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
        layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
        # Lazy import to avoid importing triton.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            shuffle_weights)

        if self.rocm_aiter_moe_enabled:
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data)

            layer.w13_weight.data = shuffled_w13
            layer.w2_weight.data = shuffled_w2

        if current_platform.is_cpu():
            if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
                from vllm.model_executor.layers.fused_moe import cpu_fused_moe
                dtype = layer.w13_weight.dtype
                if (envs.VLLM_CPU_SGL_KERNEL
                        and torch._C._cpu._is_amx_tile_supported()
                        and dtype == torch.bfloat16):
                    packed_w13_weight = torch.ops._C.convert_weight_packed(
                        layer.w13_weight)
                    assert packed_w13_weight.size() == layer.w13_weight.size()
                    layer.w13_weight.copy_(packed_w13_weight)
                    del packed_w13_weight
                    packed_w2_weight = torch.ops._C.convert_weight_packed(
                        layer.w2_weight)
                    assert packed_w2_weight.size() == layer.w2_weight.size()
                    layer.w2_weight.copy_(packed_w2_weight)
                    layer.cpu_fused_moe = cpu_fused_moe.SGLFusedMOE(layer)
                else:
                    layer.cpu_fused_moe = cpu_fused_moe.IPEXFusedMOE(layer)
            else:
                raise NotImplementedError("CPU MOE only supports x86 arch.")

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `UnquantizedFusedMoEMethod` yet.")

        return self.forward(
            x=x,
            layer=layer,
            router_logits=router_logits,
            top_k=top_k,
            renormalize=renormalize,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input)

    def forward_cuda(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        if self.rocm_aiter_moe_enabled:
            assert expert_map is None
            return self.rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            return self.fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )

    def forward_cpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        **kwargs,
    ):
        return layer.cpu_fused_moe(
            layer,
            x,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            global_num_experts,
            expert_map,
            custom_routing_function,
            scoring_func,
            e_score_correction_bias,
            apply_router_weight_on_input,
            activation,
        )

    def forward_hpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        assert layer is not None
        assert apply_router_weight_on_input is False
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for HPU.")
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for HPU.")
        return layer.hpu_fused_moe(x, layer.w13_weight, layer.w2_weight,
                                   router_logits, top_k)

    def forward_tpu(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        use_grouped_topk: bool,
        top_k: int,
        router_logits: torch.Tensor,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        assert not use_grouped_topk
        assert num_expert_group is None
        assert topk_group is None
        assert custom_routing_function is None
        assert apply_router_weight_on_input is False
        if scoring_func != "softmax":
            raise NotImplementedError(
                "Only softmax scoring function is supported for TPU.")
        if e_score_correction_bias is not None:
            raise NotImplementedError(
                "Expert score correction bias is not supported for TPU.")
        assert activation == "silu", f"{activation} is not supported for TPU."
        return fused_moe_pallas(hidden_states=x,
                                w1=layer.w13_weight,
                                w2=layer.w2_weight,
                                topk=top_k,
                                gating_output=router_logits,
                                global_num_experts=global_num_experts,
                                expert_map=expert_map,
                                renormalize=renormalize)

    if current_platform.is_tpu():
        forward_native = forward_tpu
    elif current_platform.is_cpu():
        forward_native = forward_cpu
    else:
        forward_native = forward_cuda


def determine_expert_map(
        ep_size: int, ep_rank: int,
        global_num_experts: int) -> tuple[int, Optional[torch.Tensor]]:
    """
    Calculates how many experts should be assigned to each rank for EP and
    creates a mapping from global to local expert index. Experts are
    distributed evenly across ranks. Any remaining are assigned to the
    last rank.

    Args:
        ep_size (int): The size of the expert parallel group.
        ep_rank (int): The rank of the current process in the expert
            parallel group.
        global_num_experts (int): The total number of experts in the model.

    Returns:
        tuple[int, Optional[torch.Tensor]]: A tuple containing:
            - local_num_experts (int): The number of experts assigned
              to the current rank.
            - expert_map (Optional[torch.Tensor]): A tensor of shape
              (global_num_experts,) mapping from global to local index.
              Contains -1 for experts not assigned to the current rank.
              Returns None if ep_size is 1.
    """
    assert ep_size > 0
    if ep_size == 1:
        return (global_num_experts, None)

    local_num_experts = global_num_experts // ep_size

    # Create a tensor of size num_experts filled with -1
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    # Create an expert map for the local experts
    if ep_rank < (ep_size - 1):
        # Each non-last rank gets local_num_experts experts.
        expert_map[ep_rank * local_num_experts:
                   (ep_rank + 1) * local_num_experts] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    else:
        # All remaining experts are assigned to the last rank.
        local_num_experts = (global_num_experts - ep_rank * local_num_experts)

        expert_map[-local_num_experts:] = \
            torch.arange(0, local_num_experts, dtype=torch.int32)
    return (local_num_experts, expert_map)
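

# Illustrative example (derived from the logic above): with ep_size=4 and
# global_num_experts=10, ranks 0..2 each own 10 // 4 = 2 experts and the last
# rank owns the remaining 4, e.g.
#   determine_expert_map(ep_size=4, ep_rank=1, global_num_experts=10)
#   -> (2, tensor([-1, -1,  0,  1, -1, -1, -1, -1, -1, -1], dtype=torch.int32))
#   determine_expert_map(ep_size=4, ep_rank=3, global_num_experts=10)
#   -> (4, tensor([-1, -1, -1, -1, -1, -1,  0,  1,  2,  3], dtype=torch.int32))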


class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters.
        reduce_results: Whether to apply all_reduce on the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        quant_config: Quantization configuration.
        enable_eplb: Whether to enable expert parallelism load balancer.
    """

    def __init__(
        self,
        num_experts: int,  # Global number of experts
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
        ep_size: Optional[int] = None,
        dp_size: Optional[int] = None,
        prefix: str = "",
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        num_redundant_experts: int = 0,
    ):
        super().__init__()
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        vllm_config = get_current_vllm_config()
        self.moe_parallel_config: FusedMoEParallelConfig = (
            FusedMoEParallelConfig.make(
                tp_size_=(tp_size if tp_size is not None else
                          get_tensor_model_parallel_world_size()),
                dp_size_=(dp_size if dp_size is not None else
                          get_dp_group().world_size),
                vllm_parallel_config=vllm_config.parallel_config))

        self.global_num_experts = num_experts + num_redundant_experts

        # For smuggling this layer into the fused moe custom op
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError("Duplicate layer name: {}".format(prefix))
        compilation_config.static_forward_context[prefix] = self
        self.layer_name = prefix

        self.enable_eplb = enable_eplb
        self.expert_load_view: Optional[torch.Tensor] = None
        self.logical_to_physical_map: Optional[torch.Tensor] = None
        self.logical_replica_count: Optional[torch.Tensor] = None

        # Determine expert maps
        if self.use_ep:
            if self.enable_eplb:
                assert self.global_num_experts % self.ep_size == 0, \
                    "EPLB currently only supports even distribution of " \
                    "experts across ranks."
            else:
                assert num_redundant_experts == 0, \
                    "Redundant experts are only supported with EPLB."
            self.local_num_experts, self.expert_map = determine_expert_map(
                ep_size=self.ep_size,
                ep_rank=self.ep_rank,
                global_num_experts=self.global_num_experts)
        else:
            self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                       None)

        self.top_k = top_k

        assert intermediate_size % self.tp_size == 0
        self.hidden_size = hidden_size
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.apply_router_weight_on_input = apply_router_weight_on_input
        self.activation = activation

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
                             "non-grouped topk.")
        if current_platform.is_hpu():
            from vllm_hpu_extension.ops import DynamicFusedMOE
            self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

        # Only support float8 for now.
        quant_dtype = params_dtype
        if quant_config is not None:
            input_activations = get_quant_config_input_activations(
                quant_config)
            if (input_activations is not None
                    and input_activations.num_bits == 8
                    and input_activations.type == QuantizationType.FLOAT):
                quant_dtype = torch.float8_e4m3fn

        moe = MoEConfig(
            num_experts=self.global_num_experts,
            experts_per_token=top_k,
            hidden_dim=hidden_size,
            num_local_experts=self.local_num_experts,
            moe_parallel_config=self.moe_parallel_config,
            in_dtype=params_dtype,
            quant_dtype=quant_dtype,
            max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
        )
        self.moe_config = moe
        self.quant_config = quant_config

        # Note: get_quant_method will look at the layer's local_num_experts
        # for heuristic purposes, so it must be initialized first.
        quant_method: Optional[QuantizeMethodBase] = None
        quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None
                        else quant_config.get_quant_method(self, prefix))

        assert quant_method is not None
        assert isinstance(quant_method, FusedMoEMethodBase)
        self.quant_method = quant_method

        if self.enable_eplb:
            from vllm.model_executor.layers.quantization.fp8 import (
                Fp8MoEMethod)
            if not isinstance(quant_method, Fp8MoEMethod):
                # TODO: Add support for additional quantization methods.
                # The implementation for other quantization methods does not
                # contain essential differences, but the current quant API
                # design causes duplicated work when extending to new
                # quantization methods, so I'm leaving it for now.
                # If you plan to add support for more quantization methods,
                # please refer to the implementation in `Fp8MoEMethod`.
                raise NotImplementedError("EPLB is only supported for FP8 "
                                          "quantization for now.")

        moe_quant_params = {
            "num_experts": self.local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if (self.quant_method.__class__.__name__
                in ("GPTQMarlinMoEMethod",
                    "CompressedTensorsWNA16MarlinMoEMethod",
                    "CompressedTensorsWNA16MoEMethod")):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)

        # Chunked all2all staging tensor
        self.batched_hidden_states: Optional[torch.Tensor] = None
        self.batched_router_logits: Optional[torch.Tensor] = None
        if (self.moe_parallel_config.use_pplx_kernels
                or self.moe_parallel_config.use_deepep_ll_kernels):
            act_dtype = vllm_config.model_config.dtype
            self.batched_hidden_states = torch.zeros(
                (envs.VLLM_MOE_DP_CHUNK_SIZE, self.hidden_size),
                dtype=act_dtype,
                device=torch.cuda.current_device())

            # Note here we use `num_experts`, which is the logical expert count
            self.batched_router_logits = torch.zeros(
                (envs.VLLM_MOE_DP_CHUNK_SIZE, num_experts),
                dtype=act_dtype,
                device=torch.cuda.current_device())

    @property
    def tp_size(self):
        return self.moe_parallel_config.tp_size

    @property
    def dp_size(self):
        return self.moe_parallel_config.dp_size

    @property
    def ep_size(self):
        return self.moe_parallel_config.ep_size

    @property
    def tp_rank(self):
        return self.moe_parallel_config.tp_rank

    @property
    def dp_rank(self):
        return self.moe_parallel_config.dp_rank

    @property
    def ep_rank(self):
        return self.moe_parallel_config.ep_rank

    @property
    def use_ep(self):
        return self.moe_parallel_config.use_ep

    @property
    def use_pplx_kernels(self):
        return self.moe_parallel_config.use_pplx_kernels

    @property
    def use_deepep_ht_kernels(self):
        return self.moe_parallel_config.use_deepep_ht_kernels

    @property
    def use_deepep_ll_kernels(self):
        return self.moe_parallel_config.use_deepep_ll_kernels

    def _load_per_tensor_weight_scale(self, shard_id: str,
                                      param: torch.nn.Parameter,
                                      loaded_weight: torch.Tensor,
                                      expert_id: int):
        param_data = param.data
        # for per tensor weight quantization
        if shard_id in ("w1", "w3"):
            # We have to keep the weight scales of w1 and w3 because
            # we need to re-quantize w1/w3 weights after weight loading.
            idx = 0 if shard_id == "w1" else 1
            param_data[expert_id][idx] = loaded_weight
        # If we are in the row parallel case (down_proj)
        elif shard_id == "w2":
            param_data[expert_id] = loaded_weight

    def _load_model_weight_or_group_weight_scale(self,
                                                 shard_dim: int,
                                                 expert_data: torch.Tensor,
                                                 shard_id: str,
                                                 loaded_weight: torch.Tensor,
                                                 tp_rank: int,
                                                 load_full_w2: bool = False):
        """
        Load grouped weight scales for group quantization or model weights.

        :param shard_dim: dimension to shard
        :param expert_data: parameter for a particular expert
        :param shard_id: either w1, w2, or w3
        :param loaded_weight: checkpoint weight to load into the param
        :param tp_rank: tensor parallel rank
        :param load_full_w2: whether to load the full w2 weight without
            sharding it.
        """
        if shard_id == "w2":
            # In the case where we have actorder/g_idx, we do not partition
            # the w2 scales, as indicated by the `load_full` argument, for
            # all tp cases.
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank,
                          load_full=load_full_w2)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
                                       shard_dim: int, shard_id: str,
                                       loaded_weight: torch.Tensor,
                                       tp_rank: int):
        # for per channel weight quantization
        if shard_id == "w2":
            expert_data.copy_(loaded_weight)
        elif shard_id in ("w1", "w3"):
            self._load_w13(shard_id=shard_id,
                           shard_dim=shard_dim,
                           loaded_weight=loaded_weight,
                           expert_data=expert_data,
                           tp_rank=tp_rank)

    def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
                  shard_id: str, loaded_weight: torch.Tensor, tp_rank: int):

        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
        shard_size = expert_data.shape[shard_dim] // 2
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)
        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        if shard_id == "w1":
            expert_data = expert_data.narrow(shard_dim, 0, shard_size)
        # w3, up_proj: Load into second logical weight of w13.
        else:
            assert shard_id == "w3"
            expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
        expert_data.copy_(loaded_weight)

    def _load_w2(self,
                 expert_data: torch.Tensor,
                 shard_dim: int,
                 loaded_weight: torch.Tensor,
                 tp_rank: int,
                 load_full: bool = False):

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
        shard_size = expert_data.shape[shard_dim]
        if not load_full:
            loaded_weight = loaded_weight.narrow(shard_dim,
                                                 shard_size * tp_rank,
                                                 shard_size)
        # w2, down_proj: Load into only logical weight of w2.
        expert_data.copy_(loaded_weight)

    def _load_single_value(self, param: torch.nn.Parameter,
                           loaded_weight: torch.Tensor, expert_id: int):
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        param_data[expert_id] = loaded_weight

    def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
                    shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):

        if shard_id == "w2":
            self._load_w2(shard_dim=shard_dim,
                          loaded_weight=loaded_weight,
                          expert_data=expert_data,
                          tp_rank=tp_rank)
        else:
            assert shard_id in ("w1", "w3")
            expert_data.copy_(loaded_weight)

    def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
        if self.expert_map is None:
            return expert_id
        return self.expert_map[expert_id].item()

    @overload
    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int,
                      return_success: Literal[False]) -> None:
        ...

    @overload
    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: str, expert_id: int,
                      return_success: Literal[True]) -> bool:
        ...

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
                      weight_name: str,
                      shard_id: str,
                      expert_id: int,
                      return_success: bool = False) -> Optional[bool]:
        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
        if expert_id == -1:
            # Failed to load this param since it's not local to this rank
            return False if return_success else None
        # Hereafter, `expert_id` is the local physical id

        quant_method_name = self.quant_method.__class__.__name__
        # compressed-tensors checkpoints with packed weights are stored flipped
        # TODO (mgoin): check self.quant_method.quant_config.quant_format
        # against known CompressionFormat enum values that have this quality
        if self.quant_method.__class__.__name__ in (
                "CompressedTensorsWNA16MarlinMoEMethod",
                "CompressedTensorsWNA16MoEMethod"):
            loaded_weight = loaded_weight.t().contiguous()

        if shard_id not in ("w1", "w2", "w3"):
            raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
                             f"got {shard_id}.")

        WEIGHT_SCALE_SUPPORTED = [
            e.value for e in FusedMoeWeightScaleSupported
        ]
        # Fetch the dim to shard the parameter/loaded weight based on the
        # shard id. This will be whichever dimension holds
        # intermediate_size_per_partition.
        SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}

        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()
            param.data.copy_(loaded_weight)
            return True if return_success else None

        # is_transposed: whether the dim to shard the weight should be
        # flipped. Required by GPTQ and compressed-tensors, where the sharded
        # dim is whichever dimension holds intermediate_size_per_partition.
        is_transposed = getattr(param, "is_transposed", False)
        shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
        if is_transposed:
            shard_dim = int(not shard_dim)

        full_load = len(loaded_weight.shape) == 3
        if full_load:
            shard_dim += 1

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            final_shape = list(loaded_weight.shape)
            if shard_id in ["w1", "w3"]:
                final_shape[1] *= 2
            final_shape[shard_dim] = final_shape[shard_dim] // self.tp_size
            param.materialize(final_shape, dtype=loaded_weight.dtype)

        expert_data = param.data if full_load else param.data[expert_id]

        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # this is needed for compressed-tensors only
            loaded_weight = loaded_weight.to(param.data.device)

            if ("compressed" in quant_method_name.lower()
                    and param.data[expert_id] != 1
                    and (param.data[expert_id] - loaded_weight).abs() > 1e-5):
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param.data[expert_id]} "
                    f"vs. {loaded_weight}")

            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case g_idx
        if "g_idx" in weight_name:
            self._load_g_idx(shard_dim=0,
                             shard_id=shard_id,
                             loaded_weight=loaded_weight,
                             expert_data=expert_data,
                             tp_rank=self.tp_rank)
            return True if return_success else None

        # TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
        if "ModelOpt" in quant_method_name:
            if ('weight_scale_2' in weight_name
                    or 'input_scale' in weight_name):
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            elif "weight" in weight_name:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
            return True if return_success else None

        # Case weight scales, zero_points and offset, weight/input global scales
        if ("scale" in weight_name or "zero" in weight_name
                or "offset" in weight_name):
            # load the weight scales and zp based on the quantization scheme
            # supported weight scales/zp can be found in
            # FusedMoeWeightScaleSupported
            # TODO @dsikka: once hardened, refactor to use vLLM Parameters
            # specific to each case
            quant_method = getattr(param, "quant_method", None)
            if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
                self._load_per_channel_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank)
            elif quant_method in [
                    FusedMoeWeightScaleSupported.GROUP.value,
                    FusedMoeWeightScaleSupported.BLOCK.value,
            ]:
                self._load_model_weight_or_group_weight_scale(
                    shard_id=shard_id,
                    shard_dim=shard_dim,
                    loaded_weight=loaded_weight,
                    expert_data=expert_data,
                    tp_rank=self.tp_rank,
                    load_full_w2=getattr(param, "load_full_w2", False))
            elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
                self._load_per_tensor_weight_scale(shard_id=shard_id,
                                                   param=param,
                                                   loaded_weight=loaded_weight,
                                                   expert_id=expert_id)
            else:
                raise ValueError(
                    f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
            return True if return_success else None

        # Case weight_shape
        if "weight_shape" in weight_name:
            # only required by compressed-tensors
            self._load_single_value(param=param,
                                    loaded_weight=loaded_weight,
                                    expert_id=expert_id)
            return True if return_success else None

        # Case model weights
        if "weight" in weight_name:
            self._load_model_weight_or_group_weight_scale(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=self.tp_rank)
            return True if return_success else None

        return False if return_success else None

    def get_expert_weights(self) -> Iterable[torch.Tensor]:
        weights = list(self.named_parameters())
        assert all(weight.is_contiguous() for _, weight in weights)

        # Filter out the non-expert weights.
        # `e_score_correction_bias` is a bias for each logical expert,
        # with shape (num_logical_experts,), not an expert weight.
        NON_EXPERT_WEIGHTS = {
            "e_score_correction_bias",
        }

        return [
            weight.view(self.local_num_experts, -1) for name, weight in weights
            if name not in NON_EXPERT_WEIGHTS
        ]

    def set_eplb_state(
        self,
        moe_layer_idx: int,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """
        Register the EPLB state in this layer.

        This is used later in the forward pass, where we get the expert
        mapping and record the load metrics in `expert_load_view`.
        """
        self.expert_load_view = expert_load_view[moe_layer_idx]
        self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx]
        self.logical_replica_count = logical_replica_count[moe_layer_idx]

    @staticmethod
    def select_experts(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        use_grouped_topk: bool,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        indices_type: Optional[torch.dtype] = None,
        enable_eplb: bool = False,
        expert_map: Optional[torch.Tensor] = None,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Route the input hidden states to the top-k experts based on the
        router logits.

        Returns:
            (topk_weights, topk_ids) (tuple[torch.Tensor, torch.Tensor]):
            The weights and *global physical* expert ids of the top-k experts.

            **Compatibility**: When EPLB is not enabled, the returned ids are
            equivalent to global logical ids, so they are compatible with
            plain MoE implementations without redundant experts.
        """
        from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

        # DeepSeekv2 uses grouped_top_k
        if use_grouped_topk:
            assert topk_group is not None
            assert num_expert_group is not None
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias)
            if indices_type is not None:
                topk_ids = topk_ids.to(dtype=indices_type)
        elif custom_routing_function is None:
            topk_weights, topk_ids, token_expert_indices = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                indices_type=indices_type,
            )
        else:
            topk_weights, topk_ids = custom_routing_function(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize)
            if indices_type is not None:
                topk_ids = topk_ids.to(dtype=indices_type)

        if enable_eplb:
            assert expert_load_view is not None
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None

            # 1. Convert the logical expert ids to physical expert ids:
            # directly select a random replica for each logical expert.

            # TODO: maybe optimize this by using specialized kernels,
            # or compute pseudo-random indices by modulo.

            # In case `indices_type` is not `torch.long` or `torch.int`,
            # e.g. `torch.uint32` as required by dispatch/combine kernels
            topk_ids_long = topk_ids.long()
            replica_indices = (
                torch.rand_like(topk_ids, dtype=torch.float) *
                logical_replica_count[topk_ids_long]).long().unsqueeze(-1)
            physical_ids = logical_to_physical_map[topk_ids_long].gather(
                -1, replica_indices).squeeze(-1)

            topk_ids = physical_ids

            # 2. Record expert load metrics.

            # TODO(bowen): When using `FusedMoEModularKernel`, this
            # can be done in a more unified way, since
            # `FusedMoEPrepareAndFinalize` will return the expert
            # token count, in some cases directly from the kernel.
            # However, there are currently many code paths not using
            # the modular kernel, e.g. calling `fused_experts`,
            # so we decided to keep the logic here.
            #
            # If a later refactor moves all the MoE kernel calls
            # to the modular kernel, we can move this logic there
            # to achieve better efficiency.

            # `expert_load_view`: (num_logical_experts,)

            # Mask out non-local experts
            if expert_map is not None:
                topk_ids_local = expert_map[topk_ids]
                topk_ids_flatten = topk_ids_local.flatten()
            else:
                topk_ids_flatten = topk_ids.flatten()

            # Should be equivalent to:
            # ```
            # topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
            # expert_load_view += topk_ids_masked.bincount(
            #     minlength=expert_load_view.shape[0])
            # ```
            # We use `scatter_add_` since `bincount` cannot be compiled.

            # Performance optimization:
            # `masked_fill` is significantly faster than `masked_select`
            invalid_mask = topk_ids_flatten < 0
            # Replace invalid expert ids with 0 (just a dummy position)
            # to avoid out-of-bounds errors in scatter_add_
            index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
            # `src` is the valid mask, which is 1 for valid and 0 for invalid
            src = ~invalid_mask

            expert_load_view.scatter_add_(dim=0,
                                          index=index.long(),
                                          src=src.to(expert_load_view))

            topk_ids = topk_ids.to(dtype=indices_type)

        return topk_weights, topk_ids
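
    # Illustrative note on the EPLB branch above (hypothetical numbers): if a
    # logical expert has logical_replica_count == 3, the uniform draw from
    # torch.rand_like in [0, 1) scaled by 3 and floored via .long() picks one
    # of its 3 physical replicas from logical_to_physical_map, spreading the
    # load of hot experts across ranks.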

    def must_reduce_shared_expert_outputs(self) -> bool:
        """
        The shared_experts are typically computed using the RowParallelLinear
        layer. The result of this function is typically used as the
        reduce_results argument to that module.
        When just tensor-parallel is used, it is not required to reduce
        the shared_experts results immediately. Instead, we reduce once at
        the end of the MoE op. (Refer to the DeepSeekV2MoE module.)
        With EP and all2all kernels this is no longer viable, as all
        GPU ranks in DP produce the complete set of hidden_states.
        Therefore it is required that we reduce the shared_experts output
        early.
        """
        return (self.use_pplx_kernels or self.use_deepep_ht_kernels
                or self.use_deepep_ll_kernels)

    def maybe_all_reduce_tensor_model_parallel(
            self, final_hidden_states: torch.Tensor):
        """
        The pplx combine kernel reduces across GPU ranks by default.
        """
        if (self.use_pplx_kernels or self.use_deepep_ht_kernels
                or self.use_deepep_ll_kernels):
            return final_hidden_states
        else:
            return tensor_model_parallel_all_reduce(final_hidden_states)

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                          self.layer_name)

    def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
                             full_router_logits: torch.Tensor):
        assert self.batched_hidden_states is not None
        assert self.batched_router_logits is not None
        assert self.batched_hidden_states.dtype == full_hidden_states.dtype
        assert self.batched_router_logits.dtype == full_router_logits.dtype
        # Check size compatibility.
        assert (
            self.batched_hidden_states.size(-1) == full_hidden_states.size(-1))
        assert (
            self.batched_router_logits.size(-1) == full_router_logits.size(-1))

        full_final_hidden_states = torch.empty_like(full_hidden_states)

        def process_chunk(chunk_start, chunk_end, skip_result_store=False):
            chunk_size = chunk_end - chunk_start
            hidden_states = full_hidden_states[chunk_start:chunk_end, :]
            router_logits = full_router_logits[chunk_start:chunk_end, :]

            assert (self.batched_hidden_states.size(0)  # type: ignore
                    >= chunk_size)
            assert (self.batched_router_logits.size(0)  # type: ignore
                    >= chunk_size)
            staged_hidden_states = self.batched_hidden_states[:
                                                              chunk_size, :]  # type: ignore
            staged_router_logits = self.batched_router_logits[:
                                                              chunk_size, :]  # type: ignore
            staged_hidden_states.copy_(hidden_states, non_blocking=True)
            staged_router_logits.copy_(router_logits, non_blocking=True)

            # Matrix multiply.
            final_hidden_states = self.quant_method.apply(
                layer=self,
                x=staged_hidden_states,
                router_logits=staged_router_logits,
                top_k=self.top_k,
                renormalize=self.renormalize,
                use_grouped_topk=self.use_grouped_topk,
                global_num_experts=self.global_num_experts,
                expert_map=self.expert_map,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                custom_routing_function=self.custom_routing_function,
                scoring_func=self.scoring_func,
                e_score_correction_bias=self.e_score_correction_bias,
                activation=self.activation,
                enable_eplb=self.enable_eplb,
                expert_load_view=self.expert_load_view,
                logical_to_physical_map=self.logical_to_physical_map,
                logical_replica_count=self.logical_replica_count,
            )

            if not skip_result_store:
                full_final_hidden_states[chunk_start:chunk_end, :].copy_(
                    final_hidden_states, non_blocking=True)

        ctx = get_forward_context()
        max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
        moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens

        num_tokens = full_hidden_states.size(0)
        for chunk_start_ in range(0, max_tokens_across_dp,
                                  moe_dp_chunk_size_per_rank):
            chunk_start = chunk_start_
            chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
                            max_tokens_across_dp)
            # clamp start and end
            chunk_start = min(chunk_start, num_tokens - 1)
            chunk_end = min(chunk_end, num_tokens)

            process_chunk(chunk_start,
                          chunk_end,
                          skip_result_store=chunk_start_ >= num_tokens)

        return full_final_hidden_states

    def forward_impl(self, hidden_states: torch.Tensor,
                     router_logits: torch.Tensor):
        assert self.quant_method is not None
        if (self.moe_parallel_config.use_pplx_kernels
                or self.moe_parallel_config.use_deepep_ll_kernels):
            return self.forward_impl_chunked(hidden_states, router_logits)

        do_naive_dispatch_combine: bool = (
            self.dp_size > 1
            and not self.moe_parallel_config.use_deepep_ht_kernels)
        if do_naive_dispatch_combine:
            hidden_states, router_logits = get_ep_group().dispatch(
                hidden_states, router_logits)

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.global_num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            activation=self.activation,
            apply_router_weight_on_input=self.apply_router_weight_on_input,
            enable_eplb=self.enable_eplb,
            expert_load_view=self.expert_load_view,
            logical_to_physical_map=self.logical_to_physical_map,
            logical_replica_count=self.logical_replica_count,
        )

        if do_naive_dispatch_combine:
            final_hidden_states = get_ep_group().combine(final_hidden_states)

        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            # Default set to False. (May have to add shared expert outputs.)
            final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
            cls,
            ckpt_gate_proj_name: str,
            ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int,
            num_redundant_experts: int = 0) -> list[tuple[str, str, int, str]]:

        num_physical_experts = num_experts + num_redundant_experts

        # In the returned mapping:
        # - `expert_id` is the physical expert id
        # - `weight_name` contains the weight name of the logical expert
        # So we map the physical expert id to its logical counterpart
        # in `weight_name`.
        physical_to_logical_map = \
            EplbState.build_initial_global_physical_to_logical_map(
                num_experts, num_redundant_experts)

        return [
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_" if weight_name
             in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
             f"experts.{physical_to_logical_map[expert_id]}.{weight_name}.",
             expert_id, shard_id) for expert_id in range(num_physical_experts)
            for shard_id, weight_name in [
                ("w1", ckpt_gate_proj_name),
                ("w2", ckpt_down_proj_name),
                ("w3", ckpt_up_proj_name),
            ]
        ]

    def extra_repr(self) -> str:

        s = (
            f"global_num_experts={self.global_num_experts}, "
            f"local_num_experts={self.local_num_experts}, "
            f"top_k={self.top_k}, "
            f"intermediate_size_per_partition={self.intermediate_size_per_partition}, "  # noqa: E501
            f"tp_size={self.tp_size},\n"
            f"ep_size={self.ep_size}, "
            f"reduce_results={self.reduce_results}, "
            f"renormalize={self.renormalize}, "
            f"use_grouped_topk={self.use_grouped_topk}")

        if self.use_grouped_topk:
            s += f", num_expert_group={self.num_expert_group}, topk_group={self.topk_group}"  # noqa: E501

        s += f", scoring_func='{self.scoring_func}', activation='{self.activation}'"  # noqa: E501

        return s


def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                layer_name: str) -> torch.Tensor:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    assert self.quant_method is not None

    return self.forward_impl(hidden_states, router_logits)


def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                     layer_name: str) -> torch.Tensor:
    return torch.empty_like(hidden_states)


direct_register_custom_op(
    op_name="moe_forward",
    op_func=moe_forward,
    mutates_args=["hidden_states"],
    fake_impl=moe_forward_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
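

# Illustrative usage sketch (hypothetical values, not executed as part of this
# module): inside a model's MoE block one would typically construct the layer
# and call it with per-token hidden states and router logits, e.g.
#
#   moe = FusedMoE(num_experts=8, top_k=2, hidden_size=4096,
#                  intermediate_size=14336, reduce_results=True,
#                  prefix=f"{prefix}.experts")
#   out = moe(hidden_states, router_logits)  # -> (num_tokens, hidden_size)
#
# The forward pass routes through the `moe_forward` custom op registered
# above, which dispatches back to `FusedMoE.forward_impl`.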