Commit 993ea26cfef22ea3fc8abf3c6973de1d0fc297c0
1 parent
0c12dd8941
Exists in
master
take into account list option
Showing 1 changed file with 6 additions and 3 deletions Side-by-side Diff
bin/extract_kmeans.py
... | ... | @@ -13,7 +13,7 @@ |
13 | 13 | parser = argparse.ArgumentParser(description="extract clusters") |
14 | 14 | parser.add_argument("model", type=str, help="k-means model pickle") |
15 | 15 | parser.add_argument("features", type=str, help="features") |
16 | -parser.add_argument("list", type=str, help="list file") | |
16 | +parser.add_argument("--list", type=str, default=None, help="list file") | |
17 | 17 | parser.add_argument("--outfile", type=str, default=None, help="output file std") |
18 | 18 | |
19 | 19 | args = vars(parser.parse_args()) |
20 | 20 | |
21 | 21 | |
... | ... | @@ -31,13 +31,16 @@ |
31 | 31 | features = read_file(FEATURES) |
32 | 32 | feat_ind = index_by_id(features) |
33 | 33 | |
34 | -lst = read_file(LST) | |
34 | +if LST is not None: | |
35 | + lst = read_file(LST) | |
36 | +else: | |
37 | + lst = features | |
35 | 38 | |
36 | 39 | kmeans = pickle.load(open(MODEL, "rb")) |
37 | 40 | |
38 | - | |
39 | 41 | # -- CONVERT TO NUMPY -- |
40 | 42 | X = np.asarray([feat_ind[x[0][0]][x[0][3]][1] for x in lst]) |
43 | + | |
41 | 44 | predictions = kmeans.predict(X) |
42 | 45 | |
43 | 46 | for i, line in enumerate(lst): |