Commit adbca3b1ce8ad1cd5bc482b687715ee5e3b2d3d8

Authored by Mathias
1 parent 85d6f0944e
Exists in master

Save the kmeans model

Showing 1 changed file with 4 additions and 0 deletions Side-by-side Diff

scripts/evaluations/clustering.py
... ... @@ -7,6 +7,7 @@
7 7 import pandas as pd
8 8 import os
9 9 import time
  10 +import pickle
10 11 from sklearn.preprocessing import LabelEncoder
11 12 from sklearn.metrics.pairwise import pairwise_distances
12 13 from sklearn.metrics import f1_score
... ... @@ -77,6 +78,9 @@
77 78 estimator.fit(feats)
78 79 print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}")
79 80  
  81 + with open(os.path.join(args.outdir, "kmeans.pkl"), "wb") as f:
  82 + pickle.dump(estimator, f)
  83 +
80 84 # contains distance to each cluster for each sample
81 85 dist_space = estimator.transform(feats)
82 86 predictions = np.argmin(dist_space, axis=1)