Commit 993ea26cfef22ea3fc8abf3c6973de1d0fc297c0

Authored by Mathias Quillot
1 parent 0c12dd8941
Exists in master

take into account list option

Showing 1 changed file with 6 additions and 3 deletions Inline Diff

bin/extract_kmeans.py
1 ''' 1 '''
2 This script aims to extract k-means clustering from a 2 This script aims to extract k-means clustering from a
3 a priori trained k-means. 3 a priori trained k-means.
4 ''' 4 '''
5 5
6 import argparse 6 import argparse
7 import numpy as np 7 import numpy as np
8 import pickle 8 import pickle
9 from data import read_file, index_by_id, write_line 9 from data import read_file, index_by_id, write_line
10 import sys 10 import sys
11 11
12 # -- ARGPARSE -- 12 # -- ARGPARSE --
13 parser = argparse.ArgumentParser(description="extract clusters") 13 parser = argparse.ArgumentParser(description="extract clusters")
14 parser.add_argument("model", type=str, help="k-means model pickle") 14 parser.add_argument("model", type=str, help="k-means model pickle")
15 parser.add_argument("features", type=str, help="features") 15 parser.add_argument("features", type=str, help="features")
16 parser.add_argument("list", type=str, help="list file") 16 parser.add_argument("--list", type=str, default=None, help="list file")
17 parser.add_argument("--outfile", type=str, default=None, help="output file std") 17 parser.add_argument("--outfile", type=str, default=None, help="output file std")
18 18
19 args = vars(parser.parse_args()) 19 args = vars(parser.parse_args())
20 MODEL = args["model"] 20 MODEL = args["model"]
21 FEATURES = args["features"] 21 FEATURES = args["features"]
22 LST = args["list"] 22 LST = args["list"]
23 OUTFILE = args["outfile"] 23 OUTFILE = args["outfile"]
24 24
25 if OUTFILE == None: 25 if OUTFILE == None:
26 OUTFILE = sys.stdout 26 OUTFILE = sys.stdout
27 else: 27 else:
28 OUTFILE = open(OUTFILE, "w") 28 OUTFILE = open(OUTFILE, "w")
29 29
30 # -- READ FILE -- 30 # -- READ FILE --
31 features = read_file(FEATURES) 31 features = read_file(FEATURES)
32 feat_ind = index_by_id(features) 32 feat_ind = index_by_id(features)
33 33
34 lst = read_file(LST) 34 if LST is not None:
35 lst = read_file(LST)
36 else:
37 lst = features
35 38
36 kmeans = pickle.load(open(MODEL, "rb")) 39 kmeans = pickle.load(open(MODEL, "rb"))
37 40
38
39 # -- CONVERT TO NUMPY -- 41 # -- CONVERT TO NUMPY --
40 X = np.asarray([feat_ind[x[0][0]][x[0][3]][1] for x in lst]) 42 X = np.asarray([feat_ind[x[0][0]][x[0][3]][1] for x in lst])
43
41 predictions = kmeans.predict(X) 44 predictions = kmeans.predict(X)
42 45
43 for i, line in enumerate(lst): 46 for i, line in enumerate(lst):
44 meta = line[0] 47 meta = line[0]
45 meta[1] = str(predictions[i]) 48 meta[1] = str(predictions[i])
46 write_line( 49 write_line(
47 meta, 50 meta,
48 feat_ind[meta[0]][meta[3]][1], 51 feat_ind[meta[0]][meta[3]][1],
49 OUTFILE 52 OUTFILE
50 ) 53 )
51 54
52 # -- CLOSE OUT FILE IF NECESSARY -- 55 # -- CLOSE OUT FILE IF NECESSARY --
53 if not OUTFILE == sys.stdout: 56 if not OUTFILE == sys.stdout: