Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[fix]: add Arm 4bit fused moe support (#23809)
Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
@@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
     "csrc/cpu/layernorm.cpp"
     "csrc/cpu/mla_decode.cpp"
     "csrc/cpu/pos_encoding.cpp"
-    "csrc/cpu/torch_bindings.cpp")
+    "csrc/cpu/torch_bindings.cpp"
+    "csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
 
 if (AVX512_FOUND AND NOT AVX512_DISABLED)
   set(VLLM_EXT_SRC

@@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       " int tp_rank, int blocksparse_local_blocks,"
       " int blocksparse_vert_stride, int blocksparse_block_size,"
       " int blocksparse_head_sliding_step) -> ()");
   ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
 
+  ops.def(
+      "dynamic_4bit_int_moe("
+      "Tensor x, Tensor topk_ids, Tensor topk_weights,"
+      "Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
+      "int group_size, bool apply_router_weight_on_input, int activation_kind"
+      ") -> Tensor");
+
+  ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
+
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("

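The `dynamic_4bit_int_moe` op registered above is what the CPU MoE method added later in this diff invokes through `torch.ops._C.dynamic_4bit_int_moe`. Below is a minimal sketch of the calling convention, assuming an AArch64 PyTorch build that exposes `torch.ops.aten._dyn_quant_pack_4bit_weight` and a vLLM `_C` extension built with this kernel; all sizes and tensors are illustrative placeholders, and the packing helper simply mirrors `_pack_matrix` from further down in this diff.

import torch

# Illustrative sizes only: T tokens, K experts per token, E experts,
# H hidden size, I intermediate size per expert (the op checks I2 == 2 * I).
T, K, E, H, I = 4, 2, 8, 64, 128

def pack_channelwise(w_int4_as_int8, in_features, out_features):
    # Shift [-8, 7] to unsigned nibbles, pack pairs along the input dim,
    # per-row fp32 scales, no bias, block_size == in_features (channel-wise).
    tmp = w_int4_as_int8.add(8)
    nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(torch.uint8)
    scales = torch.ones(out_features, 1, dtype=torch.float32)
    return torch.ops.aten._dyn_quant_pack_4bit_weight(
        nibbles, scales, None, in_features, in_features, out_features)

w13 = torch.randint(-8, 8, (E, 2 * I, H), dtype=torch.int8)
w2 = torch.randint(-8, 8, (E, H, I), dtype=torch.int8)
w13_packed = torch.stack([pack_channelwise(w13[e], H, 2 * I) for e in range(E)])
w2_packed = torch.stack([pack_channelwise(w2[e], I, H) for e in range(E)])

x = torch.randn(T, H)
topk_ids = torch.randint(0, E, (T, K), dtype=torch.long)
topk_weights = torch.rand(T, K)

out = torch.ops._C.dynamic_4bit_int_moe(
    x, topk_ids, topk_weights, w13_packed, w2_packed,
    H, I, 2 * I,
    -1,      # group_size: -1 selects channel-wise scales
    False,   # apply_router_weight_on_input
    0)       # activation_kind 0: SiLU(g) * u
assert out.shape == (T, H)
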
csrc/moe/dynamic_4bit_int_moe_cpu.cpp (new file, 156 lines)
@@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/all.h>

// _dyn_quant_matmul_4bit is only available on AArch64.
#if defined(__aarch64__)
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
#endif

inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
                        int64_t group_size_eff, int64_t in_features,
                        int64_t out_features) {
#if defined(__aarch64__)
  return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
                                                in_features, out_features);
#else
  TORCH_CHECK(false,
              "dynamic 4-bit int MoE path requires AArch64 (ARM64); "
              "_dyn_quant_matmul_4bit is unavailable on this architecture");
  return {};
#endif
}

enum ActivationKind : int64_t {
  SwiGLU_Gu = 0,  // act = SiLU(g) * u
  SwiGLUOAI = 1,  // act = SiLU(u) * g
  SiLU = 2        // SiLU
};

torch::Tensor dynamic_4bit_int_moe_cpu(
    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
    int64_t activation_kind) {
  TORCH_CHECK(x.dim() == 2, "x must be 2D");
  TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
              "topk tensors must be [T, K]");
  TORCH_CHECK(
      w13_packed.size(0) == w2_packed.size(0),
      "w13_packed and w2_packed must have same number of experts in dim 0");
  TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");

  const int64_t T = x.size(0);
  const int64_t K = topk_ids.size(1);
  const int64_t E = w13_packed.size(0);
  const int64_t N = T * K;

  auto x_c = x.contiguous();
  auto ids_c = topk_ids.contiguous();
  auto gates_c = topk_weights.to(at::kFloat).contiguous();

  // bucketing tokens -> experts
  c10::SmallVector<int64_t, 64> counts(
      E, 0);  // Small vector uses stack allocation
  {
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    for (int64_t i = 0; i < N; ++i) {
      const int64_t e_id = ids_ptr[i];
      TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
      counts[e_id]++;
    }
  }
  c10::SmallVector<int64_t, 65> offsets(E + 1, 0);  // (E + 1)
  for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];

  auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
  auto expert_gates = at::empty({offsets[E]}, gates_c.options());
  {
    c10::SmallVector<int64_t, 64> cursor(E, 0);
    const auto* ids_ptr = ids_c.data_ptr<int64_t>();
    const auto* gts_ptr = gates_c.data_ptr<float>();
    auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
    auto* gate_ptr = expert_gates.data_ptr<float>();

    for (int64_t t = 0; t < T; ++t) {
      const int64_t base = t * K;
      for (int64_t k = 0; k < K; ++k) {
        const int64_t idx = base + k;
        const int64_t e = ids_ptr[idx];
        const int64_t p = offsets[e] + (cursor[e]++);
        tok_ptr[p] = t;
        gate_ptr[p] = gts_ptr[idx];
      }
    }
  }

  const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
  const int64_t g_eff_2 = (group_size != -1) ? group_size : I;

  // Per-expert outputs filled in parallel
  std::vector<torch::Tensor> y_list(E);
  y_list.resize(E);

  at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
    for (int64_t e = e_begin; e < e_end; ++e) {
      const int64_t te = counts[e];
      if (te == 0) {
        y_list[e] = at::empty({0, H}, x_c.options());
        continue;
      }

      const int64_t start = offsets[e];

      auto sel_tokens =
          expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
      auto gates_e =
          expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);

      auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);

      if (apply_router_weight_on_input) {
        x_e = x_e.mul(gates_e.unsqueeze(1));
      }

      auto w13_e = w13_packed.select(/*dim=*/0, e);
      auto w2_e = w2_packed.select(/*dim=*/0, e);

      // W13
      auto y13 =
          mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);

      auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
      auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);

      torch::Tensor act;
      if (activation_kind == ActivationKind::SwiGLUOAI) {  // SwiGLUOAI
        constexpr double kAlpha = 1.702;  // GPT-OSS default
        constexpr double kLimit = 7.0;    // GPT-OSS default
        auto gate_c = at::clamp_max(g_part, kLimit);
        auto up_c = at::clamp(u_part, -kLimit, kLimit);
        auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
        act = up_c.add(1.0).mul(glu);
      } else {  // SiLU, SwiGLU_Gu; vLLM maps silu to SiluAndMul()
        act = at::silu(g_part).mul(u_part);
      }

      // W2
      auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);

      if (!apply_router_weight_on_input) {
        y = y.mul(gates_e.unsqueeze(1));
      }

      // Store per-expert result
      y_list[e] = y;
    }
  });

  // Concatenate all expert outputs to match expert_tokens order
  auto Y_all = at::cat(y_list, /*dim=*/0);
  auto out = at::zeros({T, H}, x.options());
  out =
      at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);

  return out;
}
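For readers who prefer Python, here is a rough, unoptimized PyTorch restatement of the kernel's control flow: bucket tokens by expert, run the two per-expert matmuls with a SwiGLU (or GPT-OSS-style clamped SwiGLU) in between, and scatter the results back with an index_add. It uses plain float weights in place of the packed 4-bit payloads, so it is a readability aid under those assumptions, not the optimized path above.

import torch

def dynamic_4bit_int_moe_reference(x, topk_ids, topk_weights, w13, w2,
                                   apply_router_weight_on_input=False,
                                   activation_kind=0):
    # x: [T, H]; topk_ids/topk_weights: [T, K]; w13: [E, 2I, H]; w2: [E, H, I]
    T, H = x.shape
    E = w13.shape[0]
    out = torch.zeros(T, H, dtype=x.dtype)
    flat_ids = topk_ids.reshape(-1)                  # [T*K]
    flat_gates = topk_weights.reshape(-1).float()    # [T*K]
    token_of = torch.arange(T).repeat_interleave(topk_ids.shape[1])
    for e in range(E):
        sel = (flat_ids == e).nonzero(as_tuple=True)[0]
        if sel.numel() == 0:
            continue
        tokens, gates = token_of[sel], flat_gates[sel]
        x_e = x[tokens]
        if apply_router_weight_on_input:
            x_e = x_e * gates[:, None]
        y13 = x_e @ w13[e].T                         # [n_e, 2I]
        g, u = y13.chunk(2, dim=-1)
        if activation_kind == 1:                     # SwiGLUOAI (GPT-OSS style)
            alpha, limit = 1.702, 7.0
            g = g.clamp(max=limit)
            u = u.clamp(-limit, limit)
            act = (u + 1.0) * (g * torch.sigmoid(alpha * g))
        else:                                        # SiLU(g) * u
            act = torch.nn.functional.silu(g) * u
        y = act @ w2[e].T                            # [n_e, H]
        if not apply_router_weight_on_input:
            y = y * gates[:, None]
        out.index_add_(0, tokens, y.to(out.dtype))
    return out
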
@@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
                         const std::optional<torch::Tensor>& has_initial_state,
                         const torch::Tensor& ssm_states, int64_t pad_slot_id);
 
+torch::Tensor dynamic_4bit_int_moe_cpu(
+    torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
+    torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
+    int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
+    int64_t activation_kind);
+
 using fptr_t = int64_t;
 fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
                       torch::Tensor& rank_data, int64_t rank,

@@ -98,13 +98,16 @@ def select_experts(
             e_score_correction_bias=e_score_correction_bias)
     elif custom_routing_function is None:
         assert scoring_func == "softmax"
-        topk_weights = torch.nn.functional.softmax(router_logits,
-                                                   dim=1,
-                                                   dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
+        topk_logit_vals, topk_idx = torch.topk(router_logits,
+                                               k=top_k,
+                                               dim=-1,
+                                               sorted=False)
         if renormalize:
-            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
-        return topk_weights, topk_ids.to(torch.int32)
+            topk_vals = torch.softmax(topk_logit_vals, dim=-1)
+        else:
+            logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
+            topk_vals = (topk_logit_vals - logZ).exp()
+        return topk_vals.to(torch.float32), topk_idx.to(torch.int32)
     else:
         return custom_routing_function(hidden_states=hidden_states,
                                        gating_output=router_logits,

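The rewritten branch takes the top-k over the raw router logits first and only normalizes afterwards, instead of running a softmax over every expert before the top-k. A small self-contained check (illustrative sizes only) that this matches the old behaviour when `renormalize=True`, and returns the full-softmax probabilities of the selected experts when `renormalize=False`:

import torch

torch.manual_seed(0)
num_experts, top_k = 16, 4
router_logits = torch.randn(5, num_experts)   # illustrative tensor only

# Old path: full softmax, then top-k, then renormalize.
probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
old_w, old_ids = torch.topk(probs, top_k, dim=-1)
old_w = old_w / old_w.sum(dim=-1, keepdim=True)

# New path: top-k over the raw logits first.
vals, idx = torch.topk(router_logits, k=top_k, dim=-1, sorted=False)
renorm_w = torch.softmax(vals, dim=-1)                        # renormalize=True
logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
plain_w = (vals - logZ).exp()                                 # renormalize=False

def densify(w, ids):
    # Scatter top-k weights back to a dense [tokens, num_experts] view so the
    # comparison does not depend on the (unsorted) top-k ordering.
    return torch.zeros(w.shape[0], num_experts).scatter(1, ids, w)

assert torch.allclose(densify(old_w, old_ids), densify(renorm_w, idx), atol=1e-6)
assert torch.allclose(plain_w, probs.gather(1, idx), atol=1e-6)
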
@@ -69,8 +69,6 @@ else:
     if is_rocm_aiter_moe_enabled():
         from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501
             rocm_aiter_grouped_topk as grouped_topk)
-    elif current_platform.is_cpu():
-        pass
     else:
         from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
     if current_platform.is_tpu():

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
     int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
     int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa

@@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.platforms import current_platform
+from vllm.platforms import CpuArchEnum, current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
 

@@ -63,7 +64,7 @@ __all__ = [
     "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
     "CompressedTensorsW8A8Int8MoEMethod",
     "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
-    "CompressedTensorsW4A4MoeMethod"
+    "CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod"
 ]
 
 

@@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8Int8MoEMethod(quant_config,
                                                       layer.moe_config)
+        elif quant_config._is_dynamic_token_w4a8_int(weight_quant,
+                                                     input_quant):
+            return CompressedTensorsW4A8Int8MoEMethod(quant_config,
+                                                      layer.moe_config)
         else:
             raise RuntimeError(
                 f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")

@@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             expert_map=expert_map,
             quant_config=self.moe_quant_config,
         )


class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
    """
    CPU-only MoE method using dynamic 4-bit matmul kernels on Arm platforms.
    - Weights: int4 (stored as int8 values in [-8, 7], packed to uint8 nibbles)
    - Scales: fp32 for channel-wise, bf16 for group-wise quantization
    - Bias: same dtype as the original weights
    - Activations: fp32/bf16, dynamically quantized per token (A8 int)
      inside the kernel
    """

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
            moe: FusedMoEConfig):
        super().__init__(moe)
        self.has_bias = self.moe.has_bias
        self.quant_config = quant_config

        # Validate scheme: weights=W4 (channel or group),
        # activations=dynamic TOKEN (A8)
        wq = self.quant_config.target_scheme_map["Linear"].get("weights")
        aq = self.quant_config.target_scheme_map["Linear"].get(
            "input_activations")

        # Must be dynamic per-token activations
        if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
            raise ValueError(
                "W4A8-int MoE needs dynamic per-token activation quantization."
            )

        # Weight can be channel-wise (group_size=None) or group-wise
        self.group_size = wq.group_size if (wq.group_size is not None) else -1
        if wq.num_bits != 4:
            raise ValueError(
                "This method only supports 4-bit weights (num_bits=4).")

        # CPU only
        if not current_platform.is_cpu():
            raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.")

        # Arm: check _dyn_quant op availability
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            try:
                _ = torch.ops.aten._dyn_quant_matmul_4bit
                _ = torch.ops.aten._dyn_quant_pack_4bit_weight
            except AttributeError as err:
                raise RuntimeError(
                    f"""PyTorch {torch.__version__} lacks _dyn_quant_* 4bit ops;
                    install a newer build.""") from err
        self.static_input_scales = False  # always dynamic per token

    # ---- parameter creation ----
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Shapes per local rank (TP/EP):
        #   w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
        #   w2 : [E, H, I_local]   int8
        # Scales:
        #   channel-wise: group_size=-1 -> per-output-row, single scale per row
        #   group-wise  : group_size=g  -> per-output-row, (in_features/g) scales

        E = num_experts
        H = hidden_size
        IN = intermediate_size_per_partition
        g = self.group_size

        # Per-row scale columns
        def _n_scale_cols(in_features: int) -> int:
            return 1 if g == -1 else (in_features // g)

        # Register unpacked int4-as-int8 weights the loader will fill.
        w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8),
                                 requires_grad=False)
        set_weight_attrs(w13, extra_weight_attrs)
        layer.register_parameter("w13_weight", w13)

        w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8),
                                requires_grad=False)
        set_weight_attrs(w2, extra_weight_attrs)
        layer.register_parameter("w2_weight", w2)

        # Register scales
        # KleidiAI channel-wise kernels expect float32 scales;
        # group-wise kernels expect bfloat16 scales.
        scale_dtype = torch.float32 if g == -1 else torch.bfloat16

        w13_s = torch.nn.Parameter(torch.ones(E,
                                              2 * IN,
                                              _n_scale_cols(H),
                                              dtype=scale_dtype),
                                   requires_grad=False)
        set_weight_attrs(
            w13_s, {
                "quant_method": "channel" if g == -1 else "group",
                **extra_weight_attrs
            })
        layer.register_parameter("w13_weight_scale", w13_s)

        w2_s = torch.nn.Parameter(torch.ones(E,
                                             H,
                                             _n_scale_cols(IN),
                                             dtype=scale_dtype),
                                  requires_grad=False)
        set_weight_attrs(
            w2_s, {
                "quant_method": "channel" if g == -1 else "group",
                **extra_weight_attrs
            })
        layer.register_parameter("w2_weight_scale", w2_s)

        if self.has_bias:
            w13_bias = torch.nn.Parameter(torch.zeros(E,
                                                      2 * IN,
                                                      dtype=params_dtype),
                                          requires_grad=False)
            layer.register_parameter("w13_bias", w13_bias)
            set_weight_attrs(w13_bias, extra_weight_attrs)

            w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
                                                     hidden_size,
                                                     dtype=params_dtype),
                                         requires_grad=False)
            layer.register_parameter("w2_bias", w2_bias)
            set_weight_attrs(w2_bias, extra_weight_attrs)

        # Placeholders for packed weights (will be replaced after packing)
        layer.register_parameter(
            "w13_weight_packed",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs)

        layer.register_parameter(
            "w2_weight_packed",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs)

        # dims for 4-bit fused matmuls
        layer.w13_in_features = H
        layer.w13_out_features = 2 * IN
        layer.w2_in_features = IN
        layer.w2_out_features = H
        layer.group_size = g

    # post-load packing to the dyn-4bit KleidiAI kernel's format
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        E = layer.w13_weight.shape[0]
        H = layer.w13_in_features
        I2 = layer.w13_out_features
        IN = layer.w2_in_features
        g = layer.group_size

        def _pack_matrix(int4_as_int8_2d: torch.Tensor,
                         scales_2d: torch.Tensor,
                         bias_1d: Optional[torch.Tensor], in_features: int,
                         out_features: int) -> torch.Tensor:
            # int4 values are stored as int8 in [-8,7].
            # Shift to unsigned nibble and pack pairs along the input dim.
            tmp = int4_as_int8_2d.add(8)  # [out, in]
            uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(
                torch.uint8)  # [out, in//2]

            # KleidiAI channel-wise kernels expect float32 scales;
            # group-wise kernels expect bfloat16 scales.
            scale_dtype = torch.float32 if g == -1 else torch.bfloat16
            scales = scales_2d.to(scale_dtype)
            bias = None if bias_1d is None else bias_1d.to(torch.float32)
            return torch.ops.aten._dyn_quant_pack_4bit_weight(
                uint8_nibbles, scales, bias, g if g != -1 else in_features,
                in_features, out_features)

        # Pack per expert
        w13_packed_list = []
        w2_packed_list = []

        has_w13_bias = hasattr(layer,
                               "w13_bias") and layer.w13_bias is not None
        has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None

        for e in range(E):
            w13_packed_list.append(
                _pack_matrix(
                    layer.w13_weight[e],  # [2I, H]
                    layer.w13_weight_scale[e],  # [2I, H/g or 1]
                    layer.w13_bias[e] if has_w13_bias else None,  # [2I]
                    H,
                    I2))
            w2_packed_list.append(
                _pack_matrix(
                    # w2 shape is [H, IN]; we need [out, in] == [H, IN].
                    layer.w2_weight[e],  # [H, IN]
                    layer.w2_weight_scale[e],  # [H, IN/g or 1]
                    layer.w2_bias[e] if has_w2_bias else None,  # [H]
                    IN,
                    layer.w2_out_features  # in_features=IN, out_features=H
                ))

        # each packed tensor has identical shape per expert; stack on dim 0
        w13_packed = torch.stack(w13_packed_list, dim=0)
        w2_packed = torch.stack(w2_packed_list, dim=0)

        replace_parameter(layer, "w13_weight_packed",
                          torch.nn.Parameter(w13_packed, requires_grad=False))
        replace_parameter(layer, "w2_weight_packed",
                          torch.nn.Parameter(w2_packed, requires_grad=False))

        # free raw tensors/scales/bias now that they're packed into the payload
        replace_parameter(
            layer, "w13_weight",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        replace_parameter(
            layer, "w2_weight",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        replace_parameter(
            layer, "w13_weight_scale",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        replace_parameter(
            layer, "w2_weight_scale",
            torch.nn.Parameter(torch.empty(0), requires_grad=False))
        if has_w13_bias:
            replace_parameter(
                layer, "w13_bias",
                torch.nn.Parameter(torch.empty(0), requires_grad=False))
        if has_w2_bias:
            replace_parameter(
                layer, "w2_bias",
                torch.nn.Parameter(torch.empty(0), requires_grad=False))

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        # CPU dynamic 4-bit MoE path does not use modular kernels or
        # fused_experts; quant config is not needed.
        return None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
        assert activation in (
            "silu", "swigluoai",
            "swiglu"), "Only silu/swiglu/swigluoai are supported."
        assert expert_map is None, ("expert_map/EP not implemented "
                                    "for CPU dyn-4bit MoE.")

        def _act_kind(s: str) -> int:
            # 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLUOAI, 2 = SiLU
            if s == "swiglu":
                return 0
            if s == "swigluoai":
                return 1
            if s == "silu":
                return 2
            raise ValueError(f"Unknown activation '{s}'")

        # Apply top-k softmax on the router output
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
        )

        return torch.ops._C.dynamic_4bit_int_moe(
            x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed,
            layer.w2_weight_packed, layer.w2_out_features,
            layer.w2_in_features, layer.w13_out_features, layer.group_size,
            apply_router_weight_on_input, int(_act_kind(activation)))
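As a sanity check on the nibble layout that `_pack_matrix` hands to `torch.ops.aten._dyn_quant_pack_4bit_weight` (int4 values stored as int8 in [-8, 7], shifted to unsigned nibbles, odd input columns in the high nibble and even columns in the low nibble), here is a small round-trip sketch with illustrative tensors:

import torch

# A tiny [out_features, in_features] int4-as-int8 weight (values in [-8, 7]).
w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)

# Pack exactly as _pack_matrix does: shift to [0, 15], then place the odd
# input columns in the high nibble and the even columns in the low nibble.
tmp = w.add(8)
packed = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(torch.uint8)   # [out, in//2]

# Unpack to verify the layout.
low = (packed & 0x0F).to(torch.int16) - 8    # even input columns
high = (packed >> 4).to(torch.int16) - 8     # odd input columns
restored = torch.empty(w.shape, dtype=torch.int16)
restored[:, ::2] = low
restored[:, 1::2] = high
assert torch.equal(restored.to(torch.int8), w)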