#!/usr/bin/env python
# coding: utf-8

## BERT trainer to be called by server.R
## Requires two data files with columns id, label and text
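## A minimal sketch of the expected CSV layout (file contents are
## hypothetical; only the id, label and text columns are given above):
##   id,label,text
##   1,positive,"Works as described"
##   2,negative,"Arrived broken"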

import argparse
import datasets
import numpy as np
import os
import pandas as pd
import torch
from os.path import expanduser
from sklearn import metrics
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import Trainer, TrainingArguments, TrainerCallback

# Silence the tokenizers fork-parallelism warning triggered by dataloader workers
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def main(args):
    print("Importing data")
    dattrain = pd.read_csv(expanduser(args.traindat))
    datval = pd.read_csv(expanduser(args.valdat))
    datval_id = datval["id"]
    classcolname = "label"

    ## Make class_names
    class_names = dattrain[classcolname].unique().tolist()

    ## Labels to class number
    dattrain[classcolname] = [class_names.index(x) for x in dattrain[classcolname].to_list()]
    datval[classcolname] = [class_names.index(x) for x in datval[classcolname].to_list()]

    ## Transform to datasets
    dattrain = datasets.Dataset.from_pandas(dattrain[['text', 'label']])
    datval = datasets.Dataset.from_pandas(datval[['text', 'label']])

    ## Model choice
    modelname = expanduser(args.model)
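    ## args.model may be a Hugging Face Hub model id (the default) or a
    ## local path; expanduser() makes "~/..." paths work.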

    ## Tokenizer
    print("Tokenizing")
    tokenizer = AutoTokenizer.from_pretrained(modelname)

    if args.adapt:
        # Tokenize without padding; the data collator then pads each
        # training batch to its longest sequence (see the Trainer call below)
        toktrain = dattrain.map(lambda e: tokenizer(e['text'], truncation=True, max_length=512), batched=True)
        toktest = datval.map(lambda e: tokenizer(e['text'], truncation=True, max_length=512), batched=True)
    else:
        # Pad every sequence to the full 512 tokens
        toktrain = dattrain.map(lambda e: tokenizer(e['text'], truncation=True, padding="max_length", max_length=512), batched=True)
        toktest = datval.map(lambda e: tokenizer(e['text'], truncation=True, padding="max_length", max_length=512), batched=True)
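
    ## After map(), each row also carries the tokenizer outputs, typically
    ## input_ids and attention_mask (plus token_type_ids for BERT-style
    ## models), alongside the original text and label columns.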

    del dattrain

    ## Model
    print("Loading model")
    model = AutoModelForSequenceClassification.from_pretrained(modelname, num_labels=len(class_names))
    if args.gpu:
        model.cuda()

    ## Train using Trainer interface
    print("Training...")
    BATCH_SIZE = args.batchsize
    GRAD_ACC = args.gradacc
    epochs = args.epochs

    # epochs may be fractional (type=float), so force an integer step count;
    # keep eval_steps at least 1 so short runs don't end up with 0 steps
    total_steps = int(epochs * len(toktrain)) // (BATCH_SIZE * GRAD_ACC)
    warmup_steps = total_steps // 10
    eval_steps = max(1, total_steps // args.eval)

    training_args = TrainingArguments(
        output_dir=args.session + "_train",
        learning_rate=args.lrate,
        weight_decay=args.wdecay,
        num_train_epochs=epochs,
        gradient_accumulation_steps=GRAD_ACC,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=32,
        warmup_steps=warmup_steps,
        eval_steps=eval_steps,
        evaluation_strategy="steps",
        save_strategy="steps",
        save_steps=eval_steps,
        logging_steps=eval_steps,
        do_eval=True,
        greater_is_better=False,
        load_best_model_at_end=bool(args.best),
        metric_for_best_model="eval_loss",
    )

    # Passing the tokenizer makes the Trainer default to DataCollatorWithPadding,
    # so the unpadded (args.adapt) datasets are padded per batch at collation time
    trainer = Trainer(model=model, args=training_args,
                      train_dataset=toktrain, eval_dataset=toktest,
                      tokenizer=tokenizer)

    the_session = args.session

    class HaltCallback(TrainerCallback):
        "A callback that checks for a _stop file to interrupt training"

        def on_step_begin(self, args, state, control, **kwargs):
            if os.path.exists(the_session + "_stop"):
                print("\nHalted by user.\n")
                control.should_training_stop = True
                return control

    trainer.add_callback(HaltCallback)
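
    ## To interrupt a run from outside the process (e.g. from server.R),
    ## create an empty file named "<session>_stop" in the working directory;
    ## a hypothetical shell example: touch mysession_stop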

    trainer.train()

    ## Add class names to model
    label_to_id = {v: i for i, v in enumerate(class_names)}
    model.config.label2id = label_to_id
    model.config.id2label = {i: label for label, i in label_to_id.items()}

    ## Save model
    model.save_pretrained(args.session)
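
    ## Only the model weights and config are saved above; to make the session
    ## directory self-contained for later inference, the tokenizer could be
    ## saved too (a suggestion, not part of the original flow):
    ##   tokenizer.save_pretrained(args.session)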

    ## Prediction functions

    def get_predprobs(text):
        inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt")
        if args.gpu:
            inputs = inputs.to("cuda")
        # Inference only, so skip gradient tracking
        with torch.no_grad():
            outputs = model(**inputs)
        res = outputs[0]
        if args.gpu:
            res = res.cpu()
        return res.softmax(1).numpy()

    def get_prediction(text):
        return class_names[get_predprobs(text).argmax()]

    ## Metrics on validation set
    print("Computing predictions")
    model.eval()  # disable dropout left enabled after training
    testpred = [get_prediction(txt) for txt in datval["text"]]
    testtruth = [class_names[x] for x in datval["label"]]
    print(metrics.classification_report(testtruth, testpred, zero_division=0))

    exportpred = pd.DataFrame(datval_id)
    exportpred.columns = ["id"]
    exportpred["bertpred"] = testpred
    exportpred.to_csv(args.session + "_predval.csv", index=False)
if __name__ == "__main__":
|
|
argParser = argparse.ArgumentParser()
|
|
argParser.add_argument("-m", "--model", help="Model name or path", default="microsoft/Multilingual-MiniLM-L12-H384")
|
|
argParser.add_argument("-t", "--traindat", help="Path to training data")
|
|
argParser.add_argument("-v", "--valdat", help="Path to validation data")
|
|
argParser.add_argument("-b", "--batchsize", help="Batch size for training", type=int, default=4)
|
|
argParser.add_argument("-g", "--gradacc", help="Gradient accumulation for training", type=int, default=1)
|
|
argParser.add_argument("-e", "--epochs", help="Number of training epochs", type=float, default=3)
|
|
argParser.add_argument("-l", "--lrate", help="Learning rate", type=float, default=5e-05)
|
|
argParser.add_argument("-w", "--wdecay", help="Weight decay", type=float, default=.01)
|
|
argParser.add_argument("-B", "--best", help="Load best model instead of last", type=int, choices=[0,1], default=1)
|
|
argParser.add_argument("-E", "--eval", help="Number of intermediary evaluations", type=int, default=10)
|
|
argParser.add_argument("-s", "--session", help="Session name (used to save results)")
|
|
argParser.add_argument("-G", "--gpu", help="Use GPU (CUDA)", type=int, choices=[0,1], default=0)
|
|
argParser.add_argument("-A", "--adapt", help="Adapt token length to batch", type=int, choices=[0,1], default=1)
|
|
|
|
|
|
args = argParser.parse_args()
|
|
|
|
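
    ## Example invocation from a shell (file names and session name are
    ## hypothetical, and the script file name is assumed to be bert_trainer.py):
    ##   python bert_trainer.py -t train.csv -v val.csv -s mysession -G 1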
    main(args)