diff --git a/volia/clustering.py b/volia/clustering.py index af8aefc..3093fc5 100644 --- a/volia/clustering.py +++ b/volia/clustering.py @@ -7,6 +7,7 @@ from sklearn.cluster import KMeans import pickle from clustering_modules.kmeans import kmeans from clustering_modules.kmeans_mahalanobis import kmeansMahalanobis +from clustering_modules.kmeans_multidistance import kmeansMultidistance from sklearn.preprocessing import LabelEncoder from sklearn.metrics import v_measure_score, homogeneity_score, completeness_score @@ -18,9 +19,13 @@ import json CLUSTERING_METHODS = { "k-means": kmeans(), "k-means-mahalanobis": kmeansMahalanobis(), - "k-means-mahalanobis-constrained": kmeansMahalanobis(constrained=True) + "k-means-mahalanobis-constrained": kmeansMahalanobis(constrained=True), + "k-means-basic-mahalanobis": kmeansMultidistance(distance="mahalanobis"), + "k-means-basic-cosine": kmeansMultidistance(distance="cosine") } +KMEANS_METHODS = [key for key in CLUSTERING_METHODS if key.startswith("k-means")] + EVALUATION_METHODS = { "entropy": core.measures.entropy_score, "purity": core.measures.purity_score, @@ -77,8 +82,8 @@ def kmeans_run(features: str, ninit: int, output: str, tol: float, - debug: bool = False, - mahalanobis: str = False): + modeltype: str, + debug: bool = False): """ @param features: output features @@ -94,11 +99,7 @@ def kmeans_run(features: str, def fit_model(k: int, output_file): if debug: print(f"Computing clustering with k={k}") - model = CLUSTERING_METHODS["k-means"] - if mahalanobis: - if debug: - print("Mahalanobis activated") - model = CLUSTERING_METHODS["k-means-mahalanobis-constrained"] + model = CLUSTERING_METHODS[modeltype] model.fit(X, k, tol, ninit, maxiter, debug) model.save(output_file) json_content["models"].append({ @@ -193,7 +194,10 @@ if __name__ == "__main__": parser_kmeans.add_argument("--output", default=".kmeans", help="output file if only k. Output directory if multiple kmax specified.") - parser_kmeans.add_argument("--mahalanobis", action="store_true") + parser_kmeans.add_argument("--modeltype", + required=True, + choices=KMEANS_METHODS, + help="type of model for learning") parser_kmeans.set_defaults(which="kmeans") # measure @@ -223,7 +227,7 @@ if __name__ == "__main__": parser_disequilibrium.add_argument("--lstrain", required=True, type=str, help="...") parser_disequilibrium.add_argument("--lstest", required=True, type=str, help="...") parser_disequilibrium.add_argument("--model", required=True, type=str, help="...") - parser_disequilibrium.add_argument("--model-type", + parser_disequilibrium.add_argument("--modeltype", required=True, choices=["kmeans", "2", "3"], help="...")