Merge branch 'main' into fix_xpu_kernels

Mohamed Mekkouri (committed by GitHub), 2025-10-14 16:42:13 +02:00
5 changed files with 31 additions and 21 deletions

View File

@@ -176,6 +176,11 @@ class DynamicSlidingWindowLayer(DynamicLayer):
         super().__init__()
         self.sliding_window = sliding_window
         self.cumulative_length = 0
+        self._sliding_window_tensor = torch.tensor(self.sliding_window, dtype=torch.long)
+
+    def lazy_initialization(self, key_states: torch.Tensor) -> None:
+        super().lazy_initialization(key_states)
+        self._sliding_window_tensor = self._sliding_window_tensor.to(self.device)
 
     def update(
         self,
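The pattern added in this hunk mirrors the integer `sliding_window` as a 0-dim long tensor so it can be moved to the cache's device and later be yielded next to the key/value tensors. A minimal standalone sketch of the idea (the class name below is invented for illustration; it is not the transformers API):

    import torch

    class SlidingWindowLayerSketch:  # hypothetical class, for illustration only
        def __init__(self, sliding_window: int):
            self.sliding_window = sliding_window
            # Keep the python int also as a 0-dim long tensor so it can travel
            # with the cache tensors across devices.
            self._sliding_window_tensor = torch.tensor(sliding_window, dtype=torch.long)

        def lazy_initialization(self, key_states: torch.Tensor) -> None:
            # Mirror the device of the incoming key states (the real layer uses self.device).
            self._sliding_window_tensor = self._sliding_window_tensor.to(key_states.device)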
@@ -932,7 +937,7 @@ class DynamicCache(Cache):
     def __init__(
         self,
-        ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
+        ddp_cache_data: Optional[Iterable[tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]]] = None,
         config: Optional[PreTrainedConfig] = None,
         offloading: bool = False,
         offload_only_non_sliding: bool = False,
@@ -965,10 +970,15 @@ class DynamicCache(Cache):
         # In this case, use the passed data to already fill in the Cache
         if ddp_cache_data is not None:
             # Init all the layers with the data
-            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
-                # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
+            for layer_idx, (sliding_window_tensor, key_states, value_states) in enumerate(ddp_cache_data):
+                # If the config was not passed above, initialize a new cache layer for each entry of the ddp_data
                 if config is None:
-                    layers.append(DynamicLayer())
+                    if sliding_window_tensor is not None:
+                        # Since the same layer is dispatched across replicas, sliding_window is the same for all
+                        sliding_window = sliding_window_tensor[0].item()
+                        layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
+                    else:
+                        layers.append(DynamicLayer())
 
                 # Update the layer with the data
                 _, _ = layers[layer_idx].update(key_states, value_states)
@@ -982,6 +992,10 @@ class DynamicCache(Cache):
         else:
             super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
 
+    def __iter__(self):
+        for layer in self.layers:
+            yield getattr(layer, "_sliding_window_tensor", None), layer.keys, layer.values
+
 
 class StaticCache(Cache):
     """

View File

@@ -18,6 +18,7 @@ from typing import Optional, Union
 import torch
 from torch import nn
+from torch.nn import CrossEntropyLoss
 
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
@@ -374,9 +375,6 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # num_items_in_batch is only needed for loss computation
-        num_items_in_batch = kwargs.pop("num_items_in_batch", None)
-
         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
 
         kwargs_decoder = {
@@ -435,12 +433,8 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
         if labels is not None:
             logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss = self.loss_function(
-                logits=logits,
-                labels=labels,
-                vocab_size=self.decoder.config.vocab_size,
-                num_items_in_batch=num_items_in_batch,
-            )
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
 
         if not return_dict:
             if loss is not None:
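The replacement computes a plain token-level cross-entropy directly instead of going through `self.loss_function`. A minimal standalone sketch of the same computation, under the assumed shapes `logits: (batch, seq_len, vocab_size)` and `labels: (batch, seq_len)` (label positions set to -100 are ignored by default):

    import torch
    from torch.nn import CrossEntropyLoss

    batch, seq_len, vocab_size = 2, 5, 11
    logits = torch.randn(batch, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch, seq_len))
    labels[:, 0] = -100  # e.g. prompt / padding tokens masked out of the loss

    loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
    loss = loss_fct(logits.reshape(-1, vocab_size), labels.reshape(-1))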

View File

@@ -121,17 +121,15 @@ class AutoQuantizationConfig:
     @classmethod
     def from_dict(cls, quantization_config_dict: dict):
         quant_method = quantization_config_dict.get("quant_method")
-        if quant_method is None:
+        # We need a special care for bnb models to make sure everything is BC ..
+        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+        elif quant_method is None:
             raise ValueError(
                 "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
             )
 
-        if quant_method == QuantizationMethod.BITS_AND_BYTES:
-            if quantization_config_dict.get("load_in_8bit"):
-                quant_method += "_8bit"
-            else:
-                quant_method += "_4bit"
-
         if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
             raise ValueError(
                 f"Unknown quantization type, got {quant_method} - supported types are:"

View File

@@ -1598,6 +1598,7 @@ class ModelTesterMixin:
             cache_shape = (batch_size, num_heads, cache_length, head_dim)
             non_empty_pkv = tuple(
                 (
+                    None,
                     torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                     torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                 )

View File

@@ -1806,7 +1806,10 @@ class ModelUtilsTest(TestCasePlus):
         # simulate injecting virtual tokens like in prefix tuning
         num_virtual_tokens = 3
-        past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
+        past_key_values = [
+            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+        ]
         past_key_values = DynamicCache(past_key_values)
         model_inputs["attention_mask"] = torch.cat(
             (