[Model] Granite-4 support loading quantized checkpoint (#22925)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
@@ -471,7 +471,10 @@ class GraniteMoeHybridModel(nn.Module):
             # Mapping different experts' layout:
             # from HF (input_linear, output_linear, router)
             # to vLLM (experts_w13({e}.w1, {e}.w2), experts_w3({e}.w3), gate)
-            if n.endswith('.block_sparse_moe.input_linear.weight'):
+            # The renaming and parameter loading logic is the same for weight
+            # and weight_scale tensors so we can reuse them without issues.
+            if (n.endswith('.block_sparse_moe.input_linear.weight') or
+                    n.endswith('.block_sparse_moe.input_linear.weight_scale')):
                 for e in range(p.size(0)):
                     w1_name = n.replace(
                         '.block_sparse_moe.input_linear.weight',
@@ -490,7 +493,8 @@ class GraniteMoeHybridModel(nn.Module):
                         w3_name,
                         shard_id='w3',
                         expert_id=e)
-            elif n.endswith('.block_sparse_moe.output_linear.weight'):
+            elif (n.endswith('.block_sparse_moe.output_linear.weight') or
+                  n.endswith('.block_sparse_moe.output_linear.weight_scale')):
                 for e in range(p.size(0)):
                     w2_name = n.replace(
                         '.block_sparse_moe.output_linear.weight',