Merge branch 'main' into fix_xpu_kernels
@@ -176,6 +176,11 @@ class DynamicSlidingWindowLayer(DynamicLayer):
         super().__init__()
         self.sliding_window = sliding_window
         self.cumulative_length = 0
+        self._sliding_window_tensor = torch.tensor(self.sliding_window, dtype=torch.long)
+
+    def lazy_initialization(self, key_states: torch.Tensor) -> None:
+        super().lazy_initialization(key_states)
+        self._sliding_window_tensor = self._sliding_window_tensor.to(self.device)
 
     def update(
         self,
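For context, a minimal toy sketch of the pattern added above (an illustrative stand-in, not the transformers class): the sliding-window size is stored both as a plain int and as a long tensor, and the tensor is moved onto the cache device once the first key states arrive, so it can later travel together with the KV tensors (see the `__iter__` change further down).

import torch

class ToySlidingWindowLayer:
    """Illustrative stand-in for DynamicSlidingWindowLayer; names mirror the diff above."""

    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window
        self.cumulative_length = 0
        # Tensor copy of the window size so it can be shipped around with the KV tensors (e.g. for DDP)
        self._sliding_window_tensor = torch.tensor(sliding_window, dtype=torch.long)

    def lazy_initialization(self, key_states: torch.Tensor) -> None:
        # The real layer resolves self.device in its parent class; here we take it from the input
        self.device = key_states.device
        self._sliding_window_tensor = self._sliding_window_tensor.to(self.device)

layer = ToySlidingWindowLayer(sliding_window=4)
layer.lazy_initialization(torch.randn(1, 2, 6, 8))
print(layer._sliding_window_tensor)  # tensor(4), on the device of the key states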
@@ -932,7 +937,7 @@ class DynamicCache(Cache):
 
     def __init__(
         self,
-        ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
+        ddp_cache_data: Optional[Iterable[tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]]] = None,
         config: Optional[PreTrainedConfig] = None,
         offloading: bool = False,
         offload_only_non_sliding: bool = False,
@@ -965,10 +970,15 @@ class DynamicCache(Cache):
         # In this case, use the passed data to already fill in the Cache
         if ddp_cache_data is not None:
             # Init all the layers with the data
-            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
-                # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
+            for layer_idx, (sliding_window_tensor, key_states, value_states) in enumerate(ddp_cache_data):
+                # If the config was not passed above, initialize a new cache layer for each entry of the ddp_data
                 if config is None:
-                    layers.append(DynamicLayer())
+                    if sliding_window_tensor is not None:
+                        # Since the same layer is dispatched across replicas, sliding_window is the same for all
+                        sliding_window = sliding_window_tensor[0].item()
+                        layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
+                    else:
+                        layers.append(DynamicLayer())
                 # Update the layer with the data
                 _, _ = layers[layer_idx].update(key_states, value_states)
 
@@ -982,6 +992,10 @@ class DynamicCache(Cache):
         else:
             super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
 
+    def __iter__(self):
+        for layer in self.layers:
+            yield getattr(layer, "_sliding_window_tensor", None), layer.keys, layer.values
+
 
 class StaticCache(Cache):
     """
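A hedged usage sketch of the new layout (assumes a transformers build that includes this change): `DynamicCache.__iter__` now yields one `(sliding_window_tensor, key_states, value_states)` tuple per layer, and the same tuples can be fed back as `ddp_cache_data` to rebuild an equivalent cache, which is how KV states move between replicas.

import torch
from transformers import DynamicCache

batch, heads, seq, head_dim = 1, 2, 4, 8
# One (sliding_window_tensor, key_states, value_states) entry per layer; None marks a full-attention layer
ddp_cache_data = [
    (None, torch.randn(batch, heads, seq, head_dim), torch.randn(batch, heads, seq, head_dim)),
    (None, torch.randn(batch, heads, seq, head_dim), torch.randn(batch, heads, seq, head_dim)),
]
cache = DynamicCache(ddp_cache_data)

# Iterating the cache yields the same 3-tuple layout, so an equivalent cache can be rebuilt from it
rebuilt = DynamicCache(list(iter(cache)))
print(rebuilt.get_seq_length())  # expected: 4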
@@ -18,6 +18,7 @@ from typing import Optional, Union
 
 import torch
 from torch import nn
+from torch.nn import CrossEntropyLoss
 
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
@@ -374,9 +375,6 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # num_items_in_batch is only needed for loss computation
-        num_items_in_batch = kwargs.pop("num_items_in_batch", None)
-
         kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
 
         kwargs_decoder = {
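As a side note on the context lines above, a small standalone sketch (hypothetical helper name) of how the remaining kwargs are routed: anything prefixed with `decoder_` goes to the decoder, the rest to the encoder; the actual model also strips the prefix for the decoder call.

def split_encoder_decoder_kwargs(kwargs: dict) -> tuple[dict, dict]:
    # Toy version of the encoder/decoder kwargs split shown in the hunk above
    kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
    kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
    return kwargs_encoder, kwargs_decoder

enc, dec = split_encoder_decoder_kwargs({"output_attentions": True, "decoder_input_ids": [1, 2]})
print(enc)  # {'output_attentions': True}
print(dec)  # {'input_ids': [1, 2]}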
@@ -435,12 +433,8 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
         if labels is not None:
             logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
 
-            loss = self.loss_function(
-                logits=logits,
-                labels=labels,
-                vocab_size=self.decoder.config.vocab_size,
-                num_items_in_batch=num_items_in_batch,
-            )
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), labels.reshape(-1))
 
         if not return_dict:
             if loss is not None:
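A minimal sketch of the loss computation the hunk switches to: flatten the decoder logits and the labels and apply token-level cross-entropy (toy shapes, not tied to any real checkpoint).

import torch
from torch.nn import CrossEntropyLoss

vocab_size, batch_size, seq_len = 11, 2, 5  # toy sizes
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.reshape(-1, vocab_size), labels.reshape(-1))
print(loss.item())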
@@ -121,17 +121,15 @@ class AutoQuantizationConfig:
     @classmethod
     def from_dict(cls, quantization_config_dict: dict):
         quant_method = quantization_config_dict.get("quant_method")
-        if quant_method is None:
+        # We need a special care for bnb models to make sure everything is BC ..
+        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
+            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
+            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
+        elif quant_method is None:
             raise ValueError(
                 "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
             )
 
-        if quant_method == QuantizationMethod.BITS_AND_BYTES:
-            if quantization_config_dict.get("load_in_8bit"):
-                quant_method += "_8bit"
-            else:
-                quant_method += "_4bit"
-
         if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING:
             raise ValueError(
                 f"Unknown quantization type, got {quant_method} - supported types are:"
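A hedged usage sketch of the backward-compatibility path on the added side of the hunk (import path and exact behavior assumed for a transformers build matching this diff): a legacy bitsandbytes config dict has no `quant_method` key, only the `load_in_8bit`/`load_in_4bit` flags, and `from_dict` maps it to the matching bitsandbytes config class; a dict with neither raises the `ValueError` shown above.

from transformers.quantizers.auto import AutoQuantizationConfig  # assumed import path

# Legacy-style dict: no "quant_method", only the bnb flag
legacy_bnb_dict = {"load_in_4bit": True}
config = AutoQuantizationConfig.from_dict(legacy_bnb_dict)
print(type(config).__name__)  # expected: BitsAndBytesConfig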
@@ -1598,6 +1598,7 @@ class ModelTesterMixin:
         cache_shape = (batch_size, num_heads, cache_length, head_dim)
         non_empty_pkv = tuple(
             (
+                None,
                 torch.rand(cache_shape, dtype=torch.float, device=torch_device),
                 torch.rand(cache_shape, dtype=torch.float, device=torch_device),
             )
@@ -1806,7 +1806,10 @@ class ModelUtilsTest(TestCasePlus):
 
         # simulate injecting virtual tokens like in prefix tuning
        num_virtual_tokens = 3
-        past_key_values = [torch.randn(2, 1, 2, num_virtual_tokens, 8)] * 2
+        past_key_values = [
+            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+        ]
         past_key_values = DynamicCache(past_key_values)
         model_inputs["attention_mask"] = torch.cat(
             (