Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits (24 commits)
SHA1:
ce37b8e481
05053d163c
63ae5d2134
029bdc0d50
ebaacba38b
870d71636e
982339d829
60e01ac427
6b2136a8a9
061eeca84a
fd32ebed81
eed255a58d
2f21497d3e
9ff2b7d86d
37b6c9b21b
da73925f6a
6f4be31d0d
dd56cfd89a
6c4789e4e8
956c917344
27ee0fff3c
aa50fd196f
7c91e51c26
e113101702
README.md (90 lines changed)

@@ -2,7 +2,7 @@

This repository contains an op-for-op PyTorch reimplementation of [Google's TensorFlow repository for the BERT model](https://github.com/google-research/bert) that was released together with the paper [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805) by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova.

This implementation is provided with [Google's pre-trained models](https://github.com/google-research/bert)) and a conversion script to load any pre-trained TensorFlow checkpoint for BERT is also provided.
This implementation is provided with [Google's pre-trained models](https://github.com/google-research/bert), examples, notebooks and a command-line interface to load any pre-trained TensorFlow checkpoint for BERT.

## Content

@@ -25,7 +25,7 @@ This repo was tested on Python 3.5+ and PyTorch 0.4.1
PyTorch pretrained bert can be installed by pip as follows:
```bash
pip install pytorch_pretrained_bert
pip install pytorch-pretrained-bert
```

### From source

@@ -46,23 +46,23 @@ python -m pytest -sv tests/
This package comprises the following classes that can be imported in Python and are detailed in the [Doc](#doc) section of this readme:

- Six PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights:
  - `BertModel` - raw BERT Transformer model (**fully pre-trained**),
  - `BertForMaskedLM` - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
  - `BertForNextSentencePrediction` - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
  - `BertForPreTraining` - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
  - `BertForSequenceClassification` - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
  - `BertForQuestionAnswering` - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).
- Six PyTorch models (`torch.nn.Module`) for Bert with pre-trained weights (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
  - [`BertModel`](./pytorch_pretrained_bert/modeling.py#L535) - raw BERT Transformer model (**fully pre-trained**),
  - [`BertForMaskedLM`](./pytorch_pretrained_bert/modeling.py#L689) - BERT Transformer with the pre-trained masked language modeling head on top (**fully pre-trained**),
  - [`BertForNextSentencePrediction`](./pytorch_pretrained_bert/modeling.py#L750) - BERT Transformer with the pre-trained next sentence prediction classifier on top (**fully pre-trained**),
  - [`BertForPreTraining`](./pytorch_pretrained_bert/modeling.py#L618) - BERT Transformer with masked language modeling head and next sentence prediction classifier on top (**fully pre-trained**),
  - [`BertForSequenceClassification`](./pytorch_pretrained_bert/modeling.py#L812) - BERT Transformer with a sequence classification head on top (BERT Transformer is **pre-trained**, the sequence classification head **is only initialized and has to be trained**),
  - [`BertForQuestionAnswering`](./pytorch_pretrained_bert/modeling.py#L877) - BERT Transformer with a token classification head on top (BERT Transformer is **pre-trained**, the token classification head **is only initialized and has to be trained**).

- Three tokenizers:
- Three tokenizers (in the [`tokenization.py`](./pytorch_pretrained_bert/tokenization.py) file):
  - `BasicTokenizer` - basic tokenization (punctuation splitting, lower casing, etc.),
  - `WordpieceTokenizer` - WordPiece tokenization,
  - `BertTokenizer` - perform end-to-end tokenization, i.e. basic tokenization followed by WordPiece tokenization.

- One optimizer:
- One optimizer (in the [`optimization.py`](./pytorch_pretrained_bert/optimization.py) file):
  - `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.

- A configuration class:
- A configuration class (in the [`modeling.py`](./pytorch_pretrained_bert/modeling.py) file):
  - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
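
A quick orientation on how these classes come together (a minimal sketch, assuming the package is installed under the pip name shown above and that these names are re-exported by `pytorch_pretrained_bert/__init__.py`, as the import hunk later in this diff suggests):

```python
# Minimal import sketch: the model classes and BertConfig live in modeling.py,
# the tokenizers in tokenization.py and the optimizer in optimization.py.
from pytorch_pretrained_bert import (
    BertConfig, BertModel, BertForPreTraining, BertForMaskedLM,
    BertForNextSentencePrediction, BertForSequenceClassification,
    BertForQuestionAnswering, BertTokenizer, BertAdam,
)
```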

The repository further comprises:

@@ -99,7 +99,7 @@ from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenized input
tokenized_text = "Who was Jim Henson ? Jim Henson was a puppeteer"
text = "Who was Jim Henson ? Jim Henson was a puppeteer"
tokenized_text = tokenizer.tokenize(text)

# Mask a token that we will try to predict back with `BertForMaskedLM`

@@ -159,17 +159,16 @@ Here is a detailed documentation of the classes in the package and how to use th
### Loading Google AI's pre-trained weights and PyTorch dump

To load Google AI's pre-trained weight or a PyTorch saved instance of `BertForPreTraining`, the PyTorch model classes and the tokenizer can be instantiated as
To load one of Google AI's pre-trained models or a PyTorch saved model (an instance of `BertForPreTraining` saved with `torch.save()`), the PyTorch model classes and the tokenizer can be instantiated as

```python
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH)
model = BERT_CLASS.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
```

where

- `BERT_CLASS` is either the `BertTokenizer` class (to load the vocabulary) or one of the six PyTorch model classes (to load the pre-trained weights): `BertModel`, `BertForMaskedLM`, `BertForNextSentencePrediction`, `BertForPreTraining`, `BertForSequenceClassification` or `BertForQuestionAnswering`, and

- `PRE_TRAINED_MODEL_NAME` is either:
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:

  - the shortcut name of a Google AI's pre-trained model selected in the list:

@@ -180,10 +179,12 @@ where

    - `bert-base-chinese`: Chinese Simplified and Traditional, 12-layer, 768-hidden, 12-heads, 110M parameters

  - a path or url to a pretrained model archive containing:
    . `bert_config.json` a configuration file for the model
    . `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`)

    - `bert_config.json` a configuration file for the model, and
    - `pytorch_model.bin` a PyTorch dump of a pre-trained instance `BertForPreTraining` (saved with the usual `torch.save()`)

If `PRE_TRAINED_MODEL_NAME` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_bert/modeling.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_bert/`).
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights. This option is useful in particular when you are using distributed training: to avoid concurrent access to the same weights you can set for example `cache_dir='./pretrained_model_{}'.format(args.local_rank)` (see the section on distributed training for more information)
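
For instance, a minimal sketch of the distributed-training pattern mentioned above (assuming `args.local_rank` comes from the training script's argument parser):

```python
from pytorch_pretrained_bert import BertModel

# Give each distributed worker its own cache folder so that concurrent
# processes do not race on the same downloaded archive.
model = BertModel.from_pretrained(
    'bert-base-uncased',
    cache_dir='./pretrained_model_{}'.format(args.local_rank))
```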

Example:
```python

@@ -202,15 +203,15 @@ We detail them here. This model takes as *inputs*:

- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if some input sequence lengths are smaller than the max input sequence length of the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.

This model *outputs* a tuple composed of:

- `encoded_layers`: controlled by the value of the `output_all_encoded_layers` argument:

  . `output_all_encoded_layers=True`: outputs a list of the encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
  . `output_all_encoded_layers=False`: outputs only the encoded-hidden-states corresponding to the last attention block,
  - `output_all_encoded_layers=True`: outputs a list of the encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
  - `output_all_encoded_layers=False`: outputs only the encoded-hidden-states corresponding to the last attention block, i.e. a single torch.FloatTensor of size [batch_size, sequence_length, hidden_size],

- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first token of the input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
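
A minimal sketch of a forward pass matching the inputs and outputs described above (the toy tensors mirror the docstring examples in `modeling.py`; `bert-base-uncased` is just one of the shortcut names listed earlier):

```python
import torch
from pytorch_pretrained_bert import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
model.eval()

# A toy batch of already-converted WordPiece token ids (batch_size=2, sequence_length=3).
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])  # sentence A (0) / sentence B (1)
attention_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])  # 0 marks padding positions

encoded_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
# With the default output_all_encoded_layers=True, encoded_layers is a list of
# 12 tensors for BERT-base (24 for BERT-large), each of size
# [batch_size, sequence_length, hidden_size]; pooled_output is [batch_size, hidden_size].
```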

@@ -232,6 +233,7 @@ An example on how to use this class is given in the `extract_features.py` script

- if `masked_lm_labels` and `next_sentence_label` are not `None`: Outputs the total_loss which is the sum of the masked language modeling loss and the next sentence classification loss.
- if `masked_lm_labels` or `next_sentence_label` is `None`: Outputs a tuple comprising

  - the masked language modeling logits, and
  - the next sentence classification logits.
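
A short sketch of the two call modes listed above for `BertForPreTraining` (reusing the toy `input_ids`, `token_type_ids` and `attention_mask` tensors from the `BertModel` sketch; the label tensors are assumed to be prepared by your data pipeline):

```python
from pytorch_pretrained_bert import BertForPreTraining

model = BertForPreTraining.from_pretrained('bert-base-uncased')

# Without labels: a tuple of (masked LM logits, next-sentence classification logits).
prediction_scores, seq_relationship_logits = model(input_ids, token_type_ids, attention_mask)

# With labels: the summed masked-LM loss and next-sentence classification loss.
total_loss = model(input_ids, token_type_ids, attention_mask,
                   masked_lm_labels=masked_lm_labels,
                   next_sentence_label=next_sentence_label)
```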

@@ -304,15 +306,15 @@ Please refer to the doc strings and code in [`tokenization.py`](./pytorch_pretra

The optimizer accepts the following arguments:

- `lr` : learning rate
- `warmup` : portion of t_total for the warmup, -1 means no warmup. Default : -1
- `warmup` : portion of `t_total` for the warmup, `-1` means no warmup. Default : `-1`
- `t_total` : total number of training steps for the learning
  rate schedule, -1 means constant learning rate. Default : -1
- `schedule` : schedule to use for the warmup (see above). Default : 'warmup_linear'
- `b1` : Adams b1. Default : 0.9
- `b2` : Adams b2. Default : 0.999
- `e` : Adams epsilon. Default : 1e-6
- `weight_decay_rate` : Weight decay. Default : 0.01
- `max_grad_norm` : Maximum norm for the gradients (-1 means no clipping). Default : 1.0
  rate schedule, `-1` means constant learning rate. Default : `-1`
- `schedule` : schedule to use for the warmup (see above). Default : `'warmup_linear'`
- `b1` : Adams b1. Default : `0.9`
- `b2` : Adams b2. Default : `0.999`
- `e` : Adams epsilon. Default : `1e-6`
- `weight_decay_rate` : Weight decay. Default : `0.01`
- `max_grad_norm` : Maximum norm for the gradients (`-1` means no clipping). Default : `1.0`
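
A hedged sketch of constructing the optimizer with these arguments, following the pattern used in the example scripts (the learning rate and step count are illustrative values, and plain `model.parameters()` is used here instead of the grouped parameters the scripts build):

```python
from pytorch_pretrained_bert import BertAdam

num_train_steps = 1000  # illustrative; use the real number of optimization steps
optimizer = BertAdam(model.parameters(),
                     lr=5e-5,
                     warmup=0.1,  # ramp the learning rate up over the first 10% of steps
                     t_total=num_train_steps)
```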

## Examples

@@ -419,10 +421,7 @@ To get these results we used a combination of:

Here is the full list of hyper-parameters for this run:
```bash
python ./run_squad.py \
  --vocab_file $BERT_LARGE_DIR/vocab.txt \
  --bert_config_file $BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint $BERT_LARGE_DIR/pytorch_model.bin \
  --do_lower_case \
  --bert_model bert-large-uncased \
  --do_train \
  --do_predict \
  --train_file $SQUAD_TRAIN \

@@ -442,10 +441,7 @@ If you have a recent GPU (starting from NVIDIA Volta series), you should try **1
Here is an example of hyper-parameters for a FP16 run we tried:
```bash
python ./run_squad.py \
  --vocab_file $BERT_LARGE_DIR/vocab.txt \
  --bert_config_file $BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint $BERT_LARGE_DIR/pytorch_model.bin \
  --do_lower_case \
  --bert_model bert-large-uncased \
  --do_train \
  --do_predict \
  --train_file $SQUAD_TRAIN \

@@ -467,23 +463,21 @@ The results were similar to the above FP32 results (actually slightly higher):
## Notebooks

Comparing the PyTorch model and the TensorFlow model predictions

We also include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.
We include [three Jupyter Notebooks](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/notebooks) that can be used to check that the predictions of the PyTorch model are identical to the predictions of the original TensorFlow model.

- The first NoteBook ([Comparing-TF-and-PT-models.ipynb](./notebooks/Comparing-TF-and-PT-models.ipynb)) extracts the hidden states of a full sequence on each layer of the TensorFlow and the PyTorch models and computes the standard deviation between them. In the given example, we get a standard deviation of 1.5e-7 to 9e-7 on the various hidden states of the models.

- The second NoteBook ([Comparing-TF-and-PT-models-SQuAD.ipynb](./notebooks/Comparing-TF-and-PT-models-SQuAD.ipynb)) compares the loss computed by the TensorFlow and the PyTorch models for identical initialization of the fine-tuning layer of the `BertForQuestionAnswering` and computes the standard deviation between them. In the given example, we get a standard deviation of 2.5e-7 between the models.

- The third NoteBook ([Comparing-TF-and-PT-models-MLM-NSP.ipynb](./notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb)) compares the predictions computed by the TensorFlow and the PyTorch models for masked token using the pre-trained masked language modeling model.
- The third NoteBook ([Comparing-TF-and-PT-models-MLM-NSP.ipynb](./notebooks/Comparing-TF-and-PT-models-MLM-NSP.ipynb)) compares the predictions computed by the TensorFlow and the PyTorch models for masked token language modeling using the pre-trained masked language modeling model.

Please follow the instructions given in the notebooks to run and modify them.

## Command-line interface

A command-line interface is provided to convert a TensorFlow checkpoint in a PyTorch checkpoint
A command-line interface is provided to convert a TensorFlow checkpoint into a PyTorch dump of the `BertForPreTraining` class (see above).

You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`./pytorch_pretrained_bert/convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.

This CLI takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).

@@ -497,9 +491,9 @@ Here is an example of the conversion process for a pre-trained `BERT-Base Uncase
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
  --tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
  --bert_config_file $BERT_BASE_DIR/bert_config.json \
  --pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
  $BERT_BASE_DIR/bert_model.ckpt \
  $BERT_BASE_DIR/bert_config.json \
  $BERT_BASE_DIR/pytorch_model.bin
```

You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
@@ -208,6 +208,10 @@ def main():
                        type=int,
                        default=-1,
                        help = "local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")

    args = parser.parse_args()

@@ -396,10 +396,6 @@ def main():
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--save_checkpoints_steps",
                        default=1000,
                        type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',

@@ -486,7 +482,8 @@ def main():
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
    model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list),
                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
    if args.fp16:
        model.half()
    model.to(device)

@@ -507,8 +504,8 @@ def main():
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
        ]
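    # Note on the change above: model.named_parameters() yields full dotted names
    # (ending in e.g. ".bias" or ".gamma"), which are never literally equal to
    # "bias", "gamma" or "beta"; the old `n in no_decay` test therefore applied
    # weight decay to every parameter, while the any(...) substring check routes
    # bias/LayerNorm parameters into the no-decay group as intended.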
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,

@@ -559,7 +556,8 @@ def main():
    if args.fp16 and args.loss_scale != 1.0:
        # scale down gradients for fp16 training
        for param in model.parameters():
            param.grad.data = param.grad.data / args.loss_scale
            if param.grad is not None:
                param.grad.data = param.grad.data / args.loss_scale
        is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
        if is_nan:
            logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
@@ -729,10 +729,6 @@ def main():
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
                             "of training.")
    parser.add_argument("--save_checkpoints_steps", default=1000, type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--iterations_per_loop", default=1000, type=int,
                        help="How many steps to make in each estimator call.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")

@@ -825,7 +821,8 @@ def main():
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model = BertForQuestionAnswering.from_pretrained(args.bert_model)
    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
    if args.fp16:
        model.half()
    model.to(device)
@@ -846,8 +843,8 @@ def main():
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
        ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,

@@ -902,7 +899,8 @@ def main():
    if args.fp16 and args.loss_scale != 1.0:
        # scale down gradients for fp16 training
        for param in model.parameters():
            param.grad.data = param.grad.data / args.loss_scale
            if param.grad is not None:
                param.grad.data = param.grad.data / args.loss_scale
        is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
        if is_nan:
            logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
@@ -3,3 +3,4 @@ from .modeling import (BertConfig, BertModel, BertForPreTraining,
                       BertForMaskedLM, BertForNextSentencePrediction,
                       BertForSequenceClassification, BertForQuestionAnswering)
from .optimization import BertAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
@@ -443,7 +443,7 @@ class PreTrainedBertModel(nn.Module):
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs):
    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedBertModel from a pre-trained model file.
        Download and cache the pre-trained model file if needed.

@@ -468,7 +468,7 @@ class PreTrainedBertModel(nn.Module):
            archive_file = pretrained_model_name
        # redirect to the cache, if necessary
        try:
            resolved_archive_file = cached_path(archive_file)
            resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
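            # cached_path() returns a local file path: remote archives are downloaded
            # into cache_dir (or the default cache) first, local paths are returned as-is.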
        except FileNotFoundError:
            logger.error(
                "Model name '{}' was not found in model name list ({}). "

@@ -678,8 +678,8 @@ class BertForPreTraining(PreTrainedBertModel):

        if masked_lm_labels is not None and next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
            next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss
            return total_loss
        else:
@@ -741,7 +741,7 @@ class BertForMaskedLM(PreTrainedBertModel):

        if masked_lm_labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            masked_lm_loss = loss_fct(prediction_scores, masked_lm_labels)
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
            return masked_lm_loss
        else:
            return prediction_scores

@@ -781,7 +781,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 2, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

    config = BertConfig(vocab_size=32000, hidden_size=512,
                        num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
@@ -803,7 +803,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):

        if next_sentence_label is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
            next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            return next_sentence_loss
        else:
            return seq_relationship_score

@@ -856,6 +856,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
    """
    def __init__(self, config, num_labels=2):
        super(BertForSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)

@@ -868,7 +869,7 @@ class BertForSequenceClassification(PreTrainedBertModel):

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss, logits
        else:
            return logits
@@ -38,16 +38,6 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
    'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}

def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if isinstance(text, str):
        return text
    elif isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    else:
        raise ValueError("Unsupported string type: %s" % (type(text)))


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""

@@ -65,9 +55,9 @@ def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r") as reader:
    with open(vocab_file, "r", encoding="utf-8") as reader:
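        # With an explicit utf-8 encoding, readline() returns str directly, so the
        # convert_to_unicode() helper removed above is no longer needed here.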
        while True:
            token = convert_to_unicode(reader.readline())
            token = reader.readline()
            if not token:
                break
            token = token.strip()

@@ -164,7 +154,6 @@ class BasicTokenizer(object):

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't

@@ -290,8 +279,6 @@ class WordpieceTokenizer(object):
          A list of wordpiece tokens.
        """

        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)

setup.py (2 lines changed)

@@ -2,7 +2,7 @@ from setuptools import find_packages, setup

setup(
    name="pytorch_pretrained_bert",
    version="0.1.2",
    version="0.2.0",
    author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
    author_email="thomas@huggingface.co",
    description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",