[fix]: add Arm 4bit fused moe support (#23809)

Signed-off-by: Nikhil Gupta <nikhil.gupta2@arm.com>
This commit is contained in:
Nikhil Gupta
2025-09-24 02:32:22 +01:00
committed by GitHub
parent 9df8da548e
commit 359d293006
7 changed files with 488 additions and 11 deletions

View File

@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp")
"csrc/cpu/torch_bindings.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC

View File

@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()");
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
ops.def(
"dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
"Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
"int group_size, bool apply_router_weight_on_input, int activation_kind"
") -> Tensor");
ops.impl("dynamic_4bit_int_moe", torch::kCPU, &dynamic_4bit_int_moe_cpu);
// PagedAttention V2.
ops.def(
"paged_attention_v2("

View File

@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/all.h>
// _dyn_quant_matmul_4bit is only available on AArch64.
#if defined(__aarch64__)
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
#endif
inline torch::Tensor mm(const torch::Tensor& a, const torch::Tensor& packed_w,
int64_t group_size_eff, int64_t in_features,
int64_t out_features) {
#if defined(__aarch64__)
return at::_ops::_dyn_quant_matmul_4bit::call(a, packed_w, group_size_eff,
in_features, out_features);
#else
TORCH_CHECK(false,
"dynamic 4-bit int MoE path requires AArch64 (ARM64); "
"_dyn_quant_matmul_4bit is unavailable on this architecture");
return {};
#endif
}
enum ActivationKind : int64_t {
SwiGLU_Gu = 0, // act = SiLU(g) * u
SwiGLUOAI = 1, // act = (clamp(u) + 1) * clamp(g) * sigmoid(alpha * clamp(g)) (GPT-OSS)
SiLU = 2 // SiLU
};
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
int64_t activation_kind) {
TORCH_CHECK(x.dim() == 2, "x must be 2D");
TORCH_CHECK(topk_ids.dim() == 2 && topk_weights.dim() == 2,
"topk tensors must be [T, K]");
TORCH_CHECK(
w13_packed.size(0) == w2_packed.size(0),
"w13_packed and w2_packed must have same number of experts in dim 0");
TORCH_CHECK(I2 == 2 * I, "I2 must equal 2*I");
const int64_t T = x.size(0);
const int64_t K = topk_ids.size(1);
const int64_t E = w13_packed.size(0);
const int64_t N = T * K;
auto x_c = x.contiguous();
auto ids_c = topk_ids.contiguous();
auto gates_c = topk_weights.to(at::kFloat).contiguous();
// bucketing tokens -> experts
c10::SmallVector<int64_t, 64> counts(
E, 0); // Small vector uses stack allocation
{
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
for (int64_t i = 0; i < N; ++i) {
const int64_t e_id = ids_ptr[i];
TORCH_CHECK(0 <= e_id && e_id < E, "expert id out of range");
counts[e_id]++;
}
}
c10::SmallVector<int64_t, 65> offsets(E + 1, 0); // prefix sums of counts (offsets[0] = 0)
for (int64_t e = 0; e < E; ++e) offsets[e + 1] = offsets[e] + counts[e];
auto expert_tokens = at::empty({offsets[E]}, ids_c.options());
auto expert_gates = at::empty({offsets[E]}, gates_c.options());
{
c10::SmallVector<int64_t, 64> cursor(E, 0);
const auto* ids_ptr = ids_c.data_ptr<int64_t>();
const auto* gts_ptr = gates_c.data_ptr<float>();
auto* tok_ptr = expert_tokens.data_ptr<int64_t>();
auto* gate_ptr = expert_gates.data_ptr<float>();
for (int64_t t = 0; t < T; ++t) {
const int64_t base = t * K;
for (int64_t k = 0; k < K; ++k) {
const int64_t idx = base + k;
const int64_t e = ids_ptr[idx];
const int64_t p = offsets[e] + (cursor[e]++);
tok_ptr[p] = t;
gate_ptr[p] = gts_ptr[idx];
}
}
}
const int64_t g_eff_13 = (group_size != -1) ? group_size : H;
const int64_t g_eff_2 = (group_size != -1) ? group_size : I;
// Per-expert outputs filled in parallel
std::vector<torch::Tensor> y_list(E);
at::parallel_for(0, E, 1, [&](int64_t e_begin, int64_t e_end) {
for (int64_t e = e_begin; e < e_end; ++e) {
const int64_t te = counts[e];
if (te == 0) {
y_list[e] = at::empty({0, H}, x_c.options());
continue;
}
const int64_t start = offsets[e];
auto sel_tokens =
expert_tokens.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto gates_e =
expert_gates.narrow(/*dim=*/0, /*start=*/start, /*length=*/te);
auto x_e = x_c.index_select(/*dim=*/0, sel_tokens);
if (apply_router_weight_on_input) {
x_e = x_e.mul(gates_e.unsqueeze(1));
}
auto w13_e = w13_packed.select(/*dim=*/0, e);
auto w2_e = w2_packed.select(/*dim=*/0, e);
// W13
auto y13 =
mm(x_e, w13_e, g_eff_13, /*in_features=*/H, /*out_features=*/I2);
auto g_part = y13.narrow(/*dim=*/1, /*start=*/0, /*length=*/I);
auto u_part = y13.narrow(/*dim=*/1, /*start=*/I, /*length=*/I);
torch::Tensor act;
if (activation_kind == ActivationKind::SwiGLUOAI) { // SwiGLUOAI
constexpr double kAlpha = 1.702; // GPT-OSS default
constexpr double kLimit = 7.0; // GPT-OSS default
auto gate_c = at::clamp_max(g_part, kLimit);
auto up_c = at::clamp(u_part, -kLimit, kLimit);
auto glu = gate_c.mul(at::sigmoid(gate_c.mul(kAlpha)));
act = up_c.add(1.0).mul(glu);
} else { // SiLU / SwiGLU_Gu: vLLM maps silu to SiluAndMul()
act = at::silu(g_part).mul(u_part);
}
// W2
auto y = mm(act, w2_e, g_eff_2, /*in_features=*/I, /*out_features=*/H);
if (!apply_router_weight_on_input) {
y = y.mul(gates_e.unsqueeze(1));
}
// Store per-expert result
y_list[e] = y;
}
});
// Concatenate all expert outputs to match expert_tokens order
auto Y_all = at::cat(y_list, /*dim=*/0);
auto out = at::zeros({T, H}, x.options());
out =
at::index_add(out, /*dim=*/0, /*index=*/expert_tokens, /*source=*/Y_all);
return out;
}
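For reference, a rough PyTorch sketch of the math this kernel computes on the SiLU / SwiGLU_Gu path, with the router weight applied on the output (the ref_moe helper and the unquantized w13/w2 tensors are illustrative only, not part of the change):

import torch

def ref_moe(x, topk_ids, topk_weights, w13, w2):
    # x: [T, H], topk_ids/topk_weights: [T, K], w13: [E, 2I, H], w2: [E, H, I]
    T, K = topk_ids.shape
    i_dim = w2.shape[2]
    out = torch.zeros(T, w2.shape[1], dtype=x.dtype)
    for t in range(T):
        for k in range(K):
            e = int(topk_ids[t, k])
            y13 = x[t] @ w13[e].T                      # [2I]
            g, u = y13[:i_dim], y13[i_dim:]
            act = torch.nn.functional.silu(g) * u      # SwiGLU_Gu
            out[t] += topk_weights[t, k] * (act @ w2[e].T)
    return out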

View File

@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const std::optional<torch::Tensor>& has_initial_state,
const torch::Tensor& ssm_states, int64_t pad_slot_id);
torch::Tensor dynamic_4bit_int_moe_cpu(
torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights,
torch::Tensor w13_packed, torch::Tensor w2_packed, int64_t H, int64_t I,
int64_t I2, int64_t group_size, bool apply_router_weight_on_input,
int64_t activation_kind);
using fptr_t = int64_t;
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
torch::Tensor& rank_data, int64_t rank,

View File

@ -98,13 +98,16 @@ def select_experts(
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
assert scoring_func == "softmax"
topk_weights = torch.nn.functional.softmax(router_logits,
dim=1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(topk_weights, top_k, dim=-1)
topk_logit_vals, topk_idx = torch.topk(router_logits,
k=top_k,
dim=-1,
sorted=False)
if renormalize:
topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids.to(torch.int32)
topk_vals = torch.softmax(topk_logit_vals, dim=-1)
else:
logZ = torch.logsumexp(router_logits, dim=-1, keepdim=True)
topk_vals = (topk_logit_vals - logZ).exp()
return topk_vals.to(torch.float32), topk_idx.to(torch.int32)
else:
return custom_routing_function(hidden_states=hidden_states,
gating_output=router_logits,

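The routing change above takes the top-k of the raw logits first and normalizes only the selected values; for softmax scoring this matches the previous softmax-then-top-k-then-renormalize order. A quick sanity sketch (illustrative, not part of the diff):

import torch

logits = torch.randn(4, 16)
k = 2

# old order: softmax over all experts, top-k, renormalize
old_w = torch.softmax(logits, dim=-1)
old_vals, old_idx = torch.topk(old_w, k, dim=-1)
old_vals = old_vals / old_vals.sum(dim=-1, keepdim=True)

# new order: top-k on raw logits, softmax over just the selected logits
new_logits, new_idx = torch.topk(logits, k, dim=-1)  # sorted=True here for comparison
new_vals = torch.softmax(new_logits, dim=-1)

assert torch.equal(old_idx, new_idx)
assert torch.allclose(old_vals, new_vals)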
View File

@ -69,8 +69,6 @@ else:
if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk as grouped_topk)
elif current_platform.is_cpu():
pass
else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu():

View File

@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
int4_w4a16_moe_quant_config, int8_w8a8_moe_quant_config,
int8_w8a16_moe_quant_config, nvfp4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import select_experts
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
is_valid_flashinfer_cutlass_fused_moe)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
@ -63,7 +64,7 @@ __all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
"CompressedTensorsW4A4MoeMethod"
"CompressedTensorsW4A4MoeMethod", "CompressedTensorsW4A8Int8MoEMethod"
]
@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config,
layer.moe_config)
elif quant_config._is_dynamic_token_w4a8_int(weight_quant,
input_quant):
return CompressedTensorsW4A8Int8MoEMethod(quant_config,
layer.moe_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_map=expert_map,
quant_config=self.moe_quant_config,
)
class CompressedTensorsW4A8Int8MoEMethod(CompressedTensorsMoEMethod):
"""
CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform
- Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles)
- Scales: Fp32 for Channelwise , bf16 for groupwise quantization
- Bias: Same data type as original weights
- Activations: FP32/Bf16 dynamic per-token (A8 Int),
quantized inside the kernel
"""
def __init__(
self,
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
moe: FusedMoEConfig):
super().__init__(moe)
self.has_bias = self.moe.has_bias
self.quant_config = quant_config
# Validate scheme: weights=W4 (channel or group),
# activations=dynamic TOKEN (A8)
wq = self.quant_config.target_scheme_map["Linear"].get("weights")
aq = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
# Must be dynamic per-token activations
if aq.strategy != QuantizationStrategy.TOKEN or not aq.dynamic:
raise ValueError(
"W4A8-int MoE needs dynamic per-token activation quantization."
)
# Weight can be channel-wise (group_size=None) or group-wise
self.group_size = wq.group_size if (wq.group_size is not None) else -1
if wq.num_bits != 4:
raise ValueError(
"This method only supports 4-bit weights (num_bits=4).")
# CPU only
if not current_platform.is_cpu():
raise ValueError("CompressedTensorsW4A8Int8MoEMethod is CPU-only.")
# Arm: check _dyn ops availability
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
try:
_ = torch.ops.aten._dyn_quant_matmul_4bit
_ = torch.ops.aten._dyn_quant_pack_4bit_weight
except AttributeError as err:
raise RuntimeError(
f"PyTorch {torch.__version__} lacks the _dyn_quant_* 4-bit ops; "
"install a newer build.") from err
self.static_input_scales = False # always dynamic per token
# ---- parameter creation ----
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Shapes per local rank (TP/EP):
# w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
# w2 : [E, H, I_local] int8
# Scales:
# channel-wise: group_size=-1 -> per-output-row, single scale per row
# group-wise : group_size=g ->
# per-output-row, (in_features/g) scales
E = num_experts
H = hidden_size
IN = intermediate_size_per_partition
g = self.group_size
# Per-row scale columns
def _n_scale_cols(in_features: int) -> int:
return 1 if g == -1 else (in_features // g)
# Register unpacked int4-as-int8 weights the loader will fill.
w13 = torch.nn.Parameter(torch.empty(E, 2 * IN, H, dtype=torch.int8),
requires_grad=False)
set_weight_attrs(w13, extra_weight_attrs)
layer.register_parameter("w13_weight", w13)
w2 = torch.nn.Parameter(torch.empty(E, H, IN, dtype=torch.int8),
requires_grad=False)
set_weight_attrs(w2, extra_weight_attrs)
layer.register_parameter("w2_weight", w2)
# Register scales
# KleidiAI channel-wise kernels accept float32 scales;
# group-wise kernels accept bfloat16 scales
scale_dtype = torch.float32 if g == -1 else torch.bfloat16
w13_s = torch.nn.Parameter(torch.ones(E,
2 * IN,
_n_scale_cols(H),
dtype=scale_dtype),
requires_grad=False)
set_weight_attrs(
w13_s, {
"quant_method": "channel" if g == -1 else "group",
**extra_weight_attrs
})
layer.register_parameter("w13_weight_scale", w13_s)
w2_s = torch.nn.Parameter(torch.ones(E,
H,
_n_scale_cols(IN),
dtype=scale_dtype),
requires_grad=False)
set_weight_attrs(
w2_s, {
"quant_method": "channel" if g == -1 else "group",
**extra_weight_attrs
})
layer.register_parameter("w2_weight_scale", w2_s)
if self.has_bias:
w13_bias = torch.nn.Parameter(torch.zeros(E,
2 * IN,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_bias", w13_bias)
set_weight_attrs(w13_bias, extra_weight_attrs)
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_bias", w2_bias)
set_weight_attrs(w2_bias, extra_weight_attrs)
# Placeholders for packed weights (will be replaced after packing)
layer.register_parameter(
"w13_weight_packed",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
set_weight_attrs(layer.w13_weight_packed, extra_weight_attrs)
layer.register_parameter(
"w2_weight_packed",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
set_weight_attrs(layer.w2_weight_packed, extra_weight_attrs)
# dims for 4 bit fused matmuls
layer.w13_in_features = H
layer.w13_out_features = 2 * IN
layer.w2_in_features = IN
layer.w2_out_features = H
layer.group_size = g
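To make the layouts above concrete, a toy example of the per-expert parameter shapes this method registers (the numbers are illustrative):

# toy shapes: E=4 experts, H=64, I_local=32, group_size=32 (group-wise)
E, H, IN, g = 4, 64, 32, 32
# w13_weight:       [E, 2*IN, H]      -> (4, 64, 64), int8 holding int4 values
# w2_weight:        [E, H, IN]        -> (4, 64, 32), int8
# w13_weight_scale: [E, 2*IN, H // g] -> (4, 64, 2),  bfloat16 (group-wise)
# w2_weight_scale:  [E, H, IN // g]   -> (4, 64, 1),  bfloat16
# channel-wise (group_size=-1) collapses the last scale dim to 1 and uses float32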
# post-load packing to dyn-4bit KleidiAI kernel's format
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E = layer.w13_weight.shape[0]
H = layer.w13_in_features
I2 = layer.w13_out_features
IN = layer.w2_in_features
g = layer.group_size
def _pack_matrix(int4_as_int8_2d: torch.Tensor,
scales_2d: torch.Tensor,
bias_1d: Optional[torch.Tensor], in_features: int,
out_features: int) -> torch.Tensor:
# int4 values are stored as int8 in [-8,7].
# Shift to unsigned nibble and pack pairs along input-dim.
tmp = int4_as_int8_2d.add(8) # [out, in]
uint8_nibbles = ((tmp[:, 1::2] << 4) | tmp[:, ::2]).to(
torch.uint8) # [out, in//2]
# KleidiAI channel-wise kernels accept float32 scales;
# group-wise kernels accept bfloat16 scales
scale_dtype = torch.float32 if g == -1 else torch.bfloat16
scales = scales_2d.to(scale_dtype)
bias = None if bias_1d is None else bias_1d.to(torch.float32)
return torch.ops.aten._dyn_quant_pack_4bit_weight(
uint8_nibbles, scales, bias, g if g != -1 else in_features,
in_features, out_features)
# Pack per expert
w13_packed_list = []
w2_packed_list = []
has_w13_bias = hasattr(layer,
"w13_bias") and layer.w13_bias is not None
has_w2_bias = hasattr(layer, "w2_bias") and layer.w2_bias is not None
for e in range(E):
w13_packed_list.append(
_pack_matrix(
layer.w13_weight[e], # [2I, H]
layer.w13_weight_scale[e], # [2I, H/g or 1]
layer.w13_bias[e] if has_w13_bias else None, # [2I]
H,
I2))
w2_packed_list.append(
_pack_matrix(
# w2 shape is [H, IN]; we need [out, in] == [H, IN].
layer.w2_weight[e], # [H, IN]
layer.w2_weight_scale[e], # [H, IN/g or 1]
layer.w2_bias[e] if has_w2_bias else None, # [H]
IN,
layer.w2_out_features # in_features=IN, out_features=H
))
# each packed tensor has identical shape per expert; stack on dim 0
w13_packed = torch.stack(w13_packed_list, dim=0)
w2_packed = torch.stack(w2_packed_list, dim=0)
replace_parameter(layer, "w13_weight_packed",
torch.nn.Parameter(w13_packed, requires_grad=False))
replace_parameter(layer, "w2_weight_packed",
torch.nn.Parameter(w2_packed, requires_grad=False))
# free raw tensors/scales/bias now that they're packed into the payload.
replace_parameter(
layer, "w13_weight",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
replace_parameter(
layer, "w2_weight",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
replace_parameter(
layer, "w13_weight_scale",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
replace_parameter(
layer, "w2_weight_scale",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
if has_w13_bias:
replace_parameter(
layer, "w13_bias",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
if has_w2_bias:
replace_parameter(
layer, "w2_bias",
torch.nn.Parameter(torch.empty(0), requires_grad=False))
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
# CPU dynamic 4-bit MoE path does not use modular kernels or
# fused_experts; quant config is not needed.
return None
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert not enable_eplb, "EPLB not supported for W4A8-int MoE yet."
assert activation in (
"silu", "swiglu",
"swigluoai"), "Only silu/swiglu/swigluoai activations are supported."
assert expert_map is None, (
"expert_map/EP not implemented for CPU dyn-4bit MoE.")
def _act_kind(s: str) -> int:
# 0 = SwiGLU_Gu (SiLU(g) * u), 1 = SwiGLUOAI (GPT-OSS clamped SwiGLU), 2 = SiLU
if s == "swiglu":
return 0
if s == "swigluoai":
return 1
if s == "silu":
return 2
raise ValueError(f"Unknown activation '{s}'")
# Select the top-k experts and their routing weights from the router output
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
return torch.ops._C.dynamic_4bit_int_moe(
x, topk_ids.to(torch.long), topk_weights, layer.w13_weight_packed,
layer.w2_weight_packed, layer.w2_out_features,
layer.w2_in_features, layer.w13_out_features, layer.group_size,
apply_router_weight_on_input, int(_act_kind(activation)))
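For completeness, a small round-trip check of the nibble layout used by _pack_matrix in process_weights_after_loading above (even input columns land in the low nibble, odd columns in the high nibble, after a +8 shift). This is an illustrative sketch, not part of the change:

import torch

w_int4 = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # [out, in], int4 stored as int8
tmp = (w_int4 + 8).to(torch.uint8)                        # shift to unsigned nibbles 0..15
packed = (tmp[:, 1::2] << 4) | tmp[:, ::2]                # [out, in//2], odd cols -> high nibble

even = (packed & 0x0F).to(torch.int8) - 8                 # recover even input columns
odd = (packed >> 4).to(torch.int8) - 8                    # recover odd input columns
restored = torch.stack([even, odd], dim=-1).reshape(4, 8)
assert torch.equal(restored, w_int4)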