Commit 993ea26cfef22ea3fc8abf3c6973de1d0fc297c0
1 parent: 0c12dd8941
Exists in: master
take into account list option
Showing 1 changed file with 6 additions and 3 deletions (inline diff view)
bin/extract_kmeans.py
1 | ''' | 1 | ''' |
2 | This script aims to extract k-means clustering from a | 2 | This script aims to extract k-means clustering from a |
3 | a priori trained k-means. | 3 | a priori trained k-means. |
4 | ''' | 4 | ''' |
5 | 5 | ||
6 | import argparse | 6 | import argparse |
7 | import numpy as np | 7 | import numpy as np |
8 | import pickle | 8 | import pickle |
9 | from data import read_file, index_by_id, write_line | 9 | from data import read_file, index_by_id, write_line |
10 | import sys | 10 | import sys |
11 | 11 | ||
12 | # -- ARGPARSE -- | 12 | # -- ARGPARSE -- |
13 | parser = argparse.ArgumentParser(description="extract clusters") | 13 | parser = argparse.ArgumentParser(description="extract clusters") |
14 | parser.add_argument("model", type=str, help="k-means model pickle") | 14 | parser.add_argument("model", type=str, help="k-means model pickle") |
15 | parser.add_argument("features", type=str, help="features") | 15 | parser.add_argument("features", type=str, help="features") |
16 | parser.add_argument("list", type=str, help="list file") | 16 | parser.add_argument("--list", type=str, default=None, help="list file") |
17 | parser.add_argument("--outfile", type=str, default=None, help="output file std") | 17 | parser.add_argument("--outfile", type=str, default=None, help="output file std") |
18 | 18 | ||
19 | args = vars(parser.parse_args()) | 19 | args = vars(parser.parse_args()) |
20 | MODEL = args["model"] | 20 | MODEL = args["model"] |
21 | FEATURES = args["features"] | 21 | FEATURES = args["features"] |
22 | LST = args["list"] | 22 | LST = args["list"] |
23 | OUTFILE = args["outfile"] | 23 | OUTFILE = args["outfile"] |
24 | 24 | ||
25 | if OUTFILE == None: | 25 | if OUTFILE == None: |
26 | OUTFILE = sys.stdout | 26 | OUTFILE = sys.stdout |
27 | else: | 27 | else: |
28 | OUTFILE = open(OUTFILE, "w") | 28 | OUTFILE = open(OUTFILE, "w") |
29 | 29 | ||
30 | # -- READ FILE -- | 30 | # -- READ FILE -- |
31 | features = read_file(FEATURES) | 31 | features = read_file(FEATURES) |
32 | feat_ind = index_by_id(features) | 32 | feat_ind = index_by_id(features) |
33 | 33 | ||
34 | lst = read_file(LST) | 34 | if LST is not None: |
| | 35 | lst = read_file(LST) |
| | 36 | else: |
| | 37 | lst = features |
35 | 38 | ||
36 | kmeans = pickle.load(open(MODEL, "rb")) | 39 | kmeans = pickle.load(open(MODEL, "rb")) |
37 | 40 | ||
38 | | | |
39 | # -- CONVERT TO NUMPY -- | 41 | # -- CONVERT TO NUMPY -- |
40 | X = np.asarray([feat_ind[x[0][0]][x[0][3]][1] for x in lst]) | 42 | X = np.asarray([feat_ind[x[0][0]][x[0][3]][1] for x in lst]) |
| | 43 | |
41 | predictions = kmeans.predict(X) | 44 | predictions = kmeans.predict(X) |
42 | 45 | ||
43 | for i, line in enumerate(lst): | 46 | for i, line in enumerate(lst): |
44 | meta = line[0] | 47 | meta = line[0] |
45 | meta[1] = str(predictions[i]) | 48 | meta[1] = str(predictions[i]) |
46 | write_line( | 49 | write_line( |
47 | meta, | 50 | meta, |
48 | feat_ind[meta[0]][meta[3]][1], | 51 | feat_ind[meta[0]][meta[3]][1], |
49 | OUTFILE | 52 | OUTFILE |
50 | ) | 53 | ) |
51 | 54 | ||
52 | # -- CLOSE OUT FILE IF NECESSARY -- | 55 | # -- CLOSE OUT FILE IF NECESSARY -- |
53 | if not OUTFILE == sys.stdout: | 56 | if not OUTFILE == sys.stdout: |