|
|
|
@ -51,6 +51,15 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def debug(msg, tensor):
|
|
|
|
|
pass
|
|
|
|
|
# if isinstance(tensor, torch.Tensor):
|
|
|
|
|
# print(msg, tensor.shape)
|
|
|
|
|
# print(tensor)
|
|
|
|
|
# else:
|
|
|
|
|
# print(msg, tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
|
|
|
|
|
"""Split a tensor along its last dimension.
|
|
|
|
|
|
|
|
|
@ -286,8 +295,10 @@ class BloomScaledSoftmax(nn.Module):
|
|
|
|
|
input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
|
|
|
|
|
softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
|
|
|
|
|
|
|
|
|
|
debug("input", input)
|
|
|
|
|
if self.scale is not None:
|
|
|
|
|
input = input * self.scale
|
|
|
|
|
debug("scaled input", input)
|
|
|
|
|
|
|
|
|
|
if mask is None:
|
|
|
|
|
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
|
|
|
|
@ -295,11 +306,16 @@ class BloomScaledSoftmax(nn.Module):
|
|
|
|
|
mask = mask.to(input.device)
|
|
|
|
|
seq_ids = torch.arange(max_positions, device=input.device)
|
|
|
|
|
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
|
|
|
|
|
debug("Causal mask", causal_mask)
|
|
|
|
|
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
|
|
|
|
|
debug("mask output", mask_output)
|
|
|
|
|
debug("padded causal_mask", padded_causal_mask)
|
|
|
|
|
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
|
|
|
|
|
debug("probs", probs)
|
|
|
|
|
|
|
|
|
|
if input_in_16bit and self.softmax_in_fp32:
|
|
|
|
|
probs = probs.to(dtype=input_dtype)
|
|
|
|
|
debug("final probs", probs)
|
|
|
|
|
|
|
|
|
|
return probs
|
|
|
|
|
|
|
|
|
@ -361,6 +377,7 @@ class BloomAttention(nn.Module):
|
|
|
|
|
alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
|
|
|
|
|
|
|
|
|
|
mixed_x_layer = self.query_key_value(hidden_states)
|
|
|
|
|
debug(f"Mixed x layer {self.layer_number}", mixed_x_layer)
|
|
|
|
|
|
|
|
|
|
# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
|
|
|
|
|
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
|
|
|
|
@ -394,16 +411,26 @@ class BloomAttention(nn.Module):
|
|
|
|
|
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
|
|
|
|
|
beta = 1.0 / self.layer_number
|
|
|
|
|
|
|
|
|
|
q_query_layer = query_layer.transpose(1, 0)
|
|
|
|
|
q_key_layer = key_layer.transpose(1, 0).transpose(1, 2)
|
|
|
|
|
alpha = 1.0 / self.norm_factor
|
|
|
|
|
|
|
|
|
|
debug(f"Sliced alibi {self.layer_number}", sliced_alibi)
|
|
|
|
|
debug(f"Query layer {self.layer_number}", q_query_layer)
|
|
|
|
|
debug(f"Key layer {self.layer_number}", q_key_layer)
|
|
|
|
|
debug(f"Alpha {self.layer_number}", alpha)
|
|
|
|
|
debug(f"Beta {self.layer_number}", beta)
|
|
|
|
|
matmul_result = torch.baddbmm(
|
|
|
|
|
sliced_alibi,
|
|
|
|
|
query_layer.transpose(1, 0),
|
|
|
|
|
key_layer.transpose(1, 0).transpose(1, 2),
|
|
|
|
|
q_query_layer,
|
|
|
|
|
q_key_layer,
|
|
|
|
|
beta=beta,
|
|
|
|
|
alpha=(1.0 / self.norm_factor),
|
|
|
|
|
alpha=alpha,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# change view to [batch_size, num_heads, q_length, k_length]
|
|
|
|
|
attention_scores = matmul_result.view(*output_size)
|
|
|
|
|
debug(f"Attention scores {self.layer_number}", attention_scores)
|
|
|
|
|
|
|
|
|
|
# attention scores and attention mask [b, np, sq, sk]
|
|
|
|
|
max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
|
|
|
|
@ -415,6 +442,8 @@ class BloomAttention(nn.Module):
|
|
|
|
|
if head_mask is not None:
|
|
|
|
|
attention_probs = attention_probs * head_mask
|
|
|
|
|
|
|
|
|
|
debug(f"Attention probs {self.layer_number}", attention_probs)
|
|
|
|
|
|
|
|
|
|
# context layer shape: [batch_size, num_heads, q_length, head_dim]
|
|
|
|
|
output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
|
|
|
|
|
|
|
|
|
@ -426,6 +455,7 @@ class BloomAttention(nn.Module):
|
|
|
|
|
|
|
|
|
|
# matmul: [batch_size * num_heads, q_length, head_dim]
|
|
|
|
|
context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
|
|
|
|
|
debug(f"bmm {self.layer_number}", context_layer)
|
|
|
|
|
|
|
|
|
|
# change view [batch_size, num_heads, q_length, head_dim]
|
|
|
|
|
context_layer = context_layer.view(*output_size)
|
|
|
|
@ -476,7 +506,12 @@ class BloomMLP(nn.Module):
|
|
|
|
|
self.gelu_impl = BloomGelu()
|
|
|
|
|
|
|
|
|
|
def forward(self, hidden_states, residual):
|
|
|
|
|
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
|
|
|
|
|
debug("Hidden states", hidden_states)
|
|
|
|
|
debug("Residual", residual)
|
|
|
|
|
hidden_states = self.dense_h_to_4h(hidden_states)
|
|
|
|
|
debug("Hidden states h to 4h", hidden_states)
|
|
|
|
|
hidden_states = self.gelu_impl(hidden_states)
|
|
|
|
|
debug("Hidden states gelu", hidden_states)
|
|
|
|
|
|
|
|
|
|
if self.pretraining_tp > 1 and self.slow_but_exact:
|
|
|
|
|
intermediate_output = torch.zeros_like(residual)
|
|
|
|
@ -488,8 +523,10 @@ class BloomMLP(nn.Module):
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
intermediate_output = self.dense_4h_to_h(hidden_states)
|
|
|
|
|
debug("Hidden states 4h to h", intermediate_output)
|
|
|
|
|
|
|
|
|
|
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
|
|
|
|
|
debug("Hidden states dropout add", output)
|
|
|
|
|
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
@ -498,6 +535,7 @@ class BloomBlock(nn.Module):
|
|
|
|
|
def __init__(self, config, layer_number=None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
hidden_size = config.hidden_size
|
|
|
|
|
self.layer_number = layer_number
|
|
|
|
|
|
|
|
|
|
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
|
|
|
|
self.n_head = config.n_head
|
|
|
|
@ -542,12 +580,16 @@ class BloomBlock(nn.Module):
|
|
|
|
|
output_attentions=output_attentions,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
debug(f"Attention output {self.layer_number}", attn_outputs[0])
|
|
|
|
|
|
|
|
|
|
attention_output = attn_outputs[0]
|
|
|
|
|
|
|
|
|
|
outputs = attn_outputs[1:]
|
|
|
|
|
|
|
|
|
|
layernorm_output = self.post_attention_layernorm(attention_output)
|
|
|
|
|
|
|
|
|
|
debug(f"Layer norm output {self.layer_number}", layernorm_output)
|
|
|
|
|
|
|
|
|
|
# Get residual
|
|
|
|
|
if self.apply_residual_connection_post_layernorm:
|
|
|
|
|
residual = layernorm_output
|
|
|
|
@ -557,6 +599,8 @@ class BloomBlock(nn.Module):
|
|
|
|
|
# MLP.
|
|
|
|
|
output = self.mlp(layernorm_output, residual)
|
|
|
|
|
|
|
|
|
|
debug(f"MLP output {self.layer_number}", output)
|
|
|
|
|
|
|
|
|
|
if use_cache:
|
|
|
|
|
outputs = (output,) + outputs
|
|
|
|
|
else:
|
|
|
|
@ -754,8 +798,12 @@ class BloomModel(BloomPreTrainedModel):
|
|
|
|
|
|
|
|
|
|
if inputs_embeds is None:
|
|
|
|
|
inputs_embeds = self.word_embeddings(input_ids)
|
|
|
|
|
debug("Embeddings", self.word_embeddings.weight)
|
|
|
|
|
debug("Input embeds", inputs_embeds)
|
|
|
|
|
|
|
|
|
|
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
|
|
|
|
debug("Word embeddings layernorm weights", self.word_embeddings_layernorm.weight)
|
|
|
|
|
debug("Word embeddings layernorm", hidden_states)
|
|
|
|
|
|
|
|
|
|
output_shape = input_shape + (hidden_states.size(-1),)
|
|
|
|
|
|
|
|
|
@ -768,6 +816,7 @@ class BloomModel(BloomPreTrainedModel):
|
|
|
|
|
if past_key_values[0] is not None:
|
|
|
|
|
current_sequence_length += past_key_values[0][0].shape[1]
|
|
|
|
|
alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
|
|
|
|
|
debug("Alibi", alibi)
|
|
|
|
|
|
|
|
|
|
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
|
|
|
|
|
|
|
|
@ -808,6 +857,7 @@ class BloomModel(BloomPreTrainedModel):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
hidden_states = outputs[0]
|
|
|
|
|
debug(f"Block {i}", hidden_states)
|
|
|
|
|
if use_cache is True:
|
|
|
|
|
presents = presents + (outputs[1],)
|
|
|
|
|
|
|
|
|
@ -816,6 +866,7 @@ class BloomModel(BloomPreTrainedModel):
|
|
|
|
|
|
|
|
|
|
# Add last hidden state
|
|
|
|
|
hidden_states = self.ln_f(hidden_states)
|
|
|
|
|
debug(f"Ln_f {i}", hidden_states)
|
|
|
|
|
|
|
|
|
|
if output_hidden_states:
|
|
|
|
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
@ -926,6 +977,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
|
|
|
|
|
|
|
|
|
|
lm_logits = self.lm_head(hidden_states)
|
|
|
|
|
|
|
|
|
|
debug("lm logits", lm_logits)
|
|
|
|
|
|
|
|
|
|
loss = None
|
|
|
|
|
if labels is not None:
|
|
|
|
|
# Shift so that tokens < n predict n
|
|
|
|
|