Blame view

bin/extract_kmeans.py 1.33 KB
ac78b07ea   Mathias Quillot   All base bin file...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
  '''
  This script assigns k-means cluster labels to features using a
  k-means model trained beforehand (a priori).
  '''
  
  import argparse
  import numpy as np
  import pickle
  from data import read_file, index_by_id, write_line
  import sys
  
  # -- ARGPARSE --
  # Command-line interface. The positional arguments come first: the path to
  # the pickled k-means model, then the path to the feature file.
  parser = argparse.ArgumentParser(
      description="extract clusters",
  )
  parser.add_argument(
      "model",
      help="k-means model pickle",
      type=str,
  )
  parser.add_argument(
      "features",
      help="features",
      type=str,
  )
993ea26cf   Mathias Quillot   take into account...
16
  # Optional list file restricting which entries are clustered; when omitted,
  # every entry of the feature file is used.
  parser.add_argument(
      "--list",
      default=None,
      type=str,
      help="list file",
  )
ac78b07ea   Mathias Quillot   All base bin file...
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
  parser.add_argument("--outfile", type=str, default=None, help="output file std")

  args = vars(parser.parse_args())
  MODEL = args["model"]
  FEATURES = args["features"]
  LST = args["list"]
  OUTFILE = args["outfile"]

  # Write to stdout unless an output path was given. `is None` is the correct
  # singleton check (PEP 8); `== None` can be overridden by __eq__.
  if OUTFILE is None:
      OUTFILE = sys.stdout
  else:
      # Kept open for the rest of the script; closed after predictions are
      # written (only when it is not sys.stdout).
      OUTFILE = open(OUTFILE, "w")

  # -- READ FILE --
  features = read_file(FEATURES)
  # Lookup table used below as feat_ind[meta[0]][meta[3]][1]; presumably keyed
  # by entry id — see data.index_by_id to confirm.
  feat_ind = index_by_id(features)
993ea26cf   Mathias Quillot   take into account...
33
34
35
36
  # Entries to cluster: those read from the --list file, or, when no list is
  # given, every entry of the feature file.
  lst = features if LST is None else read_file(LST)
ac78b07ea   Mathias Quillot   All base bin file...
37
38
  
  # Load the pre-trained k-means model. The `with` block guarantees the file
  # handle is closed (the original left it open).
  # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
  # only pass model files from a trusted source.
  with open(MODEL, "rb") as model_file:
      kmeans = pickle.load(model_file)
ac78b07ea   Mathias Quillot   All base bin file...
39
40
  # -- CONVERT TO NUMPY --
  # One row per selected entry: entry[0] is its metadata, and the feature
  # vector is looked up via feat_ind[meta[0]][meta[3]][1]
  # (structure of feat_ind assumed from data.index_by_id — confirm there).
  X = np.asarray([
      feat_ind[entry[0][0]][entry[0][3]][1]
      for entry in lst
  ])
993ea26cf   Mathias Quillot   take into account...
41

ac78b07ea   Mathias Quillot   All base bin file...
42
43
44
45
46
47
48
49
50
51
52
53
54
55
  predictions = kmeans.predict(X)

  # -- WRITE PREDICTIONS --
  # For each entry, field 1 of its metadata is overwritten (in place) with the
  # predicted cluster label, then the metadata and the entry's feature vector
  # are handed to write_line. try/finally ensures an opened output file is
  # closed even if writing fails part-way.
  try:
      # X was built from lst row by row, so predictions line up with lst.
      for line, label in zip(lst, predictions):
          meta = line[0]
          meta[1] = str(label)
          write_line(
              meta,
              feat_ind[meta[0]][meta[3]][1],
              OUTFILE
          )
  finally:
      # -- CLOSE OUT FILE IF NECESSARY --
      # Identity check: never close sys.stdout, only a file this script opened.
      if OUTFILE is not sys.stdout:
          OUTFILE.close()