diff --git a/bin/extract_vectors.py b/bin/extract_vectors.py index babe09f..cff805d 100644 --- a/bin/extract_vectors.py +++ b/bin/extract_vectors.py @@ -7,6 +7,7 @@ vectors you want to keep. import os import numpy as np import argparse +from data import read_file, index_by_id, write_line parser = argparse.ArgumentParser(description='Extract a subset of vectors') parser.add_argument('vectorsfile', type=str, @@ -25,28 +26,16 @@ LIST_FILE = args.listfile OUTPUT_FILE = args.output # READ VECTOR DATA -data = {} -data["en-us"] = {} -data["fr-fr"] = {} -with open(VECTOR_FILE, "r") as f: - for i, line in enumerate(f): - if TOY_VERSION == True and i > 100: - break - spl_line = line.split(" ") - if(len(pvectors) == 0): - pvectors = np.empty((0, len(spl_line[1:])), np.float32) - spl_meta = spl_line.split(",") - lang = spl_meta[0] - iden = spl_meta[3] - data[lang][iden] = line - -# READ LIST AND WRITE NEW FILE -with open(LIST_FILE, "r") as f, open(OUTPUT_FILE, "w") as o: - for i, line in enumerate(LIST_FILE): - if TOY_VERSION == True and i > 100: - break - spl_meta = line.split(",") - lang = spl_meta[0] - iden = spl_meta[3] - OUTPUT_FILE.write(data[lang][iden]) +features = read_file(VECTOR_FILE) +features_ind = index_by_id(features) +lst = read_file(LIST_FILE) + + +# COMPUTE KEPT FEATS +kept_feats = [features_ind[x[0][0]][x[0][3]] for x in lst] + +# WRITE IN FILE +with open(OUTPUT_FILE, 'w') as f: + for feat in kept_feats: + write_line(feat[0], feat[1], f=f)