Commit 46812c2eac74cd2c7e2dc0410d9620450dbc3ed3
1 parent
57883c9873
Exists in
master
implementation of a extraction script of kmeans clustering from a kmeans model for skyrim corpus
Showing 1 changed file with 57 additions and 0 deletions Side-by-side Diff
bin/extract_kmeans_skyrim.py
1 | +''' | |
2 | +This script aims to extract k-means clustering from a | |
3 | +a priori trained k-means. | |
4 | +''' | |
5 | + | |
6 | +import argparse | |
7 | +import numpy as np | |
8 | +import pickle | |
9 | +from data import read_file_skyrim, index_by_id_skyrim, write_line_skyrim | |
10 | +import sys | |
11 | + | |
12 | +# -- ARGPARSE -- | |
13 | +parser = argparse.ArgumentParser(description="extract clusters") | |
14 | +parser.add_argument("model", type=str, help="k-means model pickle") | |
15 | +parser.add_argument("features", type=str, help="features") | |
16 | +parser.add_argument("--list", type=str, default=None, help="list file") | |
17 | +parser.add_argument("--outfile", type=str, default=None, help="output file std") | |
18 | + | |
19 | +args = vars(parser.parse_args()) | |
20 | +MODEL = args["model"] | |
21 | +FEATURES = args["features"] | |
22 | +LST = args["list"] | |
23 | +OUTFILE = args["outfile"] | |
24 | + | |
25 | +if OUTFILE == None: | |
26 | + OUTFILE = sys.stdout | |
27 | +else: | |
28 | + OUTFILE = open(OUTFILE, "w") | |
29 | + | |
30 | +# -- READ FILE -- | |
31 | +features = read_file_skyrim(FEATURES) | |
32 | +feat_ind = index_by_id_skyrim(features) | |
33 | + | |
34 | +if LST is not None: | |
35 | + lst = read_file(LST) | |
36 | +else: | |
37 | + lst = features | |
38 | + | |
39 | +kmeans = pickle.load(open(MODEL, "rb")) | |
40 | + | |
41 | +# -- CONVERT TO NUMPY -- | |
42 | +X = np.asarray([feat_ind[x[0][0]][x[0][2]][1] for x in lst]) | |
43 | + | |
44 | +predictions = kmeans.predict(X) | |
45 | + | |
46 | +for i, line in enumerate(lst): | |
47 | + meta = line[0] | |
48 | + meta[1] = str(predictions[i]) | |
49 | + write_line_skyrim( | |
50 | + meta, | |
51 | + feat_ind[meta[0]][meta[2]][1], | |
52 | + OUTFILE | |
53 | + ) | |
54 | + | |
55 | +# -- CLOSE OUT FILE IF NECESSARY -- | |
56 | +if not OUTFILE == sys.stdout: | |
57 | + OUTFILE.close() |