Commit adbca3b1ce8ad1cd5bc482b687715ee5e3b2d3d8
1 parent
85d6f0944e
Exists in
master
Save the kmeans model
Showing 1 changed file with 4 additions and 0 deletions (side-by-side diff)
scripts/evaluations/clustering.py
| ... | ... | @@ -7,6 +7,7 @@ |
| 7 | 7 | import pandas as pd |
| 8 | 8 | import os |
| 9 | 9 | import time |
| | 10 | +import pickle |
| 10 | 11 | from sklearn.preprocessing import LabelEncoder |
| 11 | 12 | from sklearn.metrics.pairwise import pairwise_distances |
| 12 | 13 | from sklearn.metrics import f1_score |
| ... | ... | @@ -77,6 +78,9 @@ |
| 77 | 78 | estimator.fit(feats) |
| 78 | 79 | print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") |
| 79 | 80 | |
| | 81 | + with open(os.path.join(args.outdir, "kmeans.pkl"), "wb") as f: |
| | 82 | +     pickle.dump(estimator, f) |
| | 83 | + |
| 80 | 84 | # contains distance to each cluster for each sample |
| 81 | 85 | dist_space = estimator.transform(feats) |
| 82 | 86 | predictions = np.argmin(dist_space, axis=1) |