Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Model] use AutoWeightsLoader for commandr (#19399)
Signed-off-by: py-andy-c <pychen1017@gmail.com>
@@ -51,7 +51,8 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (extract_layer_index, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -286,6 +287,7 @@ class CohereModel(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
+        self.quant_config = quant_config
 
         self.config = config
         lora_vocab = (lora_config.lora_extra_vocab_size *
@@ -339,6 +341,62 @@ class CohereModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
     packed_modules_mapping = {
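The load_weights method added to CohereModel above carries over the manual loading logic that previously lived on CohereForCausalLM: it reads the newly stored self.quant_config to remap kv-cache quantization scales, and it uses stacked_params_mapping to fold per-shard checkpoint tensors (q_proj/k_proj/v_proj and gate_proj/up_proj) into the fused qkv_proj and gate_up_proj parameters. A minimal standalone sketch of that name-remapping step follows; the demo_remap helper and the example checkpoint name are illustrative only, not part of vLLM:

# Illustrative sketch: how stacked_params_mapping rewrites checkpoint
# parameter names onto the fused parameters, returning the shard id that
# tells the fused parameter's weight_loader which slice to fill.
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def demo_remap(name: str):
    """Return (fused_param_name, shard_id), or (name, None) if unfused."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None


# Hypothetical checkpoint name, for illustration:
print(demo_remap("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')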
@@ -408,65 +466,6 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters())
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-
-            # Skip loading rotary embeddings since vLLM has its own
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-
-            for param_name, shard_name, shard_id in stacked_params_mapping:
-                if shard_name not in name:
-                    continue
-                name = name.replace(shard_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # lm_head is not used in vllm as it is tied with embed_token.
-                # To prevent errors, skip loading lm_head.weight.
-                if "lm_head.weight" in name:
-                    continue
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(
+            self, skip_prefixes=["lm_head", "rotary_emb.inv_freq"])
+        return loader.load_weights(weights)
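For reference, the delegation pattern AutoWeightsLoader enables looks roughly like the sketch below: the top-level model only declares which checkpoint prefixes to skip, and the per-tensor work stays in each submodule's own load_weights (here CohereModel's, added in the hunk above). This is a simplified illustration, not vLLM's AutoWeightsLoader implementation; the SketchAutoLoader class and its grouping logic are assumptions for demonstration.

# Simplified sketch of AutoWeightsLoader-style delegation (not vLLM's code):
# group incoming (name, tensor) pairs by top-level child module, drop names
# matching skip_prefixes, and forward the rest to each child's load_weights.
from typing import Iterable

import torch
import torch.nn as nn


class SketchAutoLoader:

    def __init__(self, module: nn.Module, skip_prefixes=None):
        self.module = module
        self.skip_prefixes = skip_prefixes or []

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        grouped: dict[str, list[tuple[str, torch.Tensor]]] = {}
        for name, tensor in weights:
            if any(name.startswith(p) for p in self.skip_prefixes):
                continue
            child_name, _, rest = name.partition(".")
            grouped.setdefault(child_name, []).append((rest, tensor))

        loaded: set[str] = set()
        for child_name, child in self.module.named_children():
            if child_name in grouped and hasattr(child, "load_weights"):
                for sub_name in child.load_weights(grouped[child_name]):
                    loaded.add(f"{child_name}.{sub_name}")
        return loaded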