Add GPTQ support for Gemma (#3200)

Author: TechxGenus
Date: 2024-03-07 08:19:14 +08:00
Committed by: GitHub
Parent: 4cb3b924cd
Commit: d3c04b6a39


@@ -325,11 +325,17 @@ class GemmaForCausalLM(nn.Module):
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.
                 if "norm.weight" in name: