Compare commits

...

5 Commits

SHA1        Message                                   Date
db7d6a80e8  Release: v4.16.2                          2022-01-31 11:40:27 -05:00
23538cef37  Add header (#15434)                       2022-01-31 11:18:49 -05:00
1004fdf791  [Hotfix] Fix Swin model outputs (#15414)  2022-01-31 11:18:30 -05:00
            * Fix Swin model outputs
            * Rename pooler
c4ad38e5ac  Release: v4.16.1                          2022-01-28 11:53:30 -05:00
ce0102acd0  Add init to BORT (#15378)                 2022-01-28 11:52:07 -05:00
            * Add init to BORT
            * BORT should be in init
7 changed files with 45 additions and 23 deletions

docs/source/model_doc/swin.mdx  View File

@@ -54,5 +54,7 @@ This model was contributed by [novice03](https://huggingface.co/novice03>). The
     - forward
 
+## SwinForImageClassification
+
 [[autodoc]] transformers.SwinForImageClassification
     - forward

setup.py  View File

@@ -351,7 +351,7 @@ install_requires = [
 
 setup(
     name="transformers",
-    version="4.16.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.16.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
     author_email="thomas@huggingface.co",
     description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",

src/transformers/__init__.py  View File

@@ -22,7 +22,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
 
-__version__ = "4.16.0"
+__version__ = "4.16.2"
 
 from typing import TYPE_CHECKING
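
As a quick sanity check after upgrading, the bumped version string can be read back at runtime (a minimal sketch, independent of this diff):

import transformers

print(transformers.__version__)  # expected to print "4.16.2" for this release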

src/transformers/models/__init__.py  View File

@@ -31,6 +31,7 @@ from . import (
     bigbird_pegasus,
     blenderbot,
     blenderbot_small,
+    bort,
     byt5,
     camembert,
     canine,
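
Registering bort here makes the submodule resolve like any other model package. A minimal smoke test (a sketch; BORT ships conversion utilities rather than model classes, so there is little else to import from it):

# Sketch: the submodule should now be importable. Before this fix the bort
# directory presumably lacked an __init__.py, so setuptools' find_packages()
# left it out of released wheels.
from transformers.models import bort  # noqa: F401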

src/transformers/models/bort/__init__.py  View File

src/transformers/models/swin/modeling_swin.py  View File

@@ -21,11 +21,11 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_swin import SwinConfig
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
 
     def forward(self, pixel_values):
-        pixel_values = self.projection(pixel_values).flatten(2).transpose(1, 2)
-        return pixel_values
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings
 
 
 class SwinPatchMerging(nn.Module):
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
     SWIN_START_DOCSTRING,
 )
 class SwinModel(SwinPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
         self.num_layers = len(config.depths)
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
 
         self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
-        self.pool = nn.AdaptiveAvgPool1d(1)
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values=None,
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(sequence_output)
 
-        sequence_output = self.pool(sequence_output.transpose(1, 2))
-        sequence_output = torch.flatten(sequence_output, 1)
+        pooled_output = None
+        if self.pooler is not None:
+            pooled_output = self.pooler(sequence_output.transpose(1, 2))
+            pooled_output = torch.flatten(pooled_output, 1)
 
         if not return_dict:
-            return (sequence_output,) + encoder_outputs[1:]
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutput(
+        return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
         )
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
             return_dict=return_dict,
         )
 
-        sequence_output = outputs[0]
+        pooled_output = outputs[1]
 
-        logits = self.classifier(sequence_output)
+        logits = self.classifier(pooled_output)
 
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                #  We are doing regression
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
                 loss_fct = CrossEntropyLoss()
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
-            output = (logits,) + outputs[1:]
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return SequenceClassifierOutput(
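
Net effect of the hotfix: last_hidden_state keeps its (batch, seq_len, hidden) shape and the averaged vector moves to pooler_output, which is what SwinForImageClassification now feeds its head. A minimal sketch, assuming the microsoft/swin-tiny-patch4-window7-224 checkpoint name and transformers >= 4.16.2 (neither is part of this diff):

import torch
from transformers import SwinModel

model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model.eval()

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
with torch.no_grad():
    outputs = model(pixel_values)

# Before this fix, last_hidden_state was already pooled down to (batch, hidden_size).
print(outputs.last_hidden_state.shape)  # torch.Size([1, 49, 768]): (batch, seq_len, hidden_size)
print(outputs.pooler_output.shape)      # torch.Size([1, 768]): pooled, as the classifier expects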

tests/test_modeling_swin.py  View File

@@ -137,9 +137,11 @@ class SwinModelTester:
         model.eval()
         result = model(pixel_values)
 
-        num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
+        # since the model we're testing only consists of a single layer, expected_seq_len = number of patches
+        expected_seq_len = (config.image_size // config.patch_size) ** 2
+        expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
 
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features))
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
         expected_shape = torch.Size((1, 1000))
         self.assertEqual(outputs.logits.shape, expected_shape)
 
-        expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device)
+        expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
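
The rewritten loss block follows the problem_type dispatch used by other sequence classifiers in the library. A minimal sketch of the new multi-label path, using a hypothetical 3-label config (not taken from the diff):

import torch
from transformers import SwinConfig, SwinForImageClassification

config = SwinConfig(num_labels=3, problem_type="multi_label_classification")
model = SwinForImageClassification(config)  # randomly initialized; fine for a shape check

pixel_values = torch.randn(2, 3, 224, 224)
labels = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # multi-hot float targets

outputs = model(pixel_values, labels=labels)
print(outputs.loss)  # computed with BCEWithLogitsLoss under this problem_type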