Mirror of https://github.com/huggingface/transformers.git

Compare commits: torch_vers...localattn1 (10 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 8aa32719b2 | |
| | 0e6186f35e | |
| | 98cfde7fe0 | |
| | 6ae639f68b | |
| | dca62e37c9 | |
| | aae293d05f | |
| | 09530a96e6 | |
| | df7f661381 | |
| | b42abbd61f | |
| | 7ab9abf08f | |
README.md:

```diff
@@ -1,3 +1,12 @@
+# Patches
+
+This branch has the following patches:
+
+* gpt-neo model is loaded directly on the GPU to save system memory
+* repetition_penalty has range and slope settings, so it doesn't penalize all tokens in the context window
+* no copy of the state dict is made while loading a pretrained model
+* local self attention uses padding so it doesn't OOM on long sequences
+
 <!---
 Copyright 2020 The HuggingFace Team. All rights reserved.
```
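For reference, a minimal usage sketch of the API these bullets describe. The `repetition_penalty_range` and `repetition_penalty_slope` kwargs exist only on this branch, and a CUDA device is assumed throughout since the GPT-Neo patches load modules straight to GPU in fp16:

```python
# Usage sketch for this branch (not stock transformers): the two extra
# repetition-penalty kwargs and the load-directly-to-GPU behavior only
# exist with these patches applied.
from transformers import AutoTokenizer, GPTNeoForCausalLM

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")  # lands on GPU in fp16

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
output = model.generate(
    input_ids,
    do_sample=True,
    max_length=128,
    repetition_penalty=1.2,          # base penalty
    repetition_penalty_range=250,    # only the last 250 tokens are penalized
    repetition_penalty_slope=3.33,   # how sharply the penalty ramps toward recent tokens
)
print(tokenizer.decode(output[0]))
```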
src/transformers/generation_logits_process.py:

```diff
@@ -152,17 +152,30 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
     <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
     """

-    def __init__(self, penalty: float):
+    def __init__(self, penalty: float, m=3.33, penalize_last=250):
         if not isinstance(penalty, float) or not (penalty > 0):
             raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

         self.penalty = penalty
+        self.penalize_last = None
+        if m is not None and penalize_last is not None:
+            self.penalty = (torch.arange(penalize_last) / (penalize_last - 1)) * 2.0 - 1
+            self.penalty = (m * self.penalty) / (1 + torch.abs(self.penalty) * (m - 1))
+            self.penalty = 1 + ((self.penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
+            self.penalize_last = penalize_last

     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        if self.penalize_last is not None:
+            penalty_len = min(input_ids.shape[1], self.penalize_last)
+            input_ids = input_ids[:, -penalty_len:]
         score = torch.gather(scores, 1, input_ids)

         # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        if self.penalize_last is not None:
+            penalty = self.penalty.type(score.dtype).to(score.device)
+            score = torch.where(score < 0, score * penalty[:, -penalty_len:], score / penalty[:, -penalty_len:])
+        else:
+            score = torch.where(score < 0, score * self.penalty, score / self.penalty)

         scores.scatter_(1, input_ids, score)
         return scores
```
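To see what the three-line ramp in `__init__` computes, here is a standalone numeric sketch (plain torch, nothing branch-specific). With the defaults, the oldest position in the penalized range gets a multiplier of about 1.0 (effectively no penalty) and the newest position gets the full `penalty`:

```python
import torch

penalty, m, penalize_last = 1.2, 3.33, 250

ramp = (torch.arange(penalize_last) / (penalize_last - 1)) * 2.0 - 1  # linear ramp in [-1, 1]
ramp = (m * ramp) / (1 + torch.abs(ramp) * (m - 1))                   # sigmoid-like squash, slope m
ramp = 1 + ((ramp + 1) / 2).unsqueeze(0) * (penalty - 1)              # rescale into [1.0, penalty]

print(ramp[0, 0].item())   # ~1.0 -> oldest token in range: effectively unpenalized
print(ramp[0, -1].item())  # ~1.2 -> newest token: full penalty
```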
src/transformers/generation_utils.py:

```diff
@@ -570,6 +570,8 @@ class GenerationMixin:
     def _get_logits_processor(
         self,
         repetition_penalty: float,
+        repetition_penalty_range: int,
+        repetition_penalty_slope: float,
         no_repeat_ngram_size: int,
         encoder_no_repeat_ngram_size: int,
         encoder_input_ids: torch.LongTensor,
@@ -625,7 +627,7 @@ class GenerationMixin:
             )
         )
         if repetition_penalty is not None and repetition_penalty != 1.0:
-            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
+            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty, m=repetition_penalty_slope, penalize_last=repetition_penalty_range))
         if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
             processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
         if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
@@ -675,6 +677,8 @@ class GenerationMixin:
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
         repetition_penalty: Optional[float] = None,
+        repetition_penalty_range: Optional[int] = None,
+        repetition_penalty_slope: Optional[float] = 3.33,
         bad_words_ids: Optional[Iterable[int]] = None,
         bos_token_id: Optional[int] = None,
         pad_token_id: Optional[int] = None,
@@ -967,6 +971,8 @@ class GenerationMixin:
         # get distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
+            repetition_penalty_range=repetition_penalty_range,
+            repetition_penalty_slope=repetition_penalty_slope,
             no_repeat_ngram_size=no_repeat_ngram_size,
             encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
             encoder_input_ids=encoder_input_ids,
```
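These hunks just thread the two new `generate()` kwargs through to the processor. The processor can also be exercised directly on dummy tensors; a small sketch, keeping in mind the `m`/`penalize_last` kwargs only exist on this branch:

```python
import torch
from transformers import RepetitionPenaltyLogitsProcessor

proc = RepetitionPenaltyLogitsProcessor(penalty=1.2, m=3.33, penalize_last=250)

input_ids = torch.randint(0, 50257, (1, 300))  # 300 tokens of context so far
scores = torch.randn(1, 50257)                 # next-token logits
scores = proc(input_ids, scores)               # only the last 250 context tokens are penalized
```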
src/transformers/modeling_utils.py:

```diff
@@ -1138,9 +1138,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):

         # copy state_dict so _load_from_state_dict can modify it
         metadata = getattr(state_dict, "_metadata", None)
-        state_dict = state_dict.copy()
-        if metadata is not None:
-            state_dict._metadata = metadata
+        #state_dict = state_dict.copy()
+        #if metadata is not None:
+        #    state_dict._metadata = metadata

         # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
         # so we need to apply the function recursively.
```
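Context for this change: `dict.copy()` on a state dict is shallow, so it duplicates the dict container and its key references but never the tensors themselves. A quick standalone check:

```python
import torch

sd = {"w": torch.zeros(1024, 1024)}
sd_copy = sd.copy()             # new dict object...
assert sd_copy["w"] is sd["w"]  # ...but the very same underlying tensors
```

Skipping the copy saves the duplicate dict and whatever bookkeeping downstream code does on it; the trade-off, per the comment that remains in place, is that loading may now modify the caller's `state_dict` in place.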
src/transformers/models/gpt_neo/modeling_gpt_neo.py:

```diff
@@ -416,6 +416,20 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
         # compute block length and num_blocks
         batch_size, seq_length = hidden_states.shape[:2]
         full_seq_length = seq_length + past_length

+        padding = None
+        if layer_past is None and full_seq_length % self.window_size != 0 and full_seq_length > self.window_size:
+            padding = self.window_size - (full_seq_length % self.window_size)
+            if attention_mask is None:
+                attention_mask = torch.zeros(query.shape[0], query.shape[1] + padding).to(query.device)
+                attention_mask[:, padding:] = 1
+            else:
+                attention_mask = torch.cat([torch.zeros(attention_mask.shape[0], padding).to(attention_mask.device), attention_mask], axis=1)
+            pad = lambda x: torch.cat([torch.zeros(x.shape[0], padding, x.shape[2]).to(x.device), x], axis=1)
+            query, key, value = map(pad, (query, key, value))
+            seq_length += padding
+            full_seq_length += padding
+
         block_length, num_blocks = self._get_block_length_and_num_blocks(full_seq_length, self.window_size)

         # create buckets
```
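The padding amount is simply the distance to the next multiple of the attention window, so the sequence splits evenly into blocks. A worked example of the arithmetic (the window size of 256 is assumed here for illustration):

```python
window_size = 256        # local attention window (assumed value)
full_seq_length = 1000   # prompt + past tokens

if full_seq_length > window_size and full_seq_length % window_size != 0:
    padding = window_size - (full_seq_length % window_size)
else:
    padding = 0

print(padding)                                    # 24
print((full_seq_length + padding) % window_size)  # 0 -> sequence now splits evenly into blocks
```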
```diff
@@ -457,7 +471,11 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = attn_output.reshape(batch_size, seq_length, self.embed_dim)

-        attn_output = self.out_proj(attn_output)
+        if padding is not None:
+            attn_output = attn_output[:, padding:]
+            attn_weights = attn_weights[:, padding:]
+
+        attn_output = self.out_proj(attn_output.to(hidden_states.dtype))
         attn_output = self.resid_dropout(attn_output)

         outputs = (attn_output,)
```
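Because the zeros are prepended, stripping `padding` positions from the front after attention recovers the original sequence exactly; a shape-level sketch of that round trip:

```python
import torch

batch, seq, dim, padding = 2, 232, 64, 24

x = torch.randn(batch, seq, dim)
x_padded = torch.cat([torch.zeros(batch, padding, dim), x], dim=1)  # left-pad to 256
assert x_padded.shape[1] == 256
assert torch.equal(x_padded[:, padding:], x)                        # stripping recovers x
```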
```diff
@@ -704,7 +722,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.drop = nn.Dropout(config.embed_dropout)
-        self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
+        self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i).half().cuda() for i in range(config.num_layers)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         self.init_weights()
```
```diff
@@ -891,8 +909,8 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):

     def __init__(self, config):
         super().__init__(config)
-        self.transformer = GPTNeoModel(config)
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+        self.transformer = GPTNeoModel(config).half().cuda()
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False).half().cuda()

         self.init_weights()
```
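Both hunks apply the same pattern: convert and move each submodule the moment it is constructed, so only one block's fp32 weights sit in system RAM at a time rather than the whole model. A generic sketch of that pattern using toy `nn.Linear` blocks, not the actual GPT-Neo modules:

```python
import torch.nn as nn

# Each block is built on CPU in fp32, then immediately converted to fp16 and
# moved to the GPU, so peak system-RAM usage stays at roughly one block.
blocks = nn.ModuleList(
    [nn.Linear(4096, 4096).half().cuda() for _ in range(8)]
)
```

Note that hard-coding `.half().cuda()` ties this branch to an fp16 CUDA setup; constructing these classes on a machine without a GPU will raise.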