Commit ef499b777c4665a9384b4167457abc2a17baf833

Authored by quillotm
1 parent 1bcb37e33d
Exists in master

Now we can extract labels and save them in a file. Useful to learn other systems…

… based on these labels. (or to create plots)

Showing 1 changed file with 36 additions and 3 deletions Side-by-side Diff

1 1 import argparse
2 2 from os import path, mkdir
3 3 from utils import SubCommandRunner
4   -from core.data import read_features, read_lst, read_labels
  4 +from core.data import read_features, read_lst, read_labels, write_line
5 5 import numpy as np
6 6 from sklearn.cluster import KMeans
7 7 import pickle
8 8  
... ... @@ -140,9 +140,27 @@
140 140 k = int(k)
141 141 fit_model(k, path.join(output, "clustering_" + str(i) + ".pkl"))
142 142  
143   - print(json_content)
  143 + print(json.dumps(json_content))
144 144  
145 145  
  146 +def extract_run(features, lst, model, modeltype, outfile):
  147 + feats_dict = read_features(features)
  148 + lst_dict = read_lst(lst)
  149 + lst_keys = [key for key in lst_dict]
  150 + feats = np.asarray([feats_dict[key] for key in lst_keys])
  151 +
  152 + module = CLUSTERING_METHODS[modeltype]
  153 + module.load(model)
  154 + Y_pred = module.predict(feats)
  155 + with open(outfile, "w") as f:
  156 + for i, key in enumerate(lst_keys):
  157 + write_line(key, Y_pred[i], f)
  158 + json_output = {
  159 + "outfile": outfile
  160 + }
  161 + print(json.dumps(json_output))
  162 +
  163 +
146 164 if __name__ == "__main__":
147 165 # Main parser
148 166 parser = argparse.ArgumentParser(description="Clustering methods to apply")
... ... @@ -211,6 +229,20 @@
211 229 help="...")
212 230 parser_disequilibrium.set_defaults(which="disequilibrium")
213 231  
  232 + # Extract
  233 + parser_extract = subparsers.add_parser(
  234 + "extract", help="extract cluster labels")
  235 +
  236 + parser_extract.add_argument("--features", required=True, type=str, help="...")
  237 + parser_extract.add_argument("--lst", required=True, type=str, help="...")
  238 + parser_extract.add_argument("--model", required=True, type=str, help="...")
  239 + parser_extract.add_argument("--modeltype",
  240 + required=True,
  241 + choices=[key for key in CLUSTERING_METHODS],
  242 + help="type of model for learning")
  243 + parser_extract.add_argument("--outfile", required=True, type=str, help="...")
  244 + parser_extract.set_defaults(which="extract")
  245 +
214 246 # Parse
215 247 args = parser.parse_args()
216 248  
... ... @@ -218,7 +250,8 @@
218 250 runner = SubCommandRunner({
219 251 "kmeans": kmeans_run,
220 252 "measure": measure_run,
221   - "disequilibrium": disequilibrium_run
  253 + "disequilibrium": disequilibrium_run,
  254 + "extract": extract_run
222 255 })
223 256  
224 257 runner.run(args.which, args.__dict__, remove="which")