# Patch notes for volia/clustering.py (free text placed ahead of the "diff --git"
# header; it belongs to no hunk).
#
# What the hunks below change:
#   * Hunk 1: the core.data import additionally pulls in write_line.
#   * Hunk 2: kmeans_run prints json.dumps(json_content) instead of printing
#     json_content directly, and a new top-level function
#     extract_run(features, lst, model, modeltype, outfile) is added. It:
#       - calls read_features(features) and read_lst(lst);
#       - collects the keys of the lst result in iteration order and builds a
#         NumPy array from feats_dict[key] for each of those keys;
#       - looks up CLUSTERING_METHODS[modeltype], calls .load(model) on it,
#         then .predict(feats);
#       - opens outfile in "w" mode and calls write_line(key, prediction, f)
#         once per key, pairing keys and predictions by position;
#       - prints json.dumps({"outfile": outfile}).
#   * Hunk 3: an "extract" sub-parser is registered with the required options
#     --features, --lst, --model, --modeltype (choices are the keys of
#     CLUSTERING_METHODS) and --outfile, with set_defaults(which="extract").
#   * Hunk 4: "extract": extract_run is added to the SubCommandRunner mapping.
#
# NOTE(review): the patch now calls json.dumps, but the import lines visible in
#   hunk 1 (file lines 1-7) do not include json - confirm it is imported
#   further down the file.
# NOTE(review): the return value of module.load(model) is discarded, so load
#   presumably configures the object stored in CLUSTERING_METHODS in place -
#   confirm against the clustering module's implementation.
# NOTE(review): feats_dict[key] is indexed for every lst key; presumably a key
#   present in lst but absent from the features file raises - verify that this
#   is the intended failure mode.
# NOTE(review): the --modeltype help text says "type of model for learning"
#   although this sub-command only loads a model and predicts - looks copied
#   from another sub-command; confirm the wording.
# NOTE: in this copy the patch's line breaks and indentation are collapsed onto
#   two physical lines; they must be restored before the hunks can be applied.
diff --git a/volia/clustering.py b/volia/clustering.py index d6b96fd..1dfae8a 100644 --- a/volia/clustering.py +++ b/volia/clustering.py @@ -1,7 +1,7 @@ import argparse from os import path, mkdir from utils import SubCommandRunner -from core.data import read_features, read_lst, read_labels +from core.data import read_features, read_lst, read_labels, write_line import numpy as np from sklearn.cluster import KMeans import pickle @@ -140,7 +140,25 @@ def kmeans_run(features: str, k = int(k) fit_model(k, path.join(output, "clustering_" + str(i) + ".pkl")) - print(json_content) + print(json.dumps(json_content)) + + +def extract_run(features, lst, model, modeltype, outfile): + feats_dict = read_features(features) + lst_dict = read_lst(lst) + lst_keys = [key for key in lst_dict] + feats = np.asarray([feats_dict[key] for key in lst_keys]) + + module = CLUSTERING_METHODS[modeltype] + module.load(model) + Y_pred = module.predict(feats) + with open(outfile, "w") as f: + for i, key in enumerate(lst_keys): + write_line(key, Y_pred[i], f) + json_output = { + "outfile": outfile + } + print(json.dumps(json_output)) if __name__ == "__main__": @@ -211,6 +229,20 @@ if __name__ == "__main__": help="...") parser_disequilibrium.set_defaults(which="disequilibrium") + # Extract + parser_extract = subparsers.add_parser( + "extract", help="extract cluster labels") + + parser_extract.add_argument("--features", required=True, type=str, help="...") + parser_extract.add_argument("--lst", required=True, type=str, help="...") + parser_extract.add_argument("--model", required=True, type=str, help="...") + parser_extract.add_argument("--modeltype", + required=True, + choices=[key for key in CLUSTERING_METHODS], + help="type of model for learning") + parser_extract.add_argument("--outfile", required=True, type=str, help="...") + parser_extract.set_defaults(which="extract") + # Parse args = parser.parse_args() @@ -218,7 +250,8 @@ if __name__ == "__main__": runner = SubCommandRunner({ "kmeans": 
kmeans_run, "measure": measure_run, - "disequilibrium": disequilibrium_run + "disequilibrium": disequilibrium_run, + "extract": extract_run }) runner.run(args.which, args.__dict__, remove="which")