mirror of
https://github.com/huggingface/transformers.git
synced 2025-11-11 16:54:37 +08:00
@ -1402,7 +1402,7 @@ class SlidingWindowCache(StaticCache):
|
||||
value_states = value_states.to(v_out.dtype)
|
||||
|
||||
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
|
||||
if cache_position.shape[0] > self.max_cache_len:
|
||||
if cache_position.shape[0] >= self.max_cache_len:
|
||||
k_out = key_states[:, :, -self.max_cache_len :, :]
|
||||
v_out = value_states[:, :, -self.max_cache_len :, :]
|
||||
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||
@ -1413,8 +1413,8 @@ class SlidingWindowCache(StaticCache):
|
||||
return key_states, value_states
|
||||
|
||||
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
|
||||
to_shift = cache_position > self.max_cache_len - 1
|
||||
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
|
||||
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
|
||||
|
||||
k_out = k_out[:, :, indices]
|
||||
@ -1725,7 +1725,7 @@ class HybridCache(Cache):
|
||||
self.value_cache.append(new_layer_value_cache)
|
||||
|
||||
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
|
||||
if cache_position.shape[0] > max_cache_len:
|
||||
if cache_position.shape[0] >= max_cache_len:
|
||||
k_out = key_states[:, :, -max_cache_len:, :]
|
||||
v_out = value_states[:, :, -max_cache_len:, :]
|
||||
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
|
||||
@ -1736,8 +1736,8 @@ class HybridCache(Cache):
|
||||
return key_states, value_states
|
||||
|
||||
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
|
||||
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||
to_shift = cache_position > max_cache_len - 1
|
||||
cache_position = cache_position.clamp(0, max_cache_len - 1)
|
||||
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
|
||||
k_out = k_out[:, :, indices]
|
||||
v_out = v_out[:, :, indices]
|
||||
|
||||
Reference in New Issue
Block a user