mirror of https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00

Compare commits: timm_wrapp...v4.3.2 (8 commits)
Author | SHA1 | Date
---|---|---
| cd48078ce5 |
| 727ab9d398 |
| c95fae6d65 |
| cc86472c78 |
| 02451cda74 |
| 800f385d78 |
| bcf49c0438 |
| 15a8906c71 |
docs/source/conf.py

@@ -26,7 +26,7 @@ author = u'huggingface'
 # The short X.Y version
 version = u''
 # The full version, including alpha/beta/rc tags
-release = u'4.2.0'
+release = u'4.3.0'
 # Prefix link to point to master, comment this during version release and uncomment below line
 extlinks = {'prefix_link': ('https://github.com/huggingface/transformers/blob/master/%s', '')}
 # Prefix link to always point to corresponding version, uncomment this during version release
docs/source/model_doc/wav2vec2.rst

@@ -58,8 +58,8 @@ Wav2Vec2Model
     :members: forward


-Wav2Vec2ForMaskedLM
+Wav2Vec2ForCTC
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-.. autoclass:: transformers.Wav2Vec2ForMaskedLM
+.. autoclass:: transformers.Wav2Vec2ForCTC
     :members: forward
setup.py

@@ -102,7 +102,7 @@ _deps = [
     "importlib_metadata",
     "ipadic>=1.0.0,<2.0",
     "isort>=5.5.4",
-    "jax>=0.2.0",
+    "jax>=0.2.8",
     "jaxlib>=0.1.59",
     "keras2onnx",
     "numpy>=1.17",
@@ -132,7 +132,7 @@ _deps = [
     "tensorflow-cpu>=2.3",
     "tensorflow>=2.3",
     "timeout-decorator",
-    "tokenizers==0.10.1rc1",
+    "tokenizers>=0.10.1,<0.11",
     "torch>=1.0",
     "tqdm>=4.27",
     "unidic>=1.0.2",
@@ -282,7 +282,7 @@ install_requires = [

 setup(
     name="transformers",
-    version="4.3.0.rc1",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.3.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
     author_email="thomas@huggingface.co",
     description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
src/transformers/__init__.py

@@ -22,7 +22,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).

-__version__ = "4.3.0.rc1"
+__version__ = "4.3.2"

 # Work around to update TensorFlow's absl.logging threshold which alters the
 # default Python logging output behavior when present.
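A quick way to confirm the bump took effect in an installed environment (a minimal sketch; any Python shell works):

import transformers

# after installing this tag, the runtime version should match setup.py
assert transformers.__version__ == "4.3.2"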
@@ -367,6 +367,7 @@ if is_torch_available():
     _import_structure["models.wav2vec2"].extend(
         [
             "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
+            "Wav2Vec2ForCTC",
             "Wav2Vec2ForMaskedLM",
             "Wav2Vec2Model",
             "Wav2Vec2PreTrainedModel",
@@ -1813,6 +1814,7 @@ if TYPE_CHECKING:
         )
         from .models.wav2vec2 import (
             WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
+            Wav2Vec2ForCTC,
             Wav2Vec2ForMaskedLM,
             Wav2Vec2Model,
             Wav2Vec2PreTrainedModel,
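Both hunks feed the same mechanism: new names are registered in `_import_structure` and only resolved to real imports on first attribute access, so `import transformers` stays cheap. A minimal standalone sketch of that lazy-module pattern (simplified; not the library's own lazy-module class):

import importlib
import types

class LazyModule(types.ModuleType):
    def __init__(self, name, import_structure):
        super().__init__(name)
        # map each exported name to the submodule that defines it
        self._class_to_module = {
            cls: mod for mod, classes in import_structure.items() for cls in classes
        }

    def __getattr__(self, name):
        # the submodule is imported only when the attribute is first requested
        submodule = importlib.import_module("." + self._class_to_module[name], self.__name__)
        return getattr(submodule, name)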
src/transformers/dependency_versions_table.py

@@ -15,7 +15,7 @@ deps = {
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",
-    "jax": "jax>=0.2.0",
+    "jax": "jax>=0.2.8",
     "jaxlib": "jaxlib>=0.1.59",
     "keras2onnx": "keras2onnx",
     "numpy": "numpy>=1.17",
@@ -45,7 +45,7 @@ deps = {
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
     "tensorflow": "tensorflow>=2.3",
     "timeout-decorator": "timeout-decorator",
-    "tokenizers": "tokenizers==0.10.1rc1",
+    "tokenizers": "tokenizers>=0.10.1,<0.11",
     "torch": "torch>=1.0",
     "tqdm": "tqdm>=4.27",
     "unidic": "unidic>=1.0.2",
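This table mirrors the pins in setup.py so they can be enforced at runtime as well as at install time. A sketch of how such a pin can be checked against a live environment (using setuptools' pkg_resources here; the library ships its own helper for this):

import pkg_resources

# raises pkg_resources.VersionConflict if the installed tokenizers falls outside the pin
pkg_resources.require("tokenizers>=0.10.1,<0.11")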
src/transformers/file_utils.py

@@ -144,7 +144,11 @@ try:
     _faiss_version = importlib_metadata.version("faiss")
     logger.debug(f"Successfully imported faiss version {_faiss_version}")
 except importlib_metadata.PackageNotFoundError:
-    _faiss_available = False
+    try:
+        _faiss_version = importlib_metadata.version("faiss-cpu")
+        logger.debug(f"Successfully imported faiss version {_faiss_version}")
+    except importlib_metadata.PackageNotFoundError:
+        _faiss_available = False


 _scatter_available = importlib.util.find_spec("torch_scatter") is not None
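The nested try lets the check accept either distribution name (faiss from conda, faiss-cpu from pip) before declaring the package unavailable. A standalone sketch of the same probe, using stdlib importlib.metadata (Python 3.8+) rather than the backport the library imports:

import importlib.metadata as importlib_metadata

def is_dist_available(*names):
    # True if any of the given distribution names is installed
    for name in names:
        try:
            importlib_metadata.version(name)
            return True
        except importlib_metadata.PackageNotFoundError:
            continue
    return False

faiss_available = is_dist_available("faiss", "faiss-cpu")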
src/transformers/models/rag/modeling_rag.py

@@ -1306,6 +1306,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
         eos_token_id=None,
         length_penalty=None,
         no_repeat_ngram_size=None,
+        encoder_no_repeat_ngram_size=None,
         repetition_penalty=None,
         bad_words_ids=None,
         num_return_sequences=None,
@@ -1372,6 +1373,9 @@ class RagTokenForGeneration(RagPreTrainedModel):
                 order to encourage the model to produce longer sequences.
             no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
                 If set to int > 0, all ngrams of that size can only occur once.
+            encoder_no_repeat_ngram_size (:obj:`int`, `optional`, defaults to 0):
+                If set to int > 0, all ngrams of that size that occur in the ``encoder_input_ids`` cannot occur in the
+                ``decoder_input_ids``.
             bad_words_ids(:obj:`List[int]`, `optional`):
                 List of token ids that are not allowed to be generated. In order to get the tokens of the words that
                 should not appear in the generated text, use :obj:`tokenizer.encode(bad_word, add_prefix_space=True)`.
@@ -1490,6 +1494,8 @@ class RagTokenForGeneration(RagPreTrainedModel):
         pre_processor = self._get_logits_processor(
             repetition_penalty=repetition_penalty,
             no_repeat_ngram_size=no_repeat_ngram_size,
+            encoder_no_repeat_ngram_size=encoder_no_repeat_ngram_size,
+            encoder_input_ids=context_input_ids,
             bad_words_ids=bad_words_ids,
             min_length=min_length,
             eos_token_id=eos_token_id,
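Taken together, the three hunks expose `encoder_no_repeat_ngram_size` through `RagTokenForGeneration.generate()` and route it, with the retrieved `context_input_ids` acting as the encoder input, into the logits processors. A hedged usage sketch in the style of the RAG examples of this era (the checkpoint and dummy-index settings are illustrative, not taken from this diff):

from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer.prepare_seq2seq_batch("who holds the record in 100m freestyle", return_tensors="pt")
generated = model.generate(
    input_ids=inputs["input_ids"],
    encoder_no_repeat_ngram_size=3,  # 3-grams from the retrieved context cannot be copied verbatim
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))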
src/transformers/models/wav2vec2/__init__.py

@@ -29,6 +29,7 @@ if is_torch_available():
     _import_structure["modeling_wav2vec2"] = [
         "WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
         "Wav2Vec2ForMaskedLM",
+        "Wav2Vec2ForCTC",
         "Wav2Vec2Model",
         "Wav2Vec2PreTrainedModel",
     ]
@@ -41,6 +42,7 @@ if TYPE_CHECKING:
     if is_torch_available():
         from .modeling_wav2vec2 import (
             WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
+            Wav2Vec2ForCTC,
             Wav2Vec2ForMaskedLM,
             Wav2Vec2Model,
             Wav2Vec2PreTrainedModel,
src/transformers/models/wav2vec2/convert_wav2vec2_original_pytorch_checkpoint_to_pytorch.py

@@ -20,7 +20,7 @@ import argparse
 import fairseq
 import torch

-from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, logging
+from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging


 logging.set_verbosity_info()
@@ -141,7 +141,7 @@ def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_
     """
     Copy/paste/tweak model's weights to transformers design.
     """
-    hf_wav2vec = Wav2Vec2ForMaskedLM(Wav2Vec2Config())
+    hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())

     model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
         [checkpoint_path], arg_overrides={"data": dict_path}
src/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -15,6 +15,7 @@
 """ PyTorch Wav2Vec2 model. """


+import warnings
 from typing import Optional, Tuple

 import torch
@@ -24,7 +25,7 @@ from torch import nn

 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, MaskedLMOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_wav2vec2 import Wav2Vec2Config
@@ -665,6 +666,10 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

+        warnings.warn(
+            "The class `Wav2Vec2ForMaskedLM` is deprecated. Please use `Wav2Vec2ForCTC` instead.", FutureWarning
+        )
+
         self.wav2vec2 = Wav2Vec2Model(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
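Because the warning fires in `__init__`, it surfaces once per construction. A sketch of how to observe (or unit-test) it, assuming torch is installed so the real class is importable:

import warnings
from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Wav2Vec2ForMaskedLM(Wav2Vec2Config())  # randomly initialized; emits the FutureWarning

assert any(issubclass(w.category, FutureWarning) for w in caught)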
@@ -729,3 +734,77 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
             return output

         return MaskedLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+
+
+@add_start_docstrings(
+    """Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """,
+    WAV_2_VEC_2_START_DOCSTRING,
+)
+class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.wav2vec2 = Wav2Vec2Model(config)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+
+        self.init_weights()
+
+    @add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+    ):
+        r"""
+        labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+            TODO(PVP): Fill out when adding training
+
+        Returns:
+
+        Example::
+
+            >>> from transformers import Wav2Vec2Tokenizer, Wav2Vec2Model
+            >>> from datasets import load_dataset
+            >>> import soundfile as sf
+
+            >>> tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
+            >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
+
+            >>> def map_to_array(batch):
+            >>>     speech, _ = sf.read(batch["file"])
+            >>>     batch["speech"] = speech
+            >>>     return batch
+
+            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
+            >>> ds = ds.map(map_to_array)
+
+            >>> input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
+            >>> logits = model(input_values).logits
+
+            >>> predicted_ids = torch.argmax(logits, dim=-1)
+            >>> transcription = tokenizer.decode(predicted_ids[0])
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.wav2vec2(
+            input_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0]
+        hidden_states = self.dropout(hidden_states)
+        logits = self.lm_head(hidden_states)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return output
+
+        return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
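Note that the committed docstring example imports `Wav2Vec2Model` but calls `Wav2Vec2ForCTC`, and uses `torch.argmax` without importing torch. A self-contained version of the same greedy-CTC snippet (same checkpoint and dummy dataset as the docstring) would read:

import torch
import soundfile as sf
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_array(batch):
    speech, _ = sf.read(batch["file"])
    batch["speech"] = speech
    return batch

ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.map(map_to_array)

input_values = tokenizer(ds["speech"][0], return_tensors="pt").input_values  # batch size 1
logits = model(input_values).logits  # (batch, time, vocab)

# greedy CTC decoding: argmax per frame, then collapse repeats/blanks via the tokenizer
predicted_ids = torch.argmax(logits, dim=-1)
transcription = tokenizer.decode(predicted_ids[0])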
src/transformers/utils/dummy_pt_objects.py

@@ -2229,6 +2229,11 @@ def load_tf_weights_in_transfo_xl(*args, **kwargs):
 WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = None


+class Wav2Vec2ForCTC:
+    def __init__(self, *args, **kwargs):
+        requires_pytorch(self)
+
+
 class Wav2Vec2ForMaskedLM:
     def __init__(self, *args, **kwargs):
         requires_pytorch(self)
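These dummies keep `from transformers import Wav2Vec2ForCTC` importable in a torch-less environment and fail with a readable error only when the class is actually used. A standalone sketch of the pattern (the real `requires_pytorch` lives in the library and includes install instructions; this simplified stand-in is an assumption):

def requires_pytorch(obj):
    # simplified stand-in for the library's helper
    raise ImportError(f"{type(obj).__name__} requires the PyTorch library, but it was not found.")

class Wav2Vec2ForCTC:
    def __init__(self, *args, **kwargs):
        requires_pytorch(self)

Wav2Vec2ForCTC()  # raises ImportError naming the class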
tests/test_modeling_wav2vec2.py

@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init
 if is_torch_available():
     import torch

-    from transformers import Wav2Vec2Config, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer
+    from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2ForMaskedLM, Wav2Vec2Model, Wav2Vec2Tokenizer


 class Wav2Vec2ModelTester:
@@ -204,7 +204,7 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):

 @require_torch
 class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
-    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM) if is_torch_available() else ()
+    all_model_classes = (Wav2Vec2Model, Wav2Vec2ForMaskedLM, Wav2Vec2ForCTC) if is_torch_available() else ()
     test_pruning = False
     test_headmasking = False
     test_torchscript = False
@@ -289,7 +289,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         return ds["speech"][:num_samples]

     def test_inference_masked_lm_normal(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -307,7 +307,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

     def test_inference_masked_lm_normal_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-base-960h")
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
         model.to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
@@ -330,7 +330,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

     def test_inference_masked_lm_robust_batched(self):
-        model = Wav2Vec2ForMaskedLM.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
+        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
         tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)

         input_speech = self._load_datasamples(4)
utils/check_repo.py

@@ -118,6 +118,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
     "TFMT5EncoderModel",
     "TFOpenAIGPTDoubleHeadsModel",
     "TFT5EncoderModel",
+    "Wav2Vec2ForCTC",
     "XLMForQuestionAnswering",
     "XLMProphetNetDecoder",
     "XLMProphetNetEncoder",
@@ -370,6 +371,7 @@ DEPRECATED_OBJECTS = [
     "TFBartPretrainedModel",
     "TextDataset",
     "TextDatasetForNextSentencePrediction",
+    "Wav2Vec2ForMaskedLM",
     "glue_compute_metrics",
     "glue_convert_examples_to_features",
     "glue_output_modes",