datalab/docker-images-datalab/myactivetigger/activetigger/gobert_infer.py

95 lines
3.0 KiB
Python

#!/usr/bin/env python
# coding: utf-8
## BERT inference to be called by server.R
import argparse
import datasets
import json
import numpy as np
from os import path, remove
import pandas as pd
import pyarrow.feather as feather
import re
from torch import no_grad
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def chunker(seq, batch_size):
    """Split a sliceable sequence into consecutive batches.

    Args:
        seq: Any object supporting ``len()`` and slicing (list, str, ...).
        batch_size: Maximum number of elements per batch; must be >= 1.

    Returns:
        A generator yielding ``seq[pos:pos + batch_size]`` for successive
        offsets. The last batch may be shorter than ``batch_size``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # A negative step would make range() empty and silently drop every
    # element, so reject bad sizes up front (eagerly, not on first iteration).
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )
    return (seq[pos:pos + batch_size] for pos in range(0, len(seq), batch_size))
def main(args):
    """Run sequence-classification inference and write class probabilities.

    Reads a feather file, tokenizes the text column, runs the fine-tuned
    model batch by batch and writes one ``bertpred_<label>`` probability
    column per class (next to the id column) to a feather file. Progress is
    reported by overwriting ``args.logfile``; that file is deleted once the
    output has been written.

    Args:
        args: Namespace providing ``model`` (model directory), ``dat`` (input
            feather path), ``output`` (output feather path), ``idname`` and
            ``txtname`` (column names), ``logfile`` (progress file path),
            ``gpu`` (truthy to run on CUDA) and ``batch`` (batch size).

    Raises:
        ValueError: If ``args.batch`` is smaller than 1.
    """
    if args.batch < 1:
        raise ValueError(
            f"batch size must be a positive integer, got {args.batch!r}"
        )

    logfile = path.expanduser(args.logfile)
    model_dir = path.expanduser(args.model)

    def write_progress(message):
        # The file is rewritten from scratch each time so it only ever holds
        # the most recent status message.
        with open(logfile, "w") as progfile:
            progfile.write(message)

    print("Importing data")
    write_progress("Importing data")
    dat = feather.read_feather(path.expanduser(args.dat))
    write_progress("Tokenizing")

    ## Tokenizer
    print("Tokenizing")
    # The tokenizer is loaded from the name recorded in the saved config's
    # "_name_or_path" entry rather than from the model directory itself.
    with open(path.join(model_dir, "config.json"), "r") as jsonfile:
        modeltype = json.load(jsonfile)["_name_or_path"]
    tokenizer = AutoTokenizer.from_pretrained(modeltype)

    ## Model
    print("Loading model")
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    if args.gpu:
        model.cuda()

    ## Prediction functions
    def get_predprobs(text):
        """Return softmax probabilities (numpy array) for a batch of texts."""
        inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt")
        if args.gpu:
            inputs = inputs.to("cuda")
        with no_grad():
            outputs = model(**inputs)
        res = outputs[0]
        if args.gpu:
            res = res.cpu()
        return res.softmax(1).detach().numpy()

    print("Computing predictions")
    texts = [str(x) for x in dat[args.txtname]]
    n_rows = len(texts)
    pred = []
    for i, x in enumerate(chunker(texts, args.batch)):
        if i % 5 == 0:
            done = i * args.batch
            percent = round(100 * done / n_rows, 1)
            write_progress(f"Computing: {percent}% ({done}/{n_rows})")
        pred.append(get_predprobs(x))

    # Order labels by class index so that column k is named after class k,
    # regardless of the iteration order of the id2label mapping.
    labels = [
        label
        for _, label in sorted(
            model.config.id2label.items(), key=lambda item: int(item[0])
        )
    ]
    if pred:
        probs = np.concatenate(pred)
    else:
        # np.concatenate rejects an empty list; an empty input simply
        # produces an empty prediction table with the expected columns.
        probs = np.empty((0, len(labels)))
    pred = pd.DataFrame(probs)
    pred.columns = ["bertpred_" + str(label) for label in labels]
    # concat(axis=1) aligns on index; reset the id column's index so ids and
    # predictions are paired positionally.
    ids = dat[args.idname].reset_index(drop=True)
    pred = pd.concat([ids, pred], axis=1)
    # Expand "~" like every other path argument before making it absolute.
    feather.write_feather(pred, path.abspath(path.expanduser(args.output)))
    remove(logfile)
def build_arg_parser():
    """Build the command-line parser for the inference script.

    ``--model`` and ``--dat`` are mandatory: without them the script has
    nothing to load, so a missing value is reported as a usage error rather
    than failing later on a ``None`` path.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", help="Trained model path", required=True)
    parser.add_argument("-d", "--dat", help="Path to data (feather file)", required=True)
    parser.add_argument("-o", "--output", help="Output path of predictions", default="tiggerbert.feather")
    parser.add_argument("-i", "--idname", help="Name of id variable", default="id")
    parser.add_argument("-x", "--txtname", help="Name of text variable", default="text")
    parser.add_argument("-l", "--logfile", help="Path to log file", default="tiggerbert-progress.txt")
    parser.add_argument("-G", "--gpu", help="Use GPU (CUDA)", type=int, choices=[0, 1], default=1)
    parser.add_argument("-b", "--batch", help="Batch size", type=int, default=128)
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)