Compare commits

...

57 Commits

Author SHA1 Message Date
53e8738cc8 Woops 2022-08-02 16:35:20 +02:00
784073a1ff I was wrong 2022-08-02 16:29:34 +02:00
13b9ce187f Maybe this needs to be duplicated 2022-08-02 15:48:48 +02:00
c127d24caf Fix cat dimension 2022-08-02 12:06:21 +02:00
0e1a758eb4 Fix in-place operation 2022-08-02 12:05:38 +02:00
34f7e342f7 Improve typing 2022-08-02 11:54:26 +02:00
5167346acc Reduce the number of concatenation 2022-08-02 11:51:09 +02:00
8f4d60357a Revert "Add back layer_number normalization"
This reverts commit 50e1f2f9fd26abcb6ee9735f2838b796eeec8269.
2022-08-01 15:09:58 +02:00
50e1f2f9fd Add back layer_number normalization 2022-08-01 14:52:21 +02:00
7c399e6a3c Add comment about support for torchScript v1.11 2022-08-01 11:47:32 +02:00
bd1ae60d3a Apply suggestions from code review
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
2022-08-01 11:45:38 +02:00
fd117da195 Bypass torch.baddbmm 2022-07-29 17:10:57 +02:00
b58a6f7298 Try passing a tuple 2022-07-29 16:50:59 +02:00
790e23824c Does this help? 2022-07-29 16:37:50 +02:00
1e4ccaf34d Revert "Fix FX test"
This reverts commit a3d50c08708386c363bbe5b4ac2204e9856b5568.
2022-07-29 16:28:24 +02:00
a3d50c0870 Fix FX test 2022-07-29 16:22:30 +02:00
87861b2ea5 Merge remote-tracking branch 'origin/main' into thomas/bloom_clean_code 2022-07-29 16:17:10 +02:00
88814bf4f1 We don't actually need layer_number to be passed anymore 2022-07-29 15:48:21 +02:00
49aff18fee Actually i think the attention casting only makes sense when we use torch.float16 2022-07-29 15:41:09 +02:00
5fcc118aba Woops 2022-07-29 15:14:24 +02:00
1c93638e8a Make style 2022-07-29 15:00:27 +02:00
98fdf998db Add comment that explains that some methods return views 2022-07-29 14:59:08 +02:00
8699509cd0 Improve type hinting 2022-07-29 14:52:40 +02:00
2d58b7e320 Rename with long version of variables 2022-07-29 14:24:06 +02:00
323c073167 No need to manually cast the values to a specific device 2022-07-29 12:23:06 +02:00
4db82d46a0 Optimize the device allocation in case of hidden_states in multiple devices 2022-07-29 12:05:28 +02:00
b4346e14b9 make style 2022-07-29 10:41:35 +02:00
69663c5191 Improve documentation on past_key_values format 2022-07-29 10:34:50 +02:00
47608f85b6 Woops 2022-07-28 23:43:42 +02:00
c02f409343 I don't like kwargs 2022-07-28 23:40:26 +02:00
02bf51d0cb Woops 2022-07-28 23:29:25 +02:00
ad1bfe963d Nits 2022-07-28 23:26:16 +02:00
47e1969b62 Revert attempt to be backward compatible 2022-07-28 22:54:54 +02:00
4596be9e57 Woops 2022-07-28 22:49:53 +02:00
7795e23843 Try to fix beam_search 2022-07-28 22:49:06 +02:00
995d31a0c8 Woops 2022-07-28 22:44:59 +02:00
6c3bf96202 Try and be backward compatible 2022-07-28 22:43:09 +02:00
5ed059c1d7 make style 2022-07-28 22:33:59 +02:00
77f19b376b Not sure self.layer_number normalization actually matters 2022-07-28 22:28:47 +02:00
ddbe33e59d Nit 2022-07-28 21:44:23 +02:00
2677a28224 Removing layer num normalization seems to be breaking 2022-07-28 21:44:23 +02:00
d40ee96c83 Woops 2022-07-28 21:44:23 +02:00
0ffecbfa22 Woops 2022-07-28 21:44:23 +02:00
3a095d04ab Fix beam search 2022-07-28 21:44:23 +02:00
a8cde02e36 Try to fix beam_search 2022-07-28 21:44:23 +02:00
9b2c1ca021 No need for duplication 2022-07-28 21:44:23 +02:00
96307a4d07 Woops 2022-07-28 21:44:23 +02:00
42e5954430 Woops 2022-07-28 21:44:22 +02:00
298e3fde19 I don't think we actually need the layer_num scaling trick 2022-07-28 21:44:22 +02:00
ec7442c065 Woops 2022-07-28 21:44:22 +02:00
62adfce1b8 Try to reduce the number of reshape/copies 2022-07-28 21:44:22 +02:00
1a8b80b794 WIP 2022-07-28 21:44:22 +02:00
ed09d703fc Improve signatures 2022-07-28 21:44:22 +02:00
0078a6cd96 Woops 2022-07-28 21:44:22 +02:00
69227aea33 Woops 2022-07-28 21:44:22 +02:00
baa5d870e3 make style 2022-07-28 21:44:22 +02:00
d29881f3a0 Cleanup some code 2022-07-28 21:44:22 +02:00


@@ -16,12 +16,13 @@
import math
import warnings
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
from torch.nn import functional as F
from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
from ...modeling_outputs import (
@@ -52,102 +53,99 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
def _make_causal_mask(
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
) -> torch.BoolTensor:
"""
Make causal mask used for bi-directional self-attention.
Make causal mask used for self-attention.
"""
batch_size, target_length = input_ids_shape
mask = torch.full((target_length, target_length), torch.finfo(dtype).min)
mask_cond = torch.arange(mask.size(-1))
intermediate_mask = mask_cond < (mask_cond + 1).view(mask.size(-1), 1)
mask.masked_fill_(intermediate_mask, 0)
mask = mask.to(dtype)
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
mask[:, past_key_values_length:] = True
mask[:, past_key_values_length:].triu_(diagonal=1)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(target_length, past_key_values_length, dtype=dtype), mask], dim=-1)
mask[:, :past_key_values_length] = False
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
return expanded_mask
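For reference, a minimal sketch (not part of the diff, assuming the helper above is importable as shown) of what the boolean causal mask represents for a query of length 3 with 2 cached tokens; True marks positions that must not be attended to:
import torch
mask = _make_causal_mask(torch.Size([1, 3]), device=torch.device("cpu"), past_key_values_length=2)
print(mask.shape)        # torch.Size([1, 1, 3, 5])
print(mask[0, 0].int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]])   # cached tokens are always visible, future tokens are masked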
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None):
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
"""
batch_size, source_length = mask.size()
tgt_len = tgt_len if tgt_len is not None else source_length
batch_size, src_length = mask.shape
tgt_length = tgt_length if tgt_length is not None else src_length
expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, source_length).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
def build_alibi_tensor(attention_mask: torch.Tensor, n_head: int, dtype, device) -> torch.Tensor:
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
"""
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
`softmax(l+a) = softmax(l)`. Based on
https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
Args:
Returns tensor shaped (batch_size * n_head, 1, max_seq_len)
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
attention_mask (`torch.Tensor`):
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
n_head (`int`, *required*):
num_heads (`int`, *required*):
number of heads
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
dtype of the output tensor
device (`torch.device`, *optional*, default=`torch.device('cpu')`):
device of the output alibi tensor
"""
closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
batch_size, seq_length = attention_mask.shape
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
base = torch.tensor(
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != n_head:
if closest_power_of_2 != num_heads:
extra_base = torch.tensor(
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
)
num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
# Note: alibi will be added to the attention bias that will be applied to the query, key product of attention
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
# => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length)
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
# => the query_length dimension will then be broadcasted correctly
# This is more or less identical to T5's relative position bias:
# https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
# batch_size = 1, n_head = n_head, query_length
arange_tensor = (attention_mask.cumsum(-1)[:, None, :].to(device) - 1) * attention_mask[:, None]
alibi = slopes.unsqueeze(-1) * arange_tensor
alibi = alibi * attention_mask[:, None]
return alibi.reshape(alibi.shape[0] * n_head, 1, -1).to(dtype)
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
alibi = slopes[..., None] * arange_tensor
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
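As a quick illustration (toy sizes, not part of the diff): with num_heads=4 the closest power of two is 4, the base is 2**-2, and the per-head slopes come out to 1/4, 1/16, 1/64, 1/256; each head's bias is its slope times the mask-aware token position:
import torch
attention_mask = torch.ones(2, 5)                      # batch_size=2, seq_length=5, no padding
alibi = build_alibi_tensor(attention_mask, num_heads=4, dtype=torch.float32)
print(alibi.shape)                                     # torch.Size([8, 1, 5]) == (batch_size * num_heads, 1, seq_length)
print(alibi[0, 0])                                     # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) -> slope 1/4 * position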
def dropout_add(x, residual, prob, training):
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
"""
Dropout add function
Args:
x (`torch.tensor`, *required*):
input tensor
residual (`torch.tensor`, *rquired*):
residual (`torch.tensor`, *required*):
residual tensor
prob (`float`, *required*):
dropout probability
training (`bool`, *required*):
training mode
"""
out = nn.functional.dropout(x, p=prob, training=training)
out = F.dropout(x, p=prob, training=training)
out = residual + out
return out
def bloom_gelu_forward(x):
def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
"""
Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
make the model jitable.
@@ -159,7 +157,7 @@ def bloom_gelu_forward(x):
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
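The constant 0.79788456 is sqrt(2/pi), so this is the standard tanh approximation of GELU written in a jit-friendly form; a quick sanity check (not part of the diff, assuming this function is importable):
import math
import torch
x = torch.randn(4, 8)
reference = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
print(torch.allclose(bloom_gelu_forward(x), reference, atol=1e-6))    # True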
def bloom_gelu_back(g, x):
def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""
gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
0.3989423 * x * torch.exp(-0.5 * x * x)
@@ -179,12 +177,12 @@ def bloom_gelu_back(g, x):
class GeLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
def forward(ctx, input: torch.Tensor) -> torch.Tensor:
ctx.save_for_backward(input)
return bloom_gelu_forward(input)
@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
input = ctx.saved_tensors
tmp = bloom_gelu_back(grad_output, input)
return tmp
@@ -197,13 +195,12 @@ class BloomGelu(nn.Module):
copied from Megatron-DeepSpeed code and adapted for our needs
See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
"""
def __init__(self):
super().__init__()
def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.training:
return GeLUFunction.apply(x)
else:
@@ -211,7 +208,7 @@ class BloomGelu(nn.Module):
class BloomAttention(nn.Module):
def __init__(self, config, layer_number=None):
def __init__(self, config: BloomConfig):
super().__init__()
self.pretraining_tp = config.pretraining_tp
@@ -230,106 +227,145 @@ class BloomAttention(nn.Module):
)
# Layer-wise attention scaling
self.layer_number = max(1, layer_number)
self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
self.beta = 1.0
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
self.attention_dropout = nn.Dropout(config.attention_dropout)
def _split_heads(self, fused_qkv):
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim)
"""
new_tensor_shape = fused_qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
# new_tensor_shape = (fused_qkv.size(1), fused_qkv.size(0)*fused_qkv.size(2), fused_qkv.size(-1))
# fused_qkv = fused_qkv.transpose(1, 0)
fused_qkv = fused_qkv.reshape(new_tensor_shape)
# fused_qkv = fused_qkv.permute(0, 2, 1, 3)
return torch.split(fused_qkv, self.head_dim, -1)
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
storage as `fused_qkv`
def _merge_heads(self, x):
Args:
fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
Returns:
query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
value: [batch_size, seq_length, num_heads, head_dim]
"""
batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
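A small standalone check (toy sizes, replicating the view/indexing above) that the split really is copy-free, i.e. the query slice shares storage with `fused_qkv`:
import torch
num_heads, head_dim = 2, 4
fused_qkv = torch.randn(1, 3, num_heads * 3 * head_dim)             # [batch_size, seq_length, num_heads * 3 * head_dim]
query = fused_qkv.view(1, 3, num_heads, 3, head_dim)[..., 0, :]     # same indexing as _split_heads
print(query.shape)                                                  # torch.Size([1, 3, 2, 4])
print(query.data_ptr() == fused_qkv.data_ptr())                     # True -> the query view starts at the same storage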
def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
"""
Merge heads together over the last dimension
Args:
x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
Returns:
torch.tensor: [batch_size, seq_length, num_heads * head_dim]
"""
# What we want to achieve is:
# batch_size * num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads * head_dim
# batch_size * num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads * head_dim
batch_size_and_num_heads, seq_length, _ = x.shape
batch_size = batch_size_and_num_heads // self.num_heads
# First view to decompose the batch size
# batch_size*num_heads, seq_len, head_dim -> batch_size, num_heads, seq_len, head_dim
x = x.view(x.size(0) // self.num_heads, self.num_heads, x.size(1), self.head_dim)
# batch_size * num_heads, seq_length, head_dim -> batch_size, num_heads, seq_length, head_dim
x = x.view(batch_size, self.num_heads, seq_length, self.head_dim)
# batch_size, num_heads, seq_len, head_dim -> batch_size, seq_len, num_heads, head_dim
# batch_size, num_heads, seq_length, head_dim -> batch_size, seq_length, num_heads, head_dim
x = x.permute(0, 2, 1, 3)
# batch_size, seq_len, num_heads, head_dim -> batch_size, seq_len, num_heads * head_dim
return x.reshape(x.size(0), x.size(1), self.num_heads * self.head_dim)
# batch_size, seq_length, num_heads, head_dim -> batch_size, seq_length, num_heads * head_dim
return x.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
def forward(
self,
hidden_states,
residual,
layer_past=None,
attention_mask=None,
alibi=None,
head_mask=None,
use_cache=False,
output_attentions=False,
hidden_states: torch.Tensor,
residual: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor, int]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
alibi = alibi.to(hidden_states.device) # to make the model possible to run under accelerate
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, q_length, _, _ = query_layer.shape
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
present = None
if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
# Cached past key/value:
# - key: [batch_size * self.num_heads, head_dim, kv_length]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
past_key, past_value, kv_past_length = layer_past
past_buffer_size = past_key.shape[-1]
new_kv_past_length = kv_past_length + q_length
if use_cache is True:
present = (key_layer, value_layer)
# We double the buffer for past keys and values everytime we fill it up.
if new_kv_past_length >= past_buffer_size:
past_key = torch.cat((past_key, torch.empty_like(past_key)), dim=2)
past_value = torch.cat((past_value, torch.empty_like(past_value)), dim=1)
past_key[:, :, kv_past_length:new_kv_past_length] = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
past_value[:, kv_past_length:new_kv_past_length, :] = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
key_layer = past_key[:, :, :new_kv_past_length]
value_layer = past_value[:, :new_kv_past_length, :]
if use_cache is True:
present = (past_key, past_value, new_kv_past_length)
else:
present = None
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
new_kv_past_length = key_layer.shape[-1]
beta = 1.0 / self.layer_number
if use_cache is True:
present = (key_layer, value_layer, new_kv_past_length)
# # [batch_size*num_heads, head_dim, q_length] x [batch_size*num_heads, head_dim, k_length] -> [batch_size*num_heads, q_length, k_length]
matmul_result = (1.0 / self.norm_factor) * torch.bmm(
query_layer.transpose(1, 2).reshape(-1, query_layer.shape[1], query_layer.shape[3]),
key_layer.permute(0, 2, 3, 1).reshape(-1, key_layer.shape[3], key_layer.shape[1]),
) + beta * alibi
_, _, kv_length = key_layer.shape
# change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(-1, self.num_heads, matmul_result.size(1), matmul_result.size(2))
# [batch_size * num_heads, q_length, kv_length]
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
matmul_result = alibi.baddbmm(
batch1=query_layer,
batch2=key_layer,
beta=self.beta,
alpha=self.inv_norm_factor,
)
# We replace the scaled softmax by just a few line of code - [batch_size, num_heads, q_length, k_length]
# change view to [batch_size, num_heads, q_length, kv_length]
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
input_dtype = attention_scores.dtype
attn_weights = (attention_scores * self.layer_number) + attention_mask
attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
attention_probs = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
attention_probs = attention_probs * (~attention_mask.to(torch.bool))
# [batch_size, num_heads, q_length, k_length]
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
if input_dtype == torch.float16:
attention_scores = attention_scores.to(torch.float)
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
# [batch_size, num_heads, q_length, kv_length]
attention_probs = self.attention_dropout(attention_probs)
if head_mask is not None:
attention_probs = attention_probs * head_mask
# change view [batch_size x num_heads, q_length, k_length]
attention_probs_reshaped = attention_probs.view(matmul_result.shape)
# change view [batch_size x num_heads, q_length, kv_length]
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(
attention_probs_reshaped, value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3))
)
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = self._merge_heads(context_layer)
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
if self.pretraining_tp > 1 and self.slow_but_exact:
slices = context_layer.shape[-1] / self.pretraining_tp
slices = self.hidden_size / self.pretraining_tp
output_tensor = torch.zeros_like(context_layer)
for i in range(self.pretraining_tp):
output_tensor = output_tensor + nn.functional.linear(
output_tensor = output_tensor + F.linear(
context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
)
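The cache handling in this forward pre-allocates a buffer along the kv_length dimension and doubles it whenever it fills up, so each decoding step becomes an in-place write instead of a per-step torch.cat; a simplified standalone sketch of that strategy (single tensor, hypothetical helper name):
import torch
def append_to_cache(buffer: torch.Tensor, length: int, new_values: torch.Tensor):
    # buffer: [batch_size * num_heads, buffer_size, head_dim]; valid data lives in buffer[:, :length]
    new_length = length + new_values.shape[1]
    while new_length > buffer.shape[1]:
        buffer = torch.cat((buffer, torch.empty_like(buffer)), dim=1)   # double the buffer
    buffer[:, length:new_length] = new_values                           # in-place write of the new step
    return buffer, new_length
cache = torch.empty(2, 4, 8)                                            # 4 pre-allocated slots
cache, length = append_to_cache(cache, 0, torch.randn(2, 3, 8))
cache, length = append_to_cache(cache, length, torch.randn(2, 3, 8))    # triggers one doubling
print(cache.shape, length)                                              # torch.Size([2, 8, 8]) 6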
@@ -346,7 +382,7 @@ class BloomAttention(nn.Module):
class BloomMLP(nn.Module):
def __init__(self, config):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
@@ -357,14 +393,14 @@ class BloomMLP(nn.Module):
self.dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)
self.hidden_dropout = config.hidden_dropout
def forward(self, hidden_states, residual):
def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
for i in range(self.pretraining_tp):
intermediate_output = intermediate_output + nn.functional.linear(
intermediate_output = intermediate_output + F.linear(
hidden_states[:, :, int(i * slices) : int((i + 1) * slices)],
self.dense_4h_to_h.weight[:, int(i * slices) : int((i + 1) * slices)],
)
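The same slicing pattern is used in both the attention output projection and this MLP: summing F.linear over column slices of the weight is mathematically the full matmul, just computed slice by slice; a quick standalone check (toy sizes, bias omitted):
import torch
from torch.nn import functional as F
pretraining_tp, in_features, out_features = 4, 16, 8
x = torch.randn(2, 3, in_features)
weight = torch.randn(out_features, in_features)
slices = in_features / pretraining_tp
out = torch.zeros(2, 3, out_features)
for i in range(pretraining_tp):
    out = out + F.linear(x[:, :, int(i * slices): int((i + 1) * slices)],
                         weight[:, int(i * slices): int((i + 1) * slices)])
print(torch.allclose(out, F.linear(x, weight), atol=1e-5))    # True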
@@ -377,13 +413,13 @@ class BloomMLP(nn.Module):
class BloomBlock(nn.Module):
def __init__(self, config, layer_number=None):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.n_head = config.n_head
self.self_attention = BloomAttention(config, layer_number=layer_number)
self.num_heads = config.n_head
self.self_attention = BloomAttention(config)
self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)
@@ -393,13 +429,13 @@ class BloomBlock(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
use_cache=False,
output_attentions=False,
alibi=None,
hidden_states: torch.Tensor,
alibi: torch.Tensor,
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor, int]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
# hidden_states: [batch_size, seq_length, hidden_size]
@@ -462,9 +498,9 @@ class BloomPreTrainedModel(PreTrainedModel):
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
if isinstance(module, (nn.Linear)):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -478,7 +514,7 @@ class BloomPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
if isinstance(module, BloomModel):
module.gradient_checkpointing = value
@@ -501,9 +537,8 @@ BLOOM_START_DOCSTRING = r"""
BLOOM_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
sequence tokens in the vocabulary.
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
(`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
`input_ids`.
@@ -516,6 +551,10 @@ BLOOM_INPUTS_DOCSTRING = r"""
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
Each element of `past_key_values` is a tuple (past_key, past_value):
- past_key: [batch_size * num_heads, head_dim, kv_length]
- past_value: [batch_size * num_heads, kv_length, head_dim]
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -555,19 +594,18 @@ BLOOM_INPUTS_DOCSTRING = r"""
BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
def __init__(self, config):
def __init__(self, config: BloomConfig):
super().__init__(config)
self.embed_dim = config.hidden_size
self.n_head = config.n_head
self.num_heads = config.n_head
# Embedding + LN Embedding
self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])
self.h = nn.ModuleList([BloomBlock(config) for _ in range(config.num_hidden_layers)])
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -580,25 +618,29 @@ class BloomModel(BloomPreTrainedModel):
def get_input_embeddings(self):
return self.word_embeddings
def _prepare_attn_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
def _prepare_attn_mask(
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
) -> torch.BoolTensor:
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(attention_mask.device)
device = attention_mask.device
_, src_length = input_shape
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
if src_length > 1:
combined_attention_mask = _make_causal_mask(
input_shape, device=device, past_key_values_length=past_key_values_length
)
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
)
return combined_attention_mask
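A toy run (not part of the diff) of the two helpers defined at the top of this file, showing how the padding mask and the causal mask are merged with a boolean OR: a position is masked if it is either a future token or a padding token. One cached token, query length 3, first key is padding:
import torch
attention_mask = torch.tensor([[0, 1, 1, 1]])            # total kv length 4, first token is padding
causal = _make_causal_mask(torch.Size([1, 3]), device=attention_mask.device, past_key_values_length=1)
padding = _expand_mask(attention_mask, tgt_length=3)
print((padding | causal)[0, 0].int())
# tensor([[1, 0, 1, 1],
#         [1, 0, 0, 1],
#         [1, 0, 0, 0]])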
def set_input_embeddings(self, new_embeddings):
def set_input_embeddings(self, new_embeddings: torch.Tensor):
self.word_embeddings = new_embeddings
@add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
@@ -610,17 +652,17 @@ class BloomModel(BloomPreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, int], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
@@ -641,10 +683,9 @@ class BloomModel(BloomPreTrainedModel):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -653,8 +694,8 @@ class BloomModel(BloomPreTrainedModel):
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_head x N x N
# head_mask has shape n_layer x batch x n_head x N x N
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
@@ -662,27 +703,29 @@ class BloomModel(BloomPreTrainedModel):
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
output_shape = input_shape + (hidden_states.size(-1),)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
# Compute alibi tensor: check build_alibi_tensor documentation
current_sequence_length = hidden_states.shape[1]
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[1]
current_sequence_length += past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past += past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((hidden_states.shape[0], current_sequence_length), device=hidden_states.device)
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.n_head, hidden_states.dtype, hidden_states.device)
alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
causal_mask = self._prepare_attn_mask(attention_mask, input_shape, inputs_embeds, past_key_values_length)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
@@ -700,14 +743,14 @@ class BloomModel(BloomPreTrainedModel):
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions, alibi)
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
alibi,
causal_mask,
head_mask[i],
)
@@ -735,8 +778,6 @@ class BloomModel(BloomPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states = hidden_states.view(output_shape)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
@@ -758,7 +799,7 @@ class BloomModel(BloomPreTrainedModel):
class BloomForCausalLM(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
def __init__(self, config: BloomConfig):
super().__init__(config)
self.transformer = BloomModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@@ -769,16 +810,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
def set_output_embeddings(self, new_embeddings: torch.Tensor):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, int], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs
) -> dict:
# only last token for input_ids if past is not None
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
return {
"input_ids": input_ids,
"past_key_values": past,
@@ -795,16 +840,16 @@ class BloomForCausalLM(BloomPreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, int], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
@@ -845,9 +890,12 @@ class BloomForCausalLM(BloomPreTrainedModel):
# Shift so that tokens < n predict n
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
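A minimal numeric illustration (toy batch and vocab, not from the diff) of the shift above: logits at position i are scored against the label at position i+1, and both are flattened for CrossEntropyLoss:
import torch
from torch.nn import CrossEntropyLoss
lm_logits = torch.randn(2, 5, 11)                     # [batch_size, seq_length, vocab_size]
labels = torch.randint(0, 11, (2, 5))
shift_logits = lm_logits[..., :-1, :].contiguous()    # positions 0..3 predict ...
shift_labels = labels[..., 1:].contiguous()           # ... positions 1..4
batch_size, seq_length, vocab_size = shift_logits.shape
loss = CrossEntropyLoss()(
    shift_logits.view(batch_size * seq_length, vocab_size),
    shift_labels.view(batch_size * seq_length),
)
print(loss)                                           # scalar tensor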
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -862,14 +910,38 @@ class BloomForCausalLM(BloomPreTrainedModel):
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor, torch.Tensor, int], ...], beam_idx: torch.LongTensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
Output shares the same memory storage as `past`.
"""
batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
batch_size = len(beam_idx)
num_heads = batch_size_times_num_heads // batch_size
# Get a copy of `beam_idx` on all the devices where we need those indices.
device_to_beam_idx = {
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
}
# key: layer_past[0], [batch_size * num_heads, head_dim, seq_length]
# value: layer_past[1], [batch_size * num_heads, seq_length, head_dim]
# kv_past_length: layer_past[2], int
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
(
layer_past[0]
.view(batch_size, num_heads, head_dim, seq_length)
.index_select(0, device_to_beam_idx[layer_past[0].device])
.view(batch_size_times_num_heads, head_dim, seq_length),
layer_past[1]
.view(batch_size, num_heads, seq_length, head_dim)
.index_select(0, device_to_beam_idx[layer_past[0].device])
.view(batch_size_times_num_heads, seq_length, head_dim),
layer_past[2]
)
for layer_past in past
)
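Standalone sketch (toy sizes, not from the diff) of the reordering pattern above: the fused batch_size * num_heads dimension is unflattened so that index_select picks whole beams, i.e. all heads of a batch element move together:
import torch
batch_size, num_heads, head_dim, seq_length = 2, 3, 4, 5
past_key = torch.randn(batch_size * num_heads, head_dim, seq_length)
beam_idx = torch.tensor([1, 1])                                     # both beams continue from batch element 1
reordered = (
    past_key.view(batch_size, num_heads, head_dim, seq_length)
    .index_select(0, beam_idx)
    .view(batch_size * num_heads, head_dim, seq_length)
)
print(torch.equal(reordered[:num_heads], past_key[num_heads:]))     # True -> all heads of element 1 selected together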
@@ -892,7 +964,7 @@ class BloomForCausalLM(BloomPreTrainedModel):
class BloomForSequenceClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
def __init__(self, config: BloomConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = BloomModel(config)
@@ -910,16 +982,16 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, int], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
r"""
@@ -966,7 +1038,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1
else:
sequence_lengths = -1
logger.warning(
@@ -994,7 +1066,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
@@ -1021,7 +1093,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
class BloomForTokenClassification(BloomPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
def __init__(self, config):
def __init__(self, config: BloomConfig):
super().__init__(config)
self.num_labels = config.num_labels
@@ -1047,16 +1119,16 @@ class BloomForTokenClassification(BloomPreTrainedModel):
)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor, int], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
@@ -1095,8 +1167,11 @@ class BloomForTokenClassification(BloomPreTrainedModel):
loss = None
if labels is not None:
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + transformer_outputs[2:]