#!/usr/bin/env python
# coding: utf-8

## BERT inference to be called by server.R

import argparse
import json
from os import path, remove
import re

import datasets
import numpy as np
import pandas as pd
import pyarrow.feather as feather
from torch import no_grad
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def chunker(seq, batch_size):
    """Return a generator over consecutive slices of *seq*.

    Each slice holds at most *batch_size* items; the last one may be shorter.
    """
    return (seq[pos:pos + batch_size] for pos in range(0, len(seq), batch_size))


def _write_progress(logfile, message):
    """Overwrite the progress file at *logfile* with *message*."""
    with open(path.expanduser(logfile), "w") as progfile:
        progfile.write(message)


def main(args):
    """Predict class probabilities for every text and save them as feather.

    Reads the following attributes of *args* (see the argument parser at the
    bottom of this file): ``model``, ``dat``, ``output``, ``idname``,
    ``txtname``, ``logfile``, ``gpu`` and ``batch``.

    The feather file at ``args.dat`` is loaded, the column ``args.txtname`` is
    run through the sequence-classification model stored in ``args.model`` in
    batches of ``args.batch`` texts, and a table holding the ``args.idname``
    column plus one ``bertpred_<label>`` probability column per class is
    written to ``args.output``. Progress messages are written to
    ``args.logfile``, which is removed once the output has been written.
    """
    print("Importing data")
    _write_progress(args.logfile, "Importing data")
    dat = feather.read_feather(path.expanduser(args.dat))

    _write_progress(args.logfile, "Tokenizing")

    ## Tokenizer
    print("Tokenizing")
    model_dir = path.expanduser(args.model)
    # The tokenizer is loaded from the name recorded in the model's config,
    # not from the model directory itself.
    with open(path.join(model_dir, "config.json"), "r") as jsonfile:
        modeltype = json.load(jsonfile)["_name_or_path"]
    tokenizer = AutoTokenizer.from_pretrained(modeltype)

    ## Model
    print("Loading model")
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    if args.gpu:
        model.cuda()

    ## Prediction functions
    def get_predprobs(text):
        """Return a numpy array of softmax probabilities for a batch of texts."""
        inputs = tokenizer(text, padding=True, truncation=True,
                           max_length=512, return_tensors="pt")
        if args.gpu:
            inputs = inputs.to("cuda")
        with no_grad():
            outputs = model(**inputs)
        res = outputs[0]
        if args.gpu:
            res = res.cpu()
        return res.softmax(1).detach().numpy()

    print("Computing predictions")
    texts = [str(x) for x in dat[args.txtname]]
    total = len(texts)
    pred = []
    for i, batch in enumerate(chunker(texts, args.batch)):
        # Only refresh the progress file every fifth batch.
        if i % 5 == 0:
            done = i * args.batch
            percent = round(100 * done / total, 1)
            _write_progress(args.logfile,
                            f"Computing: {percent}% ({done}/{total})")
        pred.append(get_predprobs(batch))

    # Output column k of the model corresponds to class id k, so name the
    # columns in id order rather than in dict insertion order.
    labels = [label for _, label in sorted(model.config.id2label.items())]

    if pred:
        pred = np.concatenate(pred)
    else:
        # np.concatenate rejects an empty list; with no input rows, emit an
        # empty table that still has one column per class.
        pred = np.empty((0, len(labels)))
    pred = pd.DataFrame(pred)
    pred.columns = [f"bertpred_{label}" for label in labels]

    # concat(axis=1) aligns on index labels; reset the id column to a plain
    # positional index so rows pair up with the predictions one-to-one.
    ids = dat[args.idname].reset_index(drop=True)
    pred = pd.concat([ids, pred], axis=1)

    # Expand "~" like every other path argument before making it absolute.
    feather.write_feather(pred, path.abspath(path.expanduser(args.output)))
    remove(path.expanduser(args.logfile))


if __name__ == "__main__":
    argParser = argparse.ArgumentParser()
    # --model and --dat have no usable default: require them so a missing
    # value is reported by argparse instead of failing later inside main().
    argParser.add_argument("-m", "--model", help="Trained model path",
                           required=True)
    argParser.add_argument("-d", "--dat", help="Path to data (feather file)",
                           required=True)
    argParser.add_argument("-o", "--output", help="Output path of predictions",
                           default="tiggerbert.feather")
    argParser.add_argument("-i", "--idname", help="Name of id variable",
                           default="id")
    argParser.add_argument("-x", "--txtname", help="Name of text variable",
                           default="text")
    argParser.add_argument("-l", "--logfile", help="Path to log file",
                           default="tiggerbert-progress.txt")
    argParser.add_argument("-G", "--gpu", help="Use GPU (CUDA)", type=int,
                           choices=[0, 1], default=1)
    argParser.add_argument("-b", "--batch", help="Batch size", type=int,
                           default=128)
    args = argParser.parse_args()
    main(args)