diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 470a7053e1..c2c310fca4 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -235,6 +235,35 @@ class GPT2Model(nn.Module): hidden_states = self.ln_f(hidden_states) return hidden_states + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + class GPT2LMHeadModel(nn.Module, SupportsPP): @@ -283,32 +312,16 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - params_dict = dict(self.named_parameters(remove_duplicate=False)) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if ".attn.bias" in name or ".attn.masked_bias" in name: - # Skip attention mask. - # NOTE: "c_attn.bias" should not be skipped. - continue - if not name.startswith("transformer.") and not name.startswith( - "lm_head"): - name = "transformer." + name + loader = AutoWeightsLoader(self) + weights = _add_transformer_prefix(weights) + return loader.load_weights(weights) - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - # The HF's GPT-2 implementation uses Conv1D instead of Linear. - # Because of this, we need to transpose the weights. - # Note(zhuohan): the logic below might break quantized models. - for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: - if conv1d_weight_name not in name: - continue - if not name.endswith(".weight"): - continue - loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params +def _add_transformer_prefix( + weights: Iterable[tuple[str, torch.Tensor]] +) -> Iterable[tuple[str, torch.Tensor]]: + for name, tensor in weights: + if not name.startswith('transformer.') and not name.startswith( + "lm_head"): + name = 'transformer.' + name + yield name, tensor