commit 3ec7d4cfe4 (parent ee0d001de7)
Author: Younes Belkada
Date: 2022-06-27 14:08:18 +02:00
Committed by: GitHub


@@ -292,17 +292,17 @@ class BloomScaledSoftmax(nn.Module):
         if self.scale is not None:
             input = input * self.scale
 
-        if mask is not None:
-            mask = mask.to(input.device)
-            causal_mask = (
-                torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
-                .view(1, 1, max_positions, max_positions)
-                .to(input.device)
-            )
-            mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
-            probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-        else:
-            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)
+        if mask is None:
+            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
+
+        mask = mask.to(input.device)
+        causal_mask = (
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+            .view(1, 1, max_positions, max_positions)
+            .to(input.device)
+        )
+        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
+        probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
 
         if input_in_16bit and self.softmax_in_fp32:
             probs = probs.to(dtype=input_dtype)
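
For context, a minimal standalone sketch of the behavior this hunk changes: before the patch, a missing attention mask skipped the causal mask entirely and fell back to a plain softmax; after it, an all-ones padding mask is created so the causal mask is always applied. Note that simple_mask_func below is a simplified stand-in for self.mask_func, and all names and shapes are illustrative, not the library's API.

    import torch
    from torch import nn

    def simple_mask_func(scores, padding_mask, causal_mask):
        # padding_mask: (batch, seq) bool, True = real token
        # causal_mask:  (1, 1, seq, seq) bool, True = position may be attended to
        # Combine into one "masked out" boolean tensor, broadcast over heads.
        combined = ~causal_mask | ~padding_mask[:, None, None, :]
        return scores.masked_fill(combined, -1e4), combined

    batch, heads, seq = 2, 4, 8
    scores = torch.randn(batch, heads, seq, seq)
    mask = None  # the case the patch addresses: no attention mask supplied

    # Fixed path: default to an all-ones padding mask instead of skipping masking.
    if mask is None:
        mask = torch.ones(batch, seq, dtype=torch.bool, device=scores.device)

    causal_mask = (
        torch.tril(torch.ones((seq, seq), dtype=torch.bool))
        .view(1, 1, seq, seq)
        .to(scores.device)
    )

    masked_scores, combined_mask = simple_mask_func(scores, mask, causal_mask)
    # Zero out the masked positions after the softmax, as the library code does.
    probs = nn.functional.softmax(masked_scores, dim=-1) * (~combined_mask)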