mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-17 00:00:49 +08:00
* ah actually we don't discard lm head if missing -> needs to be moved to correct device and etc * fix some tests * small fixes * up * up * dik why we tie weights twice but,..,,. * ups * removeunused * fix hunyuan * small fix * nits * ish * up * rev * fix more tie weights keys * small fixes * nit * update * fix and fix * fix a test * glubs * current shitty changes * ship validated ones * more * more update * more * more * more * mllama * more up * fix ernie * fix xopies * up more * more fixes * up * up * fix-copies * fix more * more updates * AI UPDATE * up * hoey * make it fast * fix * lol * fix asjusting * more fixes * _dtype nit * up * nit * update * update * remove semaphores * fix import to avoid jit execution * try to remove custom tiing logic when its stupid * fix more individual models * fix whisper as well * fix? * fox umt5 * improve tqdm bar * cleanup a bit * oupsi * some updates * improve * remove all buffering -> much faster without it * remove some tie_weights custome funcs when not needed * more fixes related to strict matching regex * remove ALL custom tie weights * small update * revert change to init scheme (no need for params) * mixtral init * try less strict source check * tied weight first shot to the fiiiixxxxxx * does this help? * :) * fix some ppolry defined tied_weights_keys for now * subclass nn.Parameters * up * lol * Ouiiii * fix led * fix long cat flash * fix qwen and long cat flash * properly fix qwen init * just push this for now * propnet is dumb * update * push * remove explict sharing of some tied keys. * update decoder.bias * moe case * more changes to untangle old hardcoded ting * fixup * fix big faileurs * fix prophnet * fix resize token embeddings * nits * fix xcodex * asyncio? * fix smart apply * fix data-2-vec * [build-ci-image] * checkout * uupdate * fix hunyuan * update error message * fix deformable detr * fixes * fix init weights for non param gate up projs * shared todo? 
* update some models * big revert, don't break this behaviour * ty @SunMarc this fixes the buffers Co-authored-by: SunMarc <SunMarc@users.noreply.github.com> * mt5 fuck * fix lxmbert * nuke slow test fetcher * fix zamba and deepcopy for now * fix zamba tied weight keys! ~ * fix-copies * update fetch terst * fix gradient for test modeling common! * break "shared" for now I will fix tomorrow changes are properly isoalted now :) * does this fix marian? probably not * fix some vlms * D fine seems to handle this well * glob is fine actually * fix dab detr * small steps * opusy * fix some more models? * yups * better erro * fix? * fix double escape * escape wehere it makes sense * ?? * fix ibert * fix tvp as well * more fxes * try always download ref PR * ONONONO * big fixup * more fixup * small step * small nits * nits * brut force some stuff * fix vilt * make sure special models that always need tie always tie * cleaning up * small nits * fix zamba and bridge tower! * just fixup * potential culprits * revert bark and fix bridgetower * remove now non existant tie_weights * ? * lol reformer actually had nothing tied! * wow these two fucking models were really not well made * fix sam family! * fix bark revision * fix speech2test ? * push this for now.... * upsy * the fuck * fix rtdetr * update * proper * wow that one 's annoying * update * try to find the culprit * get some help on common * nit about general init and cls.padding_idx * revert num workers update * remove old loading func * fix glob * add annotations * fix re * small improvements * clean some stuff * improvements * someone did not understannnnnnd what I tried to dooo or does BNB not support that either? * gluos * fix case when `.` is just not there * remove unused arg * recover orignal parameter/buffer using _original * fix glob issu * this? 
* deepspeed best-effort * remove unused stuff * Update tie weight keys as they were just wroong Co-authored-by: Benjamin Bossan <benjaminbossan@users.noreply.github.com>" * up * augustuc clauss, a gloubs gloups gloubs * fixup * fixup * there was fucking typo * mrain * nits * fix marian 3 remaining tests * one more * fix some of the copies, not all :) * small cleanup * one propertest * fix core model loadig tes * attempt a new test * fix some of the annoying tests by supporting reading .bin sometimes * push * push more small fixes * remove 1 useless test * up * fix audio flamingo post rebase * fixup * some small updatess * fix sam models * nits * up * updates * onem ore * skip this stupid test * some other fixes * fixup * update * skip more offloaded stuff * oups * ups * update mixtral * skip this one * LET"SGO * fixup * rope delta order * fix csm * small nit --------- Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com> Co-authored-by: SunMarc <SunMarc@users.noreply.github.com> Co-authored-by: Marc Sun <marc@huggingface.co>
92 lines
3.5 KiB
Python
92 lines
3.5 KiB
Python
from typing import ClassVar, Optional
|
|
|
|
import torch
|
|
import torch.utils.checkpoint
|
|
from torch import nn
|
|
|
|
from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration
|
|
|
|
from ...cache_utils import Cache
|
|
|
|
|
|
class NewTaskModelForNewTask(PaliGemmaForConditionalGeneration):
    """PaliGemma conditional-generation model with an extra per-token embedding head.

    On top of the parent model, a linear layer (``custom_text_proj``) projects the
    last hidden state of every token to ``config.embedding_dim`` dimensions. The
    projected vectors are L2-normalized and returned in front of the parent outputs.
    """

    main_input_name: ClassVar[str] = "doc_input_ids"  # transformers-related

    def __init__(self, config):
        """Build the parent model, add the projection head and register tied weights.

        Args:
            config: Model configuration. Must expose ``embedding_dim`` and
                ``text_config.hidden_size``.
        """
        super().__init__(config=config)

        self.embedding_dim = self.config.embedding_dim
        self.custom_text_proj = nn.Linear(self.config.text_config.hidden_size, self.embedding_dim)

        if self.language_model._tied_weights_keys is not None:
            # Re-express the language model's tied-weight mapping relative to this
            # wrapper by prefixing both the target and the source key.
            prefix = "model.language_model."
            prefixed_mapping = {
                f"{prefix}{target}": f"{prefix}{source}"
                for target, source in self.language_model._tied_weights_keys.items()
            }
            if isinstance(self._tied_weights_keys, dict):
                # Build a fresh per-instance dict instead of calling `.update()`:
                # the existing dict may be a class-level attribute inherited from
                # the parent, and mutating it in place would leak the prefixed keys
                # into every other class/instance sharing that same dict.
                self._tied_weights_keys = {**self._tied_weights_keys, **prefixed_mapping}
            else:
                self._tied_weights_keys = prefixed_mapping

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ):
        r"""
        Run the parent model and compute L2-normalized per-token embeddings.

        All arguments are forwarded unchanged to the parent ``forward`` except
        ``output_hidden_states`` and ``return_dict``: both are accepted for signature
        compatibility but the parent is always called with ``True`` for each, because
        the last hidden state is needed to compute the embeddings.

        Returns:
            `tuple`: the embeddings (projected last hidden state, L2-normalized and,
            when `attention_mask` is given, multiplied by it) followed by the parent
            model's outputs converted to a tuple.
        """
        vlm_outputs = super().forward(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            token_type_ids=token_type_ids,
            cache_position=cache_position,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
            num_logits_to_keep=num_logits_to_keep,
        )
        last_hidden_states = vlm_outputs.hidden_states[-1]  # (batch_size, sequence_length, hidden_size)
        proj = self.custom_text_proj(last_hidden_states)  # (batch_size, sequence_length, dim)

        # L2 normalization
        embeddings = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        if attention_mask is not None:
            # Multiply by the mask so positions where it is 0 yield all-zero embeddings.
            embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)

        # The parent was called with `return_dict=True`, so `vlm_outputs` is a
        # dict-like output object; it must be converted to a tuple before it can be
        # concatenated (tuple + output object raises TypeError).
        return (embeddings,) + vlm_outputs.to_tuple()

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, mean_resizing=True
    ) -> nn.Embedding:
        """Resize the language model's token embeddings and sync the vocab sizes.

        Args:
            new_num_tokens: Passed through to the language model's
                ``resize_token_embeddings``.
            pad_to_multiple_of: Passed through unchanged.
            mean_resizing: Passed through unchanged.

        Returns:
            The embedding module returned by the language model after resizing.
        """
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)

        # Update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        self.config.vocab_size = model_embeds.num_embeddings
        self.vocab_size = model_embeds.num_embeddings

        return model_embeds
|