mirror of https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00

Compare commits: remove-tf-... → v4.16.2 (5 commits)

- db7d6a80e8
- 23538cef37
- 1004fdf791
- c4ad38e5ac
- ce0102acd0
docs/source/model_doc/swin.mdx:

@@ -54,5 +54,7 @@ This model was contributed by [novice03](https://huggingface.co/novice03). The
     - forward
 
+## SwinForImageClassification
+
+[[autodoc]] transformers.SwinForImageClassification
+    - forward
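The new section documents the classification head whose behavior changes below. A minimal usage sketch, assuming the microsoft/swin-tiny-patch4-window7-224 checkpoint (any Swin image-classification checkpoint on the Hub should behave the same):

# Minimal sketch of the newly documented SwinForImageClassification head.
# The checkpoint name is an assumption; substitute any Swin classification checkpoint.
import torch
from transformers import SwinForImageClassification

model = SwinForImageClassification.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
model.eval()

# A random 224x224 RGB batch stands in for preprocessed pixel values
# (in practice these come from AutoFeatureExtractor).
pixel_values = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    logits = model(pixel_values=pixel_values).logits  # shape: (1, num_labels)
print(model.config.id2label[logits.argmax(-1).item()])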
setup.py:
@@ -351,7 +351,7 @@ install_requires = [
 
 setup(
     name="transformers",
-    version="4.16.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.16.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
     author_email="thomas@huggingface.co",
     description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
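The inline comment pins down the accepted version strings. As a quick illustration, a regex paraphrase of that rule (mine, not the repo's):

# Illustrative check of the version rule in the comment above:
# x.y.z, x.y.z.dev0, or x.y.z.rc1 -- dots allowed, dashes not.
import re

VERSION_RE = re.compile(r"^\d+\.\d+\.\d+(\.(dev|rc)\d+)?$")

for candidate in ["4.16.2", "4.16.0.dev0", "4.16.0.rc1", "4.16-2"]:
    print(candidate, bool(VERSION_RE.match(candidate)))
# 4.16.2 True, 4.16.0.dev0 True, 4.16.0.rc1 True, 4.16-2 False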
src/transformers/__init__.py:

@@ -22,7 +22,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
 
-__version__ = "4.16.0"
+__version__ = "4.16.2"
 
 from typing import TYPE_CHECKING
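The comment above this hunk explains why the version lives here: `import transformers` binds names lazily, without importing any backend. A stripped-down sketch of that pattern using PEP 562 module-level __getattr__ (the library actually uses its own _LazyModule helper, so this illustrates the idea, not the repo's code):

# Sketch of lazy importing via PEP 562, as it would sit in a package __init__.py.
# The mapping entries are illustrative; transformers uses _LazyModule instead.
import importlib

_LAZY = {
    "SwinModel": ".models.swin.modeling_swin",
    "SwinConfig": ".models.swin.configuration_swin",
}

def __getattr__(name):
    if name in _LAZY:
        # The heavy import happens only on first attribute access.
        module = importlib.import_module(_LAZY[name], __package__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")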
src/transformers/models/__init__.py:

@@ -31,6 +31,7 @@ from . import (
     bigbird_pegasus,
     blenderbot,
     blenderbot_small,
+    bort,
     byt5,
     camembert,
     canine,
src/transformers/models/bort/__init__.py (new file, empty)
src/transformers/models/swin/modeling_swin.py:

@@ -21,11 +21,11 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, SequenceClassifierOutput
+from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import logging
 from .configuration_swin import SwinConfig
@@ -143,8 +143,8 @@ class SwinPatchEmbeddings(nn.Module):
         self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
 
     def forward(self, pixel_values):
-        pixel_values = self.projection(pixel_values).flatten(2).transpose(1, 2)
-        return pixel_values
+        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+        return embeddings
 
 
 class SwinPatchMerging(nn.Module):
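The rename from pixel_values to embeddings reflects what the tensor has become: the Conv2d projects each non-overlapping patch, then flatten(2).transpose(1, 2) turns the 2D feature map into a sequence. A standalone sketch of the shape arithmetic, with illustrative sizes:

# Standalone sketch of the patch-embedding shapes (sizes are illustrative).
import torch
from torch import nn

batch, channels, image_size, patch_size, embed_dim = 2, 3, 224, 4, 96
projection = nn.Conv2d(channels, embed_dim, kernel_size=patch_size, stride=patch_size)

pixel_values = torch.randn(batch, channels, image_size, image_size)
feature_map = projection(pixel_values)               # (2, 96, 56, 56)
embeddings = feature_map.flatten(2).transpose(1, 2)  # (2, 56*56, 96) = (2, 3136, 96)
assert embeddings.shape == (batch, (image_size // patch_size) ** 2, embed_dim)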
@@ -659,7 +659,7 @@ SWIN_INPUTS_DOCSTRING = r"""
     SWIN_START_DOCSTRING,
 )
 class SwinModel(SwinPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, add_pooling_layer=True):
         super().__init__(config)
         self.config = config
         self.num_layers = len(config.depths)
@@ -669,7 +669,7 @@ class SwinModel(SwinPreTrainedModel):
         self.encoder = SwinEncoder(config, self.embeddings.patch_grid)
 
         self.layernorm = nn.LayerNorm(self.num_features, eps=config.layer_norm_eps)
-        self.pool = nn.AdaptiveAvgPool1d(1)
+        self.pooler = nn.AdaptiveAvgPool1d(1) if add_pooling_layer else None
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -686,7 +686,7 @@ class SwinModel(SwinPreTrainedModel):
             self.encoder.layer[layer].attention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING)
-    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values=None,
@@ -744,14 +744,18 @@ class SwinModel(SwinPreTrainedModel):
 
         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(sequence_output)
-        sequence_output = self.pool(sequence_output.transpose(1, 2))
-        sequence_output = torch.flatten(sequence_output, 1)
+
+        pooled_output = None
+        if self.pooler is not None:
+            pooled_output = self.pooler(sequence_output.transpose(1, 2))
+            pooled_output = torch.flatten(pooled_output, 1)
 
         if not return_dict:
-            return (sequence_output,) + encoder_outputs[1:]
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
 
-        return BaseModelOutput(
+        return BaseModelOutputWithPooling(
             last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
         )
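With this change SwinModel keeps the full sequence in last_hidden_state and moves the global average into an optional pooler_output, matching BaseModelOutputWithPooling. The pooling step in isolation (sizes are illustrative):

# The optional pooling step in isolation: a mean over the sequence dimension.
import torch
from torch import nn

hidden_size, seq_len = 768, 49
sequence_output = torch.randn(2, seq_len, hidden_size)  # (batch, seq, hidden)

pooler = nn.AdaptiveAvgPool1d(1)
pooled = pooler(sequence_output.transpose(1, 2))        # (batch, hidden, 1)
pooled = torch.flatten(pooled, 1)                       # (batch, hidden)

# AdaptiveAvgPool1d(1) here is just a mean over sequence positions:
assert torch.allclose(pooled, sequence_output.mean(dim=1), atol=1e-6)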
@@ -829,22 +833,35 @@ class SwinForImageClassification(SwinPreTrainedModel):
             return_dict=return_dict,
         )
 
-        sequence_output = outputs[0]
+        pooled_output = outputs[1]
 
-        logits = self.classifier(sequence_output)
+        logits = self.classifier(pooled_output)
 
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
-            output = (logits,) + outputs[1:]
+            output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
 
         return SequenceClassifierOutput(
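The classifier now reads the pooled vector (outputs[1]) and adopts the problem_type dispatch used by the library's other classification heads, adding multi-label support via BCEWithLogitsLoss. The selection rule, stripped of the model (a sketch, not repo code):

# The problem_type inference in isolation, mirroring the branch above.
import torch

def infer_problem_type(num_labels, labels):
    if num_labels == 1:
        return "regression"                   # single continuous target
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"  # integer class indices
    return "multi_label_classification"       # float multi-hot targets

print(infer_problem_type(3, torch.tensor([0, 2, 1, 0])))           # single_label_classification
print(infer_problem_type(3, torch.randint(0, 2, (4, 3)).float()))  # multi_label_classification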
tests/test_modeling_swin.py:

@@ -137,9 +137,11 @@ class SwinModelTester:
         model.eval()
         result = model(pixel_values)
 
-        num_features = int(config.embed_dim * 2 ** (len(config.depths) - 1))
+        # since the model we're testing only consists of a single layer, expected_seq_len = number of patches
+        expected_seq_len = (config.image_size // config.patch_size) ** 2
+        expected_dim = int(config.embed_dim * 2 ** (len(config.depths) - 1))
 
-        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_features))
+        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, expected_seq_len, expected_dim))
 
     def create_and_check_for_image_classification(self, config, pixel_values, labels):
         config.num_labels = self.type_sequence_label_size
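The assertion now checks a true sequence shape, (batch, seq_len, hidden), instead of the old pooled (batch, hidden). A worked example of the arithmetic (the config values are illustrative assumptions, not necessarily the tester's defaults):

# Worked example of the expected-shape arithmetic. With a single stage,
# no patch merging shrinks the sequence, matching the test comment above.
image_size, patch_size, embed_dim, depths = 32, 4, 24, [2]

expected_seq_len = (image_size // patch_size) ** 2      # (32 // 4) ** 2 = 64 patches
expected_dim = int(embed_dim * 2 ** (len(depths) - 1))  # 24 * 2**0 = 24

print(expected_seq_len, expected_dim)  # 64 24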
@@ -392,6 +394,6 @@ class SwinModelIntegrationTest(unittest.TestCase):
         expected_shape = torch.Size((1, 1000))
         self.assertEqual(outputs.logits.shape, expected_shape)
 
-        expected_slice = torch.tensor([-0.2952, -0.4777, 0.2025]).to(torch_device)
+        expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device)
 
         self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))