Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00

Compare commits: v4.44.2 ... nit-refact (3 commits)

Commits (SHA1):
c8b173d6b7
c4aae74ba0
bb29728d40
@@ -188,6 +188,8 @@ class AttentionMaskConverter:
     @staticmethod
     def _unmask_unattended(
         expanded_mask: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        input_tensor: torch.FloatTensor,
         min_dtype: float,
     ):
         # fmt: off
@@ -227,12 +229,24 @@ class AttentionMaskConverter:
         ```
         """
         # fmt: on
-        if expanded_mask.dtype == torch.bool:
-            raise ValueError(
-                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
-            )
-
-        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+        if attention_mask is not None and attention_mask.device.type == "cuda":
+            if expanded_mask.dtype == torch.bool:
+                raise ValueError(
+                    "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+                )
+
+            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+            is_tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+            is_tracing |= isinstance(input_tensor, torch.fx.Proxy)
+
+            if not is_tracing and torch.any(attention_mask != 1):
+                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+                # Details: https://github.com/pytorch/pytorch/issues/110213
+                return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+
+        return expanded_mask


 def _prepare_4d_causal_attention_mask(
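Below the refactored helper, a minimal standalone sketch (toy tensors only, not the library API) of the expression used by both the removed and the added return paths, which flips fully masked rows back to "attend everywhere":

```python
import torch

# Toy 4D additive mask with shape (batch=1, heads=1, q_len=3, kv_len=3):
# 0.0 means "attend", min_dtype means "masked out".
min_dtype = torch.finfo(torch.float32).min
mask = torch.full((1, 1, 3, 3), min_dtype)
mask[0, 0, 1, :2] = 0.0  # row 1 attends to the first two keys
mask[0, 0, 2, :] = 0.0   # row 2 attends everywhere
# row 0 stays fully masked, e.g. a left-padded query position

# Same expression as in the hunk: rows that are entirely min_dtype are flipped
# back to all zeros (attend everywhere), which keeps the memory-efficient SDPA
# backend from returning NaNs for those rows (pytorch/pytorch#110213).
unmasked = mask.mul(~torch.all(mask == min_dtype, dim=-1, keepdim=True))

print(unmasked[0, 0])  # row 0 is now all zeros; rows 1 and 2 are unchanged
```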
@@ -1123,7 +1123,9 @@ class FalconModel(FalconPreTrainedModel):
                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                 if seq_length > 1 and attention_mask.device.type == "cuda":
-                    attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
+                    attention_mask = AttentionMaskConverter._unmask_unattended(
+                        attention_mask, attention_mask_2d, hidden_states, min_dtype=min_dtype
+                    )
             else:
                 # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
                 attention_mask = _prepare_4d_causal_attention_mask(
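As background for the "produces nans" comment repeated in these hunks, a short illustration of why a completely unattended row is numerically degenerate. This shows the general hazard with a plain softmax; the linked issue concerns the memory-efficient SDPA kernel specifically:

```python
import torch
import torch.nn.functional as F

# A row masked everywhere with the finite dtype minimum still softmaxes to a
# uniform distribution, because the equal values cancel after max subtraction...
min_dtype = torch.finfo(torch.float32).min
print(F.softmax(torch.full((1, 4), min_dtype), dim=-1))      # tensor([[0.25, 0.25, 0.25, 0.25]])

# ...whereas a row masked with -inf degenerates to NaNs, which is the kind of
# output the models above guard against for completely unattended sequences.
print(F.softmax(torch.full((1, 4), float("-inf")), dim=-1))  # tensor([[nan, nan, nan, nan]])
```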
@@ -979,22 +979,10 @@ class GemmaModel(GemmaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
+            )

         return causal_mask
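The guard removed from the model here (and re-homed into `_unmask_unattended` in the first file) only runs the data-dependent `torch.any(attention_mask != 1)` check outside of tracing and compilation. A minimal sketch of the same detection logic, using a hypothetical helper name and assuming a PyTorch build that exposes `torch._dynamo.is_compiling`:

```python
import torch

def _looks_like_tracing(input_tensor) -> bool:
    # Hypothetical helper mirroring the guard in the removed block: data-dependent
    # control flow (torch.any on real values) is skipped under TorchScript tracing,
    # torch.fx symbolic tracing, or torch.compile, where it would bake a concrete
    # branch into the captured graph.
    return (
        torch.jit.is_tracing()
        or isinstance(input_tensor, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )

# In eager mode this returns False, so the unmasking branch is allowed to run.
print(_looks_like_tracing(torch.ones(2, 3)))
```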
@@ -991,6 +991,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0)

+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
         # Self-attention mask.
         query_length = input_shape[-1]
         key_length = past_length + query_length
@@ -1036,7 +1041,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                 self_attention_mask = AttentionMaskConverter._unmask_unattended(
-                    self_attention_mask, min_dtype=min_dtype
+                    self_attention_mask, attention_mask, hidden_states, min_dtype=min_dtype
                 )

         attention_mask = self_attention_mask
@@ -1061,11 +1066,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
-
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds
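A hypothetical mini-forward (not GPTBigCode's actual code) showing why the first and third hunks move the embedding sum: the patched `_unmask_unattended` call receives `hidden_states` as its `input_tensor`, so the embeddings have to be computed before the self-attention mask block:

```python
import torch
import torch.nn as nn

wte = nn.Embedding(16, 8)   # stand-in token embedding table
wpe = nn.Embedding(32, 8)   # stand-in position embedding table

input_ids = torch.tensor([[3, 5, 7]])
position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0)

# Moved up, as in the diff: hidden_states is available here ...
hidden_states = wte(input_ids) + wpe(position_ids)

# ... before the self-attention mask is built, so the mask helper can take it as
# its input_tensor argument (stand-in tensor math instead of the real helper call).
min_dtype = torch.finfo(hidden_states.dtype).min
self_attention_mask = torch.zeros(1, 1, 3, 3)
self_attention_mask = self_attention_mask.mul(
    ~torch.all(self_attention_mask == min_dtype, dim=-1, keepdim=True)
)
```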
@@ -1091,22 +1091,10 @@ class LlamaModel(LlamaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
+            )

         return causal_mask