mirror of
				https://github.com/huggingface/transformers.git
				synced 2025-11-04 20:14:36 +08:00 
			
		
		
		
	Compare commits
	
		
			10 Commits
		
	
	
		
			kernel_con
			...
			localattn1
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 8aa32719b2 | |||
| 0e6186f35e | |||
| 98cfde7fe0 | |||
| 6ae639f68b | |||
| dca62e37c9 | |||
| aae293d05f | |||
| 09530a96e6 | |||
| df7f661381 | |||
| b42abbd61f | |||
| 7ab9abf08f | 
@ -1,3 +1,12 @@
 | 
			
		||||
# Patches
 | 
			
		||||
 | 
			
		||||
This branch has the following patches:
 | 
			
		||||
 | 
			
		||||
* gpt-neo model is loaded directly on GPU to save system memory
 | 
			
		||||
* repetition_penalty has range and slope settings, so it doesn't penalize all tokens in the context window
 | 
			
		||||
* no copy of the state dict is made while loading a pretrained model
 | 
			
		||||
* local self attention uses padding so it doesn't OOM on long sequences
 | 
			
		||||
 | 
			
		||||
<!---
 | 
			
		||||
Copyright 2020 The HuggingFace Team. All rights reserved.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -152,17 +152,30 @@ class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
 | 
			
		||||
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, penalty: float):
 | 
			
		||||
    def __init__(self, penalty: float, m=3.33, penalize_last=250):
 | 
			
		||||
        if not isinstance(penalty, float) or not (penalty > 0):
 | 
			
		||||
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
 | 
			
		||||
 | 
			
		||||
        self.penalty = penalty
 | 
			
		||||
        self.penalize_last = None
 | 
			
		||||
        if not m is None and not penalize_last is None:
 | 
			
		||||
            self.penalty = (torch.arange(penalize_last)/(penalize_last - 1)) * 2. - 1
 | 
			
		||||
            self.penalty = (m * self.penalty) / (1 + torch.abs(self.penalty) * (m - 1))
 | 
			
		||||
            self.penalty = 1 + ((self.penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
 | 
			
		||||
            self.penalize_last = penalize_last
 | 
			
		||||
 | 
			
		||||
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
 | 
			
		||||
        if not self.penalize_last is None:
 | 
			
		||||
            penality_len = min(input_ids.shape[1], self.penalize_last)
 | 
			
		||||
            input_ids = input_ids[:, -penality_len:]
 | 
			
		||||
        score = torch.gather(scores, 1, input_ids)
 | 
			
		||||
 | 
			
		||||
        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
 | 
			
		||||
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 | 
			
		||||
        if not self.penalize_last is None:
 | 
			
		||||
            penalty = self.penalty.type(score.dtype).to(score.device)
 | 
			
		||||
            score = torch.where(score < 0, score * penalty[:, -penality_len:], score / penalty[:, -penality_len:])
 | 
			
		||||
        else:
 | 
			
		||||
            score = torch.where(score < 0, score * self.penalty, score / self.penalty)
 | 
			
		||||
 | 
			
		||||
        scores.scatter_(1, input_ids, score)
 | 
			
		||||
        return scores
 | 
			
		||||
 | 
			
		||||
@ -570,6 +570,8 @@ class GenerationMixin:
 | 
			
		||||
    def _get_logits_processor(
 | 
			
		||||
        self,
 | 
			
		||||
        repetition_penalty: float,
 | 
			
		||||
        repetition_penalty_range: int,
 | 
			
		||||
        repetition_penalty_slope: float,
 | 
			
		||||
        no_repeat_ngram_size: int,
 | 
			
		||||
        encoder_no_repeat_ngram_size: int,
 | 
			
		||||
        encoder_input_ids: torch.LongTensor,
 | 
			
		||||
@ -625,7 +627,7 @@ class GenerationMixin:
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
        if repetition_penalty is not None and repetition_penalty != 1.0:
 | 
			
		||||
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
 | 
			
		||||
            processors.append(RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty, m=repetition_penalty_slope, penalize_last=repetition_penalty_range))
 | 
			
		||||
        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
 | 
			
		||||
            processors.append(NoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
 | 
			
		||||
        if encoder_no_repeat_ngram_size is not None and encoder_no_repeat_ngram_size > 0:
 | 
			
		||||
@ -675,6 +677,8 @@ class GenerationMixin:
 | 
			
		||||
        top_k: Optional[int] = None,
 | 
			
		||||
        top_p: Optional[float] = None,
 | 
			
		||||
        repetition_penalty: Optional[float] = None,
 | 
			
		||||
        repetition_penalty_range: Optional[int] = None,
 | 
			
		||||
        repetition_penalty_slope: Optional[float] = 3.33,
 | 
			
		||||
        bad_words_ids: Optional[Iterable[int]] = None,
 | 
			
		||||
        bos_token_id: Optional[int] = None,
 | 
			
		||||
        pad_token_id: Optional[int] = None,
 | 
			
		||||
@ -967,6 +971,8 @@ class GenerationMixin:
 | 
			
		||||
        # get distribution pre_processing samplers
 | 
			
		||||
        logits_processor = self._get_logits_processor(
 | 
			
		||||
            repetition_penalty=repetition_penalty,
 | 
			
		||||
            repetition_penalty_range=repetition_penalty_range,
 | 
			
		||||
            repetition_penalty_slope=repetition_penalty_slope,
 | 
			
		||||
            no_repeat_ngram_size=no_repeat_ngram_size,
 | 
			
		||||
            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
 | 
			
		||||
            encoder_input_ids=encoder_input_ids,
 | 
			
		||||
 | 
			
		||||
@ -1138,9 +1138,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
 | 
			
		||||
 | 
			
		||||
            # copy state_dict so _load_from_state_dict can modify it
 | 
			
		||||
            metadata = getattr(state_dict, "_metadata", None)
 | 
			
		||||
            state_dict = state_dict.copy()
 | 
			
		||||
            if metadata is not None:
 | 
			
		||||
                state_dict._metadata = metadata
 | 
			
		||||
            #state_dict = state_dict.copy()
 | 
			
		||||
            #if metadata is not None:
 | 
			
		||||
            #    state_dict._metadata = metadata
 | 
			
		||||
 | 
			
		||||
            # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
 | 
			
		||||
            # so we need to apply the function recursively.
 | 
			
		||||
 | 
			
		||||
@ -416,6 +416,20 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
 | 
			
		||||
        # compute block length and num_blocks
 | 
			
		||||
        batch_size, seq_length = hidden_states.shape[:2]
 | 
			
		||||
        full_seq_length = seq_length + past_length
 | 
			
		||||
 | 
			
		||||
        padding = None
 | 
			
		||||
        if layer_past is None and full_seq_length % self.window_size != 0 and full_seq_length > self.window_size:
 | 
			
		||||
            padding = self.window_size-(full_seq_length%self.window_size)
 | 
			
		||||
            if attention_mask is None:
 | 
			
		||||
                attention_mask = torch.zeros(query.shape[0], query.shape[1] + padding).to(query.device)
 | 
			
		||||
                attention_mask[:, padding:] = 1
 | 
			
		||||
            else:
 | 
			
		||||
                attention_mask = torch.cat([torch.zeros(attention_mask.shape[0], padding).to(attention_mask.device), attention_mask], axis=1)
 | 
			
		||||
            pad = lambda x: torch.cat([torch.zeros(x.shape[0],padding,x.shape[2]).to(x.device), x], axis=1)
 | 
			
		||||
            query, key, value = map(pad, (query, key, value))
 | 
			
		||||
            seq_length += padding
 | 
			
		||||
            full_seq_length += padding
 | 
			
		||||
 | 
			
		||||
        block_length, num_blocks = self._get_block_length_and_num_blocks(full_seq_length, self.window_size)
 | 
			
		||||
 | 
			
		||||
        # create buckets
 | 
			
		||||
@ -457,7 +471,11 @@ class GPTNeoLocalSelfAttention(nn.Module, GPTNeoAttentionMixin):
 | 
			
		||||
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 | 
			
		||||
        attn_output = attn_output.reshape(batch_size, seq_length, self.embed_dim)
 | 
			
		||||
 | 
			
		||||
        attn_output = self.out_proj(attn_output)
 | 
			
		||||
        if padding is not None:
 | 
			
		||||
            attn_output = attn_output[:,padding:]
 | 
			
		||||
            attn_weights = attn_weights[:,padding:]
 | 
			
		||||
 | 
			
		||||
        attn_output = self.out_proj(attn_output.to(hidden_states.dtype))
 | 
			
		||||
        attn_output = self.resid_dropout(attn_output)
 | 
			
		||||
 | 
			
		||||
        outputs = (attn_output,)
 | 
			
		||||
@ -704,7 +722,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
 | 
			
		||||
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
 | 
			
		||||
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
 | 
			
		||||
        self.drop = nn.Dropout(config.embed_dropout)
 | 
			
		||||
        self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i) for i in range(config.num_layers)])
 | 
			
		||||
        self.h = nn.ModuleList([GPTNeoBlock(config, layer_id=i).half().cuda() for i in range(config.num_layers)])
 | 
			
		||||
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 | 
			
		||||
 | 
			
		||||
        self.init_weights()
 | 
			
		||||
@ -891,8 +909,8 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, config):
 | 
			
		||||
        super().__init__(config)
 | 
			
		||||
        self.transformer = GPTNeoModel(config)
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 | 
			
		||||
        self.transformer = GPTNeoModel(config).half().cuda()
 | 
			
		||||
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False).half().cuda()
 | 
			
		||||
 | 
			
		||||
        self.init_weights()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user