Commit 46812c2eac74cd2c7e2dc0410d9620450dbc3ed3
1 parent
57883c9873
Exists in
master
implementation of a extraction script of kmeans clustering from a kmeans model for skyrim corpus
Showing 1 changed file with 57 additions and 0 deletions Inline Diff
bin/extract_kmeans_skyrim.py
File was created | 1 | ''' | |
2 | This script aims to extract k-means clustering from a | ||
3 | a priori trained k-means. | ||
4 | ''' | ||
5 | |||
6 | import argparse | ||
7 | import numpy as np | ||
8 | import pickle | ||
9 | from data import read_file_skyrim, index_by_id_skyrim, write_line_skyrim | ||
10 | import sys | ||
11 | |||
12 | # -- ARGPARSE -- | ||
13 | parser = argparse.ArgumentParser(description="extract clusters") | ||
14 | parser.add_argument("model", type=str, help="k-means model pickle") | ||
15 | parser.add_argument("features", type=str, help="features") | ||
16 | parser.add_argument("--list", type=str, default=None, help="list file") | ||
17 | parser.add_argument("--outfile", type=str, default=None, help="output file std") | ||
18 | |||
19 | args = vars(parser.parse_args()) | ||
20 | MODEL = args["model"] | ||
21 | FEATURES = args["features"] | ||
22 | LST = args["list"] | ||
23 | OUTFILE = args["outfile"] | ||
24 | |||
25 | if OUTFILE == None: | ||
26 | OUTFILE = sys.stdout | ||
27 | else: | ||
28 | OUTFILE = open(OUTFILE, "w") | ||
29 | |||
30 | # -- READ FILE -- | ||
31 | features = read_file_skyrim(FEATURES) | ||
32 | feat_ind = index_by_id_skyrim(features) | ||
33 | |||
34 | if LST is not None: | ||
35 | lst = read_file(LST) | ||
36 | else: | ||
37 | lst = features | ||
38 | |||
39 | kmeans = pickle.load(open(MODEL, "rb")) | ||
40 | |||
41 | # -- CONVERT TO NUMPY -- | ||
42 | X = np.asarray([feat_ind[x[0][0]][x[0][2]][1] for x in lst]) | ||
43 | |||
44 | predictions = kmeans.predict(X) | ||
45 | |||
46 | for i, line in enumerate(lst): | ||
47 | meta = line[0] | ||
48 | meta[1] = str(predictions[i]) | ||
49 | write_line_skyrim( | ||
50 | meta, | ||
51 | feat_ind[meta[0]][meta[2]][1], | ||
52 | OUTFILE | ||
53 | ) | ||
54 | |||
55 | # -- CLOSE OUT FILE IF NECESSARY -- | ||
56 | if not OUTFILE == sys.stdout: | ||
57 | OUTFILE.close() |