Compare commits


10 Commits

SHA1 Message Date
8aa32719b2 Update README.md 2021-04-23 22:46:21 +02:00
0e6186f35e Update modeling_gpt_neo.py 2021-04-22 18:35:41 +02:00
98cfde7fe0 Update README.md 2021-04-18 22:14:34 +02:00
6ae639f68b Update modeling_utils.py 2021-04-18 21:45:59 +02:00
dca62e37c9 Update modeling_utils.py 2021-04-18 21:30:38 +02:00
aae293d05f Update modeling_gpt_neo.py 2021-04-18 20:52:38 +02:00
09530a96e6 Update modeling_gpt_neo.py 2021-04-18 20:40:55 +02:00
df7f661381 Update modeling_gpt_neo.py 2021-04-18 17:45:10 +02:00
b42abbd61f Update generation_utils.py 2021-04-18 17:43:31 +02:00
7ab9abf08f Update generation_logits_process.py 2021-04-18 17:39:07 +02:00
5 changed files with 56 additions and 10 deletions

README.md

@@ -1,3 +1,12 @@
+# Patches
+This branch has the following patches:
+* gpt-neo model is loaded directly on GPU to save system memory
+* repetition_penalty has range and slope settings, so it doesn't penalize all tokens in the context window
+* no copy of the state dict is made while loading a pretrained model
+* local self attention uses padding so it doesn't OOM on long sequences
 <!---
 Copyright 2020 The HuggingFace Team. All rights reserved.
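
Taken together, the patches surface as two new `generate()` keyword arguments plus a model that loads straight to GPU in fp16. A minimal usage sketch against this branch (the model name, prompt, and sampling values are illustrative, not part of the patch set):

```python
import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM

# On this branch GPTNeoForCausalLM builds its modules with .half().cuda(),
# so the weights end up on the GPU in fp16 rather than in system RAM.
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")

input_ids = tokenizer("It was a dark and stormy night", return_tensors="pt").input_ids.cuda()
output = model.generate(
    input_ids,
    do_sample=True,
    max_length=200,
    repetition_penalty=1.2,         # base penalty, as in stock transformers
    repetition_penalty_range=250,   # patched: only the last 250 tokens are penalized
    repetition_penalty_slope=3.33,  # patched: penalty ramps up toward recent tokens
)
print(tokenizer.decode(output[0]))
```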

generation_logits_process.py

@@ -152,17 +152,30 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
             <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
     """
 
-    def __init__(self, penalty: float):
+    def __init__(self, penalty: float, m=3.33, penalize_last=250):
         if not isinstance(penalty, float) or not (penalty > 0):
             raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
 
         self.penalty = penalty
+        self.penalize_last = None
+        if not m is None and not penalize_last is None:
+            self.penalty = (torch.arange(penalize_last) / (penalize_last - 1)) * 2. - 1
+            self.penalty = (m * self.penalty) / (1 + torch.abs(self.penalty) * (m - 1))
+            self.penalty = 1 + ((self.penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
+            self.penalize_last = penalize_last
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if not self.penalize_last is None:
+            penality_len = min(input_ids.shape[1], self.penalize_last)
+            input_ids = input_ids[:, -penality_len:]
         score = torch.gather(scores, 1, input_ids)
 
         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        if not self.penalize_last is None:
+            penalty = self.penalty.type(score.dtype).to(score.device)
+            score = torch.where(score < 0, score * penalty[:, -penality_len:], score / penalty[:, -penality_len:])
+        else:
+            score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 
         scores.scatter_(1, input_ids, score)
         return scores
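
The three `self.penalty` assignments in `__init__` build a sigmoid-shaped ramp over the penalized window: token positions are mapped onto [-1, 1], bent by the slope `m`, then rescaled so the oldest token in range gets a penalty of 1 (none) and the most recent token gets the full `penalty`. A standalone sketch of the same arithmetic with a small window, so the ramp is easy to inspect:

```python
import torch

penalty, m, penalize_last = 1.2, 3.33, 8  # tiny window purely for illustration

x = (torch.arange(penalize_last) / (penalize_last - 1)) * 2.0 - 1  # positions -> [-1, 1]
x = (m * x) / (1 + torch.abs(x) * (m - 1))                         # sigmoid-like bend
ramp = 1 + ((x + 1) / 2) * (penalty - 1)                           # rescale to [1, penalty]

print(ramp)
# tensor([1.0000, 1.0107, 1.0286, 1.0643, 1.1357, 1.1714, 1.1893, 1.2000])
```

With the default `penalize_last=250`, position 0 of the ramp lines up with the 250th-newest token, so anything older escapes the penalty entirely.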

generation_utils.py

@@ -570,6 +570,8 @@ class GenerationMixin:
     def _get_logits_processor(
         self,
         repetition_penalty: float,
+        repetition_penalty_range: int,
+        repetition_penalty_slope: float,
         no_repeat_ngram_size: int,
         encoder_no_repeat_ngram_size: int,
         encoder_input_ids: torch.LongTensor,
@@ -625,7 +627,7 @@ class GenerationMixin:
             )
         )
         if repetition_penalty is not None and repetition_penalty != 1.0:
-            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty, m=repetition_penalty_slope, penalize_last=repetition_penalty_range))
         if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
             processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
         if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
@@ -675,6 +677,8 @@ class GenerationMixin:
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         repetition_penalty: Optional[float] = None,
+        repetition_penalty_range: Optional[int] = None,
+        repetition_penalty_slope: Optional[float] = 3.33,
         bad_words_ids: Optional[Iterable[int]] = None,
         bos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
@@ -967,6 +971,8 @@ class GenerationMixin:
         # get distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
+            repetition_penalty_range=repetition_penalty_range,
+            repetition_penalty_slope=repetition_penalty_slope,
             no_repeat_ngram_size=no_repeat_ngram_size,
             encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
             encoder_input_ids=encoder_input_ids,
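
The new arguments are threaded from `generate()` through `_get_logits_processor`, which constructs the processor list. The patched processor can also be driven directly; a small sketch with random stand-in logits (shapes and values are illustrative only):

```python
import torch
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

# mirrors what _get_logits_processor builds from the three generate() kwargs
processors = LogitsProcessorList(
    [RepetitionPenaltyLogitsProcessor(penalty=1.2, m=3.33, penalize_last=250)]
)

input_ids = torch.randint(0, 50257, (1, 300))  # stand-in context of 300 tokens
scores = torch.randn(1, 50257)                 # stand-in next-token logits
scores = processors(input_ids, scores)         # only the last 250 tokens are penalized
```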

modeling_utils.py

@@ -1138,9 +1138,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
 
         # copy state_dict so _load_from_state_dict can modify it
         metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
+        #state_dict = state_dict.copy()
+        #if metadata is not None:
+        #    state_dict._metadata = metadata
 
         # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
         # so we need to apply the function recursively.
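
`dict.copy()` here is shallow, so it never duplicated the tensors themselves; what it did do was hold a second reference to every weight, keeping all of them alive for the whole load. With the copy gone, a tensor can be reclaimed as soon as the loading path drops its last reference. A toy illustration of that effect (stand-in tensor, not the actual loader):

```python
import torch

weights = {"weight": torch.zeros(1024, 1024)}  # ~4 MB, stands in for one weight

snapshot = weights.copy()  # shallow copy: new dict, same tensor objects
weights.pop("weight")      # tensor NOT freed; snapshot still references it
del snapshot               # only now is the ~4 MB reclaimed

weights = {"weight": torch.zeros(1024, 1024)}
weights.pop("weight")      # no copy around: memory is reclaimed immediately
```

The trade-off is that the caller's `state_dict` may now be mutated during loading, which upstream deliberately avoided.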

modeling_gpt_neo.py

@@ -416,6 +416,20 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
 
         # compute block length and num_blocks
         batch_size, seq_length = hidden_states.shape[:2]
         full_seq_length = seq_length + past_length
+        padding = None
+        if layer_past is None and full_seq_length % self.window_size != 0 and full_seq_length > self.window_size:
+            padding = self.window_size - (full_seq_length % self.window_size)
+            if attention_mask is None:
+                attention_mask = torch.zeros(query.shape[0], query.shape[1] + padding).to(query.device)
+                attention_mask[:, padding:] = 1
+            else:
+                attention_mask = torch.cat([torch.zeros(attention_mask.shape[0], padding).to(attention_mask.device), attention_mask], axis=1)
+            pad = lambda x: torch.cat([torch.zeros(x.shape[0], padding, x.shape[2]).to(x.device), x], axis=1)
+            query, key, value = map(pad, (query, key, value))
+            seq_length += padding
+            full_seq_length += padding
+
         block_length, num_blocks = self._get_block_length_and_num_blocks(full_seq_length, self.window_size)
 
         # create buckets
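
The padding amount is chosen so the padded total length becomes an exact multiple of `self.window_size`; `_get_block_length_and_num_blocks` searches downward for a block length that divides the sequence evenly, so an awkward length would otherwise force tiny blocks and much larger intermediate attention tensors. The zeros are prepended, and masked out via `attention_mask`, so the real tokens keep their positions at the end of the window. The arithmetic, sketched with GPT-Neo's default `window_size` of 256:

```python
window_size = 256       # GPT-Neo local self-attention window
full_seq_length = 1000  # current tokens + past, not a multiple of 256

padding = window_size - (full_seq_length % window_size)
print(padding)                                     # 24
print((full_seq_length + padding) // window_size)  # 4 even blocks of 256
```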
@@ -457,7 +471,11 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = attn_output.reshape(batch_size, seq_length, self.embed_dim)
 
-        attn_output = self.out_proj(attn_output)
+        if padding is not None:
+            attn_output = attn_output[:, padding:]
+            attn_weights = attn_weights[:, padding:]
+
+        attn_output = self.out_proj(attn_output.to(hidden_states.dtype))
         attn_output = self.resid_dropout(attn_output)
 
         outputs = (attn_output,)
@@ -704,7 +722,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
 
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.drop = nn.Dropout(config.embed_dropout)
-        self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
+        self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i).half().cuda() for i in range(config.num_layers)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         self.init_weights()
@@ -891,8 +909,8 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-        self.transformer = GPTNeoModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.transformer = GPTNeoModel(config).half().cuda()
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False).half().cuda()
 
         self.init_weights()
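
Constructing each `GPTNeoBlock`, the `GPTNeoModel`, and the `lm_head` with `.half().cuda()` converts the weights to fp16 and moves them to the GPU as the module tree is built, instead of first materializing a full fp32 model in system RAM. Rough arithmetic for the 2.7B model:

```python
params = 2.7e9  # GPT-Neo 2.7B parameter count (approximate)

fp32_gb = params * 4 / 1e9  # ~10.8 GB for a CPU-resident fp32 copy
fp16_gb = params * 2 / 1e9  # ~5.4 GB resident on the GPU with this patch

print(f"{fp32_gb:.1f} GB fp32 on CPU vs {fp16_gb:.1f} GB fp16 on GPU")
```

The obvious caveat: with fp16 and CUDA hard-coded in the constructors, the patched classes no longer run on CPU-only machines.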