Commit a9912f135f481a97c6113e5723b33d69de6a919d

Authored by quillotm
1 parent 05afc43e54
Exists in master

We can now precise the modeltype in parameter of the kmeans learning command. Th…

…is is more permissive to evolution.

Showing 1 changed file with 14 additions and 10 deletions Side-by-side Diff

... ... @@ -7,6 +7,7 @@
7 7 import pickle
8 8 from clustering_modules.kmeans import kmeans
9 9 from clustering_modules.kmeans_mahalanobis import kmeansMahalanobis
  10 +from clustering_modules.kmeans_multidistance import kmeansMultidistance
10 11  
11 12 from sklearn.preprocessing import LabelEncoder
12 13 from sklearn.metrics import v_measure_score, homogeneity_score, completeness_score
13 14  
... ... @@ -18,9 +19,13 @@
18 19 CLUSTERING_METHODS = {
19 20 "k-means": kmeans(),
20 21 "k-means-mahalanobis": kmeansMahalanobis(),
21   - "k-means-mahalanobis-constrained": kmeansMahalanobis(constrained=True)
  22 + "k-means-mahalanobis-constrained": kmeansMahalanobis(constrained=True),
  23 + "k-means-basic-mahalanobis": kmeansMultidistance(distance="mahalanobis"),
  24 + "k-means-basic-cosine": kmeansMultidistance(distance="cosine")
22 25 }
23 26  
  27 +KMEANS_METHODS = [key for key in CLUSTERING_METHODS if key.startswith("k-means")]
  28 +
24 29 EVALUATION_METHODS = {
25 30 "entropy": core.measures.entropy_score,
26 31 "purity": core.measures.purity_score,
... ... @@ -77,8 +82,8 @@
77 82 ninit: int,
78 83 output: str,
79 84 tol: float,
80   - debug: bool = False,
81   - mahalanobis: str = False):
  85 + modeltype: str,
  86 + debug: bool = False):
82 87 """
83 88  
84 89 @param features: output features
... ... @@ -94,11 +99,7 @@
94 99 def fit_model(k: int, output_file):
95 100 if debug:
96 101 print(f"Computing clustering with k={k}")
97   - model = CLUSTERING_METHODS["k-means"]
98   - if mahalanobis:
99   - if debug:
100   - print("Mahalanobis activated")
101   - model = CLUSTERING_METHODS["k-means-mahalanobis-constrained"]
  102 + model = CLUSTERING_METHODS[modeltype]
102 103 model.fit(X, k, tol, ninit, maxiter, debug)
103 104 model.save(output_file)
104 105 json_content["models"].append({
... ... @@ -193,7 +194,10 @@
193 194 parser_kmeans.add_argument("--output",
194 195 default=".kmeans",
195 196 help="output file if only k. Output directory if multiple kmax specified.")
196   - parser_kmeans.add_argument("--mahalanobis", action="store_true")
  197 + parser_kmeans.add_argument("--modeltype",
  198 + required=True,
  199 + choices=KMEANS_METHODS,
  200 + help="type of model for learning")
197 201 parser_kmeans.set_defaults(which="kmeans")
198 202  
199 203 # measure
... ... @@ -223,7 +227,7 @@
223 227 parser_disequilibrium.add_argument("--lstrain", required=True, type=str, help="...")
224 228 parser_disequilibrium.add_argument("--lstest", required=True, type=str, help="...")
225 229 parser_disequilibrium.add_argument("--model", required=True, type=str, help="...")
226   - parser_disequilibrium.add_argument("--model-type",
  230 + parser_disequilibrium.add_argument("--modeltype",
227 231 required=True,
228 232 choices=["kmeans", "2", "3"],
229 233 help="...")