diff --git a/.gitignore b/.gitignore
index 6bbe32df6c8..bbc738b931d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -127,4 +127,7 @@ proc_data

 # examples
 runs
-examples/runs
\ No newline at end of file
+examples/runs
+
+# data
+data
\ No newline at end of file
diff --git a/examples/run_generative_finetuning.py b/examples/run_lm_finetuning.py
similarity index 75%
rename from examples/run_generative_finetuning.py
rename to examples/run_lm_finetuning.py
index 8501364ae4f..bd7047a5872 100644
--- a/examples/run_generative_finetuning.py
+++ b/examples/run_lm_finetuning.py
@@ -25,33 +25,75 @@
 import argparse
 import glob
 import logging
 import os
+import pickle
 import random

 import numpy as np
 import torch
-from torch.utils.data import (DataLoader, SequentialSampler,)
+from torch.utils.data import DataLoader, Dataset, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange

-from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                  BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
-from pytorch_transformers import AdamW, WarmupLinearSchedule

-logger = logging.getLogger(__name__)
+from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
+                                  BertConfig, BertForMaskedLM, BertTokenizer,
+                                  GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
+                                  RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)

-from utils_lm import WikiTextDataset
+
+logger = logging.getLogger(__name__)

 MODEL_CLASSES = {
     'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
     'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
-    "bert": (BertConfig, BertForMaskedLM, BertTokenizer),
-    "roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
+    'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
+    'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
 }

+
+class TextDataset(Dataset):
+    def __init__(self, tokenizer, file_path='train', block_size=512):
+        assert os.path.isfile(file_path)
+        directory, filename = os.path.split(file_path)
+        cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')
+
+        if os.path.exists(cached_features_file):
+            logger.info("Loading features from cached file %s", cached_features_file)
+            with open(cached_features_file, 'rb') as handle:
+                self.examples = pickle.load(handle)
+        else:
+            logger.info("Creating features from dataset file at %s", directory)
+
+            self.examples = []
+            with open(file_path, encoding="utf-8") as f:
+                text = f.read()
+
+            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
+            while len(tokenized_text) >= block_size:  # Truncate in blocks of block_size
+                self.examples.append(tokenized_text[:block_size])
+                tokenized_text = tokenized_text[block_size:]
+            # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
+            # If your dataset is small, first you should look for a bigger one :-) and second you
+            # can change this behavior by adding (model specific) padding.
+ + logger.info("Saving features into cached file %s", cached_features_file) + with open(cached_features_file, 'wb') as handle: + pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL) + + def __len__(self): + return len(self.examples) + + def __getitem__(self, item): + return torch.tensor(self.examples[item]) + + +def load_and_cache_examples(args, tokenizer, evaluate=False): + dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size) + return dataset + + def set_seed(args): random.seed(args.seed) np.random.seed(args.seed) @@ -59,20 +101,27 @@ def set_seed(args): if args.n_gpu > 0: torch.cuda.manual_seed_all(args.seed) -# Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original -def mask_tokens(inputs, tokenizer, args): - labels = inputs.clone() - masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte() - labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens - indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices - inputs[indices_replaced.bool()] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) # 80% of the time, replace masked input tokens with [MASK] - indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool() - random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long) - inputs[indices_random] = random_words[ - indices_random] # 10% of the time, replace masked input tokens with random word +def mask_tokens(inputs, tokenizer, args): + """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """ + labels = inputs.clone() + # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) + masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte() + labels[~masked_indices] = -1 # We only compute loss on masked tokens + + # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) + indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices + inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + + # 10% of the time, we replace masked input tokens with random word + indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced + random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) + inputs[indices_random] = random_words[indices_random] + + # The rest of the time (10% of the time) we keep the masked input tokens unchanged return inputs, labels + def train(args, train_dataset, model, tokenizer): """ Train the model """ if args.local_rank in [-1, 0]: @@ -146,13 +195,15 @@ def train(args, train_dataset, model, tokenizer): if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() - torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) else: loss.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) tr_loss += loss.item() if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate 
                 model.zero_grad()
@@ -240,24 +291,22 @@ def evaluate(args, model, tokenizer, prefix=""):
     return results


-def load_and_cache_examples(args, tokenizer, evaluate=False):
-    dataset = WikiTextDataset(args, tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
-    return dataset
-
-
 def main():
     parser = argparse.ArgumentParser()

     ## Required parameters
-    parser.add_argument("--data_dir", default=None, type=str, required=True,
-                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
+    parser.add_argument("--train_data_file", default=None, type=str, required=True,
+                        help="The input training data file (a text file).")
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")

     ## Other parameters
-    parser.add_argument("--model_name", default="bert", type=str,
+    parser.add_argument("--eval_data_file", default=None, type=str,
+                        help="An optional input evaluation data file to evaluate the perplexity on (a text file).")
+
+    parser.add_argument("--model_type", default="bert", type=str,
                         help="The model architecture to be fine-tuned.")
-    parser.add_argument("--model_checkpoint", default="bert-base-cased", type=str,
+    parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
                         help="The model checkpoint for weights initialization.")

     parser.add_argument("--mlm", action='store_true',
@@ -266,20 +315,21 @@ def main():
                         help="Ratio of tokens to mask for masked language modeling loss")

     parser.add_argument("--config_name", default="", type=str,
-                        help="Pretrained config name or path if not the same as model_name")
+                        help="Optional pretrained config name or path if not the same as model_name_or_path")
     parser.add_argument("--tokenizer_name", default="", type=str,
-                        help="Pretrained tokenizer name or path if not the same as model_name")
+                        help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
     parser.add_argument("--cache_dir", default="", type=str,
-                        help="Where do you want to store the pre-trained models downloaded from s3")
-    parser.add_argument("--max_seq_length", default=128, type=int,
-                        help="The maximum total input sequence length after tokenization. Sequences longer "
-                             "than this will be truncated, sequences shorter will be padded.")
+                        help="Optional directory to store the pre-trained models downloaded from s3 (instead of the default one)")
+    parser.add_argument("--block_size", default=-1, type=int,
+                        help="Optional input sequence length after tokenization. "
+                             "The training dataset will be truncated in blocks of this size for training. "
+ "Default to the model max input length.") parser.add_argument("--do_train", action='store_true', help="Whether to run training.") parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.") parser.add_argument("--evaluate_during_training", action='store_true', - help="Rul evaluation during training at each logging step.") + help="Run evaluation during training at each logging step.") parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.") @@ -309,7 +359,7 @@ def main(): parser.add_argument('--save_steps', type=int, default=50, help="Save checkpoint every X updates steps.") parser.add_argument("--eval_all_checkpoints", action='store_true', - help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") + help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number") parser.add_argument("--no_cuda", action='store_true', help="Avoid using CUDA when available") parser.add_argument('--overwrite_output_dir', action='store_true', @@ -330,9 +380,12 @@ def main(): parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") args = parser.parse_args() - if args.model_name in ["bert", "roberta"] and not args.mlm: + if args.model_type in ["bert", "roberta"] and not args.mlm: raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm " "flag (masked language modeling).") + if args.eval_data_file is None and args.do_eval: + raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file " + "or remove the --do_eval argument.") if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: raise ValueError("Output directory ({}) already exists and is not empty. 
@@ -368,30 +421,36 @@ def main():

     # Load pretrained model and tokenizer
     if args.local_rank not in [-1, 0]:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
+        torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training downloads model & vocab

-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_name]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_checkpoint)
-    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_checkpoint, do_lower_case=args.do_lower_case)
-    model = model_class.from_pretrained(args.model_checkpoint, from_tf=bool('.ckpt' in args.model_checkpoint), config=config)
-    args.num_embeddings = config.vocab_size  # We need this to create the model at next line (number of embeddings to use)
+    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
+    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
+    if args.block_size <= 0:
+        args.block_size = tokenizer.max_len  # Our input block size will be the max possible for the model
+    model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
+    model.to(args.device)

     if args.local_rank == 0:
-        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-
-    model.to(args.device)
+        torch.distributed.barrier()  # End of barrier to make sure only the first process in distributed training downloads model & vocab

     logger.info("Training/evaluation parameters %s", args)

-
     # Training
     if args.do_train:
+        if args.local_rank not in [-1, 0]:
+            torch.distributed.barrier()  # Barrier to make sure only the first process in distributed training processes the dataset, and the others will use the cache
+
         train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)
+
+        if args.local_rank == 0:
+            torch.distributed.barrier()
+
         global_step, tr_loss = train(args, train_dataset, model, tokenizer)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

-    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
+    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
         # Create output directory if needed
         if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
@@ -409,7 +468,7 @@ def main():

         # Load a trained model and vocabulary that you have fine-tuned
         model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)

         model.to(args.device)
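As an aside (not part of the diff): the sketch below shows how the pieces added above — TextDataset, load_and_cache_examples and mask_tokens — fit together for a single masked-LM step. It assumes the pytorch_transformers-era stack this script targets and that it is run from the examples/ directory; the corpus file name, block size and batch size are illustrative assumptions, not values taken from the diff.

    import argparse
    from torch.utils.data import DataLoader, RandomSampler
    from pytorch_transformers import BertForMaskedLM, BertTokenizer
    from run_lm_finetuning import TextDataset, mask_tokens  # the helpers added in this diff

    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    model = BertForMaskedLM.from_pretrained('bert-base-cased')

    # Any plain-text file works; 'my_corpus.txt' and block_size=128 are hypothetical choices
    dataset = TextDataset(tokenizer, file_path='my_corpus.txt', block_size=128)
    loader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=4)
    args = argparse.Namespace(mlm_probability=0.15)  # the only field mask_tokens reads

    for batch in loader:
        inputs, labels = mask_tokens(batch, tokenizer, args)  # 80% [MASK] / 10% random / 10% unchanged
        outputs = model(inputs, masked_lm_labels=labels)      # loss is outputs[0], as in the script
        outputs[0].backward()
        break  # one illustrative step
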
diff --git a/examples/utils_lm.py b/examples/utils_lm.py
deleted file mode 100644
index 251aea90e12..00000000000
--- a/examples/utils_lm.py
+++ /dev/null
@@ -1,51 +0,0 @@
-from torch.utils.data import Dataset, DataLoader
-import os
-import random
-import torch
-import torch.nn.functional as F
-import logging
-import pickle
-
-logger = logging.getLogger(__name__)
-
-
-class WikiTextDataset(Dataset):
-    def __init__(self, args, tokenizer, file='train', directory='wikitext', max_context_length=512, cache=None):
-        if args.local_rank not in [-1, 0]:
-            torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
-
-
-        cached_features_file = os.path.join(args.data_dir, f'cached_lm_{file}_{args.max_seq_length}')
-
-        if os.path.exists(cached_features_file):
-            logger.info("Loading features from cached file %s", cached_features_file)
-            with open(cached_features_file, 'rb') as handle:
-                self.examples = pickle.load(handle)
-        else:
-            logger.info("Creating features from dataset file at %s", args.data_dir)
-
-            self.max_context_length = max_context_length
-
-            self.examples = []
-
-            with open(os.path.join(directory, f"wiki.{file}.raw"), encoding="utf-8") as f:
-                text = f.read()
-            tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
-
-            while len(tokenized_text) > max_context_length:
-                self.examples.append(tokenized_text[:max_context_length])
-                tokenized_text = tokenized_text[max_context_length:]
-
-        if args.local_rank in [-1, 0]:
-            logger.info("Saving features into cached file %s", cached_features_file)
-            with open(cached_features_file, 'wb') as handle:
-                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
-
-        if args.local_rank == 0:
-            torch.distributed.barrier()
-
-    def __len__(self):
-        return len(self.examples)
-
-    def __getitem__(self, item):
-        return torch.tensor(self.examples[item])
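For reference, the renamed script is now driven by explicit file arguments rather than --data_dir: a typical masked-LM fine-tuning run would be along the lines of "python examples/run_lm_finetuning.py --train_data_file=$TRAIN_FILE --eval_data_file=$TEST_FILE --output_dir=output --model_type=roberta --model_name_or_path=roberta-base --do_train --do_eval --mlm" (paths, output directory and model name are illustrative, not taken from the diff). Evaluation requires --eval_data_file, and --block_size can be left at its default of -1 to fall back to the model's maximum input length.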