add comment

Author: Vasqu
Date: 2025-05-22 17:27:31 +02:00
Parent: a12f2f4382
Commit: 2602b0f950
@@ -313,7 +313,11 @@ class WhisperAttention(nn.Module):
bsz, tgt_len = hidden_states.shape[:-1]
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
# get query proj
# Scaling is susceptible to floating-point imprecision, which can
# lead to different results (this varies from model to model;
# Whisper is one such case). We therefore keep the original order
# of scaling to follow the original implementation, and enforce
# no scaling (1.0) in the attention call below.
query_states = self.q_proj(hidden_states) * self.scaling
query_states = query_states.view(*q_input_shape)
query_states = query_states.transpose(1, 2).contiguous()
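For readers wondering why the order of scaling matters at all: below is a minimal sketch, not part of this commit, contrasting the two orders the comment refers to. Shapes and values are illustrative assumptions, not Whisper's actual configuration; the point is only that scaling the query before the matmul versus scaling the scores afterward can differ bitwise in floating point.

```python
import torch

torch.manual_seed(0)

# Illustrative shapes only (assumption): (batch, heads, seq, head_dim).
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)
scaling = 16 ** -0.5  # head_dim ** -0.5

# Order A: scale the query first, then matmul with no further scaling.
# This is the order the comment above preserves.
scores_a = (q * scaling) @ k.transpose(-1, -2)

# Order B: matmul first, then scale the scores, as a generic attention
# kernel would do internally when handed a non-trivial scale.
scores_b = (q @ k.transpose(-1, -2)) * scaling

# Mathematically identical, but the rounding happens at different
# points, so the results need not match bitwise.
print(torch.equal(scores_a, scores_b))     # may be False
print(torch.allclose(scores_a, scores_b))  # True within tolerance
```

Passing a scale of 1.0 to the attention call, as the comment describes, keeps Order A intact: the query is already scaled, so the kernel must not scale again.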