diff --git a/scripts/evaluations/clustering.py b/scripts/evaluations/clustering.py index 6b5e136..7c6e1eb 100644 --- a/scripts/evaluations/clustering.py +++ b/scripts/evaluations/clustering.py @@ -30,27 +30,34 @@ clustering to train in order to compute the average and the def train_clustering(label_encoder, feats, classes, outdir): num_classes = len(label_encoder.classes_) - - # Compute KMEANS clustering on data - kmeans_parameters = { - "n_clusters": num_classes, - "n_init": 100, - "tol": 10-6, - "algorithm": "elkan" - } - with open(os.path.join(outdir, f"{args.prefix}kmeans_parameters.json"), "w") as f: - json.dump(kmeans_parameters, f) - - # Save parameters - - estimator = KMeans( - **kmeans_parameters - ) - estimator.fit(feats) - print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") - - with open(os.path.join(outdir, f"{args.prefix}kmeans.pkl"), "wb") as f: - pickle.dump(estimator, f) + estimator = None + kmeans_filepath = os.path.join(outdir, f"{args.prefix}kmeans.pkl") + if args.onlymeasures: + print(f"Loading model: {kmeans_filepath}") + with open(kmeans_filepath, "rb") as f: + estimator = pickle.load(f) + else: + # Compute KMEANS clustering on data + print("Saving parameters") + kmeans_parameters = { + "n_clusters": num_classes, + "n_init": 100, + "tol": 10-6, + "algorithm": "elkan" + } + with open(os.path.join(outdir, f"{args.prefix}kmeans_parameters.json"), "w") as f: + json.dump(kmeans_parameters, f) + + # Fit the model and Save parameters + print(f"Fit the model: {kmeans_filepath}") + estimator = KMeans( + **kmeans_parameters + ) + estimator.fit(feats) + print(f"Kmeans: processed {estimator.n_iter_} iterations - intertia={estimator.inertia_}") + + with open(kmeans_filepath, "wb") as f: + pickle.dump(estimator, f) # contains distance to each cluster for each sample dist_space = estimator.transform(feats) @@ -179,7 +186,10 @@ if __name__ == "__main__": parser.add_argument("--nmodels", type=int, default=1, - help="specifies the number of models to train") + help="specifies the number of models to train") + parser.add_argument("--onlymeasures", + action='store_true', + help="Don't compute the clustering, compute only the measures") args = parser.parse_args() assert args.outdir