Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Bugfix] Enable loading FP8 checkpoints for gpt_bigcode models (#5460)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@@ -299,4 +299,10 @@ class GPTBigCodeForCausalLM(nn.Module):
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
-            weight_loader(param, loaded_weight)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, 'q')
+                weight_loader(param, loaded_weight, 'k')
+                weight_loader(param, loaded_weight, 'v')
+            else:
+                weight_loader(param, loaded_weight)
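For context on why the scale is loaded three times: gpt_bigcode fuses the query, key, and value projections into a single c_attn layer, and the fused-QKV weight loader expects one call per logical shard ('q', 'k', 'v'). An FP8 checkpoint stores a single per-tensor scale for the fused layer, so that one value has to be routed to all three shards. Below is a minimal, self-contained sketch of that pattern; the FusedQKVParam class and the shard sizes are hypothetical illustrations, not vLLM's actual QKVParallelLinear implementation.

import torch

class FusedQKVParam:
    """Hypothetical stand-in for a fused-QKV parameter, loosely modeled
    on the shard-aware weight-loading pattern seen in the diff above."""

    def __init__(self, q_size: int, kv_size: int):
        # One flat buffer holding the q, k, and v slices back to back.
        self.data = torch.zeros(q_size + 2 * kv_size)
        self.shard_offsets = {"q": 0, "k": q_size, "v": q_size + kv_size}
        self.shard_sizes = {"q": q_size, "k": kv_size, "v": kv_size}

def weight_loader(param: FusedQKVParam, loaded_weight: torch.Tensor,
                  shard_id: str) -> None:
    # Copy the checkpoint tensor into the slice for this shard; a scalar
    # (per-tensor) scale broadcasts across the whole shard.
    start = param.shard_offsets[shard_id]
    size = param.shard_sizes[shard_id]
    param.data[start:start + size] = loaded_weight

param = FusedQKVParam(q_size=4, kv_size=2)

# The FP8 checkpoint stores a single scale for the fused c_attn layer,
# so, as in the diff above, the same value is written into every shard.
scale = torch.tensor(0.5)
for shard_id in ("q", "k", "v"):
    weight_loader(param, scale, shard_id)

print(param.data)  # all eight entries are 0.5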