Commit 6bc3b63707bac9240e9df369b071a6f764aa5d2f

Authored by Mathias
1 parent 3b7e63994c
Exists in master

Now saves the KMeans parameters to a JSON file.

Showing 1 changed file with 13 additions and 4 deletions Side-by-side Diff

scripts/evaluations/clustering.py
... ... @@ -9,6 +9,7 @@
9 9 import time
10 10 import pickle
11 11 import csv
  12 +import json
12 13  
13 14 from sklearn.preprocessing import LabelEncoder
14 15 from sklearn.metrics.pairwise import pairwise_distances
15 16  
... ... @@ -31,11 +32,19 @@
31 32 num_classes = len(label_encoder.classes_)
32 33  
33 34 # Compute KMEANS clustering on data
  35 + kmeans_parameters = {
  36 + "n_clusters": num_classes,
  37 + "n_init": 100,
  38 + "tol": 10-6,
  39 + "algorithm": "elkan"
  40 + }
  41 + with open(os.path.join(outdir, f"{args.prefix}kmeans_parameters.json"), "w") as f:
  42 + json.dump(kmeans_parameters, f)
  43 +
  44 + # Save parameters
  45 +
34 46 estimator = KMeans(
35   - n_clusters=num_classes,
36   - n_init=100,
37   - tol=10-6,
38   - algorithm="elkan"
  47 + **kmeans_parameters
39 48 )
40 49 estimator.fit(feats)
41 50 print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}")