Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)
fix mask (#17837)
@@ -292,7 +292,9 @@ class BloomScaledSoftmax(nn.Module):
         if self.scale is not None:
             input = input * self.scale

-        if mask is not None:
+        if mask is None:
+            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
+
         mask = mask.to(input.device)
         causal_mask = (
             torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
@@ -301,8 +303,6 @@ class BloomScaledSoftmax(nn.Module):
         )
         mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
         probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
-        else:
-            probs = nn.functional.softmax(input, dim=-1, dtype=softmax_dtype)

         if input_in_16bit and self.softmax_in_fp32:
             probs = probs.to(dtype=input_dtype)
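For context, the change makes BloomScaledSoftmax build a default all-ones attention mask when the caller passes None, so the causal (lower-triangular) mask is always applied instead of falling back to a plain softmax over the raw scores. The sketch below is a minimal, self-contained illustration of that control flow, not the transformers implementation: masked_softmax_sketch and keep are made-up names, the scores/mask shapes ([batch, heads, seq, seq] and [batch, seq]) are assumptions, and the masked_fill step stands in for the repository's self.mask_func.

import torch
import torch.nn as nn

def masked_softmax_sketch(scores, mask, max_positions):
    """Sketch only: scores is [batch, heads, seq, seq]; mask is [batch, seq] bool or None."""
    # Post-fix behavior: a missing attention mask is replaced by an all-ones mask,
    # so the causal mask below is applied unconditionally.
    if mask is None:
        mask = torch.ones(scores.shape[0], max_positions, dtype=torch.bool, device=scores.device)
    mask = mask.to(scores.device)

    # Lower-triangular causal mask, broadcast over batch and heads.
    causal_mask = torch.tril(
        torch.ones((max_positions, max_positions), dtype=torch.bool, device=scores.device)
    ).view(1, 1, max_positions, max_positions)

    # Keep a key position only if it is neither padded nor in the future.
    keep = causal_mask & mask.view(mask.shape[0], 1, 1, max_positions)

    # Stand-in for self.mask_func: fill disallowed positions with a large negative value.
    masked_scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)

    # Mirror of `softmax(...) * (~padded_causal_mask)`: zero the masked positions again
    # after the softmax so fully masked rows do not leak uniform probabilities.
    probs = nn.functional.softmax(masked_scores, dim=-1, dtype=torch.float32) * keep
    return probs.to(scores.dtype)

# Example: even with mask=None the causal structure is now enforced,
# which is the case the removed `else:` branch used to skip.
scores = torch.randn(2, 4, 8, 8)
probs = masked_softmax_sketch(scores, None, max_positions=8)
assert torch.all(probs[0, 0].triu(1) == 0)  # no attention to future positions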