diff --git a/scripts/evaluations/clustering.py b/scripts/evaluations/clustering.py index 765d950..72293e2 100644 --- a/scripts/evaluations/clustering.py +++ b/scripts/evaluations/clustering.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd import os import time +import pickle from sklearn.preprocessing import LabelEncoder from sklearn.metrics.pairwise import pairwise_distances from sklearn.metrics import f1_score @@ -77,6 +78,9 @@ if __name__ == "__main__": estimator.fit(feats) print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") + with open(os.path.join(args.outdir, "kmeans.pkl"), "wb") as f: + pickle.dump(estimator, f) + # contains distance to each cluster for each sample dist_space = estimator.transform(feats) predictions = np.argmin(dist_space, axis=1)