datalab/docker-images-datalab/myactivetigger/activetigger/embed_sbert.py

43 lines
1.2 KiB
Python
Raw Normal View History

2024-03-06 15:54:50 +01:00
#!/usr/bin/env python
# coding: utf-8
## SBERT: embed sentences with a SentenceTransformer model
## Requires a Feather data file with columns `id` and `text`
import argparse
from os.path import expanduser
import pandas as pd
import pyarrow as pa
import pyarrow.feather as feather
import re
from sentence_transformers import SentenceTransformer
def main(args):
    """Embed the ``text`` column of a Feather file with SBERT and save it.

    Reads the Feather file at ``args.data``, encodes every entry of its
    ``text`` column with the SentenceTransformer model given by
    ``args.model``, and writes a new Feather file containing the ``id``
    column followed by one column per embedding dimension, named
    ``sb001``, ``sb002``, and so on.

    The output path is the input path with its trailing ``.feather``
    replaced by ``_sb.feather``. If the input path does not end in
    ``.feather``, ``_sb.feather`` is appended instead, so the input file is
    never overwritten.

    Args:
        args: Parsed command-line namespace providing ``data`` (path to the
            input Feather file) and ``model`` (model name or local path).
            ``~`` is expanded in both.
    """
    print("SBERT: Importing data")
    datapath = expanduser(args.data)
    dat = feather.read_feather(datapath)

    # Derive the output path without a regex. When the suffix is missing,
    # a plain substitution would leave the path unchanged and the export
    # below would clobber the input file.
    suffix = ".feather"
    stem = datapath[: -len(suffix)] if datapath.endswith(suffix) else datapath
    outfile = stem + "_sb" + suffix

    print("SBERT: Loading model")
    sbert = SentenceTransformer(expanduser(args.model))
    # NOTE(review): caps input length at 512 tokens; confirm the selected
    # model actually supports sequences that long.
    sbert.max_seq_length = 512

    print("SBERT: Embedding sentences")
    # Hand the encoder a plain list so its internal positional indexing does
    # not depend on the pandas index labels of the input frame.
    vectors = sbert.encode(dat["text"].tolist())

    print("SBERT: Exporting")
    # Reuse the input frame's index: concat(axis=1) aligns on index labels,
    # so a fresh 0..n-1 index would misalign rows (and introduce NaNs) if
    # the Feather file restored a non-default index.
    emb = pd.DataFrame(vectors, index=dat.index)
    emb.columns = ["sb%03d" % (x + 1) for x in range(len(emb.columns))]
    emb = pd.concat([dat["id"], emb], axis=1)
    feather.write_feather(emb, outfile)
    print("SBERT: Done")
if __name__ == "__main__":
    # Command-line entry point: parse the model and data arguments, then run
    # the embedding pipeline.
    argParser = argparse.ArgumentParser()
    argParser.add_argument(
        "-m",
        "--model",
        help="Model name or path",
        default="distiluse-base-multilingual-cased-v1",
    )
    # --data has no sensible default. Marking it required makes argparse
    # print a usage error instead of main() failing later on a None path.
    argParser.add_argument(
        "-d",
        "--data",
        help="Path to data (feather)",
        required=True,
    )
    args = argParser.parse_args()
    main(args)