mirror of https://github.com/huggingface/transformers.git
synced 2025-10-22 02:08:58 +08:00

Compare commits: v4.51.2 ... nit-refact (3 commits)

Commits (SHA1):
- c8b173d6b7
- c4aae74ba0
- bb29728d40
````diff
@@ -188,6 +188,8 @@ class AttentionMaskConverter:
     @staticmethod
     def _unmask_unattended(
         expanded_mask: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        input_tensor: torch.FloatTensor,
         min_dtype: float,
     ):
         # fmt: off
@@ -227,12 +229,24 @@ class AttentionMaskConverter:
         ```
         """
         # fmt: on
-        if expanded_mask.dtype == torch.bool:
-            raise ValueError(
-                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
-            )
-
-        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+        if attention_mask is not None and attention_mask.device.type == "cuda":
+            if expanded_mask.dtype == torch.bool:
+                raise ValueError(
+                    "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+                )
+
+            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+            is_tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+            is_tracing |= isinstance(input_tensor, torch.fx.Proxy)
+
+            if not is_tracing and torch.any(attention_mask != 1):
+                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+                # Details: https://github.com/pytorch/pytorch/issues/110213
+                return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+
+        return expanded_mask


 def _prepare_4d_causal_attention_mask(
````
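The comments in this hunk point at https://github.com/pytorch/pytorch/issues/110213: the memory-efficient SDPA path can return NaNs for query rows whose additive mask blocks every key. Below is a minimal, self-contained sketch of the row-unmasking trick used in the new `return` branch; it is not part of the diff, and the toy mask shape and values are illustrative assumptions.

```python
import torch

# Toy additive 4D mask of shape (batch=1, heads=1, q_len=2, kv_len=3).
# Row 0 is fully masked (every entry is min_dtype), as happens for padded-out
# positions with left padding; row 1 is a normal partially-masked row.
min_dtype = torch.finfo(torch.float32).min
expanded_mask = torch.tensor(
    [[[[min_dtype, min_dtype, min_dtype],
       [0.0, 0.0, min_dtype]]]]
)

# Rows that are entirely min_dtype are zeroed out, i.e. those queries are
# allowed to attend everywhere instead of nowhere, so SDPA cannot produce NaNs.
fully_masked_rows = torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
unmasked = expanded_mask.mul(~fully_masked_rows)

print(unmasked[0, 0, 0])  # tensor([0., 0., 0.]) -> previously fully masked row
print(unmasked[0, 0, 1])  # unchanged: [0., 0., min_dtype]
```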
```diff
@@ -1123,7 +1123,9 @@ class FalconModel(FalconPreTrainedModel):
                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                 if seq_length > 1 and attention_mask.device.type == "cuda":
-                    attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
+                    attention_mask = AttentionMaskConverter._unmask_unattended(
+                        attention_mask, attention_mask_2d, hidden_states, min_dtype=min_dtype
+                    )
             else:
                 # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
                 attention_mask = _prepare_4d_causal_attention_mask(
```
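The two context comments restate the motivation: from PyTorch 2.1, the memory-efficient SDPA backend can emit NaNs when a sequence is completely unattended. The snippet below is an illustrative reproduction of the general failure mode using an explicit `-inf` mask and the default backend; it is not from the diff, and the exact backend/dtype behavior tracked in pytorch/pytorch#110213 may differ.

```python
import torch
import torch.nn.functional as F

# One batch, one head, 2 query positions, 3 key positions.
q = torch.randn(1, 1, 2, 4)
k = torch.randn(1, 1, 3, 4)
v = torch.randn(1, 1, 3, 4)

# The first query row attends to nothing: softmax over an all -inf row is NaN,
# and the NaNs propagate into the attention output for that row.
attn_mask = torch.zeros(1, 1, 2, 3)
attn_mask[0, 0, 0, :] = float("-inf")

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out[0, 0, 0])  # NaNs for the fully masked query
print(out[0, 0, 1])  # finite values for the normally masked query
```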
```diff
@@ -979,22 +979,10 @@ class GemmaModel(GemmaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
             )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask
```
```diff
@@ -991,6 +991,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0)

+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
         # Self-attention mask.
         query_length = input_shape[-1]
         key_length = past_length + query_length
@@ -1036,7 +1041,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                 self_attention_mask = AttentionMaskConverter._unmask_unattended(
-                    self_attention_mask, min_dtype=min_dtype
+                    self_attention_mask, attention_mask, hidden_states, min_dtype=min_dtype
                 )

             attention_mask = self_attention_mask
@@ -1061,11 +1066,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
-
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds
```
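The first and third hunks only move the embedding computation above the mask construction, so that `hidden_states` already exists when it is handed to `_unmask_unattended` as its `input_tensor`. Per the first hunk of this compare, that argument is only inspected by the tracing guard. A standalone sketch of that guard follows; the helper name `_is_tracing_like` is hypothetical, while the checks mirror the ones added to `_unmask_unattended`.

```python
import torch

def _is_tracing_like(input_tensor) -> bool:
    # Hypothetical standalone version of the guard added to _unmask_unattended:
    # under TorchScript tracing, torch.fx symbolic tracing, or torch.compile,
    # the data-dependent torch.any(attention_mask != 1) branch must be skipped.
    tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    return tracing or isinstance(input_tensor, torch.fx.Proxy)

print(_is_tracing_like(torch.randn(2, 3)))  # False in plain eager execution
```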
```diff
@@ -1091,22 +1091,10 @@ class LlamaModel(LlamaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
             )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask
```