Commit ef499b777c4665a9384b4167457abc2a17baf833
1 parent
1bcb37e33d
Exists in
master
Now we can extract labels and save them in a file. This is useful for training other systems based on these labels, or for creating plots.
Showing 1 changed file with 36 additions and 3 deletions Side-by-side Diff
volia/clustering.py
1 | 1 | import argparse |
2 | 2 | from os import path, mkdir |
3 | 3 | from utils import SubCommandRunner |
4 | -from core.data import read_features, read_lst, read_labels | |
4 | +from core.data import read_features, read_lst, read_labels, write_line | |
5 | 5 | import numpy as np |
6 | 6 | from sklearn.cluster import KMeans |
7 | 7 | import pickle |
8 | 8 | |
... | ... | @@ -140,9 +140,27 @@ |
140 | 140 | k = int(k) |
141 | 141 | fit_model(k, path.join(output, "clustering_" + str(i) + ".pkl")) |
142 | 142 | |
143 | - print(json_content) | |
143 | + print(json.dumps(json_content)) | |
144 | 144 | |
145 | 145 | |
146 | +def extract_run(features, lst, model, modeltype, outfile): | |
147 | + feats_dict = read_features(features) | |
148 | + lst_dict = read_lst(lst) | |
149 | + lst_keys = [key for key in lst_dict] | |
150 | + feats = np.asarray([feats_dict[key] for key in lst_keys]) | |
151 | + | |
152 | + module = CLUSTERING_METHODS[modeltype] | |
153 | + module.load(model) | |
154 | + Y_pred = module.predict(feats) | |
155 | + with open(outfile, "w") as f: | |
156 | + for i, key in enumerate(lst_keys): | |
157 | + write_line(key, Y_pred[i], f) | |
158 | + json_output = { | |
159 | + "outfile": outfile | |
160 | + } | |
161 | + print(json.dumps(json_output)) | |
162 | + | |
163 | + | |
146 | 164 | if __name__ == "__main__": |
147 | 165 | # Main parser |
148 | 166 | parser = argparse.ArgumentParser(description="Clustering methods to apply") |
... | ... | @@ -211,6 +229,20 @@ |
211 | 229 | help="...") |
212 | 230 | parser_disequilibrium.set_defaults(which="disequilibrium") |
213 | 231 | |
232 | + # Extract | |
233 | + parser_extract = subparsers.add_parser( | |
234 | + "extract", help="extract cluster labels") | |
235 | + | |
236 | + parser_extract.add_argument("--features", required=True, type=str, help="...") | |
237 | + parser_extract.add_argument("--lst", required=True, type=str, help="...") | |
238 | + parser_extract.add_argument("--model", required=True, type=str, help="...") | |
239 | + parser_extract.add_argument("--modeltype", | |
240 | + required=True, | |
241 | + choices=[key for key in CLUSTERING_METHODS], | |
242 | + help="type of model for learning") | |
243 | + parser_extract.add_argument("--outfile", required=True, type=str, help="...") | |
244 | + parser_extract.set_defaults(which="extract") | |
245 | + | |
214 | 246 | # Parse |
215 | 247 | args = parser.parse_args() |
216 | 248 | |
... | ... | @@ -218,7 +250,8 @@ |
218 | 250 | runner = SubCommandRunner({ |
219 | 251 | "kmeans": kmeans_run, |
220 | 252 | "measure": measure_run, |
221 | - "disequilibrium": disequilibrium_run | |
253 | + "disequilibrium": disequilibrium_run, | |
254 | + "extract": extract_run | |
222 | 255 | }) |
223 | 256 | |
224 | 257 | runner.run(args.which, args.__dict__, remove="which") |