43 lines
1.2 KiB
Python
43 lines
1.2 KiB
Python
#!/usr/bin/env python
|
|
# coding: utf-8
|
|
|
|
## SBERT: embed sentences
## Requires a feather data file with columns `id` and `text`
|
|
|
|
import argparse
|
|
from os.path import expanduser
|
|
import pandas as pd
|
|
import pyarrow as pa
|
|
import pyarrow.feather as feather
|
|
import re
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
def main(args):
    """Embed the ``text`` column of a feather file with SBERT and export it.

    Reads the feather file at ``args.data``, encodes its ``text`` column with
    the SentenceTransformer model given by ``args.model``, and writes a new
    feather file containing the ``id`` column followed by one ``sbNNN``
    column per embedding dimension (``sb001``, ``sb002``, ...).

    The output path is the input path with a trailing ``.feather`` replaced
    by ``_sb.feather``. If the input path has no ``.feather`` suffix,
    ``_sb.feather`` is appended instead so the input is never overwritten.

    Args:
        args: Parsed command-line namespace providing ``data`` (path to the
            input feather file) and ``model`` (model name or path).
    """
    print("SBERT: Importing data")
    datapath = expanduser(args.data)
    dat = feather.read_feather(datapath)
    outfile = re.sub(r"[.]feather$", "_sb.feather", datapath)
    if outfile == datapath:
        # No ".feather" suffix was replaced; writing to the same path would
        # clobber the input data, so append the suffix instead.
        outfile = datapath + "_sb.feather"

    print("SBERT: Loading model")
    sbert = SentenceTransformer(expanduser(args.model))
    sbert.max_seq_length = 512
    print("SBERT: Embedding sentences")
    # Pass a plain list: a pandas Series resolves integer keys by index label,
    # which is fragile if the frame does not carry a default RangeIndex.
    emb = sbert.encode(dat["text"].tolist())

    print("SBERT: Exporting")
    emb = pd.DataFrame(emb)
    emb.columns = [f"sb{i:03d}" for i in range(1, len(emb.columns) + 1)]
    # concat(axis=1) aligns on index labels; reset the id index so rows pair up
    # positionally with the freshly built (0..n-1 indexed) embedding frame.
    ids = dat["id"].reset_index(drop=True)
    emb = pd.concat([ids, emb], axis=1)
    feather.write_feather(emb, outfile)
    print("SBERT: Done")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the model and data options, then run main().
    argParser = argparse.ArgumentParser()
    argParser.add_argument(
        "-m",
        "--model",
        help="Model name or path",
        default="distiluse-base-multilingual-cased-v1",
    )
    # --data has no usable default; require it so a missing value produces a
    # clear usage error instead of a TypeError from expanduser(None) in main().
    argParser.add_argument(
        "-d",
        "--data",
        help="Path to data (feather)",
        required=True,
    )
    args = argParser.parse_args()
    main(args)
|
|
|