From adbca3b1ce8ad1cd5bc482b687715ee5e3b2d3d8 Mon Sep 17 00:00:00 2001 From: Mathias Date: Mon, 14 Sep 2020 22:15:33 +0200 Subject: [PATCH] Save the kmeans model --- scripts/evaluations/clustering.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/scripts/evaluations/clustering.py b/scripts/evaluations/clustering.py index 765d950..72293e2 100644 --- a/scripts/evaluations/clustering.py +++ b/scripts/evaluations/clustering.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import os import time +import pickle from sklearn.preprocessing import LabelEncoder from sklearn.metrics.pairwise import pairwise_distances from sklearn.metrics import f1_score @@ -77,6 +78,9 @@ if __name__ == "__main__": estimator.fit(feats) print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") + with open(os.path.join(args.outdir, "kmeans.pkl"), "wb") as f: + pickle.dump(estimator, f) + # contains distance to each cluster for each sample dist_space = estimator.transform(feats) predictions = np.argmin(dist_space, axis=1) -- 1.8.2.3