Compare commits

...

3 Commits

SHA1        Message   Date
08b9e37639  debug     2023-10-11 11:45:33 +02:00
4fa06e3276  reorder   2023-10-11 11:44:37 +02:00
b8eacdab4b  Test      2023-10-11 11:43:36 +02:00


@@ -317,7 +317,27 @@ class MistralAttention(nn.Module):
         return attn_output, attn_weights, past_key_value


-class MistralFlashAttention2(MistralAttention):
+class FlashAttentionMixin(torch.nn.Module):
+    def to(self, *args, **kwargs):
+        print("Called to")
+        target_dtype = None
+        if "dtype" not in kwargs:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    target_dtype = arg
+                    break
+        else:
+            target_dtype = kwargs["dtype"]
+
+        if target_dtype is not None and target_dtype == torch.float32:
+            raise ValueError(
+                "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
+            )
+        return super().to(*args, **kwargs)
+
+
+class MistralFlashAttention2(FlashAttentionMixin, MistralAttention):
     """
     Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
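
For illustration, here is a minimal, self-contained sketch of the behaviour this change introduces. `FlashAttentionMixin` mirrors the `to()` override from the diff; `DummyAttention` and `DummyFlashAttention2` are hypothetical stand-ins for `MistralAttention` and `MistralFlashAttention2`, used only to show that a cast to `float32` is rejected while other dtypes pass through:

import torch


class FlashAttentionMixin(torch.nn.Module):
    """Blocks `.to(torch.float32)` on Flash Attention 2 models (logic copied from the diff above)."""

    def to(self, *args, **kwargs):
        # Look for a target dtype among positional args or the `dtype` kwarg.
        target_dtype = None
        if "dtype" not in kwargs:
            for arg in args:
                if isinstance(arg, torch.dtype):
                    target_dtype = arg
                    break
        else:
            target_dtype = kwargs["dtype"]
        if target_dtype is not None and target_dtype == torch.float32:
            raise ValueError(
                "You cannot cast a model that has been loaded with Flash Attention 2 in `float32`"
            )
        return super().to(*args, **kwargs)


class DummyAttention(torch.nn.Module):
    """Hypothetical stand-in for MistralAttention, only here to exercise the override."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8, dtype=torch.float16)


class DummyFlashAttention2(FlashAttentionMixin, DummyAttention):
    pass


attn = DummyFlashAttention2()
attn.to(torch.bfloat16)       # allowed: weights are cast as usual
try:
    attn.to(torch.float32)    # rejected by the mixin's `to()` override
except ValueError as err:
    print(err)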