#!/usr/bin/env python
# coding: utf-8

## BERT inference to be called by server.R

import argparse
import json
import numpy as np
from os import path, remove
import pandas as pd
import pyarrow.feather as feather
from torch import no_grad
from transformers import AutoModelForSequenceClassification, AutoTokenizer

def chunker(seq, batch_size):
    # Split seq into consecutive slices of at most batch_size elements.
    return (seq[pos:pos + batch_size] for pos in range(0, len(seq), batch_size))
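# For example, list(chunker([1, 2, 3, 4, 5], 2)) gives [[1, 2], [3, 4], [5]]:
# consecutive batches, the last one possibly shorter.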


def main(args):
    print("Importing data")
    # Write coarse progress to the log file so server.R can poll it.
    with open(path.expanduser(args.logfile), "w") as progfile:
        progfile.write("Importing data")

    dat = feather.read_feather(path.expanduser(args.dat))

    with open(path.expanduser(args.logfile), "w") as progfile:
        progfile.write("Tokenizing")

    ## Tokenizer
    print("Tokenizing")
    # Recover the base checkpoint name from the fine-tuned model's config so
    # the matching tokenizer can be loaded.
    with open(path.join(path.expanduser(args.model), "config.json"), "r") as jsonfile:
        modeltype = json.load(jsonfile)["_name_or_path"]

    tokenizer = AutoTokenizer.from_pretrained(modeltype)
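    # Illustrative: a config.json written by save_pretrained() typically
    # contains a field like "_name_or_path": "bert-base-uncased", naming the
    # base checkpoint whose tokenizer is reused here.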

    ## Model
    print("Loading model")
    model = AutoModelForSequenceClassification.from_pretrained(path.expanduser(args.model))
    if args.gpu:
        model.cuda()

    ## Prediction function

    def get_predprobs(text):
        # Tokenize one batch, run a gradient-free forward pass, and return
        # softmax-normalized class probabilities as a NumPy array.
        inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt")
        if args.gpu:
            inputs = inputs.to("cuda")
        with no_grad():
            outputs = model(**inputs)
        res = outputs.logits
        if args.gpu:
            res = res.cpu()
        res = res.softmax(1).numpy()
        return res
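    # Illustrative: for a two-label model, get_predprobs(["text a", "text b"])
    # returns an array shaped (2, 2) such as [[0.02, 0.98], [0.97, 0.03]],
    # one row of label probabilities per input text.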

    print("Computing predictions")

    chunks = chunker([str(x) for x in dat[args.txtname]], args.batch)
    pred = []
    for i, x in enumerate(chunks):
        # Refresh the progress log every five batches.
        if i % 5 == 0:
            percent = round(100 * i * args.batch / len(dat), 1)
            logmsg = "Computing: " + str(percent) + "% (" + str(i * args.batch) + "/" + str(len(dat)) + ")"
            with open(path.expanduser(args.logfile), "w") as progfile:
                progfile.write(logmsg)
        pred.append(get_predprobs(x))

    # Assemble one "bertpred_"-prefixed column per label, ordered by label id,
    # keyed by the id column.
    pred = np.concatenate(pred)
    pred = pd.DataFrame(pred)
    pred.columns = ["bertpred_" + v for _, v in sorted(model.config.id2label.items())]
    pred = pd.concat([dat[args.idname], pred], axis=1)

    feather.write_feather(pred, path.abspath(args.output))
    # Deleting the progress file signals completion to server.R.
    remove(path.expanduser(args.logfile))


if __name__ == "__main__":
    argParser = argparse.ArgumentParser()
    argParser.add_argument("-m", "--model", help="Trained model path")
    argParser.add_argument("-d", "--dat", help="Path to data (feather file)")
    argParser.add_argument("-o", "--output", help="Output path of predictions", default="tiggerbert.feather")
    argParser.add_argument("-i", "--idname", help="Name of id variable", default="id")
    argParser.add_argument("-x", "--txtname", help="Name of text variable", default="text")
    argParser.add_argument("-l", "--logfile", help="Path to log file", default="tiggerbert-progress.txt")
    argParser.add_argument("-G", "--gpu", help="Use GPU (CUDA)", type=int, choices=[0, 1], default=1)
    argParser.add_argument("-b", "--batch", help="Batch size", type=int, default=128)

    args = argParser.parse_args()
    main(args)
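# Example invocation (script name and paths are hypothetical):
#   python bert_infer.py --model ~/models/tiggerbert --dat ~/data/texts.feather \
#       --output tiggerbert.feather --gpu 0 --batch 64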