Commit 3b960e0f1923a5a9427417aa75b5fb2eef90657c
1 parent
d97ef08feb
Exists in
master
The clustering command computes k-means for a single k, for a range from k (used as kmin) to kmax, or for an explicit list of k values.
Showing 1 changed file with 90 additions and 0 deletions Side-by-side Diff
volia/clustering.py
1 | +import argparse | |
2 | +from os import path, mkdir | |
3 | +from utils import SubCommandRunner | |
4 | +from core.data import read_features, read_lst | |
5 | + | |
6 | +import numpy as np | |
7 | +from sklearn.cluster import KMeans | |
8 | +import pickle | |
9 | + | |
10 | + | |
def _fit_and_save(X, n_clusters: int, destination: str) -> None:
    """Fit a k-means model with ``n_clusters`` clusters on ``X`` and pickle it to ``destination``."""
    print(f"Computing clustering with k={n_clusters}")
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(X)
    # The context manager closes the file even if pickling raises, so no
    # handle is leaked and the data is flushed before we move on.
    with open(destination, "wb") as model_file:
        pickle.dump(kmeans, model_file)


def kmeans_run(features: str, lst: str, k: int, kmax: int, klist, output: str):
    """
    Fit one or several k-means models and pickle each fitted model.

    Three modes are supported:

    * single value (``kmax`` and ``klist`` both ``None``): one model with
      ``k`` clusters is written to the file ``output``;
    * range (``kmax`` given): one model per value in ``k .. kmax`` inclusive
      is written to ``output/clustering_<k>.pkl``. If ``kmax < k`` the range
      is empty and no model is computed;
    * list (``klist`` given): one model per value of ``klist`` is written to
      ``output/clustering_<k>.pkl``.

    If both ``kmax`` and ``klist`` are given, the range is processed first
    and the list second, into the same directory.

    @param features: path passed to ``read_features``; the result is indexed
                     by the entries obtained from iterating the list
    @param lst: path passed to ``read_lst``; iterating its result gives the
                keys used to select and order the feature vectors
    @param k: number of clusters (kmin if kmax specified)
    @param kmax: maximum k to compute, or None
    @param klist: iterable of k values to compute (each converted with
                  ``int``), or None; ``k`` is ignored for this mode
    @param output: output file if neither kmax nor klist is specified, else
                   output directory (created if missing)
    @raise Exception: if ``output`` is an existing directory in single value
                      mode, or an existing file in a multi values mode
    """
    # Local import: the module header only brings in ``path`` and ``mkdir``.
    from os import makedirs

    multi_values = kmax is not None or klist is not None

    # Validate the destination before doing any work, so that a wrong
    # "output" is reported immediately instead of after loading the data.
    if not multi_values and path.isdir(output):
        raise Exception("The \"output\" is an existing directory while the system is waiting the path of a file.")

    if multi_values and path.isfile(output):
        raise Exception("The \"output\" is an existing file while the system is waiting the path of a directory.")

    # -- READ FILES --
    features_dict = read_features(features)
    lst_dict = read_lst(lst)
    X = np.asarray([features_dict[x] for x in lst_dict])

    # Mono value case
    if not multi_values:
        _fit_and_save(X, k, output)
        return

    # Multi values cases share the same output directory; makedirs also
    # creates missing parent directories and tolerates an existing one.
    makedirs(output, exist_ok=True)

    # Multi values case with kmax
    if kmax is not None:
        for n_clusters in range(k, kmax + 1):
            _fit_and_save(X, n_clusters, path.join(output, "clustering_" + str(n_clusters) + ".pkl"))

    # Second multi values case with klist
    if klist is not None:
        for value in klist:
            # Values may arrive as strings from the command line.
            n_clusters = int(value)
            _fit_and_save(X, n_clusters, path.join(output, "clustering_" + str(n_clusters) + ".pkl"))
61 | + | |
62 | + | |
if __name__ == "__main__":
    # Command-line entry point: builds the parser, then dispatches the
    # selected sub-command to its handler through SubCommandRunner.

    # Main parser
    parser = argparse.ArgumentParser(description="Clustering methods to apply")
    subparsers = parser.add_subparsers(title="action")

    # kmeans
    parser_kmeans = subparsers.add_parser(
        "kmeans", help="Compute clustering using k-means algorithm")

    parser_kmeans.add_argument("--features", required=True, type=str, help="Features file (works with list)")
    parser_kmeans.add_argument("--lst", required=True, type=str, help="List file (.lst)")
    parser_kmeans.add_argument("-k", default=2, type=int,
                               help="number of clusters to compute. It is kmin if kmax is specified.")
    parser_kmeans.add_argument("--kmax", default=None, type=int, help="if specified, k is kmin.")
    # type=int makes argparse reject non-numeric values with a usage error
    # instead of failing later inside the handler.
    parser_kmeans.add_argument("--klist", nargs="+", type=int,
                               help="List of k values to test. As kmax, activate the multi values mod.")
    parser_kmeans.add_argument("--output", default=".kmeans", help="output file if only k. Output directory if multiple kmax specified.")
    # "which" records the chosen sub-command; it is stripped again before
    # the remaining arguments are handed to the handler.
    parser_kmeans.set_defaults(which="kmeans")

    # Parse
    args = parser.parse_args()

    # Without a sub-command, "which" is never set; report a usage error
    # rather than crashing with an AttributeError below.
    if not hasattr(args, "which"):
        parser.error("no action specified (choose from: kmeans)")

    # Run commands
    runner = SubCommandRunner({
        "kmeans": kmeans_run
    })

    runner.run(args.which, vars(args), remove="which")