Commit adbca3b1ce8ad1cd5bc482b687715ee5e3b2d3d8
1 parent
85d6f0944e
Exists in
master
Save the kmeans model
Showing 1 changed file with 4 additions and 0 deletions Side-by-side Diff
scripts/evaluations/clustering.py
... | ... | @@ -7,6 +7,7 @@ |
7 | 7 | import pandas as pd |
8 | 8 | import os |
9 | 9 | import time |
10 | +import pickle | |
10 | 11 | from sklearn.preprocessing import LabelEncoder |
11 | 12 | from sklearn.metrics.pairwise import pairwise_distances |
12 | 13 | from sklearn.metrics import f1_score |
... | ... | @@ -77,6 +78,9 @@ |
77 | 78 | estimator.fit(feats) |
78 | 79 | print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") |
79 | 80 | |
81 | + with open(os.path.join(args.outdir, "kmeans.pkl"), "wb") as f: | |
82 | + pickle.dump(estimator, f) | |
83 | + | |
80 | 84 | # contains distance to each cluster for each sample |
81 | 85 | dist_space = estimator.transform(feats) |
82 | 86 | predictions = np.argmin(dist_space, axis=1) |