Files
transformers/examples/modular-transformers/modular_new_task_model.py
Arthur 6f6095e0cf Refactor weight loading (#41580)
* ah actually we don't discard the lm head if missing -> it needs to be moved to the correct device, etc.

* fix some tests

* small fixes

* up

* up

* don't know why we tie weights twice, but...

* ups

* remove unused

* fix hunyuan

* small fix

* nits

* ish

* up

* rev

* fix more tie weights keys

* small fixes

* nit

* update

* fix and fix

* fix a test

* glubs

* current shitty changes

* ship validated ones

* more

* more update

* more

* more

* more

* mllama

* more up

* fix ernie

* fix copies

* up more

* more fixes

* up

* up

* fix-copies

* fix more

* more updates

* AI UPDATE

* up

* hoey

* make it fast

* fix

* lol

* fix adjusting

* more fixes

* _dtype nit

* up

* nit

* update

* update

* remove semaphores

* fix import to avoid jit execution

* try to remove custom tying logic when it's stupid

* fix more individual models

* fix whisper as well

* fix?

* fix umt5

* improve tqdm bar

* cleanup a bit

* oopsie

* some updates

* improve

* remove all buffering -> much faster without it

* remove some tie_weights custom funcs when not needed

* more fixes related to strict matching regex

* remove ALL custom tie weights

* small update

* revert change to init scheme (no need for params)

* mixtral init

* try less strict source check

* tied weights: first shot at the fiiiixxxxxx

* does this help?

* :)

* fix some poorly defined tied_weights_keys for now

* subclass nn.Parameter

* up

* lol

* Yesss

* fix led

* fix long cat flash

* fix qwen and long cat flash

* properly fix qwen init

* just push this for now

* prophetnet is dumb

* update

* push

* remove explicit sharing of some tied keys.

* update decoder.bias

* moe case

* more changes to untangle old hardcoded tying

* fixup

* fix big failures

* fix prophetnet

* fix resize token embeddings

* nits

* fix xcodec

* asyncio?

* fix smart apply

* fix data-2-vec

* [build-ci-image]

* checkout

* update

* fix hunyuan

* update error message

* fix deformable detr

* fixes

* fix init weights for non param gate up projs

* shared todo?

* update some models

* big revert, don't break this behaviour

* ty @SunMarc this fixes the buffers

Co-authored-by: SunMarc <SunMarc@users.noreply.github.com>

* mt5 fuck

* fix lxmert

* nuke slow test fetcher

* fix zamba and deepcopy for now

* fix zamba tied weight keys! ~

* fix-copies

* update fetch test

* fix gradient for test modeling common!

* break "shared" for now I will fix tomorrow changes are properly isoalted now :)

* does this fix marian? probably not

* fix some vlms

* D fine seems to handle this well

* glob is fine actually

* fix dab detr

* small steps

* opusy

* fix some more models?

* yups

* better error

* fix?

* fix double escape

* escape where it makes sense

* ??

* fix ibert

* fix tvp as well

* more fixes

* try always download ref PR

* ONONONO

* big fixup

* more fixup

* small step

* small nits

* nits

* brute force some stuff

* fix vilt

* make sure special models that always need tie always tie

* cleaning up

* small nits

* fix zamba and bridge tower!

* just fixup

* potential culprits

* revert bark and fix bridgetower

* remove now non-existent tie_weights

* ?

* lol reformer actually had nothing tied!

* wow these two fucking models were really not well made

* fix sam family!

* fix bark revision

* fix speech2text?

* push this for now....

* upsy

* the fuck

* fix rtdetr

* update

* proper

* wow that one's annoying

* update

* try to find the culprit

* get some help on common

* nit about general init and cls.padding_idx

* revert num workers update

* remove old loading func

* fix glob

* add annotations

* fix re

* small improvements

* clean some stuff

* improvements

* someone did not understand what I tried to do, or does BNB not support that either?

* gluos

* fix case when `.` is just not there

* remove unused arg

* recover original parameter/buffer using _original

* fix glob issue

* this?

* deepspeed best-effort

* remove unused stuff

* Update tie weight keys as they were just wrong

Co-authored-by: Benjamin Bossan <benjaminbossan@users.noreply.github.com>

* up

* augustuc clauss, a gloubs gloups gloubs

* fixup

* fixup

* there was a fucking typo

* marian

* nits

* fix marian 3 remaining tests

* one more

* fix some of the copies, not all :)

* small cleanup

* one proper test

* fix core model loading test

* attempt a new test

* fix some of the annoying tests by supporting reading .bin sometimes

* push

* push more small fixes

* remove 1 useless test

* up

* fix audio flamingo post rebase

* fixup

* some small updates

* fix sam models

* nits

* up

* updates

* one more

* skip this stupid test

* some other fixes

* fixup

* update

* skip more offloaded stuff

* oops

* ups

* update mixtral

* skip this one

* LET"SGO

* fixup

* rope delta order

* fix csm

* small nit

---------

Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
Co-authored-by: SunMarc <SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <marc@huggingface.co>
2025-11-13 17:12:52 +01:00

92 lines
3.5 KiB
Python

from typing import ClassVar, Optional

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration

from ...cache_utils import Cache


class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config):
        super().__init__(config=config)

        self.embedding_dim = self.config.embedding_dim
        self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)

        # The language model's tied-weight keys are expressed relative to itself;
        # re-expose them under the "model.language_model." prefix so they resolve
        # from this wrapper model as well.
        if self.language_model._tied_weights_keys is not None:
            prefix = "model.language_model."
            prefixed_mapping = {
                f"{prefix}{target}": f"{prefix}{source}"
                for target, source in self.language_model._tied_weights_keys.items()
            }
            if isinstance(self._tied_weights_keys, dict):
                self._tied_weights_keys.update(prefixed_mapping)
            else:
                self._tied_weights_keys = prefixed_mapping

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ):
        r"""
        Returns:
        """
        # Hidden states are always requested here: the projection below consumes the last layer.
        vlm_outputs = super().forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            cache_position=cache_position,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            num_logits_to_keep=num_logits_to_keep,
        )
        last_hidden_states = vlm_outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        embeddings = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        if attention_mask is not None:
            # Zero out embeddings at padding positions
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return (embeddings,) + vlm_outputs

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing=True
    ) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

        # Update vocab size everywhere it is mirrored
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings

        return model_embeds
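
The `_tied_weights_keys` remapping in `__init__` above is plain key prefixing. A minimal standalone sketch of that step, where the `tied_weights_keys` mapping is a made-up example (the real keys depend on the wrapped language model):

# Sketch only: `tied_weights_keys` below is a made-up example mapping.
tied_weights_keys = {"lm_head.weight": "embed_tokens.weight"}

prefix = "model.language_model."
prefixed_mapping = {
    f"{prefix}{target}": f"{prefix}{source}"
    for target, source in tied_weights_keys.items()
}
print(prefixed_mapping)
# {'model.language_model.lm_head.weight': 'model.language_model.embed_tokens.weight'}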
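
The tail of `forward` reduces to two tensor operations: normalize each token embedding to unit L2 norm, then zero out padding positions. A toy sketch of just that step, with random tensors standing in for the projected hidden states:

import torch

# Toy stand-ins for the projection output and the attention mask.
proj = torch.randn(2, 4, 8)                    # (batch, seq_len, dim)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])  # 0 marks padding

embeddings = proj / proj.norm(dim=-1, keepdim=True)     # unit L2 norm per token
embeddings = embeddings * attention_mask.unsqueeze(-1)  # zero padded positions

print(embeddings.norm(dim=-1))  # ~1.0 where mask == 1, 0.0 where mask == 0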
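
A hedged usage sketch of the class as a whole: `my-org/new-task-model` is a placeholder checkpoint name, and it assumes the checkpoint's config defines the `embedding_dim` field read in `__init__`:

import torch

# Placeholder checkpoint; assumes its config carries `embedding_dim`.
model = NewTaskModelForNewTask.from_pretrained("my-org/new-task-model")
model.eval()

input_ids = torch.tensor([[2, 45, 89, 1]])  # toy token ids
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

embeddings = outputs[0]  # (batch_size, sequence_length, embedding_dim), L2-normalized

# resize_token_embeddings above also keeps the mirrored vocab sizes in sync
model.resize_token_embeddings(new_num_tokens=model.config.text_config.vocab_size + 8)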