mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00
Compare commits
10 Commits
merging_to
...
v4.2.2
Author | SHA1 | Date | |
---|---|---|---|
98af96a156 | |||
0ed8bccc9c | |||
c5f6719040 | |||
21d45958af | |||
4d4d2ce135 | |||
1528b1007f | |||
236cc365af | |||
5b05321b56 | |||
412d878c5e | |||
59fbd64b1c |
5
.github/workflows/release-conda.yml
vendored
5
.github/workflows/release-conda.yml
vendored
@ -37,7 +37,8 @@ jobs:
|
||||
- name: Build conda packages
|
||||
run: |
|
||||
conda info
|
||||
conda build .github/conda
|
||||
conda list
|
||||
conda-build .github/conda
|
||||
|
||||
- name: Upload to Anaconda
|
||||
run: anaconda upload `conda build .github/conda --output` --force
|
||||
run: anaconda upload `conda-build .github/conda --output` --force
|
||||
|
2
setup.py
2
setup.py
@ -248,7 +248,7 @@ install_requires = [
|
||||
|
||||
setup(
|
||||
name="transformers",
|
||||
version="4.2.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="4.2.2", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
author="Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Sam Shleifer, Patrick von Platen, Sylvain Gugger, Google AI Language Team Authors, Open AI team Authors, Facebook AI Authors, Carnegie Mellon University Authors",
|
||||
author_email="thomas@huggingface.co",
|
||||
description="State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch",
|
||||
|
@ -22,7 +22,7 @@
|
||||
# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names
|
||||
# in the namespace without actually importing anything (and especially none of the backends).
|
||||
|
||||
__version__ = "4.2.0"
|
||||
__version__ = "4.2.2"
|
||||
|
||||
# Work around to update TensorFlow's absl.logging threshold which alters the
|
||||
# default Python logging output behavior when present.
|
||||
|
@ -89,8 +89,20 @@ if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VA
|
||||
try:
|
||||
_tf_version = importlib_metadata.version("tensorflow-cpu")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_tf_version = None
|
||||
_tf_available = False
|
||||
try:
|
||||
_tf_version = importlib_metadata.version("tensorflow-gpu")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
try:
|
||||
_tf_version = importlib_metadata.version("tf-nightly")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
try:
|
||||
_tf_version = importlib_metadata.version("tf-nightly-cpu")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
try:
|
||||
_tf_version = importlib_metadata.version("tf-nightly-gpu")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_tf_version = None
|
||||
_tf_available = False
|
||||
if _tf_available:
|
||||
if version.parse(_tf_version) < version.parse("2"):
|
||||
logger.info(f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum.")
|
||||
|
@ -19,8 +19,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -23,9 +23,15 @@ import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from ...utils import logging
|
||||
from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer
|
||||
from .modeling_bart import _make_linear_from_emb
|
||||
from transformers import (
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
)
|
||||
from transformers.models.bart.modeling_bart import _make_linear_from_emb
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||
|
@ -28,8 +28,8 @@ import re
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import BertConfig, BertModel
|
||||
from transformers import BertConfig, BertModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -19,8 +19,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -22,7 +22,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from . import BertModel
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
|
@ -18,8 +18,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...models.bart import BartConfig, BartForConditionalGeneration
|
||||
from ...utils import logging
|
||||
from transformers import BartConfig, BartForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -17,7 +17,7 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from ...file_utils import WEIGHTS_NAME
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
|
||||
|
||||
DIALOGPT_MODELS = ["small", "medium", "large"]
|
||||
|
@ -19,8 +19,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from torch.serialization import default_restore_location
|
||||
|
||||
from ...models.bert import BertConfig
|
||||
from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from .transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
|
||||
|
||||
CheckpointState = collections.namedtuple(
|
||||
|
@ -19,8 +19,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -31,10 +31,11 @@ import torch
|
||||
from fairseq import hub_utils
|
||||
from fairseq.data.dictionary import Dictionary
|
||||
|
||||
from ...file_utils import WEIGHTS_NAME
|
||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from ...utils import logging
|
||||
from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
|
||||
from transfomers.models.fsmt.tokenization_fsmt import VOCAB_FILES_NAMES
|
||||
from transformers import FSMTConfig, FSMTForConditionalGeneration
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
|
@ -23,6 +23,7 @@ from ...file_utils import _BaseLazyModule, is_tf_available, is_tokenizers_availa
|
||||
|
||||
_import_structure = {
|
||||
"configuration_funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig"],
|
||||
"convert_funnel_original_tf_checkpoint_to_pytorch": [],
|
||||
"tokenization_funnel": ["FunnelTokenizer"],
|
||||
}
|
||||
|
||||
|
@ -16,14 +16,14 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
||||
from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
|
@ -19,9 +19,9 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||
from transformers import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -1862,7 +1862,6 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
hidden_states = inputs["inputs_embeds"]
|
||||
|
||||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
||||
combined_attention_mask = None
|
||||
if input_shape[-1] > 1:
|
||||
combined_attention_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
|
||||
else:
|
||||
@ -1870,20 +1869,9 @@ class TFLEDDecoder(tf.keras.layers.Layer):
|
||||
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["attention_mask"] is None and inputs["input_ids"] is not None and input_shape[-1] > 1:
|
||||
inputs["attention_mask"] = tf.cast(
|
||||
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), inputs["input_ids"].dtype
|
||||
)
|
||||
inputs["attention_mask"] = tf.concat(
|
||||
[
|
||||
tf.ones((input_shape[0], past_key_values_length), dtype=inputs["attention_mask"].dtype),
|
||||
inputs["attention_mask"],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
inputs["attention_mask"] = tf.ones(
|
||||
(input_shape[0], input_shape[1] + past_key_values_length), dtype=tf.int32
|
||||
if inputs["attention_mask"] is not None and input_shape[-1] > 1:
|
||||
combined_attention_mask = combined_attention_mask + _expand_mask(
|
||||
inputs["attention_mask"], tgt_len=input_shape[-1]
|
||||
)
|
||||
|
||||
if inputs["encoder_hidden_states"] is not None and inputs["encoder_attention_mask"] is not None:
|
||||
|
@ -20,7 +20,7 @@ import argparse
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from . import LongformerForQuestionAnswering, LongformerModel
|
||||
from transformers import LongformerForQuestionAnswering, LongformerModel
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
|
||||
|
@ -16,14 +16,14 @@
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
||||
from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.set_verbosity_info()
|
||||
|
||||
|
||||
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
|
||||
|
@ -17,7 +17,7 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from .convert_marian_to_pytorch import (
|
||||
from transformers.models.marian.convert_marian_to_pytorch import (
|
||||
FRONT_MATTER_TEMPLATE,
|
||||
_parse_readme,
|
||||
convert_all_sentencepiece_models,
|
||||
|
@ -26,8 +26,8 @@ import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...hf_api import HfApi
|
||||
from . import MarianConfig, MarianMTModel, MarianTokenizer
|
||||
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
||||
from transformers.hf_api import HfApi
|
||||
|
||||
|
||||
def remove_suffix(text: str, suffix: str):
|
||||
|
@ -16,9 +16,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ..bart import BartForConditionalGeneration
|
||||
from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
||||
from . import MBartConfig
|
||||
from transformers import BartForConditionalGeneration, MBartConfig
|
||||
from transformers.models.bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
||||
|
||||
|
||||
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
||||
|
@ -16,8 +16,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -19,9 +19,9 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||
from transformers import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -22,8 +22,8 @@ import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||
from .configuration_pegasus import DEFAULTS, task_specific_params
|
||||
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||
from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params
|
||||
|
||||
|
||||
PATTERNS = [
|
||||
|
@ -19,6 +19,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
||||
|
||||
# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
|
||||
# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
|
||||
from transformers_old.modeling_prophetnet import (
|
||||
@ -28,8 +30,6 @@ from transformers_old.modeling_xlm_prophetnet import (
|
||||
XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
|
||||
)
|
||||
|
||||
from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
logging.set_verbosity_info()
|
||||
|
@ -21,8 +21,8 @@ import pickle
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from . import ReformerConfig, ReformerModelWithLMHead
|
||||
from transformers import ReformerConfig, ReformerModelWithLMHead
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -24,9 +24,15 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||
from packaging import version
|
||||
|
||||
from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
|
||||
from ...utils import logging
|
||||
from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
|
@ -17,8 +17,8 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from ...utils import logging
|
||||
from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -17,8 +17,7 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from ...utils import logging
|
||||
from . import (
|
||||
from transformers import (
|
||||
TapasConfig,
|
||||
TapasForMaskedLM,
|
||||
TapasForQuestionAnswering,
|
||||
@ -27,6 +26,7 @@ from . import (
|
||||
TapasTokenizer,
|
||||
load_tf_weights_in_tapas,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -22,11 +22,11 @@ import sys
|
||||
|
||||
import torch
|
||||
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
||||
from . import tokenization_transfo_xl as data_utils
|
||||
from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||
from transformers import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.models.transfo_xl import tokenization_transfo_xl as data_utils
|
||||
from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -21,9 +21,9 @@ import json
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from .tokenization_xlm import VOCAB_FILES_NAMES
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
@ -20,15 +20,15 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import (
|
||||
from transformers import (
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetLMHeadModel,
|
||||
load_tf_weights_in_xlnet,
|
||||
)
|
||||
from transformers.file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
GLUE_TASKS_NUM_LABELS = {
|
||||
|
@ -65,6 +65,12 @@ def _is_torch(x):
|
||||
return isinstance(x, torch.Tensor)
|
||||
|
||||
|
||||
def _is_torch_device(x):
|
||||
import torch
|
||||
|
||||
return isinstance(x, torch.device)
|
||||
|
||||
|
||||
def _is_tensorflow(x):
|
||||
import tensorflow as tf
|
||||
|
||||
@ -801,7 +807,7 @@ class BatchEncoding(UserDict):
|
||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||
# into a HalfTensor
|
||||
if isinstance(device, str) or isinstance(device, torch.device) or isinstance(device, int):
|
||||
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
|
||||
self.data = {k: v.to(device=device) for k, v in self.data.items()}
|
||||
else:
|
||||
logger.warning(
|
||||
|
@ -16,7 +16,7 @@ import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .file_utils import cached_property, is_torch_available, is_torch_tpu_available, torch_required
|
||||
from .trainer_utils import EvaluationStrategy, SchedulerType
|
||||
@ -426,7 +426,6 @@ class TrainingArguments:
|
||||
|
||||
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
||||
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
||||
self._n_gpu = torch.cuda.device_count()
|
||||
|
||||
def __repr__(self):
|
||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||
@ -467,14 +466,14 @@ class TrainingArguments:
|
||||
|
||||
@cached_property
|
||||
@torch_required
|
||||
def _setup_devices(self) -> Tuple["torch.device", int]:
|
||||
def _setup_devices(self) -> "torch.device":
|
||||
logger.info("PyTorch: setting up devices")
|
||||
if self.no_cuda:
|
||||
device = torch.device("cpu")
|
||||
n_gpu = 0
|
||||
self._n_gpu = 0
|
||||
elif is_torch_tpu_available():
|
||||
device = xm.xla_device()
|
||||
n_gpu = 0
|
||||
self._n_gpu = 0
|
||||
elif self.local_rank == -1:
|
||||
# if n_gpu is > 1 we'll use nn.DataParallel.
|
||||
# If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
|
||||
@ -485,9 +484,7 @@ class TrainingArguments:
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
# Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
|
||||
# the default value.
|
||||
if self._n_gpu == -1:
|
||||
self._n_gpu = torch.cuda.device_count()
|
||||
n_gpu = self._n_gpu
|
||||
self._n_gpu = torch.cuda.device_count()
|
||||
else:
|
||||
# Here, we'll use torch.distributed.
|
||||
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
|
||||
@ -507,12 +504,12 @@ class TrainingArguments:
|
||||
else:
|
||||
torch.distributed.init_process_group(backend="nccl")
|
||||
device = torch.device("cuda", self.local_rank)
|
||||
n_gpu = 1
|
||||
self._n_gpu = 1
|
||||
|
||||
if device.type == "cuda":
|
||||
torch.cuda.set_device(device)
|
||||
|
||||
return device, n_gpu
|
||||
return device
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
@ -520,7 +517,7 @@ class TrainingArguments:
|
||||
"""
|
||||
The device used by this process.
|
||||
"""
|
||||
return self._setup_devices[0]
|
||||
return self._setup_devices
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
@ -532,7 +529,9 @@ class TrainingArguments:
|
||||
This will only be greater than one when you have multiple GPUs available but are not using distributed
|
||||
training. For distributed training, it will always be 1.
|
||||
"""
|
||||
return self._setup_devices[1]
|
||||
# Make sure `self._n_gpu` is properly setup.
|
||||
_ = self._setup_devices
|
||||
return self._n_gpu
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
|
@ -1704,6 +1704,10 @@ class TokenizerTesterMixin:
|
||||
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
|
||||
sequence = " ".join(first_ten_tokens)
|
||||
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
|
||||
|
||||
# Ensure that the BatchEncoding.to() method works.
|
||||
encoded_sequence.to(model.device)
|
||||
|
||||
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
|
||||
# This should not fail
|
||||
|
||||
|
@ -381,9 +381,11 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
# Make the Trainer believe it's a parallelized model
|
||||
model.is_parallelizable = True
|
||||
model.model_parallel = True
|
||||
trainer = Trainer(model=model, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
|
||||
args = TrainingArguments("./regression", per_device_train_batch_size=16, per_device_eval_batch_size=16)
|
||||
trainer = Trainer(model, args, train_dataset=RegressionDataset(), eval_dataset=RegressionDataset())
|
||||
# Check the Trainer was fooled
|
||||
self.assertTrue(trainer.is_model_parallel)
|
||||
self.assertEqual(trainer.args.n_gpu, 1)
|
||||
|
||||
# The batch size of the training and evaluation dataloaders should be 16, not 16 * n_gpu
|
||||
self.assertEqual(trainer.get_train_dataloader().batch_size, 16)
|
||||
|
Reference in New Issue
Block a user