From 6bc3b63707bac9240e9df369b071a6f764aa5d2f Mon Sep 17 00:00:00 2001
From: Mathias
Date: Mon, 28 Sep 2020 14:35:34 +0200
Subject: [PATCH] Save the KMeans parameters to a JSON file and fix tol

Collect the KMeans keyword arguments in a dict, dump that dict to
"<outdir>/<prefix>kmeans_parameters.json", and build the estimator from
the same dict so the saved file always matches what was actually used.

While here, fix the tolerance: it was written as `10-6`, which is the
integer subtraction 10 - 6 == 4, not 1e-6.
---
Reviewer notes (this section is ignored by `git am`):

* `"tol": 1e-6` replaces `"tol": 10-6` on the added line only. The
  removed line keeps `tol=10-6` because it has to match the pre-image.
  This changes clustering behaviour (tolerance 4 -> 1e-6), so results
  will differ from earlier runs.
* The "# Save parameters" comment now sits above the code that writes
  the JSON file instead of above the KMeans constructor.
* NOTE(review): `args.prefix` is read inside train_clustering(), whose
  signature has no `args` parameter. Presumably `args` is a module-level
  global set by the CLI entry point -- confirm, or pass the prefix in.
* NOTE(review): the indentation of the context lines was reconstructed
  (4 spaces per level). If the patch does not apply cleanly, retry with
  `git apply --ignore-whitespace`. The post-image blob id on the `index`
  line was not recomputed after the edits above.
* "intertia" in the print() below is a typo for "inertia", but it is on
  a context line and must stay as-is for the patch to apply; fix it in
  a follow-up commit.

 scripts/evaluations/clustering.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/scripts/evaluations/clustering.py b/scripts/evaluations/clustering.py
index 7eba2ae..6b5e136 100644
--- a/scripts/evaluations/clustering.py
+++ b/scripts/evaluations/clustering.py
@@ -9,6 +9,7 @@ import os
 import time
 import pickle
 import csv
+import json
 
 from sklearn.preprocessing import LabelEncoder
 from sklearn.metrics.pairwise import pairwise_distances
@@ -31,11 +32,19 @@ def train_clustering(label_encoder, feats, classes, outdir):
     num_classes = len(label_encoder.classes_)
 
     # Compute KMEANS clustering on data
+    kmeans_parameters = {
+        "n_clusters": num_classes,
+        "n_init": 100,
+        "tol": 1e-6,
+        "algorithm": "elkan"
+    }
+
+    # Save parameters
+    with open(os.path.join(outdir, f"{args.prefix}kmeans_parameters.json"), "w") as f:
+        json.dump(kmeans_parameters, f)
+
     estimator = KMeans(
-        n_clusters=num_classes,
-        n_init=100,
-        tol=10-6,
-        algorithm="elkan"
+        **kmeans_parameters
     )
     estimator.fit(feats)
     print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}")
-- 
1.8.2.3