Compare commits

...

1 Commit

Author SHA1 Message Date
c1a668ef9d Adding a bunch of debug places. 2022-07-08 17:43:44 +02:00
3 changed files with 70 additions and 4 deletions

View File

@@ -1679,6 +1679,10 @@ class GenerationMixin:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
print("Input ids", input_ids.shape)
print("Input ids", model_inputs["input_ids"].shape)
print("Keys", model_inputs.keys())
# forward pass to get next token
outputs = self(
**model_inputs,
@@ -1934,6 +1938,10 @@ class GenerationMixin:
# prepare model inputs
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
print("Input ids", input_ids.shape)
print("Input ids", model_inputs["input_ids"].shape)
print("Keys", model_inputs.keys())
# forward pass to get next token
outputs = self(
**model_inputs,
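
The two shape prints added in both generation paths are easiest to interpret with a cache-enabled run. A minimal usage sketch, assuming a small BLOOM checkpoint (the model name and prompt are only illustrative):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
# With the cache enabled, the prints above are expected to show the running input_ids
# growing by one token per step, while model_inputs["input_ids"] drops to shape
# (batch, 1) after the first step, because only the newest token is fed to the model
# together with past_key_values.
model.generate(input_ids, max_new_tokens=5)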

View File

@@ -51,6 +51,15 @@ BLOOM_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
def debug(msg, tensor):
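# No-op by default; re-enable the commented prints below to trace tensor shapes and values.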
pass
# if isinstance(tensor, torch.Tensor):
# print(msg, tensor.shape)
# print(tensor)
# else:
# print(msg, tensor)
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
"""Split a tensor along its last dimension.
@@ -286,8 +295,10 @@ class BloomScaledSoftmax(nn.Module):
input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
debug("input", input)
if self.scale is not None:
input = input * self.scale
debug("scaled input", input)
if mask is None:
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
@@ -295,11 +306,16 @@ class BloomScaledSoftmax(nn.Module):
mask = mask.to(input.device)
seq_ids = torch.arange(max_positions, device=input.device)
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
debug("Causal mask", causal_mask)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
debug("mask output", mask_output)
debug("padded causal_mask", padded_causal_mask)
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
debug("probs", probs)
if input_in_16bit and self.softmax_in_fp32:
probs = probs.to(dtype=input_dtype)
debug("final probs", probs)
return probs
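
For reference, the causal mask built above is just a lower-triangular boolean matrix: position i may attend to positions j <= i. A small self-contained check:

import torch

max_positions = 4
seq_ids = torch.arange(max_positions)
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions)
print(causal_mask[0, 0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
assert torch.equal(causal_mask[0, 0], torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool)))
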
@@ -361,6 +377,7 @@ class BloomAttention(nn.Module):
alibi = pre_process_alibi_for_pad(alibi, attention_mask, self.num_heads)
mixed_x_layer = self.query_key_value(hidden_states)
debug(f"Mixed x layer {self.layer_number}", mixed_x_layer)
# [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
@@ -394,16 +411,26 @@ class BloomAttention(nn.Module):
# Raw attention scores. [batch_size * num_heads, q_length, k_length]
beta = 1.0 / self.layer_number
q_query_layer = query_layer.transpose(1, 0)
q_key_layer = key_layer.transpose(1, 0).transpose(1, 2)
alpha = 1.0 / self.norm_factor
debug(f"Sliced alibi {self.layer_number}", sliced_alibi)
debug(f"Query layer {self.layer_number}", q_query_layer)
debug(f"Key layer {self.layer_number}", q_key_layer)
debug(f"Alpha {self.layer_number}", alpha)
debug(f"Beta {self.layer_number}", beta)
matmul_result = torch.baddbmm(
sliced_alibi,
query_layer.transpose(1, 0),
key_layer.transpose(1, 0).transpose(1, 2),
q_query_layer,
q_key_layer,
beta=beta,
alpha=(1.0 / self.norm_factor),
alpha=alpha,
)
# change view to [batch_size, num_heads, q_length, k_length]
attention_scores = matmul_result.view(*output_size)
debug(f"Attention scores {self.layer_number}", attention_scores)
# attention scores and attention mask [b, np, sq, sk]
max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
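
The baddbmm call traced above computes beta * sliced_alibi + alpha * (Q @ K^T) in one fused op, folding the ALiBi bias into the raw attention scores. A quick sketch of the semantics with toy tensors:

import torch

bsz_heads, q_len, k_len, head_dim = 2, 3, 3, 4
alibi = torch.randn(bsz_heads, q_len, k_len)
q = torch.randn(bsz_heads, q_len, head_dim)
k_t = torch.randn(bsz_heads, head_dim, k_len)
beta, alpha = 0.5, 0.125

out = torch.baddbmm(alibi, q, k_t, beta=beta, alpha=alpha)
ref = beta * alibi + alpha * torch.bmm(q, k_t)
assert torch.allclose(out, ref, atol=1e-6)
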
@@ -415,6 +442,8 @@ class BloomAttention(nn.Module):
if head_mask is not None:
attention_probs = attention_probs * head_mask
debug(f"Attention probs {self.layer_number}", attention_probs)
# context layer shape: [batch_size, num_heads, q_length, head_dim]
output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
@@ -426,6 +455,7 @@ class BloomAttention(nn.Module):
# matmul: [batch_size * num_heads, q_length, head_dim]
context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
debug(f"bmm {self.layer_number}", context_layer)
# change view [batch_size, num_heads, q_length, head_dim]
context_layer = context_layer.view(*output_size)
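
The bmm traced here multiplies the attention probabilities, [batch*heads, q_length, k_length], with the values, [batch*heads, k_length, head_dim] once transposed into batch-first layout, giving the context layer [batch*heads, q_length, head_dim]. A shape check with toy tensors:

import torch

bh, q_len, k_len, head_dim = 2, 3, 3, 4
attention_probs_reshaped = torch.rand(bh, q_len, k_len)
value_layer = torch.rand(bh, k_len, head_dim)
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
print(context_layer.shape)  # torch.Size([2, 3, 4])
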
@@ -476,7 +506,12 @@ class BloomMLP(nn.Module):
self.gelu_impl = BloomGelu()
def forward(self, hidden_states, residual):
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
debug("Hidden states", hidden_states)
debug("Residual", residual)
hidden_states = self.dense_h_to_4h(hidden_states)
debug("Hidden states h to 4h", hidden_states)
hidden_states = self.gelu_impl(hidden_states)
debug("Hidden states gelu", hidden_states)
if self.pretraining_tp > 1 and self.slow_but_exact:
intermediate_output = torch.zeros_like(residual)
@@ -488,8 +523,10 @@ class BloomMLP(nn.Module):
)
else:
intermediate_output = self.dense_4h_to_h(hidden_states)
debug("Hidden states 4h to h", intermediate_output)
output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
debug("Hidden states dropout add", output)
return output
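
The four debug calls above walk the MLP in order: h -> 4h projection, GELU, 4h -> h projection, then dropout plus residual add. A rough standalone sketch of that flow (layer sizes are made up, and plain F.gelu stands in for BloomGelu):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size = 8
dense_h_to_4h = nn.Linear(hidden_size, 4 * hidden_size)
dense_4h_to_h = nn.Linear(4 * hidden_size, hidden_size)

hidden_states = torch.randn(1, 5, hidden_size)            # [batch_size, seq_length, hidden_size]
residual = hidden_states

x = dense_h_to_4h(hidden_states)                          # "Hidden states h to 4h"
x = F.gelu(x)                                             # "Hidden states gelu"
x = dense_4h_to_h(x)                                      # "Hidden states 4h to h"
output = residual + F.dropout(x, p=0.0, training=False)   # "Hidden states dropout add"
print(output.shape)                                       # torch.Size([1, 5, 8])
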
@@ -498,6 +535,7 @@ class BloomBlock(nn.Module):
def __init__(self, config, layer_number=None):
super().__init__()
hidden_size = config.hidden_size
self.layer_number = layer_number
self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.n_head = config.n_head
@@ -542,12 +580,16 @@ class BloomBlock(nn.Module):
output_attentions=output_attentions,
)
debug(f"Attention output {self.layer_number}", attn_outputs[0])
attention_output = attn_outputs[0]
outputs = attn_outputs[1:]
layernorm_output = self.post_attention_layernorm(attention_output)
debug(f"Layer norm output {self.layer_number}", layernorm_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
@@ -557,6 +599,8 @@ class BloomBlock(nn.Module):
# MLP.
output = self.mlp(layernorm_output, residual)
debug(f"MLP output {self.layer_number}", output)
if use_cache:
outputs = (output,) + outputs
else:
@@ -754,8 +798,12 @@ class BloomModel(BloomPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
debug("Embeddings", self.word_embeddings.weight)
debug("Input embeds", inputs_embeds)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
debug("Word embeddings layernorm weights", self.word_embeddings_layernorm.weight)
debug("Word embeddings layernorm", hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
@@ -768,6 +816,7 @@ class BloomModel(BloomPreTrainedModel):
if past_key_values[0] is not None:
current_sequence_length += past_key_values[0][0].shape[1]
alibi = build_alibi_tensor(current_sequence_length, self.n_head, hidden_states.dtype)
debug("Alibi", alibi)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
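
For orientation on the tensor that debug("Alibi", alibi) traces: each attention head gets a fixed slope and the bias grows linearly with position. A simplified sketch for a power-of-two head count (this is not the library's build_alibi_tensor; the shapes and padding handling there differ):

import torch

def simple_alibi(seq_length, n_head, dtype=torch.float32):
    # Geometric slopes per head, linear growth along the key positions.
    slopes = torch.tensor([2 ** (-8.0 * (i + 1) / n_head) for i in range(n_head)], dtype=dtype)
    positions = torch.arange(seq_length, dtype=dtype)
    return slopes[:, None] * positions[None, :]          # [n_head, seq_length]

print(simple_alibi(seq_length=4, n_head=2))
# tensor([[0.0000, 0.0625, 0.1250, 0.1875],
#         [0.0000, 0.0039, 0.0078, 0.0117]])
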
@@ -808,6 +857,7 @@ class BloomModel(BloomPreTrainedModel):
)
hidden_states = outputs[0]
debug(f"Block {i}", hidden_states)
if use_cache is True:
presents = presents + (outputs[1],)
@@ -816,6 +866,7 @@ class BloomModel(BloomPreTrainedModel):
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
debug(f"Ln_f {i}", hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -926,6 +977,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
lm_logits = self.lm_head(hidden_states)
debug("lm logits", lm_logits)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
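
The shift the comment above introduces is the usual causal-LM pattern: logits at position t are scored against the token at position t + 1. A self-contained sketch:

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 1, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss.shape)  # scalar loss, torch.Size([])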

View File

@@ -212,7 +212,12 @@ class TextGenerationPipeline(Pipeline):
else:
in_b = input_ids.shape[0]
prompt_text = model_inputs.pop("prompt_text")
import datetime
start = datetime.datetime.now()
generated_sequence = self.model.generate(input_ids=input_ids, **generate_kwargs) # BS x SL
print(f"Generation took {datetime.datetime.now() - start}")
out_b = generated_sequence.shape[0]
if self.framework == "pt":
generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
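
The added timing wraps only self.model.generate, so generation time can be separated from tokenization and postprocessing overhead. A usage sketch (checkpoint name and generate kwargs are illustrative):

from transformers import pipeline

generator = pipeline("text-generation", model="bigscience/bloom-560m")
generator("Hello, my name is", max_new_tokens=20)
# Expected stdout from the print above, e.g.: "Generation took 0:00:01.234567"
# (the actual value depends on hardware).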