Support loading transformers models with named parameters (#16868)

Signed-off-by: Alex <alexwu@character.ai>
Author: Alex Wu
Date: 2025-04-28 15:15:58 -07:00
Committed by: GitHub
Parent: dcbac4cb4b
Commit: 6e74fd4945


@@ -166,6 +166,9 @@ class TransformersModel(nn.Module):
        # Initialize buffers (e.g. rotary embedding inverse frequency)
        self.init_buffers(self.model)

        # Initialize parameters
        self.init_parameters(self.model)

        # Move remaining meta tensors to device (should happen last)
        self.meta_to_empty(self.model)
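For context, the hunk above relies on PyTorch's meta-device deferred initialization: the model skeleton is built without allocating real storage, and parameters are materialized afterwards. Below is a minimal sketch of that pattern (illustrative only, not the vLLM code; the `nn.Linear` layer and CPU target device are assumptions for the example), mirroring what `init_parameters` does in the second hunk:

```python
import torch
import torch.nn as nn

# Build the module under the meta device: no real memory is allocated.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

assert layer.weight.device == torch.device("meta")

# Materialize: replace each meta parameter with an empty tensor on a
# real device, keeping the original shape and dtype.
for name, param in layer.named_parameters(recurse=False):
    if param.device == torch.device("meta"):
        setattr(layer, name,
                nn.Parameter(torch.empty_like(param.data, device="cpu")))

assert layer.weight.device == torch.device("cpu")
```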
@@ -298,6 +301,25 @@ class TransformersModel(nn.Module):
        for child in module.children():
            self.init_buffers(child)

    def init_parameters(self, module: nn.Module):
        """
        If a `parameter` is on the `meta` device, then its parent
        `module` is the original module created by:

        ```python
        with torch.device("meta"):
            self.model: PreTrainedModel = AutoModel.from_config(...)
        ```
        """
        for name, param in module.named_parameters(recurse=False):
            if param.device == torch.device("meta"):
                new_param = nn.Parameter(
                    torch.empty_like(param.data,
                                     device=self.device_config.device))
                setattr(module, name, new_param)
        for child in module.children():
            self.init_parameters(child)

    def meta_to_empty(self, module: nn.Module):
        tensors = list(chain(module.buffers(), module.parameters()))
        if tensors and all(t.device == torch.device("meta") for t in tensors):
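`meta_to_empty` is the fallback for modules whose tensors are all still on `meta` after `init_parameters` has run; PyTorch's `nn.Module.to_empty()` swaps meta storage for uninitialized storage on a real device in a single call. A minimal sketch of that built-in (illustrative module, not the vLLM code):

```python
import torch
import torch.nn as nn

# All parameters in this block start on the meta device.
with torch.device("meta"):
    block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

# to_empty() moves every parameter and buffer to real (but
# uninitialized) storage on the requested device.
block.to_empty(device="cpu")
assert all(p.device == torch.device("cpu") for p in block.parameters())
```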
@@ -342,6 +364,7 @@ class TransformersModel(nn.Module):
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        params_dict = dict(self.named_parameters())

        loaded_params = set[str]()
        for name, loaded_weight in weights:
            # Use "model" instead of base_model_prefix because