Compare commits

...

19 Commits

Author SHA1 Message Date
090d9c4b2a Merge branch 'main' into tensor-cache 2025-01-24 12:02:45 +01:00
5ccb79c16d fixed dynamic cache 2025-01-23 16:45:28 +01:00
80b49d721b rebased 2025-01-22 17:31:39 +01:00
dc1bd15ba9 Merge branch 'main' into tensor-cache 2025-01-22 17:30:23 +01:00
338f5954b9 more reverts 2025-01-22 17:29:48 +01:00
2f4e0bc93e Update src/transformers/cache_utils.py 2025-01-22 17:18:28 +01:00
485f959f85 revert 2025-01-22 17:17:17 +01:00
2bbbbbcf97 add device and dtype setters 2025-01-22 17:15:12 +01:00
85c71b004b Merge branch 'main' into tensor-cache 2025-01-22 15:53:33 +01:00
da60604f2c fix test_cache_utils 2025-01-22 15:43:14 +01:00
6e9799c817 add clone and to 2025-01-22 15:42:43 +01:00
4950a9e3f0 extract wrapper kwargs from init signature to correctly instantate 2025-01-22 13:49:01 +01:00
b67b6eb9b2 make cache class exportable and executorch compatible 2025-01-20 18:47:30 +01:00
d269417aab fix zamba and jamba dynamic cache 2025-01-20 17:21:49 +01:00
95c1686ee0 style 2025-01-20 17:09:21 +01:00
8606594ad4 fix boolean evaluation 2025-01-20 17:08:37 +01:00
45bb39bb80 torch tensor subclassing 2025-01-20 17:01:49 +01:00
a77a94b209 unproxy cache 2025-01-20 14:43:41 +01:00
d4b631edd0 use tensor cache instead of module cache 2025-01-20 14:17:28 +01:00
9 changed files with 164 additions and 108 deletions

View File

@@ -1,5 +1,6 @@
 import copy
 import importlib.metadata
+import inspect
 import json
 import os
 from dataclasses import dataclass
@@ -9,12 +10,7 @@ import torch
 from packaging import version
 from .configuration_utils import PretrainedConfig
-from .utils import (
-    is_hqq_available,
-    is_optimum_quanto_available,
-    is_torchdynamo_compiling,
-    logging,
-)
+from .utils import is_hqq_available, is_optimum_quanto_available, logging
 from .utils.deprecation import deprecate_kwarg
@@ -24,13 +20,82 @@ if is_hqq_available():
 logger = logging.get_logger(__name__)
-class Cache(torch.nn.Module):
+class Cache(torch.Tensor):
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
-    def __init__(self):
-        super().__init__()
+    @staticmethod
+    def __new__(cls, *args, **kwargs):
+        # We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method
+        wrapper_kwargs = {}
+        init_signature = inspect.signature(cls.__init__)
+        init_arguments = list(init_signature.parameters.keys())
+        init_defaults = {
+            k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
+        }
+        for argument in ["dtype", "device"]:
+            if argument in init_arguments:
+                arg_idx = init_arguments.index(argument)
+                if len(args) > arg_idx and args[arg_idx] is not None:
+                    wrapper_kwargs[argument] = args[arg_idx]
+                elif kwargs.get(argument, None) is not None:
+                    wrapper_kwargs[argument] = kwargs[argument]
+                elif init_defaults[argument] is not None:
+                    wrapper_kwargs[argument] = init_defaults[argument]
+        if "cache_config" in init_arguments:
+            cache_config_idx = init_arguments.index("cache_config")
+            if len(args) > cache_config_idx and args[cache_config_idx] is not None:
+                wrapper_kwargs["device"] = args[cache_config_idx].device
+            elif kwargs.get("cache_config", None) is not None:
+                wrapper_kwargs["device"] = kwargs["cache_config"].device
+            elif init_defaults["cache_config"] is not None:
+                wrapper_kwargs["device"] = init_defaults["cache_config"].device
+        self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)
+        # we create a dummy empty tensor for generic tensor flattening/unflattening
+        self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False)
+        return self
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        assert (
+            func.__name__ in cls.__dict__
+        ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
+        return getattr(cls, func.__name__)(*args, **kwargs)
+    def __repr__(self):
+        return f"{self.__class__.__name__}()"
+    def __bool__(self):
+        # in many places, past_key_values is checked for not being None using `if past_key_values:`
+        # I think `if past_key_values is not None:` should be used instead
+        return self is not None  # True
+    def to(self, *args, **kwargs):
+        # originals
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
+        # overrides
+        for arg in list(args) + list(kwargs.values()):
+            if isinstance(arg, (torch.device, str, int)):
+                wrapper_kwargs["device"] = arg
+            elif isinstance(arg, torch.dtype):
+                wrapper_kwargs["dtype"] = arg
+        # new wrapper
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
+        new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
+        return new_self
+    def clone(self):
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
+        new_self.__dict__ = copy.deepcopy(self.__dict__)
+        return new_self
     def update(
         self,
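
The heart of the change is visible above: `Cache` is no longer an `nn.Module` but a storage-less `torch.Tensor` wrapper. `__new__` inspects the concrete subclass's `__init__` signature to recover `dtype`/`device` (directly or via a `cache_config`) and feeds them to `torch.Tensor._make_wrapper_subclass`, so a cache passes `isinstance(x, torch.Tensor)` checks and can be traced, while the real key/value tensors stay in ordinary attributes. A minimal, self-contained sketch of the same pattern (hypothetical class name, assuming PyTorch >= 2.1 semantics of `_make_wrapper_subclass`):

import torch


class MetaTensorWrapper(torch.Tensor):
    # Hypothetical example (not part of the PR): a tensor subclass with no storage
    # that only carries dtype/device metadata, mirroring how Cache.__new__ builds itself.
    @staticmethod
    def __new__(cls, dtype=torch.float32, device="cpu"):
        # Shape () and no data: the object is "a tensor" for isinstance/tracing purposes only.
        return torch.Tensor._make_wrapper_subclass(cls, (), dtype=dtype, device=device, requires_grad=False)

    def __init__(self, dtype=torch.float32, device="cpu"):
        # The real payload lives in plain attributes, like key_cache/value_cache on the real classes.
        self.payload = []


wrapper = MetaTensorWrapper(dtype=torch.float16)
print(isinstance(wrapper, torch.Tensor), wrapper.dtype, wrapper.device)  # True torch.float16 cpu
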
@@ -304,7 +369,7 @@ class StaticCacheConfig(CacheConfig):
     cache_implementation = "static"
-    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
+    def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")):
         self.batch_size = batch_size
         self.max_cache_len = max_cache_len
         self.device = device
@@ -361,6 +426,16 @@ class DynamicCache(Cache):
     ```
     """
+    def __tensor_flatten__(self):
+        return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, _, __):
+        cache = DynamicCache()
+        cache._seen_tokens = meta["_seen_tokens"]
+        cache._empty_tensor = inner_tensors["_empty_tensor"]
+        return cache
     @deprecate_kwarg("num_hidden_layers", version="4.47.0")
     def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__()
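
`__tensor_flatten__` and `__tensor_unflatten__` implement the traceable tensor-subclass protocol that `torch.export`/`torch.compile` use to decompose a wrapper into plain inner tensors plus metadata and to rebuild it afterwards; here the dummy `_empty_tensor` created in `Cache.__new__` is the only inner tensor and `_seen_tokens` travels as metadata. A rough round-trip of how the protocol is consumed (assuming this branch is installed; the last two positional arguments are unused placeholders in this sketch):

from transformers import DynamicCache

cache = DynamicCache()
cache._seen_tokens = 7  # pretend some tokens have already been processed

# Decompose into traceable inner tensors plus plain metadata...
inner_names, meta = cache.__tensor_flatten__()
inner_tensors = {name: getattr(cache, name) for name in inner_names}

# ...and rebuild an equivalent cache from the pieces.
rebuilt = DynamicCache.__tensor_unflatten__(inner_tensors, meta, None, None)
assert rebuilt._seen_tokens == 7
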
@@ -448,7 +523,7 @@ class DynamicCache(Cache):
             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
             or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
         )
-        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0)
         return layer_seq_length
     def get_max_cache_shape(self) -> Optional[int]:
@@ -675,9 +750,6 @@ class QuantizedCache(DynamicCache):
         self.axis_key = cache_config.axis_key
         self.axis_value = cache_config.axis_value
         self.compute_dtype = cache_config.compute_dtype
-        self.device = cache_config.device
-        super().__init__()
     def update(
         self,
@@ -777,7 +849,7 @@ class QuantoQuantizedCache(QuantizedCache):
             raise ImportError(
                 f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
             )
-        from optimum.quanto import MaxOptimizer, qint2, qint4
+        from optimum.quanto import MaxOptimizer, qint2, qint4  # type: ignore
         if self.nbits not in [2, 4]:
             raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -796,7 +868,7 @@ class QuantoQuantizedCache(QuantizedCache):
     def _quantize(self, tensor, axis):
         # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
         if is_optimum_quanto_available():
-            from optimum.quanto import quantize_weight
+            from optimum.quanto import quantize_weight  # type: ignore
             scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
             qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
@@ -1105,7 +1177,7 @@ class StaticCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1116,7 +1188,6 @@ class StaticCache(Cache):
                 f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                 "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )
         self.max_batch_size = batch_size or max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
@@ -1125,8 +1196,6 @@ class StaticCache(Cache):
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )
-        self.dtype = dtype
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1144,18 +1213,10 @@ class StaticCache(Cache):
                 layer_device = self.device
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            # Notes:
-            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            # it is not needed anyway)
-            # 2. `torch.export()` requires mutations to be registered as buffers.
-            if not is_torchdynamo_compiling():
-                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
-                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+            # preventing compiled graph breaks when updating the cache.
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
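
With the `nn.Module` base gone, `register_buffer` is no longer available on the cache itself, so buffer registration moves to the exporting wrapper module (see the executorch integration further down) and only the `torch._dynamo.mark_static_address` hint remains here. A minimal usage sketch of that hint (the shape below is made up for illustration):

import torch

kv = torch.zeros(1, 8, 128, 64)  # (batch, kv_heads, max_cache_len, head_dim), illustrative sizes only
# Hint to torch.compile that this tensor keeps a stable storage address across calls,
# so compiled / CUDA-graph code can reuse it without re-recording.
torch._dynamo.mark_static_address(kv)
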
@@ -1304,7 +1365,7 @@ class SlidingWindowCache(StaticCache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1619,7 +1680,7 @@ class HybridCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: Union[torch.device, str] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1648,7 +1709,6 @@ class HybridCache(Cache):
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
         self.is_sliding = torch.tensor(
             [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@@ -1781,7 +1841,7 @@ class HybridCache(Cache):
         return self.max_batch_size
-class MambaCache:
+class MambaCache(Cache):
     """
     Cache for mamba model which does not have attention mechanism and key value states.
@@ -1838,7 +1898,7 @@ class MambaCache:
         config: PretrainedConfig,
         batch_size: int = None,
         dtype: torch.dtype = torch.float16,
-        device: Optional[Union[torch.device, str]] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         max_batch_size: Optional[int] = None,
     ):
         if batch_size is not None:
@@ -1846,12 +1906,10 @@ class MambaCache:
                 f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                 "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )
-        self.dtype = dtype
         self.max_batch_size = batch_size or max_batch_size
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         self.conv_states: List[torch.Tensor] = []
         self.ssm_states: List[torch.Tensor] = []
@@ -1981,17 +2039,14 @@ class OffloadedStaticCache(StaticCache):
         config: PretrainedConfig,
         max_batch_size: int,
         max_cache_len: Optional[int],
-        device: Union[str, torch.device],
-        dtype: Optional[torch.dtype] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
+        dtype: torch.dtype = torch.float32,
         offload_device: Union[str, torch.device] = torch.device("cpu"),
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
-        super(Cache, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
-        self.dtype = dtype if dtype is not None else torch.float32
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
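
The same cleanup repeats across the subclasses: explicit `self.dtype = ...` / `self.device = ...` assignments are dropped, and the values flow through `Cache.__new__` into the wrapper metadata instead. A rough illustration of the resulting behaviour (assuming this branch is installed; the tiny config below is invented purely for the example):

import torch
from transformers import LlamaConfig, StaticCache

config = LlamaConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu", dtype=torch.float16)

# dtype/device now come from the tensor wrapper metadata rather than from attributes set in __init__.
print(cache.dtype, cache.device)   # torch.float16 cpu
print(cache.key_cache[0].shape)    # torch.Size([1, 4, 16, 16])
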

View File

@@ -731,6 +731,7 @@ class GenerationMixin:
                     key != "cache_position"
                     and dict_to_expand[key] is not None
                     and isinstance(dict_to_expand[key], torch.Tensor)
+                    and not isinstance(dict_to_expand[key], Cache)
                 ):
                     dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
         return dict_to_expand
@@ -4519,13 +4520,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
     """
     if data is None:
        return [None] * (full_batch_size // split_size)
-    if isinstance(data, torch.Tensor):
-        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
     # New cache format
-    elif isinstance(data, DynamicCache) or (
+    if isinstance(data, DynamicCache) or (
         isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
     ):
         return data.batch_split(full_batch_size, split_size, num_hidden_layers)
+    if isinstance(data, torch.Tensor):
+        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
     elif isinstance(data, tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0], tuple):
@@ -4632,13 +4633,13 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf
         """
         if any(data is None for data in data):
             return None
-        if isinstance(data[0], torch.Tensor):
-            return torch.cat(data, dim=0)
         # New cache format
-        elif isinstance(data[0], DynamicCache):
+        if isinstance(data[0], DynamicCache):
             return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
         elif isinstance(data[0], EncoderDecoderCache):
             return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
+        elif isinstance(data[0], torch.Tensor):
+            return torch.cat(data, dim=0)
         elif isinstance(data[0], tuple):
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
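
The reordering of the `isinstance` checks matters because caches are now `torch.Tensor` subclasses: a `torch.Tensor` test placed first would also match a `DynamicCache` and send it down the plain-tensor path. A generic sketch of the pitfall (the `describe` helper is hypothetical; the cache-is-a-tensor behaviour assumes this branch):

import torch
from transformers import DynamicCache


def describe(x):
    # Check the most specific type first: on this branch a DynamicCache also passes the torch.Tensor check.
    if isinstance(x, DynamicCache):
        return "cache"
    if isinstance(x, torch.Tensor):
        return "plain tensor"
    return "other"


print(describe(DynamicCache()))     # "cache"
print(describe(torch.zeros(2, 2)))  # "plain tensor"
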

View File

@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available
 if is_torch_available():
-    from transformers import (
-        PreTrainedModel,
-        StaticCache,
-    )
+    from transformers import PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
@@ -68,6 +65,8 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
             )
         self.model = model
+        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
         self.static_cache = StaticCache(
             config=self.model.config,
             batch_size=self.model.generation_config.cache_config.batch_size,
@@ -75,14 +74,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
             dtype=self.model.dtype,
             device=self.model.generation_config.cache_config.device,
         )
-        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
+        for i in range(len(self.static_cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
         if self.is_causal:
             causal_mask = torch.tril(
-                torch.ones(
-                    self.static_cache.max_cache_len,
-                    self.static_cache.max_cache_len,
-                    dtype=torch.bool,
-                )
+                torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool)
             )
             self.register_buffer("mask", causal_mask, persistent=False)
@@ -108,15 +106,20 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
         """
         _, seqlen = input_ids.shape
         attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
+        position_ids = cache_position.unsqueeze(0)
+        past_key_values = self.static_cache
         outs = self.model(
             input_ids=input_ids,
             attention_mask=attn_mask,
-            position_ids=cache_position.unsqueeze(0),
+            position_ids=position_ids,
+            past_key_values=past_key_values,
             cache_position=cache_position,
-            past_key_values=self.static_cache,
             use_cache=True,
         )
         return outs.logits
     @staticmethod
@@ -143,7 +146,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         prompt_token_len = prompt_token_ids.shape[-1]
         max_generation_length = prompt_token_len + max_new_tokens
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 max_cache_len = buffer.shape[2]
                 max_generation_length = min(max_generation_length, max_cache_len)
                 break
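
The buffer names change as a consequence: the cache tensors are now registered directly on the wrapper module (e.g. `key_cache_0`) instead of living under a `static_cache.` submodule prefix, which is why the `named_buffers()` lookups here and in the tests drop that prefix. A stripped-down sketch of the registration pattern (hypothetical module, not the real integration):

import torch


class CacheBufferHolder(torch.nn.Module):
    # Hypothetical stand-in for the export wrapper: since the cache is no longer an
    # nn.Module, its per-layer tensors are registered as buffers on the wrapper so
    # torch.export can track their mutations.
    def __init__(self, key_cache, value_cache):
        super().__init__()
        for i, (k, v) in enumerate(zip(key_cache, value_cache)):
            self.register_buffer(f"key_cache_{i}", k, persistent=False)
            self.register_buffer(f"value_cache_{i}", v, persistent=False)


layers = [torch.zeros(1, 4, 16, 16) for _ in range(2)]
holder = CacheBufferHolder(layers, [t.clone() for t in layers])
print([name for name, _ in holder.named_buffers()])
# ['key_cache_0', 'value_cache_0', 'key_cache_1', 'value_cache_1']
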

View File

@@ -215,7 +215,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         super().__init__()
-        self.dtype = dtype
         self.layers_block_type = config.layers_block_type
         self.has_previous_state = False  # only used by mamba
         intermediate_size = config.mamba_expand * config.hidden_size

View File

@@ -129,7 +129,6 @@ class ZambaHybridDynamicCache(DynamicCache):
     """
     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
-        self.dtype = dtype
         self.layers_block_type = config.layers_block_type
         self.has_previous_state = False  # only used by mamba
         self.intermediate_size = config.mamba_expand * config.hidden_size
@@ -139,9 +138,7 @@ class ZambaHybridDynamicCache(DynamicCache):
         self.conv_states = []
         self.ssm_states = []
         self.transformer_layers = []
-        self._modules = {}
-        self._parameters = {}
-        self._buffers = {}
         for i in range(config.num_hidden_layers):
             self.conv_states += [
                 torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)

View File

@@ -35,7 +35,7 @@ from torch.fx._symbolic_trace import is_fx_tracing
 from torch.fx.proxy import ParameterProxy
 from .. import logging
-from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
+from ..cache_utils import Cache
 from ..modeling_utils import PretrainedConfig, PreTrainedModel
 from ..models.auto import get_values
 from ..models.auto.modeling_auto import (
@@ -811,40 +811,40 @@ def _proxies_to_metas(v):
     return v
-def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
-    def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
-        global _CURRENT_TRACER
-        if not isinstance(_CURRENT_TRACER, HFTracer):
-            raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
-        cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
-        cache_proxy.install_orig_cache_cls(orig_cache_cls)
-        return cache_proxy
-    return cache_proxy_factory_fn
-# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
-ProxyableCache = HFProxyableClassMeta(
-    "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
-)
-ProxyableDynamicCache = HFProxyableClassMeta(
-    "ProxyableDynamicCache",
-    (DynamicCache,),
-    {},
-    proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
-)
-ProxyableSinkCache = HFProxyableClassMeta(
-    "ProxyableSinkCache",
-    (SinkCache,),
-    {},
-    proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
-)
-ProxyableStaticCache = HFProxyableClassMeta(
-    "ProxyableStaticCache",
-    (StaticCache,),
-    {},
-    proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
-)
+# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
+#     def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
+#         global _CURRENT_TRACER
+#         if not isinstance(_CURRENT_TRACER, HFTracer):
+#             raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
+#         cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
+#         cache_proxy.install_orig_cache_cls(orig_cache_cls)
+#         return cache_proxy
+#     return cache_proxy_factory_fn
+# # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
+# ProxyableCache = HFProxyableClassMeta(
+#     "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
+# )
+# ProxyableDynamicCache = HFProxyableClassMeta(
+#     "ProxyableDynamicCache",
+#     (DynamicCache,),
+#     {},
+#     proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
+# )
+# ProxyableSinkCache = HFProxyableClassMeta(
+#     "ProxyableSinkCache",
+#     (SinkCache,),
+#     {},
+#     proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
+# )
+# ProxyableStaticCache = HFProxyableClassMeta(
+#     "ProxyableStaticCache",
+#     (StaticCache,),
+#     {},
+#     proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
+# )
 def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
@@ -879,10 +879,10 @@ class HFTracer(Tracer):
         "tril",
     ]
     _CLASSES_TO_PATCH = {
-        Cache: ProxyableCache,
-        DynamicCache: ProxyableDynamicCache,
-        SinkCache: ProxyableSinkCache,
-        StaticCache: ProxyableStaticCache,
+        # Cache: ProxyableCache,
+        # DynamicCache: ProxyableDynamicCache,
+        # SinkCache: ProxyableSinkCache,
+        # StaticCache: ProxyableStaticCache,
     }
     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

View File

@@ -738,6 +738,7 @@ class LlamaIntegrationTest(unittest.TestCase):
     @slow
     @require_read_token
     def test_export_static_cache(self):
+        # this test only run with an accelerator but it doesn't need an accelerator ?
         if version.parse(torch.__version__) < version.parse("2.4.0"):
             self.skipTest(reason="This test requires torch >= 2.4 to run.")

View File

@@ -2390,7 +2390,7 @@ class ModelTesterMixin:
                 elif tuple_object is None:
                     return
                 # model might return non-tensors objects (e.g. Cache class)
-                elif isinstance(tuple_object, torch.Tensor):
+                elif isinstance(tuple_object, torch.Tensor) and not isinstance(tuple_object, Cache):
                     self.assertTrue(
                         torch.allclose(
                             set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5

View File

@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
        # Check if the exported model is configured with the `StaticCache` correctly
        n_static_key_caches = n_static_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_key_caches = n_static_key_caches + 1
-            if buffer_name.startswith("static_cache.value_cache"):
+            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_value_caches = n_static_value_caches + 1
@@ -364,7 +364,7 @@ class CacheIntegrationTest(unittest.TestCase):
            input_ids = gen_out
        # We went well beyond the cache length
-        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+        self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)
        # And it still produces a coherent english
        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
            "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
            'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
        ]  # fmt: skip
-        self.assertTrue(responses == EXPECTED_DECODED_TEXT)
+        self.assertEqual(responses, EXPECTED_DECODED_TEXT)