[Bugfix] Enable loading FP8 checkpoints for gpt_bigcode models (#5460)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell
Date: 2024-06-14 22:28:11 +02:00
Committed by: GitHub
Parent: 6e2527a7cb
Commit: e2afb03c92
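
With this fix, a gpt_bigcode checkpoint that ships FP8 per-tensor scales should load through vLLM's standard entry point. A minimal usage sketch, assuming a hypothetical FP8-quantized StarCoder-style repository (the model id below is a placeholder, not a real checkpoint):

from vllm import LLM

# Placeholder model id; any gpt_bigcode checkpoint with FP8 scales applies.
llm = LLM(model="my-org/starcoder-fp8", quantization="fp8")
outputs = llm.generate("def fibonacci(n):")
print(outputs[0].outputs[0].text)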


@@ -299,4 +299,10 @@ class GPTBigCodeForCausalLM(nn.Module):
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
-            weight_loader(param, loaded_weight)
+            # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
+            if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
+                weight_loader(param, loaded_weight, 'q')
+                weight_loader(param, loaded_weight, 'k')
+                weight_loader(param, loaded_weight, 'v')
+            else:
+                weight_loader(param, loaded_weight)
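
Why the scale parameters are loaded three times: the checkpoint stores a single input_scale/weight_scale tensor for the fused c_attn projection, but the parameter's weight_loader is QKV-shard-aware and expects a shard id, so the same scale is routed once each to the q, k, and v shards. A minimal sketch of that idea, not vLLM's actual FP8 loader; the helper name and slot layout are illustrative:

import torch

# Map each logical shard of the fused QKV projection to a slot index
# (illustrative layout, not vLLM's internal representation).
_SHARD_INDEX = {"q": 0, "k": 1, "v": 2}

def qkv_scale_loader(param: torch.Tensor,
                     loaded_weight: torch.Tensor,
                     shard_id: str) -> None:
    """Record a checkpoint scale for one shard of a fused parameter
    (hypothetical stand-in for the shard-aware weight_loader)."""
    param[_SHARD_INDEX[shard_id]] = loaded_weight.item()

# Usage mirroring the diff: one checkpoint scale, routed to all three shards.
scales = torch.zeros(3)
checkpoint_scale = torch.tensor(0.125)
for shard_id in ("q", "k", "v"):
    qkv_scale_loader(scales, checkpoint_scale, shard_id)
print(scales)  # tensor([0.1250, 0.1250, 0.1250])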