Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.50.3-De ... tensor-cac (19 commits)
Commits (SHA1):
090d9c4b2a
5ccb79c16d
80b49d721b
dc1bd15ba9
338f5954b9
2f4e0bc93e
485f959f85
2bbbbbcf97
85c71b004b
da60604f2c
6e9799c817
4950a9e3f0
b67b6eb9b2
d269417aab
95c1686ee0
8606594ad4
45bb39bb80
a77a94b209
d4b631edd0
@@ -1,5 +1,6 @@
import copy
import importlib.metadata
import inspect
import json
import os
from dataclasses import dataclass

@@ -9,12 +10,7 @@ import torch
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import (
    is_hqq_available,
    is_optimum_quanto_available,
    is_torchdynamo_compiling,
    logging,
)
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg

@@ -24,13 +20,82 @@ if is_hqq_available():
logger = logging.get_logger(__name__)


class Cache(torch.nn.Module):
class Cache(torch.Tensor):
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def __new__(cls, *args, **kwargs):
        # We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method

        wrapper_kwargs = {}
        init_signature = inspect.signature(cls.__init__)
        init_arguments = list(init_signature.parameters.keys())
        init_defaults = {
            k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
        }

        for argument in ["dtype", "device"]:
            if argument in init_arguments:
                arg_idx = init_arguments.index(argument)
                if len(args) > arg_idx and args[arg_idx] is not None:
                    wrapper_kwargs[argument] = args[arg_idx]
                elif kwargs.get(argument, None) is not None:
                    wrapper_kwargs[argument] = kwargs[argument]
                elif init_defaults[argument] is not None:
                    wrapper_kwargs[argument] = init_defaults[argument]

        if "cache_config" in init_arguments:
            cache_config_idx = init_arguments.index("cache_config")
            if len(args) > cache_config_idx and args[cache_config_idx] is not None:
                wrapper_kwargs["device"] = args[cache_config_idx].device
            elif kwargs.get("cache_config", None) is not None:
                wrapper_kwargs["device"] = kwargs["cache_config"].device
            elif init_defaults["cache_config"] is not None:
                wrapper_kwargs["device"] = init_defaults["cache_config"].device

        self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)
        # we create a dummy empty tensor for generic tensor flattening/unflattening
        self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False)
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        assert (
            func.__name__ in cls.__dict__
        ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
        return getattr(cls, func.__name__)(*args, **kwargs)

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def __bool__(self):
        # in many places, past_key_values is checked for not being None using `if past_key_values:`
        # I think `if past_key_values is not None:` should be used instead
        return self is not None  # True

    def to(self, *args, **kwargs):
        # originals
        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}

        # overrides
        for arg in list(args) + list(kwargs.values()):
            if isinstance(arg, (torch.device, str, int)):
                wrapper_kwargs["device"] = arg
            elif isinstance(arg, torch.dtype):
                wrapper_kwargs["dtype"] = arg

        # new wrapper
        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
        new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
        return new_self

    def clone(self):
        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
        new_self.__dict__ = copy.deepcopy(self.__dict__)
        return new_self

    def update(
        self,
@@ -304,7 +369,7 @@ class StaticCacheConfig(CacheConfig):

    cache_implementation = "static"

    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
    def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")):
        self.batch_size = batch_size
        self.max_cache_len = max_cache_len
        self.device = device

@@ -361,6 +426,16 @@ class DynamicCache(Cache):
    ```
    """

    def __tensor_flatten__(self):
        return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, _, __):
        cache = DynamicCache()
        cache._seen_tokens = meta["_seen_tokens"]
        cache._empty_tensor = inner_tensors["_empty_tensor"]
        return cache

    @deprecate_kwarg("num_hidden_layers", version="4.47.0")
    def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
        super().__init__()
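`__tensor_flatten__`/`__tensor_unflatten__` are the hooks `torch.compile`/`torch.export` use to decompose a tensor subclass into its real inner tensors plus rebuild metadata, and to reassemble it afterwards. A minimal sketch of the contract, exercised by hand (plain PyTorch; `FlattenableWrapper` is a stand-in, not `DynamicCache` itself):

```python
import torch


class FlattenableWrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, seen_tokens=0):
        return torch.Tensor._make_wrapper_subclass(cls, (), dtype=torch.float32, requires_grad=False)

    def __init__(self, seen_tokens=0):
        super().__init__()
        self._seen_tokens = seen_tokens
        self._empty_tensor = torch.tensor([])  # dummy inner tensor, as in the hunk above

    def __tensor_flatten__(self):
        # names of attributes that hold real tensors + metadata needed to rebuild the object
        return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
        obj = FlattenableWrapper(seen_tokens=meta["_seen_tokens"])
        obj._empty_tensor = inner_tensors["_empty_tensor"]
        return obj


cache = FlattenableWrapper(seen_tokens=3)
names, ctx = cache.__tensor_flatten__()
rebuilt = FlattenableWrapper.__tensor_unflatten__(
    {name: getattr(cache, name) for name in names}, ctx, None, None
)
print(rebuilt._seen_tokens)  # 3
```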
@@ -448,7 +523,7 @@ class DynamicCache(Cache):
            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
        )
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0)
        return layer_seq_length

    def get_max_cache_shape(self) -> Optional[int]:

@@ -675,9 +750,6 @@ class QuantizedCache(DynamicCache):
        self.axis_key = cache_config.axis_key
        self.axis_value = cache_config.axis_value
        self.compute_dtype = cache_config.compute_dtype
        self.device = cache_config.device

        super().__init__()

    def update(
        self,

@@ -777,7 +849,7 @@ class QuantoQuantizedCache(QuantizedCache):
            raise ImportError(
                f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
            )
        from optimum.quanto import MaxOptimizer, qint2, qint4
        from optimum.quanto import MaxOptimizer, qint2, qint4  # type: ignore

        if self.nbits not in [2, 4]:
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

@@ -796,7 +868,7 @@ class QuantoQuantizedCache(QuantizedCache):
    def _quantize(self, tensor, axis):
        # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
        if is_optimum_quanto_available():
            from optimum.quanto import quantize_weight
            from optimum.quanto import quantize_weight  # type: ignore

            scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
            qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)

@@ -1105,7 +1177,7 @@ class StaticCache(Cache):
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: torch.device = None,
        device: Union[torch.device, str] = torch.device("meta"),
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,

@@ -1116,7 +1188,6 @@ class StaticCache(Cache):
                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
            )

        self.max_batch_size = batch_size or max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

@@ -1125,8 +1196,6 @@ class StaticCache(Cache):
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

        self.dtype = dtype
        self.device = torch.device(device) if device is not None else torch.device("meta")
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None

@@ -1144,18 +1213,10 @@ class StaticCache(Cache):
            layer_device = self.device
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
            # Notes:
            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
            #    it is not needed anyway)
            # 2. `torch.export()` requires mutations to be registered as buffers.
            if not is_torchdynamo_compiling():
                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)
            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
            # preventing compiled graph breaks when updating the cache.
            torch._dynamo.mark_static_address(new_layer_key_cache)
            torch._dynamo.mark_static_address(new_layer_value_cache)
            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
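The static cache pre-allocates fixed-shape per-layer tensors and tags their storage with `torch._dynamo.mark_static_address` so in-place updates do not force graph re-captures under `torch.compile`; the longer branch shown above (buffer registration guarded by `is_torchdynamo_compiling()`) comes from the earlier `nn.Module`-based cache. A small sketch of the allocate-then-update-in-place pattern (shapes are illustrative, not tied to any model config):

```python
import torch

# Pre-allocate one layer's key cache: (batch, kv_heads, max_cache_len, head_dim).
max_batch, kv_heads, max_cache_len, head_dim = 1, 8, 128, 64
key_cache = torch.zeros(max_batch, kv_heads, max_cache_len, head_dim)

# Tag the storage as a fixed address, as the diff does for each layer's key/value tensors.
torch._dynamo.mark_static_address(key_cache)

# An update step writes new key states at the current cache positions without
# reallocating, which is what keeps the data pointer stable across decoding steps.
cache_position = torch.arange(4)                      # positions of the 4 prompt tokens
new_keys = torch.randn(max_batch, kv_heads, 4, head_dim)
key_cache.index_copy_(2, cache_position, new_keys)    # in-place write along the length dim
```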
@@ -1304,7 +1365,7 @@ class SlidingWindowCache(StaticCache):
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: torch.device = None,
        device: Union[torch.device, str] = torch.device("meta"),
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,

@@ -1619,7 +1680,7 @@ class HybridCache(Cache):
        config: PretrainedConfig,
        batch_size: int = None,
        max_cache_len: int = None,
        device: Union[torch.device, str] = None,
        device: Union[torch.device, str] = torch.device("meta"),
        dtype: torch.dtype = torch.float32,
        max_batch_size: Optional[int] = None,
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,

@@ -1648,7 +1709,6 @@ class HybridCache(Cache):
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        self.device = torch.device(device) if device is not None else torch.device("meta")
        layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
        self.is_sliding = torch.tensor(
            [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool

@@ -1781,7 +1841,7 @@ class HybridCache(Cache):
        return self.max_batch_size


class MambaCache:
class MambaCache(Cache):
    """
    Cache for mamba model which does not have attention mechanism and key value states.

@@ -1838,7 +1898,7 @@ class MambaCache:
        config: PretrainedConfig,
        batch_size: int = None,
        dtype: torch.dtype = torch.float16,
        device: Optional[Union[torch.device, str]] = None,
        device: Union[torch.device, str] = torch.device("meta"),
        max_batch_size: Optional[int] = None,
    ):
        if batch_size is not None:

@@ -1846,12 +1906,10 @@ class MambaCache:
                f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                "v4.49. Use the more precisely named 'max_batch_size' argument instead."
            )
        self.dtype = dtype
        self.max_batch_size = batch_size or max_batch_size
        self.intermediate_size = config.intermediate_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel
        self.device = torch.device(device) if device is not None else torch.device("meta")

        self.conv_states: List[torch.Tensor] = []
        self.ssm_states: List[torch.Tensor] = []

@@ -1981,17 +2039,14 @@ class OffloadedStaticCache(StaticCache):
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: Optional[int],
        device: Union[str, torch.device],
        dtype: Optional[torch.dtype] = None,
        device: Union[torch.device, str] = torch.device("meta"),
        dtype: torch.dtype = torch.float32,
        offload_device: Union[str, torch.device] = torch.device("cpu"),
        layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
    ) -> None:
        super(Cache, self).__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
        self.offload_device = torch.device(offload_device)
        self.dtype = dtype if dtype is not None else torch.float32

        # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -731,6 +731,7 @@ class GenerationMixin:
                key != "cache_position"
                and dict_to_expand[key] is not None
                and isinstance(dict_to_expand[key], torch.Tensor)
                and not isinstance(dict_to_expand[key], Cache)
            ):
                dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
        return dict_to_expand
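Because `Cache` is now a `torch.Tensor` subclass, code that used `isinstance(x, torch.Tensor)` to mean "a plain data tensor" would start matching cache objects too, which is why the hunk above adds an explicit `not isinstance(..., Cache)` exclusion. A short illustration with a stand-in wrapper class (hypothetical `CacheLikeTensor`, not the transformers `Cache`):

```python
import torch


class CacheLikeTensor(torch.Tensor):
    @staticmethod
    def __new__(cls):
        # zero-sized wrapper, same construction trick as in the cache_utils hunk
        return torch.Tensor._make_wrapper_subclass(cls, (), requires_grad=False)


obj = CacheLikeTensor()
print(isinstance(obj, torch.Tensor))                                           # True: would be expanded
print(isinstance(obj, torch.Tensor) and not isinstance(obj, CacheLikeTensor))  # False: now skipped
```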
@@ -4519,13 +4520,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
    """
    if data is None:
        return [None] * (full_batch_size // split_size)
    if isinstance(data, torch.Tensor):
        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
    # New cache format
    elif isinstance(data, DynamicCache) or (
        isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
    ):
        return data.batch_split(full_batch_size, split_size, num_hidden_layers)
    if isinstance(data, torch.Tensor):
        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
    elif isinstance(data, tuple):
        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
        if isinstance(data[0], tuple):

@@ -4632,13 +4633,13 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf
    """
    if any(data is None for data in data):
        return None
    if isinstance(data[0], torch.Tensor):
        return torch.cat(data, dim=0)
    # New cache format
    elif isinstance(data[0], DynamicCache):
    if isinstance(data[0], DynamicCache):
        return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
    elif isinstance(data[0], EncoderDecoderCache):
        return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
    elif isinstance(data[0], torch.Tensor):
        return torch.cat(data, dim=0)
    elif isinstance(data[0], tuple):
        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
        if isinstance(data[0][0], tuple):

@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available


if is_torch_available():
    from transformers import (
        PreTrainedModel,
        StaticCache,
    )
    from transformers import PreTrainedModel, StaticCache
    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


@@ -68,6 +65,8 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
            )

        self.model = model
        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)

        self.static_cache = StaticCache(
            config=self.model.config,
            batch_size=self.model.generation_config.cache_config.batch_size,

@@ -75,14 +74,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
            dtype=self.model.dtype,
            device=self.model.generation_config.cache_config.device,
        )
        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
        for i in range(len(self.static_cache.key_cache)):
            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

        if self.is_causal:
            causal_mask = torch.tril(
                torch.ones(
                    self.static_cache.max_cache_len,
                    self.static_cache.max_cache_len,
                    dtype=torch.bool,
                )
                torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool)
            )
            self.register_buffer("mask", causal_mask, persistent=False)

@@ -108,15 +106,20 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
        ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
        """
        _, seqlen = input_ids.shape

        attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
        position_ids = cache_position.unsqueeze(0)
        past_key_values = self.static_cache

        outs = self.model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=cache_position.unsqueeze(0),
            position_ids=position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            past_key_values=self.static_cache,
            use_cache=True,
        )

        return outs.logits
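For the causal path, the forward pass slices a precomputed lower-triangular boolean mask with the absolute cache positions of the current tokens instead of rebuilding a mask every step. A small sketch of that indexing for a prompt of length 4 (shapes are illustrative):

```python
import torch

max_cache_len = 8
# Precomputed once and registered as a buffer in the module above.
mask = torch.tril(torch.ones(max_cache_len, max_cache_len, dtype=torch.bool))

seqlen = 4
cache_position = torch.arange(seqlen)       # absolute positions of the prompt tokens
attn_mask = mask[cache_position, :seqlen]   # (seqlen, seqlen) causal mask for the prefill step
print(attn_mask.long())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```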
    @staticmethod

@@ -143,7 +146,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
        prompt_token_len = prompt_token_ids.shape[-1]
        max_generation_length = prompt_token_len + max_new_tokens
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("static_cache.key_cache"):
            if buffer_name.startswith("key_cache"):
                max_cache_len = buffer.shape[2]
                max_generation_length = min(max_generation_length, max_cache_len)
                break
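The buffer lookup changes from `"static_cache.key_cache"` to `"key_cache"` because the KV tensors are no longer attributes of a `static_cache` submodule; the export wrapper registers them as its own buffers (the `register_buffer(f"key_cache_{i}", ...)` loop earlier in this diff), so they surface in `named_buffers()` under their plain names. A tiny illustration with a hypothetical module (not the transformers wrapper):

```python
import torch


class TinyCacheModule(torch.nn.Module):
    def __init__(self, n_layers=2, max_cache_len=16):
        super().__init__()
        for i in range(n_layers):
            # Buffers registered directly on the module keep their plain names.
            self.register_buffer(f"key_cache_{i}", torch.zeros(1, 1, max_cache_len, 8), persistent=False)

    def forward(self, x):
        return x


m = TinyCacheModule()
print([name for name, _ in m.named_buffers() if name.startswith("key_cache")])
# ['key_cache_0', 'key_cache_1']
```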
@@ -215,7 +215,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        super().__init__()
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        self.has_previous_state = False  # only used by mamba
        intermediate_size = config.mamba_expand * config.hidden_size

@@ -129,7 +129,6 @@ class ZambaHybridDynamicCache(DynamicCache):
    """

    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        self.has_previous_state = False  # only used by mamba
        self.intermediate_size = config.mamba_expand * config.hidden_size

@@ -139,9 +138,7 @@ class ZambaHybridDynamicCache(DynamicCache):
        self.conv_states = []
        self.ssm_states = []
        self.transformer_layers = []
        self._modules = {}
        self._parameters = {}
        self._buffers = {}

        for i in range(config.num_hidden_layers):
            self.conv_states += [
                torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)

@@ -35,7 +35,7 @@ from torch.fx._symbolic_trace import is_fx_tracing
from torch.fx.proxy import ParameterProxy

from .. import logging
from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from ..cache_utils import Cache
from ..modeling_utils import PretrainedConfig, PreTrainedModel
from ..models.auto import get_values
from ..models.auto.modeling_auto import (

@@ -811,40 +811,40 @@ def _proxies_to_metas(v):
    return v


def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
    def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
        global _CURRENT_TRACER
        if not isinstance(_CURRENT_TRACER, HFTracer):
            raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
        cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
        cache_proxy.install_orig_cache_cls(orig_cache_cls)
        return cache_proxy
# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
# def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
# global _CURRENT_TRACER
# if not isinstance(_CURRENT_TRACER, HFTracer):
# raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
# cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
# cache_proxy.install_orig_cache_cls(orig_cache_cls)
# return cache_proxy

    return cache_proxy_factory_fn
# return cache_proxy_factory_fn


# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
ProxyableCache = HFProxyableClassMeta(
    "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
)
ProxyableDynamicCache = HFProxyableClassMeta(
    "ProxyableDynamicCache",
    (DynamicCache,),
    {},
    proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
)
ProxyableSinkCache = HFProxyableClassMeta(
    "ProxyableSinkCache",
    (SinkCache,),
    {},
    proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
)
ProxyableStaticCache = HFProxyableClassMeta(
    "ProxyableStaticCache",
    (StaticCache,),
    {},
    proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)
# # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
# ProxyableCache = HFProxyableClassMeta(
# "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
# )
# ProxyableDynamicCache = HFProxyableClassMeta(
# "ProxyableDynamicCache",
# (DynamicCache,),
# {},
# proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
# )
# ProxyableSinkCache = HFProxyableClassMeta(
# "ProxyableSinkCache",
# (SinkCache,),
# {},
# proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
# )
# ProxyableStaticCache = HFProxyableClassMeta(
# "ProxyableStaticCache",
# (StaticCache,),
# {},
# proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
# )


def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):

@@ -879,10 +879,10 @@ class HFTracer(Tracer):
        "tril",
    ]
    _CLASSES_TO_PATCH = {
        Cache: ProxyableCache,
        DynamicCache: ProxyableDynamicCache,
        SinkCache: ProxyableSinkCache,
        StaticCache: ProxyableStaticCache,
        # Cache: ProxyableCache,
        # DynamicCache: ProxyableDynamicCache,
        # SinkCache: ProxyableSinkCache,
        # StaticCache: ProxyableStaticCache,
    }

    supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

@@ -738,6 +738,7 @@ class LlamaIntegrationTest(unittest.TestCase):
    @slow
    @require_read_token
    def test_export_static_cache(self):
        # this test only runs with an accelerator, but it doesn't need an accelerator?
        if version.parse(torch.__version__) < version.parse("2.4.0"):
            self.skipTest(reason="This test requires torch >= 2.4 to run.")

@@ -2390,7 +2390,7 @@ class ModelTesterMixin:
            elif tuple_object is None:
                return
            # model might return non-tensors objects (e.g. Cache class)
            elif isinstance(tuple_object, torch.Tensor):
            elif isinstance(tuple_object, torch.Tensor) and not isinstance(tuple_object, Cache):
                self.assertTrue(
                    torch.allclose(
                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5

@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
        # Check if the exported model is configured with the `StaticCache` correctly
        n_static_key_caches = n_static_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("static_cache.key_cache"):
            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_key_caches = n_static_key_caches + 1
            if buffer_name.startswith("static_cache.value_cache"):
            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_value_caches = n_static_value_caches + 1

@@ -364,7 +364,7 @@ class CacheIntegrationTest(unittest.TestCase):
            input_ids = gen_out

        # We went well beyond the cache length
        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
        self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)

        # And it still produces a coherent english
        decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)

@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
            "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
            'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
        ]  # fmt: skip
        self.assertTrue(responses == EXPECTED_DECODED_TEXT)
        self.assertEqual(responses, EXPECTED_DECODED_TEXT)