#!/usr/bin/env python
# coding: utf-8

## BERT trainer to be called by server.R
## Requires two CSV data files with the columns: id, label, text

import argparse
import os
from os.path import expanduser

import datasets
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer, TrainingArguments, TrainerCallback

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def main(args):
    print("Importing data")
    dattrain = pd.read_csv(expanduser(args.traindat))
    datval = pd.read_csv(expanduser(args.valdat))
    datval_id = datval["id"]
    classcolname = "label"

    ## Build the list of class names from the training labels
    class_names = list(dattrain[classcolname].unique())

    ## Map labels to class indices
    dattrain[classcolname] = [class_names.index(x) for x in dattrain[classcolname].to_list()]
    datval[classcolname] = [class_names.index(x) for x in datval[classcolname].to_list()]

    ## Convert to Hugging Face datasets
    dattrain = datasets.Dataset.from_pandas(dattrain[["text", "label"]])
    datval = datasets.Dataset.from_pandas(datval[["text", "label"]])

    ## Model choice
    modelname = expanduser(args.model)

    ## Tokenize: pad dynamically to the longest sequence in each batch
    ## (--adapt 1) or statically to the 512-token maximum (--adapt 0)
    print("Tokenizing")
    tokenizer = AutoTokenizer.from_pretrained(modelname)
    padding = True if args.adapt else "max_length"
    toktrain = dattrain.map(
        lambda e: tokenizer(e["text"], truncation=True, padding=padding, max_length=512),
        batched=True)
    toktest = datval.map(
        lambda e: tokenizer(e["text"], truncation=True, padding=padding, max_length=512),
        batched=True)
    del dattrain

    ## Model
    print("Loading model")
    model = AutoModelForSequenceClassification.from_pretrained(
        modelname, num_labels=len(class_names))
    if args.gpu:
        model.cuda()

    ## Train using the Trainer interface
    print("Training...")
    BATCH_SIZE = args.batchsize
    GRAD_ACC = args.gradacc
    epochs = args.epochs

    ## Step counts for warmup, evaluation and checkpointing; cast to int
    ## (epochs is a float) and keep eval_steps >= 1 for small datasets
    total_steps = int(epochs * len(toktrain)) // (BATCH_SIZE * GRAD_ACC)
    warmup_steps = total_steps // 10
    eval_steps = max(1, total_steps // args.eval)

    training_args = TrainingArguments(
        output_dir=args.session + "_train",
        learning_rate=args.lrate,
        weight_decay=args.wdecay,
        num_train_epochs=epochs,
        gradient_accumulation_steps=GRAD_ACC,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=32,
        warmup_steps=warmup_steps,
        eval_steps=eval_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        save_steps=eval_steps,
        logging_steps=eval_steps,
        do_eval=True,
        greater_is_better=False,
        load_best_model_at_end=bool(args.best),
        metric_for_best_model="eval_loss",
    )

    trainer = Trainer(model=model, args=training_args,
                      train_dataset=toktrain, eval_dataset=toktest)

    the_session = args.session

    class HaltCallback(TrainerCallback):
        """Interrupt training when the user creates a <session>_stop file."""

        def on_step_begin(self, args, state, control, **kwargs):
            if os.path.exists(the_session + "_stop"):
                print("\nHalted by user.\n")
                control.should_training_stop = True
                return control

    trainer.add_callback(HaltCallback)
    trainer.train()

    ## Store the class names in the model config
    label_to_id = {v: i for i, v in enumerate(class_names)}
    model.config.label2id = label_to_id
    model.config.id2label = {id_: label for label, id_ in model.config.label2id.items()}

    ## Save model
    model.save_pretrained(args.session)

    ## Prediction functions
    model.eval()  # disable dropout for inference

    def get_predprobs(text):
        """Return the softmax class probabilities for one text."""
        inputs = tokenizer(text, padding=True, truncation=True, max_length=512,
                           return_tensors="pt")
        if args.gpu:
            inputs = inputs.to("cuda")
        with torch.no_grad():
            outputs = model(**inputs)
        res = outputs.logits
        if args.gpu:
            res = res.cpu()
        return res.softmax(1).numpy()

    def get_prediction(text):
        return class_names[get_predprobs(text).argmax()]

    ## Predictions on the validation set, exported for server.R
    print("Computing predictions")
    testpred = [get_prediction(txt) for txt in datval["text"]]

    exportpred = pd.DataFrame(datval_id)
    exportpred.columns = ["id"]
    exportpred["bertpred"] = testpred
    exportpred.to_csv(args.session + "_predval.csv", index=False)


if __name__ == "__main__":
    argParser = argparse.ArgumentParser()
    argParser.add_argument("-m", "--model", help="Model name or path",
                           default="microsoft/Multilingual-MiniLM-L12-H384")
    argParser.add_argument("-t", "--traindat", help="Path to training data")
    argParser.add_argument("-v", "--valdat", help="Path to validation data")
    argParser.add_argument("-b", "--batchsize", help="Batch size for training",
                           type=int, default=4)
    argParser.add_argument("-g", "--gradacc", help="Gradient accumulation steps for training",
                           type=int, default=1)
    argParser.add_argument("-e", "--epochs", help="Number of training epochs",
                           type=float, default=3)
    argParser.add_argument("-l", "--lrate", help="Learning rate",
                           type=float, default=5e-05)
    argParser.add_argument("-w", "--wdecay", help="Weight decay",
                           type=float, default=.01)
    argParser.add_argument("-B", "--best", help="Load best model instead of last",
                           type=int, choices=[0, 1], default=1)
    argParser.add_argument("-E", "--eval", help="Number of intermediary evaluations",
                           type=int, default=10)
    argParser.add_argument("-s", "--session", help="Session name (used to save results)")
    argParser.add_argument("-G", "--gpu", help="Use GPU (CUDA)",
                           type=int, choices=[0, 1], default=0)
    argParser.add_argument("-A", "--adapt", help="Adapt token length to batch",
                           type=int, choices=[0, 1], default=1)

    args = argParser.parse_args()
    main(args)
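
## Example invocation. The script filename, CSV paths and session name below
## are illustrative placeholders, not fixed by server.R:
##
##   python bert_trainer.py --traindat ~/data/train.csv --valdat ~/data/val.csv \
##       --session session01 --epochs 3 --gpu 1
##
## This writes checkpoints to "session01_train/", saves the final model to
## "session01/", and exports validation predictions to "session01_predval.csv".
## Creating a file named "session01_stop" halts training at the next step.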