Blame view

LDA/03-perplex.py 2.76 KB
b6d0165d1   Killian   Initial commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
  import gensim
  import time
  import os
  import sys
  import pickle
  from gensim.models.ldamodel import  LdaModel
  from gensim.models.ldamulticore import LdaMulticore
  from collections import Counter
  import numpy as np
  import codecs
  import shelve
  import logging
  import glob
  from tinydb import TinyDB, where, Query
  from itertools import izip_longest, repeat
  from multiprocessing import Pool
  
  def grouper(n, iterable, fillvalue=None):
      """Collect the items of *iterable* into tuples of length *n*.

      The final tuple is padded with *fillvalue* when the number of items is
      not a multiple of *n*: grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx.
      Returns a lazy iterator of tuples.
      """
      shared = iter(iterable)
      # The same iterator object appears n times, so izip_longest takes n
      # consecutive items from it for every output tuple.
      return izip_longest(*([shared] * n), fillvalue=fillvalue)
  
  
  def _stop_words(train, sw_size):
      """Return the union of the ``sw_size`` most frequent TRAIN ids of ASR and TRS.

      ``train["ASR_wid"]["TRAIN"]`` and ``train["TRS_wid"]["TRAIN"]`` are each
      iterated as a sequence of documents, every document an iterable of ids.
      """
      stop_words = set()
      for key in ("ASR_wid", "TRS_wid"):
          counts = Counter(wid for doc in train[key]["TRAIN"] for wid in doc)
          stop_words.update(wid for wid, _ in counts.most_common(sw_size))
      return stop_words


  def _dev_bow(docs, stop_words):
      """Convert each document into ``(id, count)`` pairs, dropping stop words."""
      return [
          [(wid, cnt) for wid, cnt in Counter(doc).items() if wid not in stop_words]
          for doc in docs
      ]


  def calc_perp(params):
      """Score the ASR and TRS LDA models of one run folder on the DEV split.

      ``params`` is an ``(in_dir, train)`` tuple so the function can be used
      with ``Pool.map``.

      ``in_dir`` must contain ``lda_trs.model`` and ``lda_asr.model``. Its base
      name encodes the run settings, e.g. ``s40_it1_sw50_a0.01_e0.1_p6_c1000``.
      The third ``_``-separated field (``sw50``) gives how many of the most
      frequent TRAIN ids of each source are removed as stop words.

      ``train`` maps "ASR_wid" and "TRS_wid" to dicts holding "TRAIN" and "DEV"
      document lists. NOTE(review): these are presumably lists of word ids;
      verify against the script that builds the shelve.

      Returns:
          None if a module-level TinyDB ``db`` exists and already holds a
          record with this name.

          ``{"name", "asr", "trs"}`` on success. The "asr" and "trs" values are
          what ``LdaModel.log_perplexity`` returns (gensim documents this as the
          per-word likelihood bound, not the perplexity itself).

          ``{"name": name}`` if any step fails; the traceback is logged.
          NOTE(review): the caller stores this name-only record too, so the
          folder is then treated as "already done" on later runs.
      """
      in_dir, train = params
      # normpath drops a trailing "/" so the base name is never empty.
      name = os.path.basename(os.path.normpath(in_dir))
      try:
          # `db` is the TinyDB opened in the __main__ block and inherited by
          # forked workers. Look it up defensively: under the "spawn" start
          # method it does not exist in the worker, and a NameError here would
          # make every run fail and be recorded as done.
          shared_db = globals().get("db")
          if shared_db is not None and shared_db.search(Query().name == name):
              logging.warning("{} already done".format(name))
              return None

          # e.g. s40_it1_sw50_a0.01_e0.1_p6_c1000 -> field "sw50" -> 50
          sw_size = int(name.split("_")[2][2:])
          logging.warning(" go {} ".format(name))

          logging.warning("Redo Vocab and stop")
          stop_words = _stop_words(train, sw_size)

          logging.warning("TRS  to be done")
          dev_trs = _dev_bow(train["TRS_wid"]["DEV"], stop_words)
          lda_trs = LdaModel.load("{}/lda_trs.model".format(in_dir))
          perp_trs = lda_trs.log_perplexity(dev_trs)

          logging.warning("ASR  to be done")
          dev_asr = _dev_bow(train["ASR_wid"]["DEV"], stop_words)
          lda_asr = LdaModel.load("{}/lda_asr.model".format(in_dir))
          perp_asr = lda_asr.log_perplexity(dev_asr)

          logging.warning("ASR  saving")
          return {"name": name, "asr": perp_asr, "trs": perp_trs}
      except Exception:
          # Best effort, as before: a broken run yields a name-only record.
          # The traceback is now logged instead of silently swallowed, and
          # KeyboardInterrupt/SystemExit are no longer trapped in the workers.
          logging.exception("{} failed".format(name))
          return {"name": name}
b6d0165d1   Killian   Initial commit
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
  
  if __name__ == "__main__":
      # Usage: 03-perplex.py INPUT_SHELVE INPUT_DIR DB_PATH [N_PROCESSES]
      input_shelve = sys.argv[1]
      input_dir = sys.argv[2]
      db_path = sys.argv[3]
      # Optional 4th argument; defaults to the previously hard-coded 14 workers.
      n_processes = int(sys.argv[4]) if len(sys.argv) > 4 else 14
      logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)

      # Copy the shelve into a plain dict (sent to the workers), then release
      # the underlying file.
      shelf = shelve.open(input_shelve)
      try:
          train = dict(shelf)
      finally:
          shelf.close()

      # `db` is a module-level global: calc_perp reads it in the forked
      # workers to skip runs that are already recorded.
      db = TinyDB(db_path)
      done = set(x["name"] for x in db.all())
      # Only sub-directories can hold the lda_*.model files. Skip stray files
      # and runs already in the database, so the large `train` dict is not
      # shipped to workers for nothing.
      folders = [
          f for f in glob.glob("{}/*".format(input_dir))
          if os.path.isdir(f) and os.path.basename(os.path.normpath(f)) not in done
      ]

      s = time.time()
      p = Pool(processes=n_processes, maxtasksperchild=10)
      perplexs = p.map(calc_perp, zip(folders, repeat(train, len(folders))))
      p.close()
      p.join()

      # calc_perp returns None for runs that are already in the database.
      for perp in perplexs:
          if perp:
              db.insert(perp)
      e = time.time()
      # len() instead of the last loop index: that index was unbound (NameError)
      # when no folder was processed. The call form is valid in Python 2 and 3.
      print("FIN :  {} : {}".format(len(perplexs), e - s))