Compare commits

...

19 Commits

Author SHA1 Message Date
090d9c4b2a Merge branch 'main' into tensor-cache 2025-01-24 12:02:45 +01:00
5ccb79c16d fixed dynamic cache 2025-01-23 16:45:28 +01:00
80b49d721b rebased 2025-01-22 17:31:39 +01:00
dc1bd15ba9 Merge branch 'main' into tensor-cache 2025-01-22 17:30:23 +01:00
338f5954b9 more reverts 2025-01-22 17:29:48 +01:00
2f4e0bc93e Update src/transformers/cache_utils.py 2025-01-22 17:18:28 +01:00
485f959f85 revert 2025-01-22 17:17:17 +01:00
2bbbbbcf97 add device and dtype setters 2025-01-22 17:15:12 +01:00
85c71b004b Merge branch 'main' into tensor-cache 2025-01-22 15:53:33 +01:00
da60604f2c fix test_cache_utils 2025-01-22 15:43:14 +01:00
6e9799c817 add clone and to 2025-01-22 15:42:43 +01:00
4950a9e3f0 extract wrapper kwargs from init signature to correctly instantate 2025-01-22 13:49:01 +01:00
b67b6eb9b2 make cache class exportable and executorch compatible 2025-01-20 18:47:30 +01:00
d269417aab fix zamba and jamba dynamic cache 2025-01-20 17:21:49 +01:00
95c1686ee0 style 2025-01-20 17:09:21 +01:00
8606594ad4 fix boolean evaluation 2025-01-20 17:08:37 +01:00
45bb39bb80 torch tensor subclassing 2025-01-20 17:01:49 +01:00
a77a94b209 unproxy cache 2025-01-20 14:43:41 +01:00
d4b631edd0 use tensor cache instead of module cache 2025-01-20 14:17:28 +01:00
9 changed files with 164 additions and 108 deletions

View File

@@ -1,5 +1,6 @@
import copy
import importlib.metadata
import inspect
import json
import os
from dataclasses import dataclass
@@ -9,12 +10,7 @@ import torch
from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg
@@ -24,13 +20,82 @@ if is_hqq_available():
logger = logging.get_logger(__name__)
class Cache(torch.nn.Module):
class Cache(torch.Tensor):
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""
def __init__(self):
super().__init__()
@staticmethod
def __new__(cls, *args, **kwargs):
# We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method
wrapper_kwargs = {}
init_signature = inspect.signature(cls.__init__)
init_arguments = list(init_signature.parameters.keys())
init_defaults = {
k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
}
for argument in ["dtype", "device"]:
if argument in init_arguments:
arg_idx = init_arguments.index(argument)
if len(args) > arg_idx and args[arg_idx] is not None:
wrapper_kwargs[argument] = args[arg_idx]
elif kwargs.get(argument, None) is not None:
wrapper_kwargs[argument] = kwargs[argument]
elif init_defaults[argument] is not None:
wrapper_kwargs[argument] = init_defaults[argument]
if "cache_config" in init_arguments:
cache_config_idx = init_arguments.index("cache_config")
if len(args) > cache_config_idx and args[cache_config_idx] is not None:
wrapper_kwargs["device"] = args[cache_config_idx].device
elif kwargs.get("cache_config", None) is not None:
wrapper_kwargs["device"] = kwargs["cache_config"].device
elif init_defaults["cache_config"] is not None:
wrapper_kwargs["device"] = init_defaults["cache_config"].device
self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)
# we create a dummy empty tensor for generic tensor flattening/unflattening
self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False)
return self
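For context, the pattern above relies on torch.Tensor._make_wrapper_subclass, which creates a shape-() tensor that carries only metadata (dtype, device) while the actual cache contents live in ordinary Python attributes; this is what lets a Cache instance satisfy isinstance(x, torch.Tensor) and flow through tracing. A minimal, self-contained sketch of the idea (not the transformers implementation):

import torch

class MetadataOnlyTensor(torch.Tensor):
    # Sketch only: a tensor subclass with no storage, mirroring the wrapper pattern above.
    @staticmethod
    def __new__(cls, *, dtype=torch.float32, device="cpu"):
        # shape () and requires_grad=False: the wrapper records dtype/device and nothing else
        return torch.Tensor._make_wrapper_subclass(cls, (), dtype=dtype, device=device, requires_grad=False)

    def __init__(self, *, dtype=torch.float32, device="cpu"):
        self.payload = []  # arbitrary Python state, invisible to the tensor machinery

t = MetadataOnlyTensor(dtype=torch.float16)
print(isinstance(t, torch.Tensor), t.dtype, t.device)  # True torch.float16 cpu

Any torch op routed through such a subclass then reaches __torch_dispatch__, which the class below restricts to methods it explicitly reimplements.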
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
assert (
func.__name__ in cls.__dict__
), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
return getattr(cls, func.__name__)(*args, **kwargs)
def __repr__(self):
return f"{self.__class__.__name__}()"
def __bool__(self):
# in many places, `past_key_values` is checked for not being None with `if past_key_values:`;
# the more explicit `if past_key_values is not None:` would be preferable, so this always evaluates to True
return self is not None # always True
def to(self, *args, **kwargs):
# start from the wrapper's current metadata
wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
# apply any dtype/device overrides passed positionally or as keywords
for arg in list(args) + list(kwargs.values()):
if isinstance(arg, (torch.device, str, int)):
wrapper_kwargs["device"] = arg
elif isinstance(arg, torch.dtype):
wrapper_kwargs["dtype"] = arg
# rebuild the wrapper with the new metadata, carrying over all Python attributes
new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
return new_self
def clone(self):
wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
new_self.__dict__ = copy.deepcopy(self.__dict__)
return new_self
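Assuming this branch, to and clone only rebuild the metadata wrapper; the Python-level state (key/value lists, _seen_tokens, ...) is carried over by reference or deepcopy. A hedged usage sketch:

import torch
from transformers import DynamicCache  # behavior assumed from the diff above, not the released API

cache = DynamicCache()
half = cache.to(torch.float16)   # new wrapper with dtype=float16, same Python attributes
copy_ = cache.clone()            # new wrapper plus a deepcopy of the Python attributes
print(half.dtype, copy_ is cache)  # torch.float16 False

Note that, as written above, the base-class to only changes the wrapper's metadata; the per-layer key/value tensors stored in the Python attributes are not themselves converted.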
def update(
self,
@@ -304,7 +369,7 @@ class StaticCacheConfig(CacheConfig):
cache_implementation = "static"
def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")):
self.batch_size = batch_size
self.max_cache_len = max_cache_len
self.device = device
@@ -361,6 +426,16 @@ class DynamicCache(Cache):
```
"""
def __tensor_flatten__(self):
return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}
@staticmethod
def __tensor_unflatten__(inner_tensors, meta, _, __):
cache = DynamicCache()
cache._seen_tokens = meta["_seen_tokens"]
cache._empty_tensor = inner_tensors["_empty_tensor"]
return cache
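These two hooks implement the flattening contract that torch.export and pytree use for traceable tensor subclasses: __tensor_flatten__ returns the attribute names of the inner tensors plus pickleable metadata, and __tensor_unflatten__ rebuilds an equivalent instance from them. A hedged round-trip sketch, assuming this branch:

from transformers import DynamicCache  # behavior assumed from this branch

cache = DynamicCache()
attrs, meta = cache.__tensor_flatten__()              # (["_empty_tensor"], {"_seen_tokens": 0})
inner = {name: getattr(cache, name) for name in attrs}
rebuilt = DynamicCache.__tensor_unflatten__(inner, meta, None, None)
assert rebuilt._seen_tokens == cache._seen_tokens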
@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__()
@@ -448,7 +523,7 @@ class DynamicCache(Cache):
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0)
return layer_seq_length
def get_max_cache_shape(self) -> Optional[int]:
@@ -675,9 +750,6 @@ class QuantizedCache(DynamicCache):
self.axis_key = cache_config.axis_key
self.axis_value = cache_config.axis_value
self.compute_dtype = cache_config.compute_dtype
self.device = cache_config.device
super().__init__()
def update(
self,
@@ -777,7 +849,7 @@ class QuantoQuantizedCache(QuantizedCache):
raise ImportError(
f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
)
from optimum.quanto import MaxOptimizer, qint2, qint4
from optimum.quanto import MaxOptimizer, qint2, qint4 # type: ignore
if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -796,7 +868,7 @@ class QuantoQuantizedCache(QuantizedCache):
def _quantize(self, tensor, axis):
# We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
if is_optimum_quanto_available():
from optimum.quanto import quantize_weight
from optimum.quanto import quantize_weight # type: ignore
scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
@@ -1105,7 +1177,7 @@ class StaticCache(Cache):
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1116,7 +1188,6 @@ class StaticCache(Cache):
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
@@ -1125,8 +1196,6 @@ class StaticCache(Cache):
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)
self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1144,18 +1213,10 @@ class StaticCache(Cache):
layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)
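For reference, a minimal sketch of the static-address tagging used above; torch._dynamo.mark_static_address marks the tensor's data pointer as stable so compiled graphs under torch.compile can keep reusing the same buffer when the cache is updated in place (buffer shape here is an arbitrary placeholder):

import torch

buf = torch.zeros(1, 8, 128, 64)          # stand-in for one per-layer key/value cache buffer
torch._dynamo.mark_static_address(buf)    # tag the buffer's address as stable for torch.compile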
@@ -1304,7 +1365,7 @@ class SlidingWindowCache(StaticCache):
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1619,7 +1680,7 @@ class HybridCache(Cache):
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: Union[torch.device, str] = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1648,7 +1709,6 @@ class HybridCache(Cache):
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)
self.device = torch.device(device) if device is not None else torch.device("meta")
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@@ -1781,7 +1841,7 @@ class HybridCache(Cache):
return self.max_batch_size
class MambaCache:
class MambaCache(Cache):
"""
Cache for mamba model which does not have attention mechanism and key value states.
@@ -1838,7 +1898,7 @@ class MambaCache:
config: PretrainedConfig,
batch_size: int = None,
dtype: torch.dtype = torch.float16,
device: Optional[Union[torch.device, str]] = None,
device: Union[torch.device, str] = torch.device("meta"),
max_batch_size: Optional[int] = None,
):
if batch_size is not None:
@@ -1846,12 +1906,10 @@ class MambaCache:
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype
self.max_batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.device = torch.device(device) if device is not None else torch.device("meta")
self.conv_states: List[torch.Tensor] = []
self.ssm_states: List[torch.Tensor] = []
@@ -1981,17 +2039,14 @@ class OffloadedStaticCache(StaticCache):
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: Optional[int],
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super(Cache, self).__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads

View File

@@ -731,6 +731,7 @@ class GenerationMixin:
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and not isinstance(dict_to_expand[key], Cache)
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
@@ -4519,13 +4520,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
"""
if data is None:
return [None] * (full_batch_size // split_size)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# New cache format
elif isinstance(data, DynamicCache) or (
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
):
return data.batch_split(full_batch_size, split_size, num_hidden_layers)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0], tuple):
@@ -4632,13 +4633,13 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf
"""
if any(data is None for data in data):
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
if isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):

View File

@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available
if is_torch_available():
from transformers import (
PreTrainedModel,
StaticCache,
)
from transformers import PreTrainedModel, StaticCache
from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
@@ -68,6 +65,8 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
)
self.model = model
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
self.static_cache = StaticCache(
config=self.model.config,
batch_size=self.model.generation_config.cache_config.batch_size,
@@ -75,14 +74,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
dtype=self.model.dtype,
device=self.model.generation_config.cache_config.device,
)
self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
for i in range(len(self.static_cache.key_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
if self.is_causal:
causal_mask = torch.tril(
torch.ones(
self.static_cache.max_cache_len,
self.static_cache.max_cache_len,
dtype=torch.bool,
)
torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool)
)
self.register_buffer("mask", causal_mask, persistent=False)
@@ -108,15 +106,20 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
"""
_, seqlen = input_ids.shape
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache
outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0),
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
past_key_values=self.static_cache,
use_cache=True,
)
return outs.logits
@staticmethod
@@ -143,7 +146,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
prompt_token_len = prompt_token_ids.shape[-1]
max_generation_length = prompt_token_len + max_new_tokens
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
if buffer_name.startswith("key_cache"):
max_cache_len = buffer.shape[2]
max_generation_length = min(max_generation_length, max_cache_len)
break
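A hedged end-to-end sketch of how the exportable wrapper above is typically driven; the checkpoint name and cache sizes are placeholders, not taken from this diff, and the wrapper reads batch_size, max_cache_len and device from generation_config.cache_config as its constructor expects:

import torch
from transformers import AutoModelForCausalLM, GenerationConfig
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache

model = AutoModelForCausalLM.from_pretrained("your-causal-lm-checkpoint")  # placeholder checkpoint
model.generation_config = GenerationConfig(
    use_cache=True,
    cache_implementation="static",
    cache_config={"batch_size": 1, "max_cache_len": 64, "device": "cpu"},
)
wrapper = TorchExportableModuleWithStaticCache(model)
example_input_ids = torch.zeros((1, 1), dtype=torch.long)
example_cache_position = torch.tensor([0], dtype=torch.long)
exported_program = torch.export.export(wrapper, (example_input_ids, example_cache_position))

With the buffer renaming in this diff, the exported program exposes the cache as key_cache_{i} / value_cache_{i} buffers rather than static_cache.key_cache_{i}, which is what the generate helper above now looks for.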

View File

@@ -215,7 +215,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
super().__init__()
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
intermediate_size = config.mamba_expand * config.hidden_size

View File

@@ -129,7 +129,6 @@ class ZambaHybridDynamicCache(DynamicCache):
"""
def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
self.intermediate_size = config.mamba_expand * config.hidden_size
@@ -139,9 +138,7 @@ class ZambaHybridDynamicCache(DynamicCache):
self.conv_states = []
self.ssm_states = []
self.transformer_layers = []
self._modules = {}
self._parameters = {}
self._buffers = {}
for i in range(config.num_hidden_layers):
self.conv_states += [
torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)

View File

@@ -35,7 +35,7 @@ from torch.fx._symbolic_trace import is_fx_tracing
from torch.fx.proxy import ParameterProxy
from .. import logging
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from ..cache_utils import Cache
from ..modeling_utils import PretrainedConfig, PreTrainedModel
from ..models.auto import get_values
from ..models.auto.modeling_auto import (
@@ -811,40 +811,40 @@ def _proxies_to_metas(v):
return v
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
global _CURRENT_TRACER
if not isinstance(_CURRENT_TRACER, HFTracer):
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
cache_proxy.install_orig_cache_cls(orig_cache_cls)
return cache_proxy
# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
# def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
# global _CURRENT_TRACER
# if not isinstance(_CURRENT_TRACER, HFTracer):
# raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
# cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
# cache_proxy.install_orig_cache_cls(orig_cache_cls)
# return cache_proxy
return cache_proxy_factory_fn
# return cache_proxy_factory_fn
# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
ProxyableCache = HFProxyableClassMeta(
"ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
)
ProxyableDynamicCache = HFProxyableClassMeta(
"ProxyableDynamicCache",
(DynamicCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
)
ProxyableSinkCache = HFProxyableClassMeta(
"ProxyableSinkCache",
(SinkCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
)
ProxyableStaticCache = HFProxyableClassMeta(
"ProxyableStaticCache",
(StaticCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)
# # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
# ProxyableCache = HFProxyableClassMeta(
# "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
# )
# ProxyableDynamicCache = HFProxyableClassMeta(
# "ProxyableDynamicCache",
# (DynamicCache,),
# {},
# proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
# )
# ProxyableSinkCache = HFProxyableClassMeta(
# "ProxyableSinkCache",
# (SinkCache,),
# {},
# proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
# )
# ProxyableStaticCache = HFProxyableClassMeta(
# "ProxyableStaticCache",
# (StaticCache,),
# {},
# proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
# )
def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
@@ -879,10 +879,10 @@ class HFTracer(Tracer):
"tril",
]
_CLASSES_TO_PATCH = {
Cache: ProxyableCache,
DynamicCache: ProxyableDynamicCache,
SinkCache: ProxyableSinkCache,
StaticCache: ProxyableStaticCache,
# Cache: ProxyableCache,
# DynamicCache: ProxyableDynamicCache,
# SinkCache: ProxyableSinkCache,
# StaticCache: ProxyableStaticCache,
}
supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

View File

@@ -738,6 +738,7 @@ class LlamaIntegrationTest(unittest.TestCase):
@slow
@require_read_token
def test_export_static_cache(self):
# this test only runs with an accelerator, but does it actually need one?
if version.parse(torch.__version__) < version.parse("2.4.0"):
self.skipTest(reason="This test requires torch >= 2.4 to run.")

View File

@@ -2390,7 +2390,7 @@ class ModelTesterMixin:
elif tuple_object is None:
return
# model might return non-tensors objects (e.g. Cache class)
elif isinstance(tuple_object, torch.Tensor):
elif isinstance(tuple_object, torch.Tensor) and not isinstance(tuple_object, Cache):
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5

View File

@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
# Check if the exported model is configured with the `StaticCache` correctly
n_static_key_caches = n_static_value_caches = 0
for buffer_name, buffer in exported_program.named_buffers():
if buffer_name.startswith("static_cache.key_cache"):
if buffer_name.startswith("key_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_key_caches = n_static_key_caches + 1
if buffer_name.startswith("static_cache.value_cache"):
if buffer_name.startswith("value_cache"):
self.assertTrue(buffer.shape[0] == batch_size)
self.assertTrue(buffer.shape[2] == max_cache_len)
n_static_value_caches = n_static_value_caches + 1
@@ -364,7 +364,7 @@ class CacheIntegrationTest(unittest.TestCase):
input_ids = gen_out
# We went well beyond the cache length
self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)
# And it still produces a coherent english
decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
] # fmt: skip
self.assertTrue(responses == EXPECTED_DECODED_TEXT)
self.assertEqual(responses, EXPECTED_DECODED_TEXT)