Compare commits


8 Commits

SHA1        Message                                                     Date
7655f11076  Release v4.11.2                                             2021-09-30 11:54:39 -04:00
6b87918441  Fix gather for TPU (#13813)                                 2021-09-30 11:53:13 -04:00
54f9d62c61  Release v4.11.1                                             2021-09-29 12:04:25 -04:00
22d3156881  Fix length of IterableDatasetShard and add test (#13792)    2021-09-29 12:03:56 -04:00
9bb3d33a46  Implement len in IterableDatasetShard (#13780)               2021-09-29 12:03:51 -04:00
a05400e020  Fix warning for gradient_checkpointing (#13767)              2021-09-29 12:03:46 -04:00
10083244a3  Fix LayoutLM ONNX test error (#13710)                        2021-09-29 12:03:43 -04:00
11144a3048  up (#13777)                                                  2021-09-29 12:03:40 -04:00
11 changed files with 63 additions and 11 deletions

View File

@@ -27,7 +27,9 @@ author = "huggingface"
 # The short X.Y version
 version = ""
 # The full version, including alpha/beta/rc tags
-release = "4.11.0"
+release = "4.11.2"

View File

@@ -42,6 +42,7 @@ Ready-made configurations include the following models:
 - BERT
 - DistilBERT
 - GPT-2
+- LayoutLM
 - RoBERTa
 - T5
 - XLM-RoBERTa

View File

@@ -344,7 +344,7 @@ install_requires = [
 setup(
     name="transformers",
-    version="4.11.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.11.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Suraj Patil, Stas Bekman, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
     author_email="thomas@huggingface.co",
     description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",

View File

@@ -22,7 +22,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.11.0"
+__version__ = "4.11.2"
 # Work around to update TensorFlow's absl.logging threshold which alters the
 # default Python logging output behavior when present.

View File

@@ -332,7 +332,7 @@ class PretrainedConfig(PushToHubMixin):
         self.transformers_version = kwargs.pop("transformers_version", None)

         # Deal with gradient checkpointing
-        if kwargs.get("gradient_checkpointing", True):
+        if kwargs.get("gradient_checkpointing", False):
             warnings.warn(
                 "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
                 "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "

View File

@@ -964,6 +964,14 @@ class HubertForCTC(HubertPreTrainedModel):
         self.hubert = HubertModel(config)
         self.dropout = nn.Dropout(config.final_dropout)

+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `HubertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
+                "or define `vocab_size` of your model's configuration."
+            )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

         self.init_weights()
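Note: a short sketch (not part of the diff) of the usage the new error message suggests; the checkpoint name and vocabulary size are illustrative only, and the randomly initialized `lm_head` warning from `from_pretrained` is expected.

    # Sketch only: supply the CTC vocabulary size at load time.
    from transformers import HubertForCTC

    model = HubertForCTC.from_pretrained("facebook/hubert-base-ls960", vocab_size=32)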

View File

@@ -183,11 +183,6 @@ class LayoutLMOnnxConfig(OnnxConfig):
             raise ValueError("Cannot generate dummy inputs without PyTorch installed.")
         import torch

-        input_dict["bbox"] = torch.tensor(
-            [
-                [0] * 4,
-                *[box] * seq_length,
-                [self.max_2d_positions] * 4,
-            ]
-        ).tile(batch_size, 1, 1)
+        batch_size, seq_length = input_dict["input_ids"].shape
+        input_dict["bbox"] = torch.tensor([*[box] * seq_length]).tile(batch_size, 1, 1)
         return input_dict
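Note: the removed version padded the box list with two extra entries, giving the dummy `bbox` a sequence length of `seq_length + 2`, which no longer matched `input_ids`. The new version repeats one box per token and tiles it across the batch, so the shapes line up. A standalone PyTorch illustration (values are made up, not from the diff):

    import torch

    box = [48, 84, 73, 128]            # illustrative (x0, y0, x1, y1) token box
    batch_size, seq_length = 2, 8
    bbox = torch.tensor([*[box] * seq_length]).tile(batch_size, 1, 1)
    print(bbox.shape)                  # torch.Size([2, 8, 4]) -- (batch, sequence, 4)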

View File

@@ -1416,6 +1416,14 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
         self.wav2vec2 = Wav2Vec2Model(config)
         self.dropout = nn.Dropout(config.final_dropout)

+        if config.vocab_size is None:
+            raise ValueError(
+                f"You are trying to instantiate {self.__class__} with a configuration that "
+                "does not define the vocabulary size of the language model head. Please "
+                "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`."
+                "or define `vocab_size` of your model's configuration."
+            )
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

         self.init_weights()
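Note: before this change a missing `vocab_size` only surfaced as an opaque failure inside `nn.Linear`; the constructor now raises a descriptive `ValueError` instead. A minimal sketch (not part of the diff, assuming this patched release is installed) of both paths:

    from transformers import Wav2Vec2Config, Wav2Vec2ForCTC

    try:
        Wav2Vec2ForCTC(Wav2Vec2Config(vocab_size=None))    # no LM-head vocabulary defined
    except ValueError as err:
        print(err)                                         # the new, descriptive error

    model = Wav2Vec2ForCTC(Wav2Vec2Config(vocab_size=32))  # defining vocab_size works as before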

View File

@@ -152,6 +152,8 @@ def nested_xla_mesh_reduce(tensors, name):
         if isinstance(tensors, (list, tuple)):
             return type(tensors)(nested_xla_mesh_reduce(t, f"{name}_{i}") for i, t in enumerate(tensors))
+        if tensors.ndim == 0:
+            tensors = tensors[None]
         return xm.mesh_reduce(name, tensors, torch.cat)
     else:
         raise ImportError("Torch xla must be installed to use `nested_xla_mesh_reduce`")
@@ -772,6 +774,13 @@
             for i in process_slice:
                 yield current_batch[i]

+    def __len__(self):
+        # Will raise an error if the underlying dataset is not sized.
+        if self.drop_last:
+            return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size
+        else:
+            return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size
+

 # In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
 # helper methods here
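Note: the new `__len__` counts only the elements each shard will actually yield. A quick standalone check of the arithmetic, using the same numbers as the test below (100 examples, batch_size=4, num_processes=2):

    import math

    dataset_len, batch_size, num_processes = 100, 4, 2
    print((dataset_len // (batch_size * num_processes)) * batch_size)           # drop_last=True  -> 48
    print(math.ceil(dataset_len / (batch_size * num_processes)) * batch_size)   # drop_last=False -> 52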

View File

@@ -355,6 +355,34 @@ class TrainerUtilsTest(unittest.TestCase):
         self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
         self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)

+    def test_iterable_dataset_shard_with_length(self):
+        sampler_shards = [
+            IterableDatasetShard(list(range(100)), batch_size=4, drop_last=True, num_processes=2, process_index=i)
+            for i in range(2)
+        ]
+        # Build expected shards: each process will have batches of size 4 until there is not enough elements to
+        # form two full batches (so we stop at 96 = (100 // (4 * 2)) * 4)
+        expected_shards = [[], []]
+        current_shard = 0
+        for i in range(0, 96, 4):
+            expected_shards[current_shard].extend(list(range(i, i + 4)))
+            current_shard = 1 - current_shard
+
+        self.assertListEqual([list(shard) for shard in sampler_shards], expected_shards)
+        self.assertListEqual([len(shard) for shard in sampler_shards], [len(shard) for shard in expected_shards])
+
+        sampler_shards = [
+            IterableDatasetShard(list(range(100)), batch_size=4, drop_last=False, num_processes=2, process_index=i)
+            for i in range(2)
+        ]
+        # When drop_last=False, we get two last full batches by looping back to the beginning.
+        expected_shards[0].extend(list(range(96, 100)))
+        expected_shards[1].extend(list(range(0, 4)))
+
+        self.assertListEqual([list(shard) for shard in sampler_shards], expected_shards)
+        self.assertListEqual([len(shard) for shard in sampler_shards], [len(shard) for shard in expected_shards])
+
     def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
         shards = [
             ShardSampler(
View File

@@ -281,6 +281,7 @@ SPECIAL_MODULE_TO_TEST_MAP = {
         "test_trainer_distributed.py",
         "test_trainer_tpu.py",
     ],
+    "train_pt_utils.py": "test_trainer_utils.py",
     "utils/versions.py": "test_versions_utils.py",
 }