mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Model] Fix Baichuan BNB online quantization (#10572)
Signed-off-by: Chen Wu <cntryroa@gmail.com>
This commit is contained in:
@ -350,6 +350,21 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
embedding_modules = {}
|
||||
embedding_padding_modules = []
|
||||
|
||||
# BitandBytes specific attributes
|
||||
default_bitsandbytes_target_modules = [
|
||||
".W_pack.",
|
||||
".o_proj.",
|
||||
".down_proj.",
|
||||
".up_proj.",
|
||||
".gate_proj.",
|
||||
".up_proj.",
|
||||
]
|
||||
bitsandbytes_stacked_params_mapping = {
|
||||
# shard_name, weight_name, index
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
Reference in New Issue
Block a user