Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-22 10:19:00 +08:00

Compare commits: v4.55.4 ... tensor-cac
19 Commits
090d9c4b2a
5ccb79c16d
80b49d721b
dc1bd15ba9
338f5954b9
2f4e0bc93e
485f959f85
2bbbbbcf97
85c71b004b
da60604f2c
6e9799c817
4950a9e3f0
b67b6eb9b2
d269417aab
95c1686ee0
8606594ad4
45bb39bb80
a77a94b209
d4b631edd0
@@ -1,5 +1,6 @@
 import copy
 import importlib.metadata
+import inspect
 import json
 import os
 from dataclasses import dataclass
@@ -9,12 +10,7 @@ import torch
 from packaging import version

 from .configuration_utils import PretrainedConfig
-from .utils import (
-    is_hqq_available,
-    is_optimum_quanto_available,
-    is_torchdynamo_compiling,
-    logging,
-)
+from .utils import is_hqq_available, is_optimum_quanto_available, logging
 from .utils.deprecation import deprecate_kwarg

@@ -24,13 +20,82 @@ if is_hqq_available():
 logger = logging.get_logger(__name__)


-class Cache(torch.nn.Module):
+class Cache(torch.Tensor):
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """

-    def __init__(self):
-        super().__init__()
+    @staticmethod
+    def __new__(cls, *args, **kwargs):
+        # We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method
+
+        wrapper_kwargs = {}
+        init_signature = inspect.signature(cls.__init__)
+        init_arguments = list(init_signature.parameters.keys())
+        init_defaults = {
+            k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
+        }
+
+        for argument in ["dtype", "device"]:
+            if argument in init_arguments:
+                arg_idx = init_arguments.index(argument)
+                if len(args) > arg_idx and args[arg_idx] is not None:
+                    wrapper_kwargs[argument] = args[arg_idx]
+                elif kwargs.get(argument, None) is not None:
+                    wrapper_kwargs[argument] = kwargs[argument]
+                elif init_defaults[argument] is not None:
+                    wrapper_kwargs[argument] = init_defaults[argument]
+
+        if "cache_config" in init_arguments:
+            cache_config_idx = init_arguments.index("cache_config")
+            if len(args) > cache_config_idx and args[cache_config_idx] is not None:
+                wrapper_kwargs["device"] = args[cache_config_idx].device
+            elif kwargs.get("cache_config", None) is not None:
+                wrapper_kwargs["device"] = kwargs["cache_config"].device
+            elif init_defaults["cache_config"] is not None:
+                wrapper_kwargs["device"] = init_defaults["cache_config"].device
+
+        self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)
+        # we create a dummy empty tensor for generic tensor flattening/unflattening
+        self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False)
+        return self
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        assert (
+            func.__name__ in cls.__dict__
+        ), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
+        return getattr(cls, func.__name__)(*args, **kwargs)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}()"
+
+    def __bool__(self):
+        # in many places, past_key_values is checked for not being None using `if past_key_values:`
+        # I think `if past_key_values is not None:` should be used instead
+        return self is not None  # True
+
+    def to(self, *args, **kwargs):
+        # originals
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
+
+        # overrides
+        for arg in list(args) + list(kwargs.values()):
+            if isinstance(arg, (torch.device, str, int)):
+                wrapper_kwargs["device"] = arg
+            elif isinstance(arg, torch.dtype):
+                wrapper_kwargs["dtype"] = arg
+
+        # new wrapper
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
+        new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
+        return new_self
+
+    def clone(self):
+        wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
+        new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
+        new_self.__dict__ = copy.deepcopy(self.__dict__)
+        return new_self
+
     def update(
         self,
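Context for the hunk above: `torch.Tensor._make_wrapper_subclass` builds a zero-sized tensor shell that only carries `device`/`dtype` metadata, while the real state stays in ordinary Python attributes; that is what lets the cache satisfy `isinstance(x, torch.Tensor)` and be traced as a forward input. A minimal, self-contained sketch of the mechanism (the `MetaOnlyTensor` / `payload` names are illustrative, not part of the branch):

```python
import torch


class MetaOnlyTensor(torch.Tensor):
    """Toy wrapper subclass: no storage of its own, metadata from the wrapper, state in __dict__."""

    @staticmethod
    def __new__(cls, payload: torch.Tensor):
        # Zero-sized wrapper; .dtype/.device are taken from the keyword arguments below.
        self = torch.Tensor._make_wrapper_subclass(
            cls, (), dtype=payload.dtype, device=payload.device, requires_grad=False
        )
        self.payload = payload  # real data lives in a plain attribute
        return self

    def __repr__(self):
        return f"{self.__class__.__name__}(dtype={self.dtype}, device={self.device})"


x = MetaOnlyTensor(torch.zeros(2, 3, dtype=torch.float16))
print(isinstance(x, torch.Tensor))  # True -> passes tensor-only code paths during tracing
print(x.dtype, x.device)            # torch.float16 cpu, with no explicit self.dtype/self.device attributes
```

Because the wrapper itself supplies `.dtype` and `.device`, the later hunks can drop the explicit `self.dtype` / `self.device` assignments from the concrete cache classes.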
@@ -304,7 +369,7 @@ class StaticCacheConfig(CacheConfig):

     cache_implementation = "static"

-    def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
+    def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")):
         self.batch_size = batch_size
         self.max_cache_len = max_cache_len
         self.device = device
@@ -361,6 +426,16 @@ class DynamicCache(Cache):
     ```
     """

+    def __tensor_flatten__(self):
+        return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, meta, _, __):
+        cache = DynamicCache()
+        cache._seen_tokens = meta["_seen_tokens"]
+        cache._empty_tensor = inner_tensors["_empty_tensor"]
+        return cache
+
     @deprecate_kwarg("num_hidden_layers", version="4.47.0")
     def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
         super().__init__()
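`__tensor_flatten__` / `__tensor_unflatten__` are the hooks PyTorch calls to take a tensor subclass apart into plain inner tensors plus metadata and to rebuild it (for example under `torch.export` or `torch.compile`). A hedged sketch of a manual round trip through the two methods added above, assuming this branch of `transformers` is installed:

```python
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)

# Flatten: names of inner tensor attributes plus non-tensor metadata.
inner_names, meta = cache.__tensor_flatten__()
inner_tensors = {name: getattr(cache, name) for name in inner_names}

# Unflatten: rebuild a cache from the pieces (the last two arguments are
# outer size/stride hints, unused by this implementation).
rebuilt = DynamicCache.__tensor_unflatten__(inner_tensors, meta, None, None)
print(inner_names, meta)       # expected: ['_empty_tensor'] {'_seen_tokens': 4}
print(type(rebuilt).__name__)  # DynamicCache
```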
@@ -448,7 +523,7 @@ class DynamicCache(Cache):
             or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
             or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
         )
-        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0)
         return layer_seq_length

     def get_max_cache_shape(self) -> Optional[int]:
@@ -675,9 +750,6 @@ class QuantizedCache(DynamicCache):
         self.axis_key = cache_config.axis_key
         self.axis_value = cache_config.axis_value
         self.compute_dtype = cache_config.compute_dtype
-        self.device = cache_config.device
-
-        super().__init__()

     def update(
         self,
@@ -777,7 +849,7 @@ class QuantoQuantizedCache(QuantizedCache):
                 raise ImportError(
                     f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
                 )
-            from optimum.quanto import MaxOptimizer, qint2, qint4
+            from optimum.quanto import MaxOptimizer, qint2, qint4  # type: ignore

         if self.nbits not in [2, 4]:
             raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -796,7 +868,7 @@ class QuantoQuantizedCache(QuantizedCache):
     def _quantize(self, tensor, axis):
         # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
         if is_optimum_quanto_available():
-            from optimum.quanto import quantize_weight
+            from optimum.quanto import quantize_weight  # type: ignore

             scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
             qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
@@ -1105,7 +1177,7 @@ class StaticCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1116,7 +1188,6 @@ class StaticCache(Cache):
                 f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                 "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )

         self.max_batch_size = batch_size or max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-
@@ -1125,8 +1196,6 @@ class StaticCache(Cache):
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype
-        self.device = torch.device(device) if device is not None else torch.device("meta")
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1144,18 +1213,10 @@ class StaticCache(Cache):
                 layer_device = self.device
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            # Notes:
-            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            #    it is not needed anyway)
-            # 2. `torch.export()` requires mutations to be registered as buffers.
-            if not is_torchdynamo_compiling():
-                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
-                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+            # preventing compiled graph breaks when updating the cache.
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)

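On `torch._dynamo.mark_static_address`: it tags a pre-allocated buffer as having a fixed storage address so that `torch.compile`'s CUDA-graph machinery can reuse it across calls; the removed `register_buffer` branch for `torch.export` is superseded by the buffer re-registration done in `TorchExportableModuleWithStaticCache` further down. A small illustration of the marking pattern, independent of the diff (shapes are arbitrary):

```python
import torch

max_cache_len, head_dim = 128, 64
key_cache = torch.zeros(1, 8, max_cache_len, head_dim)

# Tag the pre-allocated cache so compiled code treats its address as fixed.
torch._dynamo.mark_static_address(key_cache)


@torch.compile
def write_step(cache: torch.Tensor, new_key: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    # In-place write at `pos`: the cache keeps its storage, which is what the tag relies on.
    cache.index_copy_(2, pos, new_key)
    return cache


out = write_step(key_cache, torch.randn(1, 8, 1, head_dim), torch.tensor([0]))
```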
@@ -1304,7 +1365,7 @@ class SlidingWindowCache(StaticCache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: torch.device = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1619,7 +1680,7 @@ class HybridCache(Cache):
         config: PretrainedConfig,
         batch_size: int = None,
         max_cache_len: int = None,
-        device: Union[torch.device, str] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         dtype: torch.dtype = torch.float32,
         max_batch_size: Optional[int] = None,
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1648,7 +1709,6 @@ class HybridCache(Cache):
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )

-        self.device = torch.device(device) if device is not None else torch.device("meta")
         layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2  # 2 is for BC
         self.is_sliding = torch.tensor(
             [bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@@ -1781,7 +1841,7 @@ class HybridCache(Cache):
         return self.max_batch_size


-class MambaCache:
+class MambaCache(Cache):
     """
     Cache for mamba model which does not have attention mechanism and key value states.

@@ -1838,7 +1898,7 @@ class MambaCache:
         config: PretrainedConfig,
         batch_size: int = None,
         dtype: torch.dtype = torch.float16,
-        device: Optional[Union[torch.device, str]] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
         max_batch_size: Optional[int] = None,
     ):
         if batch_size is not None:
@@ -1846,12 +1906,10 @@ class MambaCache:
                 f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                 "v4.49. Use the more precisely named 'max_batch_size' argument instead."
             )
-        self.dtype = dtype
         self.max_batch_size = batch_size or max_batch_size
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel
-        self.device = torch.device(device) if device is not None else torch.device("meta")

         self.conv_states: List[torch.Tensor] = []
         self.ssm_states: List[torch.Tensor] = []
@@ -1981,17 +2039,14 @@ class OffloadedStaticCache(StaticCache):
         config: PretrainedConfig,
         max_batch_size: int,
         max_cache_len: Optional[int],
-        device: Union[str, torch.device],
-        dtype: Optional[torch.dtype] = None,
+        device: Union[torch.device, str] = torch.device("meta"),
+        dtype: torch.dtype = torch.float32,
         offload_device: Union[str, torch.device] = torch.device("cpu"),
         layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
     ) -> None:
-        super(Cache, self).__init__()
         self.max_batch_size = max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
-        self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
         self.offload_device = torch.device(offload_device)
-        self.dtype = dtype if dtype is not None else torch.float32

         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -731,6 +731,7 @@ class GenerationMixin:
                     key != "cache_position"
                     and dict_to_expand[key] is not None
                     and isinstance(dict_to_expand[key], torch.Tensor)
+                    and not isinstance(dict_to_expand[key], Cache)
                 ):
                     dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
         return dict_to_expand
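The added guard matters because `Cache` now satisfies `isinstance(x, torch.Tensor)`, so the old condition would have tried to `repeat_interleave` the cache wrapper itself. A small sketch of the pitfall with a stand-in subclass (`FakeCache` is illustrative, not the real class):

```python
import torch


class FakeCache(torch.Tensor):
    """Stand-in for the tensor-backed Cache wrapper."""

    @staticmethod
    def __new__(cls):
        return torch.Tensor._make_wrapper_subclass(cls, (), requires_grad=False)


model_kwargs = {"attention_mask": torch.ones(1, 4), "past_key_values": FakeCache()}

for key, value in model_kwargs.items():
    is_tensor = isinstance(value, torch.Tensor)                      # now also True for the cache
    should_expand = is_tensor and not isinstance(value, FakeCache)   # mirrors the added guard
    print(key, is_tensor, should_expand)
# attention_mask True True
# past_key_values True False
```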
@@ -4519,13 +4520,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
     """
     if data is None:
         return [None] * (full_batch_size // split_size)
-    if isinstance(data, torch.Tensor):
-        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
     # New cache format
     elif isinstance(data, DynamicCache) or (
         isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
     ):
         return data.batch_split(full_batch_size, split_size, num_hidden_layers)
+    if isinstance(data, torch.Tensor):
+        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
     elif isinstance(data, tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0], tuple):
@@ -4632,13 +4633,13 @@ def stack_model_outputs(model_outputs: List[ModelOutput], config: PretrainedConf
     """
     if any(data is None for data in data):
         return None
-    if isinstance(data[0], torch.Tensor):
-        return torch.cat(data, dim=0)
     # New cache format
-    elif isinstance(data[0], DynamicCache):
+    if isinstance(data[0], DynamicCache):
         return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
     elif isinstance(data[0], EncoderDecoderCache):
         return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
+    elif isinstance(data[0], torch.Tensor):
+        return torch.cat(data, dim=0)
     elif isinstance(data[0], tuple):
         # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
         if isinstance(data[0][0], tuple):
@@ -16,10 +16,7 @@ from ..utils.import_utils import is_torch_available


 if is_torch_available():
-    from transformers import (
-        PreTrainedModel,
-        StaticCache,
-    )
+    from transformers import PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


@@ -68,6 +65,8 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         )

         self.model = model
+        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
+
         self.static_cache = StaticCache(
             config=self.model.config,
             batch_size=self.model.generation_config.cache_config.batch_size,
@@ -75,14 +74,13 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
             dtype=self.model.dtype,
             device=self.model.generation_config.cache_config.device,
         )
-        self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
+        for i in range(len(self.static_cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+
         if self.is_causal:
             causal_mask = torch.tril(
-                torch.ones(
-                    self.static_cache.max_cache_len,
-                    self.static_cache.max_cache_len,
-                    dtype=torch.bool,
-                )
+                torch.ones(self.static_cache.max_cache_len, self.static_cache.max_cache_len, dtype=torch.bool)
             )
             self.register_buffer("mask", causal_mask, persistent=False)

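Since `Cache` no longer derives from `nn.Module`, the `StaticCache` tensors are not picked up as submodule buffers anymore; the loop above re-registers each per-layer tensor directly on the export wrapper, which is also why the buffer names checked later lose their `static_cache.` prefix. A reduced sketch of the registration pattern (`ExportWrapper`, shapes and layer count are illustrative):

```python
import torch
from torch import nn


class ExportWrapper(nn.Module):
    def __init__(self, n_layers: int = 2, shape=(1, 8, 128, 64)):
        super().__init__()
        key_cache = [torch.zeros(shape) for _ in range(n_layers)]
        value_cache = [torch.zeros(shape) for _ in range(n_layers)]
        # Register each cache tensor on the module itself so export sees it as a buffer.
        for i in range(n_layers):
            self.register_buffer(f"key_cache_{i}", key_cache[i], persistent=False)
            self.register_buffer(f"value_cache_{i}", value_cache[i], persistent=False)


wrapper = ExportWrapper()
print([name for name, _ in wrapper.named_buffers()])
# ['key_cache_0', 'value_cache_0', 'key_cache_1', 'value_cache_1']
```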
@@ -108,15 +106,20 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
         """
         _, seqlen = input_ids.shape

         attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
+        position_ids = cache_position.unsqueeze(0)
+        past_key_values = self.static_cache
+
         outs = self.model(
             input_ids=input_ids,
             attention_mask=attn_mask,
-            position_ids=cache_position.unsqueeze(0),
+            position_ids=position_ids,
+            past_key_values=past_key_values,
             cache_position=cache_position,
-            past_key_values=self.static_cache,
             use_cache=True,
         )

         return outs.logits

     @staticmethod
@@ -143,7 +146,7 @@ class TorchExportableModuleWithStaticCache(torch.nn.Module):
         prompt_token_len = prompt_token_ids.shape[-1]
         max_generation_length = prompt_token_len + max_new_tokens
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 max_cache_len = buffer.shape[2]
                 max_generation_length = min(max_generation_length, max_cache_len)
                 break
@@ -215,7 +215,6 @@ class HybridMambaAttentionDynamicCache(DynamicCache):

     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
         super().__init__()
-        self.dtype = dtype
         self.layers_block_type = config.layers_block_type
         self.has_previous_state = False  # only used by mamba
         intermediate_size = config.mamba_expand * config.hidden_size
@@ -129,7 +129,6 @@ class ZambaHybridDynamicCache(DynamicCache):
     """

     def __init__(self, config, batch_size, dtype=torch.float16, device=None):
-        self.dtype = dtype
         self.layers_block_type = config.layers_block_type
         self.has_previous_state = False  # only used by mamba
         self.intermediate_size = config.mamba_expand * config.hidden_size
@@ -139,9 +138,7 @@ class ZambaHybridDynamicCache(DynamicCache):
         self.conv_states = []
         self.ssm_states = []
         self.transformer_layers = []
-        self._modules = {}
-        self._parameters = {}
-        self._buffers = {}
         for i in range(config.num_hidden_layers):
             self.conv_states += [
                 torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)
@@ -35,7 +35,7 @@ from torch.fx._symbolic_trace import is_fx_tracing
 from torch.fx.proxy import ParameterProxy

 from .. import logging
-from ..cache_utils import Cache, DynamicCache, SinkCache, StaticCache
+from ..cache_utils import Cache
 from ..modeling_utils import PretrainedConfig, PreTrainedModel
 from ..models.auto import get_values
 from ..models.auto.modeling_auto import (
@@ -811,40 +811,40 @@ def _proxies_to_metas(v):
     return v


-def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
-    def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
-        global _CURRENT_TRACER
-        if not isinstance(_CURRENT_TRACER, HFTracer):
-            raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
-        cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
-        cache_proxy.install_orig_cache_cls(orig_cache_cls)
-        return cache_proxy
+# def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
+#     def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
+#         global _CURRENT_TRACER
+#         if not isinstance(_CURRENT_TRACER, HFTracer):
+#             raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
+#         cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
+#         cache_proxy.install_orig_cache_cls(orig_cache_cls)
+#         return cache_proxy

-    return cache_proxy_factory_fn
+#     return cache_proxy_factory_fn


-# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
-ProxyableCache = HFProxyableClassMeta(
-    "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
-)
-ProxyableDynamicCache = HFProxyableClassMeta(
-    "ProxyableDynamicCache",
-    (DynamicCache,),
-    {},
-    proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
-)
-ProxyableSinkCache = HFProxyableClassMeta(
-    "ProxyableSinkCache",
-    (SinkCache,),
-    {},
-    proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
-)
-ProxyableStaticCache = HFProxyableClassMeta(
-    "ProxyableStaticCache",
-    (StaticCache,),
-    {},
-    proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
-)
+# # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
+# ProxyableCache = HFProxyableClassMeta(
+#     "ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
+# )
+# ProxyableDynamicCache = HFProxyableClassMeta(
+#     "ProxyableDynamicCache",
+#     (DynamicCache,),
+#     {},
+#     proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
+# )
+# ProxyableSinkCache = HFProxyableClassMeta(
+#     "ProxyableSinkCache",
+#     (SinkCache,),
+#     {},
+#     proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
+# )
+# ProxyableStaticCache = HFProxyableClassMeta(
+#     "ProxyableStaticCache",
+#     (StaticCache,),
+#     {},
+#     proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
+# )


 def _generate_random_int(low: int = 10, high: int = 20, forbidden_values: Optional[List[int]] = None):
@@ -879,10 +879,10 @@ class HFTracer(Tracer):
         "tril",
     ]
     _CLASSES_TO_PATCH = {
-        Cache: ProxyableCache,
-        DynamicCache: ProxyableDynamicCache,
-        SinkCache: ProxyableSinkCache,
-        StaticCache: ProxyableStaticCache,
+        # Cache: ProxyableCache,
+        # DynamicCache: ProxyableDynamicCache,
+        # SinkCache: ProxyableSinkCache,
+        # StaticCache: ProxyableStaticCache,
     }

     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
@@ -738,6 +738,7 @@ class LlamaIntegrationTest(unittest.TestCase):
     @slow
     @require_read_token
     def test_export_static_cache(self):
+        # this test only run with an accelerator but it doesn't need an accelerator ?
         if version.parse(torch.__version__) < version.parse("2.4.0"):
             self.skipTest(reason="This test requires torch >= 2.4 to run.")

@@ -2390,7 +2390,7 @@ class ModelTesterMixin:
             elif tuple_object is None:
                 return
             # model might return non-tensors objects (e.g. Cache class)
-            elif isinstance(tuple_object, torch.Tensor):
+            elif isinstance(tuple_object, torch.Tensor) and not isinstance(tuple_object, Cache):
                 self.assertTrue(
                     torch.allclose(
                         set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
@@ -215,11 +215,11 @@ class CacheTest(unittest.TestCase):
         # Check if the exported model is configured with the `StaticCache` correctly
         n_static_key_caches = n_static_value_caches = 0
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_key_caches = n_static_key_caches + 1
-            if buffer_name.startswith("static_cache.value_cache"):
+            if buffer_name.startswith("value_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_value_caches = n_static_value_caches + 1
@@ -364,7 +364,7 @@ class CacheIntegrationTest(unittest.TestCase):
             input_ids = gen_out

         # We went well beyond the cache length
-        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+        self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)

         # And it still produces a coherent english
         decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
@@ -619,4 +619,4 @@ class CacheIntegrationTest(unittest.TestCase):
             "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
             'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
         ]  # fmt: skip
-        self.assertTrue(responses == EXPECTED_DECODED_TEXT)
+        self.assertEqual(responses, EXPECTED_DECODED_TEXT)