#!/usr/bin/env python
# coding: utf-8

## SBERT embed sentences
## Requires data file with columns id and text

import argparse
from os.path import expanduser
import pandas as pd
import pyarrow as pa
import pyarrow.feather as feather
import re
from sentence_transformers import SentenceTransformer


def _output_path(datapath):
    """Return the path the embeddings of *datapath* are written to.

    A trailing ``.feather`` extension is replaced by ``_sb.feather``
    (``data.feather`` -> ``data_sb.feather``); any other path simply gets
    ``_sb.feather`` appended (``data.arrow`` -> ``data.arrow_sb.feather``).
    """
    # Substituting only on a match would leave non-.feather paths unchanged,
    # making the output path equal the input path and overwriting the input.
    return re.sub(r"[.]feather$", "", datapath) + "_sb.feather"


def main(args):
    """Embed the ``text`` column of a feather file with a SentenceTransformer.

    Reads the feather file at ``args.data`` (columns ``id`` and ``text`` are
    required), encodes every text with the model ``args.model`` and writes a
    feather file holding the ``id`` column followed by one ``sbNNN`` column per
    embedding dimension (path derived by ``_output_path``).

    ``args.max_seq_length`` (optional, default 512) sets the model's maximum
    input length in tokens.

    Raises:
        ValueError: if the input lacks an ``id`` or ``text`` column.
    """
    print("SBERT: Importing data")
    datapath = expanduser(args.data)
    dat = feather.read_feather(datapath)
    missing = {"id", "text"} - set(dat.columns)
    if missing:
        raise ValueError(
            f"{datapath} is missing required column(s): "
            f"{', '.join(sorted(missing))}"
        )
    outfile = _output_path(datapath)

    print("SBERT: Loading model")
    sbert = SentenceTransformer(expanduser(args.model))
    # getattr keeps callers that build a Namespace without this field working.
    sbert.max_seq_length = getattr(args, "max_seq_length", 512)

    print("SBERT: Embedding sentences")
    # encode() expects a str or list of str; passing a Series lets pandas
    # interpret integer positions as index labels, which breaks whenever the
    # frame's index is not 0..n-1.
    vectors = sbert.encode(dat["text"].tolist())

    print("SBERT: Exporting")
    out = pd.DataFrame(vectors)
    out.columns = [f"sb{i:03d}" for i in range(1, len(out.columns) + 1)]
    # Attach ids by position: concatenating along axis=1 would align on the
    # index and yield NaN/misaligned rows if dat's index is not a RangeIndex.
    out.insert(0, "id", dat["id"].reset_index(drop=True))
    feather.write_feather(out, outfile)
    print("SBERT: Done")


if __name__ == "__main__":
    argParser = argparse.ArgumentParser()
    argParser.add_argument("-m", "--model", help="Model name or path",
                           default="distiluse-base-multilingual-cased-v1")
    argParser.add_argument("-d", "--data", help="Path to data (feather)",
                           required=True)
    argParser.add_argument("--max-seq-length", type=int, default=512,
                           help="Maximum input length in tokens (default: 512)")
    args = argParser.parse_args()
    main(args)