Commit ef499b777c4665a9384b4167457abc2a17baf833
1 parent
1bcb37e33d
Exists in
master
Now we can extract cluster labels and save them to a file. This is useful for training other systems…
… based on these labels (or for creating plots).
Showing 1 changed file with 36 additions and 3 deletions Side-by-side Diff
volia/clustering.py
| 1 | 1 | import argparse |
| 2 | 2 | from os import path, mkdir |
| 3 | 3 | from utils import SubCommandRunner |
| 4 | -from core.data import read_features, read_lst, read_labels | |
| 4 | +from core.data import read_features, read_lst, read_labels, write_line | |
| 5 | 5 | import numpy as np |
| 6 | 6 | from sklearn.cluster import KMeans |
| 7 | 7 | import pickle |
| 8 | 8 | |
| ... | ... | @@ -140,9 +140,27 @@ |
| 140 | 140 | k = int(k) |
| 141 | 141 | fit_model(k, path.join(output, "clustering_" + str(i) + ".pkl")) |
| 142 | 142 | |
| 143 | - print(json_content) | |
| 143 | + print(json.dumps(json_content)) | |
| 144 | 144 | |
| 145 | 145 | |
| 146 | +def extract_run(features, lst, model, modeltype, outfile): | |
| 147 | + feats_dict = read_features(features) | |
| 148 | + lst_dict = read_lst(lst) | |
| 149 | + lst_keys = [key for key in lst_dict] | |
| 150 | + feats = np.asarray([feats_dict[key] for key in lst_keys]) | |
| 151 | + | |
| 152 | + module = CLUSTERING_METHODS[modeltype] | |
| 153 | + module.load(model) | |
| 154 | + Y_pred = module.predict(feats) | |
| 155 | + with open(outfile, "w") as f: | |
| 156 | + for i, key in enumerate(lst_keys): | |
| 157 | + write_line(key, Y_pred[i], f) | |
| 158 | + json_output = { | |
| 159 | + "outfile": outfile | |
| 160 | + } | |
| 161 | + print(json.dumps(json_output)) | |
| 162 | + | |
| 163 | + | |
| 146 | 164 | if __name__ == "__main__": |
| 147 | 165 | # Main parser |
| 148 | 166 | parser = argparse.ArgumentParser(description="Clustering methods to apply") |
| ... | ... | @@ -211,6 +229,20 @@ |
| 211 | 229 | help="...") |
| 212 | 230 | parser_disequilibrium.set_defaults(which="disequilibrium") |
| 213 | 231 | |
| 232 | + # Extract | |
| 233 | + parser_extract = subparsers.add_parser( | |
| 234 | + "extract", help="extract cluster labels") | |
| 235 | + | |
| 236 | + parser_extract.add_argument("--features", required=True, type=str, help="...") | |
| 237 | + parser_extract.add_argument("--lst", required=True, type=str, help="...") | |
| 238 | + parser_extract.add_argument("--model", required=True, type=str, help="...") | |
| 239 | + parser_extract.add_argument("--modeltype", | |
| 240 | + required=True, | |
| 241 | + choices=[key for key in CLUSTERING_METHODS], | |
| 242 | + help="type of model for learning") | |
| 243 | + parser_extract.add_argument("--outfile", required=True, type=str, help="...") | |
| 244 | + parser_extract.set_defaults(which="extract") | |
| 245 | + | |
| 214 | 246 | # Parse |
| 215 | 247 | args = parser.parse_args() |
| 216 | 248 | |
| ... | ... | @@ -218,7 +250,8 @@ |
| 218 | 250 | runner = SubCommandRunner({ |
| 219 | 251 | "kmeans": kmeans_run, |
| 220 | 252 | "measure": measure_run, |
| 221 | - "disequilibrium": disequilibrium_run | |
| 253 | + "disequilibrium": disequilibrium_run, | |
| 254 | + "extract": extract_run | |
| 222 | 255 | }) |
| 223 | 256 | |
| 224 | 257 | runner.run(args.which, args.__dict__, remove="which") |