Compare commits

...

3 Commits

SHA1        Message   Date
c8b173d6b7  fix       2024-02-29 09:10:09 +09:00
c4aae74ba0  refactor  2024-02-29 09:03:34 +09:00
bb29728d40  refactor  2024-02-29 09:03:14 +09:00
5 changed files with 34 additions and 42 deletions

View File

@@ -188,6 +188,8 @@ class AttentionMaskConverter:
     @staticmethod
     def _unmask_unattended(
         expanded_mask: torch.FloatTensor,
+        attention_mask: torch.FloatTensor,
+        input_tensor: torch.FloatTensor,
         min_dtype: float,
     ):
         # fmt: off
@@ -227,12 +229,24 @@ class AttentionMaskConverter:
         ```
         """
         # fmt: on
-        if expanded_mask.dtype == torch.bool:
-            raise ValueError(
-                "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
-            )
-
-        return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+        if attention_mask is not None and attention_mask.device.type == "cuda":
+            if expanded_mask.dtype == torch.bool:
+                raise ValueError(
+                    "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
+                )
+
+            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
+            is_tracing = torch.jit.is_tracing() or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+            is_tracing |= isinstance(input_tensor, torch.fx.Proxy)
+
+            if not is_tracing and torch.any(attention_mask != 1):
+                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+                # Details: https://github.com/pytorch/pytorch/issues/110213
+                return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
+
+        return expanded_mask

 def _prepare_4d_causal_attention_mask(
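
Aside, not part of the diff: a minimal standalone sketch of what the row-unmasking in `_unmask_unattended` does, assuming a single sequence with one left-padding token; the tensor values below are illustrative only.

import torch

# Illustrative 4D additive mask of shape (batch, 1, query_len, key_len) for one
# sequence whose first position is left padding; min_dtype means "masked out".
min_dtype = torch.finfo(torch.float32).min
expanded_mask = torch.tensor(
    [[[[min_dtype, min_dtype, min_dtype],   # pad query row: every key is masked
       [min_dtype, 0.0, min_dtype],         # real token: pad column + causal masking
       [min_dtype, 0.0, 0.0]]]]             # real token: only the pad column masked
)

# Rows that are masked everywhere are zeroed out, i.e. made to attend to all keys,
# which is what keeps SDPA's memory-efficient backend from producing NaNs there.
fully_masked_rows = torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)
print(expanded_mask.mul(~fully_masked_rows)[0, 0, 0])  # all-zero row: the pad token now attends everywhere
print(expanded_mask.mul(~fully_masked_rows)[0, 0, 1])  # rows of real tokens are left untouched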

View File

@@ -1123,7 +1123,9 @@ class FalconModel(FalconPreTrainedModel):
                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                 if seq_length > 1 and attention_mask.device.type == "cuda":
-                    attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
+                    attention_mask = AttentionMaskConverter._unmask_unattended(
+                        attention_mask, attention_mask_2d, hidden_states, min_dtype=min_dtype
+                    )
             else:
                 # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
                 attention_mask = _prepare_4d_causal_attention_mask(

View File

@@ -979,22 +979,10 @@ class GemmaModel(GemmaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
             )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask
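
Aside, not part of the diff: the `is_tracing` guard that the hunk above removes from the model (it now lives inside `_unmask_unattended`) exists because `torch.any(attention_mask != 1)` branches on tensor values, which graph capture cannot resolve. A small sketch of that failure mode under torch.fx symbolic tracing; the module name is made up for illustration.

import torch
import torch.fx

class DataDependentBranch(torch.nn.Module):
    # Stand-in for the guarded check: branching on tensor *values* needs a concrete tensor.
    def forward(self, attention_mask):
        if torch.any(attention_mask != 1):  # bool() on an fx.Proxy raises during tracing
            return attention_mask + 1
        return attention_mask

try:
    torch.fx.symbolic_trace(DataDependentBranch())
except torch.fx.proxy.TraceError as exc:
    # e.g. "symbolically traced variables cannot be used as inputs to control flow"
    print(type(exc).__name__)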

View File

@@ -991,6 +991,11 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
             position_ids = position_ids.unsqueeze(0)

+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        position_embeds = self.wpe(position_ids)
+        hidden_states = inputs_embeds + position_embeds
+
         # Self-attention mask.
         query_length = input_shape[-1]
         key_length = past_length + query_length
@@ -1036,7 +1041,7 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
                 # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                 # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                 self_attention_mask = AttentionMaskConverter._unmask_unattended(
-                    self_attention_mask, min_dtype=min_dtype
+                    self_attention_mask, attention_mask, hidden_states, min_dtype=min_dtype
                 )

         attention_mask = self_attention_mask
@@ -1061,11 +1066,6 @@ class GPTBigCodeModel(GPTBigCodePreTrainedModel):
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
-        position_embeds = self.wpe(position_ids)
-        hidden_states = inputs_embeds + position_embeds
-
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
             hidden_states = hidden_states + token_type_embeds

View File

@@ -1091,22 +1091,10 @@ class LlamaModel(LlamaPreTrainedModel):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

-        if (
-            self.config._attn_implementation == "sdpa"
-            and attention_mask is not None
-            and attention_mask.device.type == "cuda"
-        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+        if self.config._attn_implementation == "sdpa":
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, attention_mask, input_tensor, min_dtype
             )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

         return causal_mask