Commit 3b960e0f1923a5a9427417aa75b5fb2eef90657c
1 parent
d97ef08feb
Exists in
master
The clustering command lets you compute k-means for a single k, for every k in a kmin–kmax range, or for an explicit list of k values.
Showing 1 changed file with 90 additions and 0 deletions Inline Diff
volia/clustering.py
File was created | 1 | import argparse | |
2 | from os import path, mkdir | ||
3 | from utils import SubCommandRunner | ||
4 | from core.data import read_features, read_lst | ||
5 | |||
6 | import numpy as np | ||
7 | from sklearn.cluster import KMeans | ||
8 | import pickle | ||
9 | |||
10 | |||
def _fit_and_save(X, n_clusters: int, model_path: str):
    """Fit a k-means model with n_clusters clusters on X and pickle it to model_path."""
    print(f"Computing clustering with k={n_clusters}")
    # random_state is fixed so that repeated runs give the same clustering.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(X)
    # Context manager guarantees the file is flushed and closed, even on error.
    with open(model_path, "wb") as model_file:
        pickle.dump(kmeans, model_file)


def kmeans_run(features: str, lst: str, k: int, kmax: int, klist, output: str):
    """
    Fit one or several k-means models and pickle every fitted model.

    Three modes are supported:
      - single value: neither kmax nor klist is given; one model with k clusters
        is written to the file "output";
      - range: kmax is given; one model per value in [k, kmax] is written to
        "output"/clustering_<k>.pkl;
      - list: klist is given; one model per listed value is written to
        "output"/clustering_<k>.pkl.
    If both kmax and klist are given, the range is computed first, then the list.

    @param features: features file, loaded with read_features
    @param lst: list file, loaded with read_lst; its entries select (and order)
                the rows taken from the features
    @param k: k (kmin if kmax specified)
    @param kmax: maximum k to compute (inclusive), or None
    @param klist: list of k values to compute (each converted with int()), or None
    @param output: output file if neither kmax nor klist is specified, else
                   output directory (created if it does not exist)
    @raise Exception: if "output" is an existing directory in single value mode,
                      or an existing file in one of the multi values modes
    """
    single_mode = kmax is None and klist is None

    # Exception cases: checked first so that a wrong "output" fails fast,
    # before any (possibly large) input file is loaded.
    if single_mode and path.isdir(output):
        raise Exception("The \"output\" is an existing directory while the system is waiting the path of a file.")

    if not single_mode and path.isfile(output):
        raise Exception("The \"output\" is an existing file while the system is waiting the path of a directory.")

    # -- READ FILES --
    features_dict = read_features(features)
    lst_dict = read_lst(lst)
    # One row per entry of the list file, in the list file's iteration order.
    X = np.asarray([features_dict[x] for x in lst_dict])

    # Mono value case
    if single_mode:
        _fit_and_save(X, k, output)
        return

    # Multi values cases: gather every requested k, range first, then list.
    # Note: an empty range (kmax < k) silently yields no model.
    k_values = []
    if kmax is not None:
        k_values.extend(range(k, kmax + 1))
    if klist is not None:
        # Converted up front so an invalid value fails before any fitting.
        k_values.extend(int(value) for value in klist)

    if not path.isdir(output):
        mkdir(output)

    for n_clusters in k_values:
        _fit_and_save(X, n_clusters, path.join(output, "clustering_" + str(n_clusters) + ".pkl"))
61 | |||
62 | |||
if __name__ == "__main__":
    # Command-line entry point: one sub-command per clustering method.

    # Main parser
    parser = argparse.ArgumentParser(description="Clustering methods to apply")
    subparsers = parser.add_subparsers(title="action")

    # kmeans
    parser_kmeans = subparsers.add_parser(
        "kmeans", help="Compute clustering using k-means algorithm")

    parser_kmeans.add_argument("--features", required=True, type=str, help="Features file (works with list)")
    parser_kmeans.add_argument("--lst", required=True, type=str, help="List file (.lst)")
    parser_kmeans.add_argument("-k", default=2, type=int,
                               help="number of clusters to compute. It is kmin if kmax is specified.")
    parser_kmeans.add_argument("--kmax", default=None, type=int, help="if specified, k is kmin.")
    # type=int lets argparse reject non-integer values before any work starts.
    parser_kmeans.add_argument("--klist", nargs="+", type=int,
                               help="List of k values to test. As kmax, activate the multi values mod.")
    parser_kmeans.add_argument("--output", default=".kmeans", help="output file if only k. Output directory if multiple kmax specified.")
    # "which" names the sub-command to dispatch to; it is removed before the call below.
    parser_kmeans.set_defaults(which="kmeans")

    # Parse
    args = parser.parse_args()

    # Without a sub-command, "which" is never set: report a usage error
    # instead of crashing with an AttributeError.
    if not hasattr(args, "which"):
        parser.error("no action specified (choose from: kmeans)")

    # Run commands
    runner = SubCommandRunner({
        "kmeans": kmeans_run
    })

    runner.run(args.which, vars(args), remove="which")