Compare commits

...

10 Commits

SHA1 Message Date
98af96a156 Release: v4.2.2 2021-01-21 09:06:41 +01:00
0ed8bccc9c Fix GPT conversion script (#9676) 2021-01-21 09:00:32 +01:00
c5f6719040 Fix imports in conversion scripts (#9674) 2021-01-21 09:00:26 +01:00
21d45958af [TF Led] Fix wrong decoder attention mask behavior (#9601) 2021-01-21 08:57:35 +01:00
    * fix tf led
    * remove loop file
4d4d2ce135 Remove branch overload in conda yml 2021-01-14 14:27:36 +01:00
1528b1007f Upload v4.2.1 to anaconda 2021-01-14 14:24:32 +01:00
236cc365af Release: v4.2.1 2021-01-14 14:17:56 +01:00
5b05321b56 BatchEncoding.to with device with tests (#9584) 2021-01-14 14:08:40 +01:00
412d878c5e Compliancy with tf-nightly (#9570) 2021-01-14 14:08:30 +01:00
    * Compliancy with tf-nightly
    * Add more version + restore min version check
59fbd64b1c Fix Trainer with a parallel model (#9578) 2021-01-14 14:08:19 +01:00
    * Fix Trainer with a parallel model
    * More clean up
38 changed files with 125 additions and 101 deletions

View File

@@ -37,7 +37,8 @@ jobs:
       - name: Build conda packages
         run: |
           conda info
-          conda build .github/conda
+          conda list
+          conda-build .github/conda
       - name: Upload to Anaconda
-        run: anaconda upload `conda build .github/conda --output` --force
+        run: anaconda upload `conda-build .github/conda --output` --force

View File

@@ -248,7 +248,7 @@ install_requires = [
 setup(
     name="transformers",
-    version="4.2.0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="4.2.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
     author_email="thomas@huggingface.co",
     description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",

View File

@@ -22,7 +22,7 @@
 # to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
 # in the namespace without actually importing anything (and especially none of the backends).
-__version__ = "4.2.0"
+__version__ = "4.2.2"

 # Work around to update TensorFlow's absl.logging threshold which alters the
 # default Python logging output behavior when present.

View File

@@ -89,8 +89,20 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VA
         try:
             _tf_version = importlib_metadata.version("tensorflow-cpu")
         except importlib_metadata.PackageNotFoundError:
-            _tf_version = None
-            _tf_available = False
+            try:
+                _tf_version = importlib_metadata.version("tensorflow-gpu")
+            except importlib_metadata.PackageNotFoundError:
+                try:
+                    _tf_version = importlib_metadata.version("tf-nightly")
+                except importlib_metadata.PackageNotFoundError:
+                    try:
+                        _tf_version = importlib_metadata.version("tf-nightly-cpu")
+                    except importlib_metadata.PackageNotFoundError:
+                        try:
+                            _tf_version = importlib_metadata.version("tf-nightly-gpu")
+                        except importlib_metadata.PackageNotFoundError:
+                            _tf_version = None
+                            _tf_available = False
 if _tf_available:
     if version.parse(_tf_version) < version.parse("2"):
         logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")

View File

@@ -19,8 +19,8 @@ import argparse
 import torch

-from ...utils import logging
-from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
+from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
+from transformers.utils import logging

 logging.set_verbosity_info()
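This hunk and most of the conversion-script hunks below apply the same fix from #9674: the scripts used package-relative imports, which only resolve when the file is imported as a module of the transformers package. Run directly from a shell, the file has no parent package and the import fails. A minimal sketch of the failure and the fix, assuming transformers is installed (the script name here is hypothetical):

# convert_albert_checkpoint_demo.py, run as `python convert_albert_checkpoint_demo.py`
#
#     from . import AlbertConfig   # would raise an error like:
#     ImportError: attempted relative import with no known parent package
#
# The absolute form resolves through site-packages instead, so the
# script works no matter where it is invoked from:
from transformers import AlbertConfig, AlbertForPreTraining

print(AlbertConfig())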

View File

@@ -23,9 +23,15 @@ import fairseq
 import torch
 from packaging import version

-from ...utils import logging
-from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer
-from .modeling_bart import _make_linear_from_emb
+from transformers import (
+    BartConfig,
+    BartForConditionalGeneration,
+    BartForSequenceClassification,
+    BartModel,
+    BartTokenizer,
+)
+from transformers.models.bart.modeling_bart import _make_linear_from_emb
+from transformers.utils import logging

 FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]

View File

@@ -28,8 +28,8 @@ import re
 import tensorflow as tf
 import torch

-from ...utils import logging
-from . import BertConfig, BertModel
+from transformers import BertConfig, BertModel
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -19,8 +19,8 @@ import argparse
 import torch

-from ...utils import logging
-from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert
+from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -22,7 +22,7 @@ import numpy as np
 import tensorflow as tf
 import torch

-from . import BertModel
+from transformers import BertModel

 def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):

View File

@@ -18,8 +18,8 @@ import argparse
 import torch

-from ...models.bart import BartConfig, BartForConditionalGeneration
-from ...utils import logging
+from transformers import BartConfig, BartForConditionalGeneration
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -17,7 +17,7 @@ import os
 import torch

-from ...file_utils import WEIGHTS_NAME
+from transformers.file_utils import WEIGHTS_NAME

 DIALOGPT_MODELS = ["small", "medium", "large"]

View File

@@ -19,8 +19,7 @@ from pathlib import Path
 import torch
 from torch.serialization import default_restore_location

-from ...models.bert import BertConfig
-from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
+from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader

 CheckpointState = collections.namedtuple(

View File

@@ -19,8 +19,8 @@ import argparse
 import torch

-from ...utils import logging
-from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
+from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -31,10 +31,11 @@ import torch
 from fairseq import hub_utils
 from fairseq.data.dictionary import Dictionary

-from ...file_utils import WEIGHTS_NAME
-from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
-from ...utils import logging
-from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
+from transformers import FSMTConfig, FSMTForConditionalGeneration
+from transformers.file_utils import WEIGHTS_NAME
+from transformers.models.fsmt.tokenization_fsmt import VOCAB_FILES_NAMES
+from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
+from transformers.utils import logging

 logging.set_verbosity_warning()

View File

@@ -23,6 +23,7 @@ from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_availa
 _import_structure = {
     "configuration_funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig"],
+    "convert_funnel_original_tf_checkpoint_to_pytorch": [],
     "tokenization_funnel": ["FunnelTokenizer"],
 }
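Context for this one-line addition: _import_structure drives the library's lazy imports; keys name submodules and values list the public names each one re-exports. Registering the conversion script with an empty list keeps it reachable as a submodule without loading it (or TensorFlow) eagerly. A stripped-down sketch of the idea, where LazyModule is an illustrative stand-in for the real _BaseLazyModule, not its actual implementation:

import importlib
from types import ModuleType

class LazyModule(ModuleType):
    # Illustrative only: defers submodule imports until attribute access.
    def __init__(self, name, import_structure):
        super().__init__(name)
        self._modules = set(import_structure)
        # Map each exported name back to the submodule defining it.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        if attr in self._modules:  # e.g. the conversion script, exports nothing
            return importlib.import_module(f"{self.__name__}.{attr}")
        if attr in self._name_to_module:  # e.g. "FunnelConfig"
            module = importlib.import_module(f"{self.__name__}.{self._name_to_module[attr]}")
            return getattr(module, attr)
        raise AttributeError(f"module {self.__name__} has no attribute {attr}")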

View File

@@ -16,14 +16,14 @@
 import argparse
-import logging

 import torch

-from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
+from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
+from transformers.utils import logging

-logging.basicConfig(level=logging.INFO)
+logging.set_verbosity_info()

 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):

View File

@@ -19,9 +19,9 @@ import argparse
 import torch

-from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
-from ...utils import logging
-from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
+from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -1862,7 +1862,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
             hidden_states = inputs["inputs_embeds"]

         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-        combined_attention_mask = None
         if input_shape[-1] > 1:
             combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
         else:
@@ -1870,20 +1869,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
                 tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
             )

-        if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
-            inputs["attention_mask"] = tf.cast(
-                tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
-            )
-            inputs["attention_mask"] = tf.concat(
-                [
-                    tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
-                    inputs["attention_mask"],
-                ],
-                axis=-1,
-            )
-        else:
-            inputs["attention_mask"] = tf.ones(
-                (input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
-            )
+        if inputs["attention_mask"] is not None and input_shape[-1] > 1:
+            combined_attention_mask = combined_attention_mask + _expand_mask(
+                inputs["attention_mask"], tgt_len=input_shape[-1]
+            )

         if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:

View File

@@ -20,7 +20,7 @@ import argparse
 import pytorch_lightning as pl
 import torch

-from . import LongformerForQuestionAnswering, LongformerModel
+from transformers import LongformerForQuestionAnswering, LongformerModel

 class LightningModel(pl.LightningModule):

View File

@@ -16,14 +16,14 @@
 import argparse
-import logging

 import torch

-from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
+from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
+from transformers.utils import logging

-logging.basicConfig(level=logging.INFO)
+logging.set_verbosity_info()

 def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):

View File

@@ -17,7 +17,7 @@ import os
 from pathlib import Path
 from typing import List, Tuple

-from .convert_marian_to_pytorch import (
+from transformers.models.marian.convert_marian_to_pytorch import (
     FRONT_MATTER_TEMPLATE,
     _parse_readme,
     convert_all_sentencepiece_models,

View File

@@ -26,8 +26,8 @@ import numpy as np
 import torch
 from tqdm import tqdm

-from ...hf_api import HfApi
-from . import MarianConfig, MarianMTModel, MarianTokenizer
+from transformers import MarianConfig, MarianMTModel, MarianTokenizer
+from transformers.hf_api import HfApi

 def remove_suffix(text: str, suffix: str):

View File

@@ -16,9 +16,8 @@ import argparse
 import torch

-from ..bart import BartForConditionalGeneration
-from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
-from . import MBartConfig
+from transformers import BartForConditionalGeneration, MBartConfig
+from transformers.models.bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_

 def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):

View File

@@ -16,8 +16,8 @@ import argparse
 import torch

-from ...utils import logging
-from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
+from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -19,9 +19,9 @@ import argparse
 import torch

-from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
-from ...utils import logging
-from . import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
+from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -22,8 +22,8 @@ import tensorflow as tf
 import torch
 from tqdm import tqdm

-from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
-from .configuration_pegasus import DEFAULTS, task_specific_params
+from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
+from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params

 PATTERNS = [

View File

@@ -19,6 +19,8 @@ import argparse
 import torch

+from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
+
 # transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
 # original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
 from transformers_old.modeling_prophetnet import (
@@ -28,8 +30,6 @@ from transformers_old.modeling_xlm_prophetnet import (
     XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
 )

-from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
-
 logger = logging.get_logger(__name__)

 logging.set_verbosity_info()

View File

@@ -21,8 +21,8 @@ import pickle
 import numpy as np
 import torch

-from ...utils import logging
-from . import ReformerConfig, ReformerModelWithLMHead
+from transformers import ReformerConfig, ReformerModelWithLMHead
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -24,9 +24,15 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from packaging import version

-from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
-from ...utils import logging
-from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
+from transformers import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
+from transformers.models.bert.modeling_bert import (
+    BertIntermediate,
+    BertLayer,
+    BertOutput,
+    BertSelfAttention,
+    BertSelfOutput,
+)
+from transformers.utils import logging

 if version.parse(fairseq.__version__) < version.parse("0.9.0"):

View File

@@ -17,8 +17,8 @@
 import argparse

-from ...utils import logging
-from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
+from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -17,8 +17,7 @@
 import argparse

-from ...utils import logging
-from . import (
+from transformers import (
     TapasConfig,
     TapasForMaskedLM,
     TapasForQuestionAnswering,
@@ -27,6 +26,7 @@ from . import (
     TapasTokenizer,
     load_tf_weights_in_tapas,
 )
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -22,11 +22,11 @@ import sys
 import torch

-from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
-from ...utils import logging
-from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
-from . import tokenization_transfo_xl as data_utils
-from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
+from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers.models.transfo_xl import tokenization_transfo_xl as data_utils
+from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -21,9 +21,9 @@ import json
 import numpy
 import torch

-from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
-from ...utils import logging
-from .tokenization_xlm import VOCAB_FILES_NAMES
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
+from transformers.utils import logging

 logging.set_verbosity_info()

View File

@@ -20,15 +20,15 @@ import os
 import torch

-from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
-from ...utils import logging
-from . import (
+from transformers import (
     XLNetConfig,
     XLNetForQuestionAnswering,
     XLNetForSequenceClassification,
     XLNetLMHeadModel,
     load_tf_weights_in_xlnet,
 )
+from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
+from transformers.utils import logging

 GLUE_TASKS_NUM_LABELS = {
GLUE_TASKS_NUM_LABELS = { GLUE_TASKS_NUM_LABELS = {

View File

@@ -65,6 +65,12 @@ def _is_torch(x):
     return isinstance(x, torch.Tensor)

+def _is_torch_device(x):
+    import torch
+
+    return isinstance(x, torch.device)
+
+
 def _is_tensorflow(x):
     import tensorflow as tf
@@ -801,7 +807,7 @@ class BatchEncoding(UserDict):
         # This check catches things like APEX blindly calling "to" on all inputs to a module
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
-        if isinstance(device, str) or isinstance(device, torch.device) or isinstance(device, int):
+        if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
             self.data = {k: v.to(device=device) for k, v in self.data.items()}
         else:
             logger.warning(
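Why the helper exists: this module imports torch lazily, so the old isinstance(device, torch.device) check referenced a name that may not be bound at call time and could fail, or push valid devices into the warning branch. _is_torch_device performs the import inside the function, keeping the module importable without torch. A usage sketch, assuming torch is installed and the checkpoint is reachable:

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer("hello world", return_tensors="pt")

# str, int, and torch.device arguments are all handled by BatchEncoding.to:
batch = batch.to("cpu")
batch = batch.to(torch.device("cpu"))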

View File

@@ -16,7 +16,7 @@ import json
 import os
 from dataclasses import asdict, dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional

 from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
 from .trainer_utils import EvaluationStrategy, SchedulerType
@@ -426,7 +426,6 @@ class TrainingArguments:
         if is_torch_available() and self.device.type != "cuda" and self.fp16:
             raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
-        self._n_gpu = torch.cuda.device_count()

     def __repr__(self):
         # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
@@ -467,14 +466,14 @@ class TrainingArguments:
     @cached_property
     @torch_required
-    def _setup_devices(self) -> Tuple["torch.device", int]:
+    def _setup_devices(self) -> "torch.device":
         logger.info("PyTorch: setting up devices")
         if self.no_cuda:
             device = torch.device("cpu")
-            n_gpu = 0
+            self._n_gpu = 0
         elif is_torch_tpu_available():
             device = xm.xla_device()
-            n_gpu = 0
+            self._n_gpu = 0
         elif self.local_rank == -1:
             # if n_gpu is > 1 we'll use nn.DataParallel.
             # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
@@ -485,9 +484,7 @@ class TrainingArguments:
             device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
             # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
             # the default value.
-            if self._n_gpu == -1:
-                self._n_gpu = torch.cuda.device_count()
-            n_gpu = self._n_gpu
+            self._n_gpu = torch.cuda.device_count()
         else:
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
@@ -507,12 +504,12 @@ class TrainingArguments:
         else:
             torch.distributed.init_process_group(backend="nccl")
             device = torch.device("cuda", self.local_rank)
-            n_gpu = 1
+            self._n_gpu = 1

         if device.type == "cuda":
             torch.cuda.set_device(device)

-        return device, n_gpu
+        return device

     @property
     @torch_required
@@ -520,7 +517,7 @@ class TrainingArguments:
         """
         The device used by this process.
         """
-        return self._setup_devices[0]
+        return self._setup_devices

     @property
     @torch_required
@@ -532,7 +529,9 @@ class TrainingArguments:
         This will only be greater than one when you have multiple GPUs available but are not using distributed
         training. For distributed training, it will always be 1.
         """
-        return self._setup_devices[1]
+        # Make sure `self._n_gpu` is properly setup.
+        _ = self._setup_devices
+        return self._n_gpu

     @property
     @torch_required
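The shape of the refactor: _setup_devices now returns only the device and records the GPU count as instance state, so code paths that need to pin n_gpu (the model-parallel test below expects it to be 1) all flow through one cached computation. A minimal sketch of the cached-property-with-side-state pattern, assuming torch is installed; DeviceArgs is illustrative, not the real TrainingArguments:

from functools import cached_property

import torch

class DeviceArgs:
    @cached_property
    def _setup_devices(self) -> "torch.device":
        # Runs once; subsequent accesses reuse the cached device.
        if torch.cuda.is_available():
            self._n_gpu = torch.cuda.device_count()
            return torch.device("cuda:0")
        self._n_gpu = 0
        return torch.device("cpu")

    @property
    def device(self) -> "torch.device":
        return self._setup_devices

    @property
    def n_gpu(self) -> int:
        _ = self._setup_devices  # make sure _n_gpu has been set
        return self._n_gpu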

View File

@@ -1704,6 +1704,10 @@ class TokenizerTesterMixin:
         first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
         sequence = " ".join(first_ten_tokens)
         encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
+
+        # Ensure that the BatchEncoding.to() method works.
+        encoded_sequence.to(model.device)
+
         batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")

         # This should not fail

View File

@@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
         # Make the Trainer believe it's a parallelized model
         model.is_parallelizable = True
         model.model_parallel = True
-        trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
+        args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
+        trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
         # Check the Trainer was fooled
         self.assertTrue(trainer.is_model_parallel)
+        self.assertEqual(trainer.args.n_gpu, 1)
         # The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
         self.assertEqual(trainer.get_train_dataloader().batch_size, 16)