Mirror of https://github.com/huggingface/transformers.git (synced 2025-11-04 03:44:37 +08:00)

Compare commits: v4.52.0...nit-refact (3 commits)

| SHA1 |
|---|
| c8b173d6b7 |
| c4aae74ba0 |
| bb29728d40 |
In `AttentionMaskConverter`, `_unmask_unattended` gains `attention_mask` and `input_tensor` parameters and now performs the CUDA, dtype, and tracing checks itself before re-enabling attention on fully masked rows:

````diff
@@ -188,6 +188,8 @@ class AttentionMaskConverter:
     @staticmethod
     def _unmask_unattended(
         expanded_mask: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        input_tensor: torch.FloatTensor,
         min_dtype: float,
     ):
         # fmt: off
@@ -227,13 +229,25 @@ class AttentionMaskConverter:
         ```
         """
         # fmt: on

+        if attention_mask is not None and attention_mask.device.type == "cuda":
             if expanded_mask.dtype == torch.bool:
                 raise ValueError(
                     "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
                 )

+            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+            is_tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+            is_tracing |= isinstance(input_tensor, torch.fx.Proxy)
+
+            if not is_tracing and torch.any(attention_mask != 1):
+                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+                # Details: https://github.com/pytorch/pytorch/issues/110213
+                return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+
+        return expanded_mask


 def _prepare_4d_causal_attention_mask(
     attention_mask: Optional[torch.Tensor],
````
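For quick reference, here is a minimal, self-contained sketch of the behaviour the refactored helper ends up with, reconstructed from the hunk above (the real implementation is `AttentionMaskConverter._unmask_unattended` on the `nit-refact` branch; the function name below is made up for the sketch):

```python
import torch


def unmask_unattended_sketch(expanded_mask, attention_mask, input_tensor, min_dtype):
    """Sketch of the refactored helper: on CUDA, outside of tracing, rows of the 4D
    additive mask that are masked everywhere are zeroed so they attend to all keys."""
    if attention_mask is not None and attention_mask.device.type == "cuda":
        if expanded_mask.dtype == torch.bool:
            raise ValueError("expects a float `expanded_mask`, got a BoolTensor")

        # Data-dependent branching must be skipped under torch.jit / torch.fx / torch.compile.
        is_tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
        is_tracing |= isinstance(input_tensor, torch.fx.Proxy)

        if not is_tracing and torch.any(attention_mask != 1):
            # A row that is min_dtype everywhere would make SDPA's memory-efficient
            # backend return NaNs, so the row is zeroed (i.e. it attends everywhere).
            return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))

    return expanded_mask
```

Note that `input_tensor` is only consulted for the `torch.fx.Proxy` check, which is why the model call sites below can simply pass their hidden states for it.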
In `FalconModel`, the SDPA branch updates its call to the new signature, passing the 2D attention mask and the hidden states along with the 4D mask:

```diff
@@ -1123,7 +1123,9 @@ class FalconModel(FalconPreTrainedModel):
                     # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                     # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                     if seq_length > 1 and attention_mask.device.type == "cuda":
-                        attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
+                        attention_mask = AttentionMaskConverter._unmask_unattended(
+                            attention_mask, attention_mask_2d, hidden_states, min_dtype=min_dtype
+                        )
             else:
                 # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
                 attention_mask = _prepare_4d_causal_attention_mask(
```
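Left padding is the situation this guard protects against: the query rows that correspond to padding tokens end up with no attendable key at all. A toy illustration (shapes and values invented for this note, not taken from Falcon):

```python
import torch

min_dtype = torch.finfo(torch.float32).min

# Batch of one sequence of length 4 whose first two tokens are left padding.
attention_mask_2d = torch.tensor([[0, 0, 1, 1]])

# 4D additive mask: causal part plus padding part, clamped back into float range.
causal = torch.triu(torch.full((4, 4), min_dtype), diagonal=1)
expanded = causal[None, None] + (1 - attention_mask_2d[:, None, None, :]).float() * min_dtype
expanded = expanded.clamp(min=min_dtype)

# Query rows 0 and 1 (the padding positions) are masked for every key, which is
# exactly the case that yields NaNs in SDPA's memory-efficient backend on CUDA.
print(torch.all(expanded == min_dtype, dim=-1))  # tensor([[[ True,  True, False, False]]])
```

Passing `attention_mask_2d` and `hidden_states` into the helper lets it detect and repair exactly these rows in one place instead of in every model.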
In `GemmaModel`, the caller-side device and tracing guard is dropped in favour of a single call to the refactored helper:

```diff
@@ -979,22 +979,10 @@ class GemmaModel(GemmaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
+            )

         return causal_mask

```
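The single call that replaces the removed block relies on the multiply trick from the utilities hunk. A tiny standalone demonstration of that trick on toy tensors (nothing Gemma-specific):

```python
import torch

min_dtype = torch.finfo(torch.float32).min

# Additive mask for one head, two query rows and three keys: row 0 is fully masked,
# row 1 may attend to the last key only.
causal_mask = torch.tensor(
    [[[[min_dtype, min_dtype, min_dtype],
       [min_dtype, min_dtype, 0.0]]]]
)

# Rows that are min_dtype everywhere are multiplied by 0 (False) and therefore end up
# attending to every key; all other rows are multiplied by 1 (True) and stay unchanged.
fully_masked_rows = torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
unmasked = causal_mask.mul(~fully_masked_rows)

print(unmasked[0, 0, 0].tolist())  # [0.0, 0.0, 0.0] (up to the sign of zero)
print(unmasked[0, 0, 1].tolist())  # [min_dtype, min_dtype, 0.0], left untouched
```

With the device and tracing checks moved into the helper, the caller no longer needs to know about CUDA or graph capture at all.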
In `GPTBigCodeModel`, the computation of `inputs_embeds`, `position_embeds`, and `hidden_states` is moved ahead of the self-attention mask construction (and removed from its old location further down) so that `hidden_states` is available for the updated `_unmask_unattended` call:

```diff
@@ -991,6 +991,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0)

+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
         # Self-attention mask.
         query_length = input_shape[-1]
         key_length = past_length + query_length
@@ -1036,7 +1041,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                     # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                     # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                     self_attention_mask = AttentionMaskConverter._unmask_unattended(
-                        self_attention_mask, min_dtype=min_dtype
+                        self_attention_mask, attention_mask, hidden_states, min_dtype=min_dtype
                     )

             attention_mask = self_attention_mask
@@ -1061,11 +1066,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
-
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds
```
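The reordering above exists so that `hidden_states` can be handed to the helper as its `input_tensor` probe. Under `torch.fx` symbolic tracing that probe is a `Proxy`, so the data-dependent unmasking is skipped and graph export keeps working. A minimal sketch with an invented toy module (not GPT-BigCode itself):

```python
import torch
import torch.fx


def is_tracing(input_tensor) -> bool:
    # Same detection as in the refactored helper above.
    tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    return tracing or isinstance(input_tensor, torch.fx.Proxy)


class Toy(torch.nn.Module):
    def forward(self, hidden_states):
        if is_tracing(hidden_states):
            # Graph-safe path: leave any data-dependent mask handling to runtime.
            return hidden_states
        return hidden_states * 2


print(is_tracing(torch.ones(2)))       # False: a plain tensor in eager mode
traced = torch.fx.symbolic_trace(Toy())
print(traced.code)                     # the traced graph took the tracing branch
print(traced(torch.ones(2)).tolist())  # [1.0, 1.0]: the `* 2` branch was not captured
```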
`LlamaModel` receives the same simplification as `GemmaModel`:

```diff
@@ -1091,22 +1091,10 @@ class LlamaModel(LlamaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
+            )

         return causal_mask

```