This commit is contained in:
Younes Belkada
2022-06-27 14:08:18 +02:00
committed by GitHub
parent ee0d001de7
commit 3ec7d4cfe4

View File

@ -292,7 +292,9 @@ class BloomScaledSoftmax(nn.Module):
if self.scale is not None:
input = input * self.scale
if mask is not None:
if mask is None:
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
mask = mask.to(input.device)
causal_mask = (
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
@ -301,8 +303,6 @@ class BloomScaledSoftmax(nn.Module):
)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
else:
probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
if input_in_16bit and self.softmax_in_fp32:
probs = probs.to(dtype=input_dtype)