Fix cache update! (#38046)

* fix slicing

* better fix
This commit is contained in:
Cyril Vallez
2025-05-09 17:54:48 +02:00
committed by GitHub
parent 7f1a97bae3
commit aaed2f5577

View File

@ -1402,7 +1402,7 @@ class SlidingWindowCache(StaticCache):
value_states = value_states.to(v_out.dtype)
# assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
if cache_position.shape[0] > self.max_cache_len:
if cache_position.shape[0] >= self.max_cache_len:
k_out = key_states[:, :, -self.max_cache_len :, :]
v_out = value_states[:, :, -self.max_cache_len :, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
@ -1413,8 +1413,8 @@ class SlidingWindowCache(StaticCache):
return key_states, value_states
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
to_shift = cache_position > self.max_cache_len - 1
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
k_out = k_out[:, :, indices]
@ -1725,7 +1725,7 @@ class HybridCache(Cache):
self.value_cache.append(new_layer_value_cache)
def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
if cache_position.shape[0] > max_cache_len:
if cache_position.shape[0] >= max_cache_len:
k_out = key_states[:, :, -max_cache_len:, :]
v_out = value_states[:, :, -max_cache_len:, :]
# Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
@ -1736,8 +1736,8 @@ class HybridCache(Cache):
return key_states, value_states
slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, max_cache_len - 1)
to_shift = cache_position > max_cache_len - 1
cache_position = cache_position.clamp(0, max_cache_len - 1)
indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]