
various fix and clean up on run_lm_finetuning
thomwolf committed Aug 20, 2019
1 parent f94f1c6 commit a690eda
Showing 3 changed files with 113 additions and 102 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -127,4 +127,7 @@ proc_data

# examples
runs
examples/runs

# data
data
159 changes: 109 additions & 50 deletions examples/run_generative_finetuning.py → examples/run_lm_finetuning.py
@@ -25,54 +25,103 @@
import glob
import logging
import os
import pickle
import random

import numpy as np
import torch
from torch.utils.data import (DataLoader, SequentialSampler,)
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange

from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertConfig, BertForMaskedLM, BertTokenizer, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from pytorch_transformers import AdamW, WarmupLinearSchedule
logger = logging.getLogger(__name__)
from pytorch_transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
BertConfig, BertForMaskedLM, BertTokenizer,
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)


from utils_lm import WikiTextDataset
logger = logging.getLogger(__name__)


MODEL_CLASSES = {
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer)
}


class TextDataset(Dataset):
def __init__(self, tokenizer, file_path='train', block_size=512):
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, f'cached_lm_{block_size}_{filename}')

if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, 'rb') as handle:
self.examples = pickle.load(handle)
else:
logger.info("Creating features from dataset file at %s", directory)

self.examples = []
with open(file_path, encoding="utf-8") as f:
text = f.read()

tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
while len(tokenized_text) >= block_size: # Truncate in blocks of block_size
self.examples.append(tokenized_text[:block_size])
tokenized_text = tokenized_text[block_size:]
# Note that we are losing the last truncated example here for the sake of simplicity (no padding)
# If your dataset is small, first you should look for a bigger one :-) and second you
# can change this behavior by adding (model specific) padding; see the sketch after this class.

logger.info("Saving features into cached file %s", cached_features_file)
with open(cached_features_file, 'wb') as handle:
pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

def __len__(self):
return len(self.examples)

def __getitem__(self, item):
return torch.tensor(self.examples[item])
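
Editor's note: as the comment in __init__ points out, the trailing partial block is simply dropped. A minimal sketch of the alternative that comment alludes to, right-padding the last block up to block_size, could look like the helper below. The pad_token_id argument is an assumption and not part of the commit; padding is model specific, and GPT-2, for instance, ships without a pad token.

# Hypothetical helper, not part of the diff: keep the trailing partial
# block by right-padding it with pad_token_id up to block_size.
def blocks_with_padding(tokenized_text, block_size, pad_token_id):
    examples = []
    for start in range(0, len(tokenized_text), block_size):
        block = tokenized_text[start:start + block_size]
        if len(block) < block_size:
            block = block + [pad_token_id] * (block_size - len(block))
        examples.append(block)
    return examples

If something like this were used, the padded positions would also need to be excluded from the language modeling loss, which the simpler truncation above avoids entirely.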


def load_and_cache_examples(args, tokenizer, evaluate=False):
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
return dataset
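
Editor's note: the train() and evaluate() bodies that consume this dataset are collapsed in the hunks below, but the imports at the top of the file point at the usual wiring. The sketch below is illustrative only; the batch size and the sampler choice are placeholders, not taken from the diff.

from torch.utils.data import DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

# Illustration only: wrap the dataset in a DataLoader, switching to a
# DistributedSampler when running under torch.distributed.
def build_dataloader(dataset, local_rank=-1, batch_size=4):
    sampler = SequentialSampler(dataset) if local_rank == -1 else DistributedSampler(dataset)
    return DataLoader(dataset, sampler=sampler, batch_size=batch_size)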


def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)

# Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original

def mask_tokens(inputs, tokenizer, args):
""" Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability, which defaults to 0.15 in BERT/RoBERTa)
masked_indices = torch.bernoulli(torch.full(labels.shape, args.mlm_probability)).byte()
labels[~masked_indices.bool()] = -1 # We only compute loss on masked tokens
labels[~masked_indices] = -1 # We only compute loss on masked tokens

# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).byte() & masked_indices
inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]

inputs[indices_replaced.bool()] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) # 80% of the time, replace masked input tokens with [MASK]
indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced).bool()
random_words = torch.randint(args.num_embeddings, labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[
indices_random] # 10% of the time, replace masked input tokens with random word
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
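
Editor's note: with the default mlm_probability of 0.15 mentioned in the comments above, the expected corruption rates work out as follows (a quick arithmetic check, not part of the commit):

# Expected fraction of all token positions, given mlm_probability = 0.15:
mlm_probability = 0.15
frac_mask      = mlm_probability * 0.8        # 0.12   replaced by [MASK]
frac_random    = mlm_probability * 0.2 * 0.5  # 0.015  replaced by a random token
frac_unchanged = mlm_probability * 0.2 * 0.5  # 0.015  left unchanged but still predicted
frac_untouched = 1 - mlm_probability          # 0.85   never selected, label set to -1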


def train(args, train_dataset, model, tokenizer):
""" Train the model """
if args.local_rank in [-1, 0]:
@@ -146,13 +195,15 @@ def train(args, train_dataset, model, tokenizer):
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
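
Editor's note: the hunk above moves gradient clipping from right after each backward pass to just before the optimizer step, so gradients accumulated over several micro-batches are clipped once per update. A self-contained toy sketch of that pattern follows; the model, data and hyperparameters are arbitrary placeholders, not taken from the diff.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps, max_grad_norm = 2, 1.0

for step in range(4):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y) / gradient_accumulation_steps
    loss.backward()  # gradients accumulate across micro-batches
    if (step + 1) % gradient_accumulation_steps == 0:
        # clip the accumulated gradient once, right before the update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        model.zero_grad()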
@@ -240,24 +291,22 @@ def evaluate(args, model, tokenizer, prefix=""):
return results


def load_and_cache_examples(args, tokenizer, evaluate=False):
dataset = WikiTextDataset(args, tokenizer, file="test" if evaluate else "train", directory=args.data_dir)
return dataset


def main():
parser = argparse.ArgumentParser()

## Required parameters
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--train_data_file", default=None, type=str, required=True,
help="The input training data file (a text file).")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")

## Other parameters
parser.add_argument("--model_name", default="bert", type=str,
parser.add_argument("--eval_data_file", default=None, type=str,
help="An optional input evaluation data file to evaluate the perplexity on (a text file).")

parser.add_argument("--model_type", default="bert", type=str,
help="The model architecture to be fine-tuned.")
parser.add_argument("--model_checkpoint", default="bert-base-cased", type=str,
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str,
help="The model checkpoint for weights initialization.")

parser.add_argument("--mlm", action='store_true',
@@ -266,20 +315,21 @@ def main():
help="Ratio of tokens to mask for masked language modeling loss")

parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
help="Optional pretrained config name or path if not the same as model_name_or_path")
parser.add_argument("--tokenizer_name", default="", type=str,
help="Pretrained tokenizer name or path if not the same as model_name")
help="Optional pretrained tokenizer name or path if not the same as model_name_or_path")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.")
help="Optional directory to store the pre-trained models downloaded from s3 (instread of the default one)")
parser.add_argument("--block_size", default=-1, type=int,
help="Optional input sequence length after tokenization."
"The training dataset will be truncated in block of this size for training."
"Default to the model max input length.")
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the dev set.")
parser.add_argument("--evaluate_during_training", action='store_true',
help="Rul evaluation during training at each logging step.")
help="Run evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")

@@ -309,7 +359,7 @@ def main():
parser.add_argument('--save_steps', type=int, default=50,
help="Save checkpoint every X updates steps.")
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
help="Evaluate all checkpoints starting with the same prefix as model_name_or_path ending and ending with step number")
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--overwrite_output_dir', action='store_true',
@@ -330,9 +380,12 @@ def main():
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
args = parser.parse_args()

if args.model_name in ["bert", "roberta"] and not args.mlm:
if args.model_type in ["bert", "roberta"] and not args.mlm:
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
"flag (masked language modeling).")
if args.eval_data_file is None and args.do_eval:
raise ValueError("Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
"or remove the --do_eval argument.")

if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
@@ -368,30 +421,36 @@ def main():

# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab

config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_name]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_checkpoint)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_checkpoint, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.model_checkpoint, from_tf=bool('.ckpt' in args.model_checkpoint), config=config)
args.num_embeddings = config.vocab_size # We need this to create the model at next line (number of embeddings to use)
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training downloads model & vocab

config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)
if args.block_size <= 0:
args.block_size = tokenizer.max_len # Our input block size will be the max possible for the model
model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), config=config)
model.to(args.device)

if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab

model.to(args.device)
torch.distributed.barrier() # End of barrier to make sure only the first process in distributed training downloads model & vocab
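
Editor's note: the two barriers above implement a "rank 0 goes first" pattern: every other process blocks at the first barrier, rank 0 downloads and caches the model and vocabulary, and its own barrier call then releases the waiting processes, which load from the cache. In isolation the pattern looks roughly like this (sketch only, assuming an already-initialized process group):

import torch

def run_rank0_first(local_rank, fn):
    # local_rank == -1 means single-process training: no barriers involved.
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()  # wait here until rank 0 has finished fn()
    result = fn()                    # rank 0 runs first; the others hit its cache
    if local_rank == 0:
        torch.distributed.barrier()  # release the processes waiting above
    return result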

logger.info("Training/evaluation parameters %s", args)


# Training
if args.do_train:
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training processes the dataset, and the others will use the cache

train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False)

if args.local_rank == 0:
torch.distributed.barrier()

global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
# Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
@@ -409,7 +468,7 @@ def main():

# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)


51 changes: 0 additions & 51 deletions examples/utils_lm.py

This file was deleted.

