Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-20 17:13:56 +08:00)
[Masks] Fix mask handling in eager for vision models (#41625)
Add mask handling for the models that do use it.
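The fix is the same in every affected model: the eager attention path used to ignore the attention_mask argument, and it now crops the 4-D additive mask to the key length and adds it to the raw scores before the softmax, so masked positions end up with near-zero probability. Below is a minimal, self-contained sketch of that pattern (illustrative only, not the library code verbatim; the helper name, tensor shapes, and mask construction are assumptions for the demo):

import torch
from torch import nn


def eager_attention_with_mask(query, key, value, attention_mask=None, scaling=None, dropout=0.0):
    # Hypothetical stand-in for the library helper, showing only the masking pattern.
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Raw attention scores: (batch, heads, q_len, k_len)
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Crop the additive mask to the key length and add it before softmax
        # (0 keeps a position, a large negative value drops it).
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=False)
    return torch.matmul(attn_weights, value), attn_weights


batch, heads, seq, dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# Additive mask that hides the last key position for every query.
mask = torch.zeros(batch, 1, seq, seq)
mask[..., -1] = torch.finfo(mask.dtype).min

_, weights = eager_attention_with_mask(q, k, v, attention_mask=mask)
print(weights[..., -1].abs().max())  # ~0: the masked key is effectively ignored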
@@ -96,25 +96,28 @@ class ASTPatchEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -161,25 +161,28 @@ class DeiTPatchEmbeddings(nn.Module):
         return x


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -149,25 +149,28 @@ class Dinov2PatchEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -176,18 +176,21 @@ def eager_attention_forward(
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -194,18 +194,21 @@ def eager_attention_forward(
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -32,7 +32,8 @@ from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BaseModelOutput, DepthEstimatorOutput, SemanticSegmenterOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...utils import ModelOutput, auto_docstring, logging, torch_int
+from ...processing_utils import Unpack
+from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging, torch_int
 from ...utils.backbone_utils import load_backbone
 from ...utils.generic import can_return_tuple, check_model_inputs
 from .configuration_dpt import DPTConfig
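The two imports added here back the new signature: Unpack and TransformersKwargs type the keyword arguments that previously arrived through a bare **kwargs. A schematic of how they fit together (paths mirrored from the relative imports above; body elided, illustration only):

from typing import Optional

import torch
from torch import nn
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],  # runtime flags are now typed instead of untyped **kwargs
):
    ...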
@@ -267,25 +268,28 @@ class DPTViTPatchEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -147,18 +147,21 @@ def eager_attention_forward(
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -178,25 +178,28 @@ class VideoMAEPatchEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -167,24 +167,28 @@ class ViTPatchEmbeddings(nn.Module):
         return embeddings


+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -326,25 +326,28 @@ class ViTMAEPatchEmbeddings(nn.Module):
         return x


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -163,25 +163,28 @@ class ViTMSNPatchEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -30,7 +30,8 @@ from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import BackboneOutput, BaseModelOutput
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ...utils import auto_docstring, logging
+from ...processing_utils import Unpack
+from ...utils import TransformersKwargs, auto_docstring, logging
 from ...utils.backbone_utils import BackboneMixin
 from ...utils.generic import check_model_inputs
 from .configuration_vitpose_backbone import VitPoseBackboneConfig
@@ -95,25 +96,28 @@ class VitPoseBackboneEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -156,25 +156,28 @@ class VivitEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
     dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)
@@ -211,25 +211,28 @@ class YolosPatchEmbeddings(nn.Module):
         return embeddings


-# Copied from transformers.models.vit.modeling_vit.eager_attention_forward
+# Copied from transformers.models.bert.modeling_bert.eager_attention_forward
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     attention_mask: Optional[torch.Tensor],
-    scaling: float,
+    scaling: Optional[float] = None,
    dropout: float = 0.0,
-    **kwargs,
+    **kwargs: Unpack[TransformersKwargs],
 ):
+    if scaling is None:
+        scaling = query.size(-1) ** -0.5
+
     # Take the dot product between "query" and "key" to get the raw attention scores.
-    attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
+    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling

-    # Normalize the attention scores to probabilities.
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+    if attention_mask is not None:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+        attn_weights = attn_weights + attention_mask

     # This is actually dropping out entire tokens to attend to, which might
     # seem a bit unusual, but is taken from the original Transformer paper.
+    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

     attn_output = torch.matmul(attn_weights, value)