Compare commits

3 Commits

2 changed files with 16 additions and 4 deletions

View File

@@ -3252,7 +3252,7 @@ class GenerationMixin:
         model_forward = self.__call__
         if isinstance(model_kwargs.get("past_key_values"), Cache):
             is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
-            is_compileable = is_compileable and not self.generation_config.disable_compile
+            is_compileable = is_compileable and not generation_config.disable_compile
             if is_compileable and (
                 self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
             ):
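The change above points the compile gate at the per-call generation_config that generate() resolves, instead of the self.generation_config stored on the model, so a disable_compile override passed at call time is honoured. A minimal toy sketch of that pattern; the Model/GenerationConfig classes here are hypothetical stand-ins, not the real transformers ones:

# Toy sketch (hypothetical classes) of why the per-call config must drive the gate.
from dataclasses import dataclass

@dataclass
class GenerationConfig:
    disable_compile: bool = False

class Model:
    def __init__(self):
        self.generation_config = GenerationConfig()  # checkpoint-level default

    def generate(self, **overrides):
        # generate() resolves a per-call config from the default plus the caller's overrides
        generation_config = GenerationConfig(**{**vars(self.generation_config), **overrides})
        is_compileable = True  # pretend the cache is compileable for the sketch
        # The fix: consult the per-call config, not self.generation_config,
        # so a disable_compile override passed to generate() is respected.
        is_compileable = is_compileable and not generation_config.disable_compile
        return "compiled forward" if is_compileable else "eager forward"

toy = Model()
print(toy.generate())                      # compiled forward
print(toy.generate(disable_compile=True))  # eager forward

Before the fix, the second call would still pick the compiled path, because the override only lives on the per-call config.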

View File

@@ -81,6 +81,12 @@ class MistralAttention(LlamaAttention):
             else:
                 attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
+        if self.config._attn_implementation == "flash_attention_2" and attention_mask.dim() == 4:
+            # for static cache, the attention mask is 4D, but flash_attention_2 expects 2D, so we recover the 2D mask
+            min_dtype = torch.finfo(query_states.dtype).min
+            query_length = attention_mask.size(2)
+            attention_mask = (attention_mask != min_dtype).int()[:, 0, query_length - 1, :query_length]
+
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
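The new flash_attention_2 branch collapses the 4D additive mask built for a static cache back into the 2D padding mask that flash_attention_2 consumes: masked slots hold the dtype minimum, so thresholding against it and slicing the last query row of the first head, truncated to the live length, recovers a 0/1 mask per token. A standalone sketch with made-up shapes (illustrative only, not model code):

# Standalone illustration of the 4D -> 2D recovery used above; shapes are made up.
import torch

batch, query_length, kv_length = 2, 5, 8
dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# 4D additive mask as a static cache builds it: 0.0 where attention is allowed,
# dtype-min everywhere else (unused cache slots and padded tokens included).
mask_4d = torch.full((batch, 1, query_length, kv_length), min_dtype, dtype=dtype)
causal = torch.tril(torch.ones(query_length, query_length, dtype=torch.bool))
mask_4d[:, 0, :, :query_length] = torch.where(
    causal, torch.tensor(0.0, dtype=dtype), torch.tensor(min_dtype, dtype=dtype)
)
mask_4d[1, 0, :, :2] = min_dtype  # left-pad the second sequence by two tokens

# Same slice as the diff: threshold against dtype-min, keep the last query row,
# and truncate to the live positions.
mask_2d = (mask_4d != min_dtype).int()[:, 0, query_length - 1, :query_length]
print(mask_2d)
# tensor([[1, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1]], dtype=torch.int32)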
@@ -120,9 +126,17 @@ class MistralModel(LlamaModel):
         past_key_values: Cache,
         output_attentions: bool,
     ):
+        dtype, device = input_tensor.dtype, input_tensor.device
+        min_dtype = torch.finfo(dtype).min
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and past_key_values is not None:
-                is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+                if attention_mask.dim() == 4:  # for static cache, the attention mask is 4D, so we recover the 2D mask
+                    query_length = attention_mask.size(2)
+                    mask_2d = (attention_mask != min_dtype).int()[:, 0, query_length - 1, :query_length]
+                else:
+                    mask_2d = attention_mask
+
+                is_padding_right = mask_2d[:, -1].sum().item() != input_tensor.size()[0]
                 if is_padding_right:
                     raise ValueError(
                         "You are attempting to perform batched generation with padding_side='right'"
@@ -155,8 +169,6 @@ class MistralModel(LlamaModel):
             ):
                 return None
 
-        dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
         # SlidingWindowCache or StaticCache
         if using_sliding_window_cache or using_static_cache:
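Taken together, these hunks target running a flash_attention_2 Mistral model with a static cache, and hence the compiled decoding path. A hedged end-to-end sketch of that usage; the checkpoint name, dtype, and an installed flash-attn build are assumptions, not part of the diff:

# Hedged usage sketch; checkpoint, dtype, and flash-attn availability are assumptions.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # any Mistral-architecture checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # the path patched above
    device_map="cuda",
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
# cache_implementation="static" hands the attention layers a 4D mask, which the
# patched MistralAttention converts back to the 2D mask flash-attn expects.
output = model.generate(**inputs, cache_implementation="static", max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))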