Compare commits

...

7 Commits

Author SHA1 Message Date
9f43a425fe Release: v4.4.2 2021-03-18 15:09:04 -04:00
45dae78e61 Fix distributed evaluation (#10795)
* Fix distributed evaluation

* Use logger
2021-03-18 15:07:21 -04:00
12b04b5003 Smmp batch not divisible by microbatches fix (#10778)
* Added debug prints

* Added config

* Added prints

* Added prints

* Added extra samples to SequentialDistributedSampler

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

* Added debug prints

* Removed extra prints

* Making predictions and labels a multiple of the batch size

* updated number of microbatches

* Removed extra prints

* Made start_remainder similar to DistributedSamplerWithLoop

* Minor spacing update

* Added debug prints

Added config

Added prints

Added prints

* Added extra samples to SequentialDistributedSampler

Updated SequentialDistributedSampler call

Added extra samples to SequentialDistributedSampler

Added debug prints

Removed extra prints

Making predictions and labels a multiple of the batch size

updated number of microbatches

Removed extra prints

Squashing redundant commits

* Made start_remainder similar to DistributedSamplerWithLoop

Minor spacing update

Made start_remainder similar to DistributedSamplerWithLoop

* Test and styling

* Rename test

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
2021-03-18 15:05:26 -04:00
6460e9a0f3 Add support for detecting intel-tensorflow version (#10781)
Signed-off-by: Morgan Funtowicz <funtowiczmo@gmail.com>
2021-03-18 15:05:17 -04:00
f213d23941 Patches full import failure when sentencepiece is not installed (#10752)
* Patches full import failure when sentencepiece is not installed

* Dummies :)
2021-03-16 15:58:42 -04:00
c5d6a28810 Release: v4.4.1 2021-03-16 15:39:48 -04:00
3b9a733e03 Patches the full import failure and adds a test (#10750)
* Patches the full import failure and adds a test

* Add comment
2021-03-16 15:39:26 -04:00
12 changed files with 103 additions and 18 deletions

View File

@@ -26,7 +26,9 @@ author = u'huggingface'
# The short X.Y version
version = u''
# The full version, including alpha/beta/rc tags
release = u'4.4.0'
release = u'4.4.2'
# Prefix link to point to master, comment this during version release and uncomment below line
extlinks = {'prefix_link': ('https://github.com/huggingface/transformers/blob/master/%s', '')}

View File

@@ -278,7 +278,7 @@ install_requires = [
setup(
name="transformers",
version="4.4.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="4.4.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
author_email="thomas@huggingface.co",
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",

View File

@@ -22,7 +22,7 @@
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
# in the namespace without actually importing anything (and especially none of the backends).
__version__ = "4.4.0"
__version__ = "4.4.2"
# Work around to update TensorFlow's absl.logging threshold which alters the
# default Python logging output behavior when present.
@@ -134,7 +134,7 @@ _import_structure = {
"Wav2Vec2FeatureExtractor",
"Wav2Vec2Processor",
],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config", "M2M100Tokenizer"],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"],
"models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
@@ -171,7 +171,7 @@ _import_structure = {
"models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
"models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"],
"models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"],
"models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config", "DebertaV2Tokenizer"],
"models.deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
"models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"],
"models.dpr": [
"DPR_PRETRAINED_CONFIG_ARCHIVE_MAP",
@@ -274,6 +274,8 @@ if is_sentencepiece_available():
_import_structure["models.barthez"].append("BarthezTokenizer")
_import_structure["models.bert_generation"].append("BertGenerationTokenizer")
_import_structure["models.camembert"].append("CamembertTokenizer")
_import_structure["models.deberta_v2"].append("DebertaV2Tokenizer")
_import_structure["models.m2m_100"].append("M2M100Tokenizer")
_import_structure["models.marian"].append("MarianTokenizer")
_import_structure["models.mbart"].append("MBartTokenizer")
_import_structure["models.mbart"].append("MBart50Tokenizer")
@@ -652,10 +654,8 @@ if is_torch_available():
"IBertForQuestionAnswering",
"IBertForSequenceClassification",
"IBertForTokenClassification",
"IBertLayer",
"IBertModel",
"IBertPreTrainedModel",
"load_tf_weights_in_ibert",
]
)
_import_structure["models.layoutlm"].extend(
@@ -1363,7 +1363,7 @@ if TYPE_CHECKING:
from .models.convbert import CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ConvBertConfig, ConvBertTokenizer
from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer
from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer
from .models.deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config, DebertaV2Tokenizer
from .models.deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer
from .models.dpr import (
DPR_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -1385,7 +1385,7 @@ if TYPE_CHECKING:
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config, M2M100Tokenizer
from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config
from .models.marian import MarianConfig
from .models.mbart import MBartConfig
from .models.mmbt import MMBTConfig
@@ -1482,6 +1482,8 @@ if TYPE_CHECKING:
from .models.barthez import BarthezTokenizer
from .models.bert_generation import BertGenerationTokenizer
from .models.camembert import CamembertTokenizer
from .models.deberta_v2 import DebertaV2Tokenizer
from .models.m2m_100 import M2M100Tokenizer
from .models.marian import MarianTokenizer
from .models.mbart import MBart50Tokenizer, MBartTokenizer
from .models.mt5 import MT5Tokenizer
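
Taken together, these hunks drop DebertaV2Tokenizer and M2M100Tokenizer from the unconditional part of the lazy _import_structure (and from the eager TYPE_CHECKING imports) and re-register them under the is_sentencepiece_available() guard, so that importing transformers no longer fails on machines without sentencepiece. A minimal sketch of that registration pattern, using hypothetical ExampleConfig/ExampleTokenizer names rather than the real entries:

def is_sentencepiece_available():
    try:
        import sentencepiece  # noqa: F401
        return True
    except ImportError:
        return False

# Configs never need sentencepiece, so they are always registered.
_import_structure = {"models.example": ["ExampleConfig"]}

if is_sentencepiece_available():
    # The tokenizer is only exposed when its backend can actually be imported;
    # otherwise the dummy class from utils/dummy_sentencepiece_objects.py takes
    # its place and raises a helpful error as soon as it is used.
    _import_structure["models.example"].append("ExampleTokenizer")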

View File

@@ -102,8 +102,12 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VA
try:
_tf_version = importlib_metadata.version("tf-nightly-gpu")
except importlib_metadata.PackageNotFoundError:
_tf_version = None
_tf_available = False
# Support for intel-tensorflow version
try:
_tf_version = importlib_metadata.version("intel-tensorflow")
except importlib_metadata.PackageNotFoundError:
_tf_version = None
_tf_available = False
if _tf_available:
if version.parse(_tf_version) < version.parse("2"):
logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")

View File

@@ -18,7 +18,7 @@
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
from ...file_utils import _BaseLazyModule, is_torch_available
_import_structure = {
@@ -28,6 +28,7 @@ _import_structure = {
if is_torch_available():
_import_structure["modeling_ibert"] = [
"IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"IBertPreTrainedModel",
"IBertForMaskedLM",
"IBertForMultipleChoice",
"IBertForQuestionAnswering",
@@ -48,6 +49,7 @@ if TYPE_CHECKING:
IBertForSequenceClassification,
IBertForTokenClassification,
IBertModel,
IBertPreTrainedModel,
)
else:

View File

@@ -112,7 +112,12 @@ class SageMakerTrainer(Trainer):
def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
if self.is_model_parallel_enabled:
return SequentialDistributedSampler(eval_dataset, num_replicas=smp.dp_size(), rank=smp.dp_rank())
return SequentialDistributedSampler(
eval_dataset,
num_replicas=smp.dp_size(),
rank=smp.dp_rank(),
batch_size=self.args.per_device_eval_batch_size,
)
else:
return super()._get_eval_sampler(eval_dataset)
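
Passing per_device_eval_batch_size through to the sampler pads each data-parallel rank's share of the eval set up to a full batch, so SageMaker Model Parallel never receives a trailing batch it cannot split into microbatches. A quick worked example with hypothetical sizes (100 eval samples, smp.dp_size() == 2, eval batch size 16):

import math

num_examples, dp_size, eval_batch_size = 100, 2, 16  # hypothetical sizes

per_rank_before = math.ceil(num_examples / dp_size)  # 50 -> last batch has only 2 samples
per_rank_after = math.ceil(num_examples / (eval_batch_size * dp_size)) * eval_batch_size  # 64
# With batch_size passed, each rank yields 4 full batches of 16; the duplicated
# samples introduced by the padding are dropped again when predictions are
# gathered and truncated back to num_examples.
print(per_rank_before, per_rank_after)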

View File

@@ -670,7 +670,7 @@ class Trainer:
"""
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
Will raise an exception if the underlying dataset dese not implement method :obj:`__len__`
Will raise an exception if the underlying dataset does not implement method :obj:`__len__`
"""
return len(dataloader.dataset)
@@ -1783,8 +1783,13 @@ class Trainer:
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
if not prediction_loss_only:
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
# The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
# a batch size to the sampler)
make_multiple_of = None
if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
make_multiple_of = dataloader.sampler.batch_size
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
model.eval()
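
Since the padded sampler can now hand each process more elements than its exact share, the prediction and label gatherers are built with the same make_multiple_of so that their preallocated storage matches what actually arrives; the gatherer's finalize step then truncates back to num_examples. A rough illustration of the sizes involved, under the assumption that make_multiple_of rounds the storage up to a multiple of world_size * make_multiple_of (numbers are hypothetical):

import math

world_size, num_examples, make_multiple_of = 2, 100, 16  # hypothetical eval run

block = world_size * make_multiple_of                    # 32
total_storage = math.ceil(num_examples / block) * block  # 128 gathered slots
kept_after_finalize = num_examples                       # truncated back to 100
print(total_storage, kept_after_finalize)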

View File

@@ -220,7 +220,7 @@ class SequentialDistributedSampler(Sampler):
or `reduce` resulting tensors at the end of the loop.
"""
def __init__(self, dataset, num_replicas=None, rank=None):
def __init__(self, dataset, num_replicas=None, rank=None, batch_size=None):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@@ -232,8 +232,14 @@ class SequentialDistributedSampler(Sampler):
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
num_samples = len(self.dataset)
# Add extra samples to make num_samples a multiple of batch_size if passed
if batch_size is not None:
self.num_samples = int(math.ceil(num_samples / (batch_size * num_replicas))) * batch_size
else:
self.num_samples = int(math.ceil(num_samples / num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.batch_size = batch_size
def __iter__(self):
indices = list(range(len(self.dataset)))
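
The new test further down (test_sequential_distributed_sampler) pins down what this padding should do: extra indices wrap around to the start of the dataset, and concatenating the shards in rank order reproduces the dataset as a prefix. A self-contained sketch of that arithmetic with hypothetical sizes (23 samples, 2 replicas, batch size 16):

import math

dataset = list(range(23))          # hypothetical dataset
num_replicas, batch_size = 2, 16

num_samples = math.ceil(len(dataset) / (batch_size * num_replicas)) * batch_size  # 16 per rank
total_size = num_samples * num_replicas                                           # 32
indices = dataset + dataset[: total_size - len(dataset)]  # pad by wrapping to the start
shards = [indices[r * num_samples : (r + 1) * num_samples] for r in range(num_replicas)]

flat = [i for shard in shards for i in shard]
assert flat[: len(dataset)] == dataset                        # every sample appears once, in order
assert all(len(shard) % batch_size == 0 for shard in shards)  # each rank sees whole batches only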

View File

@@ -38,6 +38,24 @@ class CamembertTokenizer:
requires_sentencepiece(self)
class DebertaV2Tokenizer:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_sentencepiece(self)
class M2M100Tokenizer:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_sentencepiece(self)
class MarianTokenizer:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
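
Each dummy mirrors the public surface of the real tokenizer (a constructor plus a from_pretrained classmethod), so importing DebertaV2Tokenizer from transformers keeps working without sentencepiece installed and the error only fires when the class is actually used. A simplified sketch of the helper the dummies call (the real requires_sentencepiece in file_utils.py formats a longer installation hint):

from transformers.file_utils import is_sentencepiece_available

def requires_sentencepiece(obj):
    # Simplified stand-in for the real helper; only the shape of the check matters.
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_sentencepiece_available():
        raise ImportError(
            f"{name} requires the SentencePiece library but it was not found in your "
            "environment. You can install it with `pip install sentencepiece`."
        )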

View File

@@ -15,6 +15,9 @@
import unittest
import requests
# Try to import everything from transformers to ensure every object can be loaded.
from transformers import * # noqa F406
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME, filename_to_url, get_from_cache, hf_bucket_url
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER

View File

@@ -97,6 +97,11 @@ if __name__ == "__main__":
def compute_metrics(p: EvalPrediction) -> Dict:
sequential = list(range(len(dataset)))
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
if not success and training_args.local_rank == 0:
logger.warning(
"Predictions and/or labels do not match expected results:\n - predictions: "
f"{p.predictions.tolist()}\n - labels: {p.label_ids.tolist()}\n - expected: {sequential}"
)
return {"success": success}
trainer = Trainer(

View File

@@ -31,6 +31,7 @@ if is_torch_available():
DistributedTensorGatherer,
LabelSmoother,
LengthGroupedSampler,
SequentialDistributedSampler,
get_parameter_names,
)
@@ -167,3 +168,35 @@ class TrainerUtilsTest(unittest.TestCase):
self.assertEqual(set(total[:length]), set(dataset))
self.assertEqual(set(total[length:]), set(total[: (len(total) - length)]))
def test_sequential_distributed_sampler(self):
batch_size = 16
for length in [23, 64, 123]:
dataset = list(range(length))
shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0)
shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1)
# Sample
samples1 = list(shard1)
samples2 = list(shard2)
total = samples1 + samples2
self.assertListEqual(total[:length], dataset)
self.assertListEqual(total[length:], dataset[: (len(total) - length)])
# With a batch_size passed
shard1 = SequentialDistributedSampler(dataset, num_replicas=2, rank=0, batch_size=batch_size)
shard2 = SequentialDistributedSampler(dataset, num_replicas=2, rank=1, batch_size=batch_size)
# Sample
samples1 = list(shard1)
samples2 = list(shard2)
self.assertTrue(len(samples1) % batch_size == 0)
self.assertTrue(len(samples2) % batch_size == 0)
total = samples1 + samples2
self.assertListEqual(total[:length], dataset)
self.assertListEqual(total[length:], dataset[: (len(total) - length)])