Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
Support loading transformers models with named parameters (#16868)
Signed-off-by: Alex <alexwu@character.ai>
@@ -166,6 +166,9 @@ class TransformersModel(nn.Module):
        # Initialize buffers (e.g. rotary embedding inverse frequency)
        self.init_buffers(self.model)

        # Initialize parameters
        self.init_parameters(self.model)

        # Move remaining meta tensors to device (should happen last)
        self.meta_to_empty(self.model)
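For context, here is a small standalone sketch (standard PyTorch only, not vLLM code) of the starting point these three passes operate on: a module constructed under `torch.device("meta")` carries shapes and dtypes but no backing storage, which is why buffers and parameters must be re-created on the target device before any weights can be loaded.

```python
import torch
import torch.nn as nn

# Construct a module on the meta device: allocation is deferred.
with torch.device("meta"):
    layer = nn.Linear(16, 16)

# Shape and dtype are fully known...
print(layer.weight.shape, layer.weight.dtype)  # torch.Size([16, 16]) torch.float32
# ...but the tensors live on the meta device and hold no data yet.
print({name: p.device.type for name, p in layer.named_parameters()})
# {'weight': 'meta', 'bias': 'meta'}
```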
@@ -298,6 +301,25 @@ class TransformersModel(nn.Module):
        for child in module.children():
            self.init_buffers(child)

    def init_parameters(self, module: nn.Module):
        """
        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```
        """
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
            self.init_parameters(child)

    def meta_to_empty(self, module: nn.Module):
        tensors = list(chain(module.buffers(), module.parameters()))
        if tensors and all(t.device == torch.device("meta") for t in tensors):
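As a rough illustration of the whole-module path that `meta_to_empty` guards (its body is cut off at the hunk boundary), the sketch below uses PyTorch's standard `nn.Module.to_empty` to materialize a fully-meta module in one call; the module and device are placeholders, and this is not vLLM's actual implementation.

```python
import torch
import torch.nn as nn
from itertools import chain

# Build a toy module entirely on the meta device.
with torch.device("meta"):
    block = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Mirror the guard shown above: only act if every tensor is still on meta.
tensors = list(chain(block.buffers(), block.parameters()))
if tensors and all(t.device == torch.device("meta") for t in tensors):
    # Replace all parameters and buffers with uninitialized storage on the
    # target device in a single pass.
    block = block.to_empty(device="cpu")

print({name: p.device.type for name, p in block.named_parameters()})
# {'0.weight': 'cpu', '0.bias': 'cpu', '1.weight': 'cpu', '1.bias': 'cpu'}
```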
@@ -342,6 +364,7 @@ class TransformersModel(nn.Module):
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())

        loaded_params = set[str]()
        for name, loaded_weight in weights:
            # Use "model" instead of base_model_prefix because
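The `load_weights` body is also truncated at the hunk boundary. As a hedged sketch of the general pattern it follows (checkpoint names matched against `named_parameters()`, with the set of loaded names returned), the toy module below is purely illustrative and not vLLM's actual implementation.

```python
import torch
import torch.nn as nn
from collections.abc import Iterable


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model exposing a load_weights method."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(
            self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # Map parameter names to the live nn.Parameter objects.
        params_dict = dict(self.named_parameters())
        loaded_params = set[str]()
        for name, loaded_weight in weights:
            param = params_dict.get(name)
            if param is None:
                continue  # name not present in this module; skip it
            param.data.copy_(loaded_weight)
            loaded_params.add(name)
        return loaded_params


model = TinyModel()
loaded = model.load_weights([("proj.weight", torch.zeros(4, 4)),
                             ("proj.bias", torch.zeros(4))])
print(loaded)  # {'proj.weight', 'proj.bias'}
```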