Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Model] use AutoWeightsLoader for gpt2 (#18625)
Signed-off-by: zt2370 <ztang2370@gmail.com>
@@ -43,7 +43,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
 
@@ -235,6 +235,35 @@ class GPT2Model(nn.Module):
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
 
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: set[str] = set()
+        for name, loaded_weight in weights:
+            if ".attn.bias" in name or ".attn.masked_bias" in name:
+                # Skip attention mask.
+                # NOTE: "c_attn.bias" should not be skipped.
+                continue
+
+            if is_pp_missing_parameter(name, self):
+                continue
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
 
 class GPT2LMHeadModel(nn.Module, SupportsPP):
 
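Aside: the transpose in the loop above is needed because the HF GPT-2 checkpoint stores c_attn/c_proj/c_fc as Conv1D weights of shape (in_features, out_features), while nn.Linear-style layers expect (out_features, in_features). A minimal standalone sketch of that equivalence (plain torch; the concrete shapes are illustrative, not part of this diff):

import torch

# HF Conv1D computes x @ W with W of shape (in_features, out_features);
# nn.Linear computes x @ W.t() with W of shape (out_features, in_features).
# Transposing the checkpoint weight makes the two agree.
in_features, out_features = 768, 2304  # e.g. c_attn in GPT-2 small (illustrative)
conv1d_weight = torch.randn(in_features, out_features)

linear = torch.nn.Linear(in_features, out_features, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv1d_weight.t())  # the .t() from the diff

x = torch.randn(4, in_features)
assert torch.allclose(x @ conv1d_weight, linear(x), atol=1e-5)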
@@ -283,32 +312,16 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if ".attn.bias" in name or ".attn.masked_bias" in name:
-                # Skip attention mask.
-                # NOTE: "c_attn.bias" should not be skipped.
-                continue
-            if not name.startswith("transformer.") and not name.startswith(
-                    "lm_head"):
-                name = "transformer." + name
+        loader = AutoWeightsLoader(self)
+        weights = _add_transformer_prefix(weights)
+        return loader.load_weights(weights)
 
-            if is_pp_missing_parameter(name, self):
-                continue
-
-            param = params_dict[name]
-            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-            # Because of this, we need to transpose the weights.
-            # Note(zhuohan): the logic below might break quantized models.
-            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                if conv1d_weight_name not in name:
-                    continue
-                if not name.endswith(".weight"):
-                    continue
-                loaded_weight = loaded_weight.t()
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+
+def _add_transformer_prefix(
+        weights: Iterable[tuple[str, torch.Tensor]]
+) -> Iterable[tuple[str, torch.Tensor]]:
+    for name, tensor in weights:
+        if not name.startswith('transformer.') and not name.startswith(
+                "lm_head"):
+            name = 'transformer.' + name
+        yield name, tensor
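After this change, the model-specific details (the Conv1D transpose, skipping PP-missing parameters) live in GPT2Model.load_weights, and the LM-head class only rewrites checkpoint names before delegating to AutoWeightsLoader, which dispatches each weight to the submodule that owns it. A quick standalone check of the renaming generator (sample checkpoint names are illustrative; assumes _add_transformer_prefix from the diff above is in scope):

import torch

ckpt = [
    ("h.0.attn.c_attn.weight", torch.empty(0)),   # bare HF name -> gets the prefix
    ("transformer.ln_f.weight", torch.empty(0)),  # already prefixed -> unchanged
    ("lm_head.weight", torch.empty(0)),           # head weight -> unchanged
]
renamed = list(_add_transformer_prefix(ckpt))
assert [n for n, _ in renamed] == [
    "transformer.h.0.attn.c_attn.weight",
    "transformer.ln_f.weight",
    "lm_head.weight",
]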