# 03-perplex.py (2.76 KB)
import gensim
import time
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import glob
from tinydb import TinyDB, where, Query
from itertools import izip_longest, repeat
from multiprocessing import Pool
def grouper(n, iterable, fillvalue=None):
    """Split *iterable* into consecutive tuples of length *n*.

    The last tuple is padded with *fillvalue* when the number of items is not
    a multiple of *n*: grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx.
    """
    source = iter(iterable)
    # The same iterator appears n times, so every output tuple consumes n
    # consecutive items from the input.
    columns = [source] * n
    return izip_longest(*columns, fillvalue=fillvalue)
def _stop_words(train, sw_size):
    """Return the union of the `sw_size` most frequent word ids of the ASR and TRS TRAIN splits.

    `train[key]["TRAIN"]` is iterated as a sequence of documents, each an
    iterable of word ids.
    """
    stop_words = set()
    for key in ("ASR_wid", "TRS_wid"):
        counts = Counter(wid for doc in train[key]["TRAIN"] for wid in doc)
        stop_words.update(wid for wid, _ in counts.most_common(sw_size))
    return stop_words


def _dev_perplexity(in_dir, train, key, model_file, stop_words):
    """Return the log-perplexity of the LDA model `in_dir/model_file` on `train[key]["DEV"]`.

    Each DEV document is turned into a (word_id, count) bag of words with the
    ids in `stop_words` removed before scoring.
    """
    dev_bow = [[(wid, cnt) for wid, cnt in Counter(doc).items() if wid not in stop_words]
               for doc in train[key]["DEV"]]
    model = LdaModel.load("{}/{}".format(in_dir, model_file))
    return model.log_perplexity(dev_bow)


def calc_perp(params):
    """Compute held-out LDA log-perplexity (TRS and ASR) for one model folder.

    Args:
        params: an ``(in_dir, train)`` pair. ``in_dir`` is a folder containing
            ``lda_trs.model`` and ``lda_asr.model``. Its basename encodes the
            run settings, e.g. ``s40_it1_sw50_a0.01_e0.1_p6_c1000``, where the
            third ``_``-separated field (``sw50``) gives the stop-word list
            size. ``train`` is indexed as ``train["ASR_wid"|"TRS_wid"]
            ["TRAIN"|"DEV"]``, each a sequence of word-id documents.

    Returns:
        None if a record with this folder name is already in the global ``db``.
        Otherwise ``{"name": ..., "asr": ..., "trs": ...}`` with the two
        log-perplexities, or just ``{"name": name}`` when anything fails
        (the traceback is logged).
    """
    # Unpack outside the try so the error handler can always reference `name`.
    in_dir, train = params
    name = in_dir.split("/")[-1]
    try:
        # NOTE(review): `db` is the module global opened under __main__; worker
        # processes only see it because the Pool forks after it is created
        # (POSIX default start method) -- confirm on other platforms.
        entry = Query()
        if db.search(entry.name == name):
            logging.warning("{} already done".format(name))
            return None
        sw_size = int(name.split("_")[2][2:])
        logging.warning(" go {} ".format(name))
        logging.warning("Redo Vocab and stop")
        stop_words = _stop_words(train, sw_size)
        logging.warning("TRS to be done")
        perp_trs = _dev_perplexity(in_dir, train, "TRS_wid", "lda_trs.model", stop_words)
        logging.warning("ASR to be done")
        perp_asr = _dev_perplexity(in_dir, train, "ASR_wid", "lda_asr.model", stop_words)
        logging.warning("ASR saving")
        return {"name": name, "asr": perp_asr, "trs": perp_trs}
    except Exception:
        # Keep the best-effort contract (a {"name": ...} record is still
        # returned), but surface the cause instead of hiding it.
        # NOTE(review): if the caller stores this record, the "already done"
        # check above will skip this folder on later runs.
        logging.exception("perplexity computation failed for {}".format(name))
        return {"name": name}
if __name__ == "__main__":
    # Usage: INPUT_SHELVE INPUT_DIR DB_PATH [N_PROCESSES]
    # Computes LDA perplexities for every folder under INPUT_DIR not yet
    # recorded in the TinyDB at DB_PATH, and stores the results there.
    from contextlib import closing

    if len(sys.argv) < 4:
        sys.exit("usage: {} INPUT_SHELVE INPUT_DIR DB_PATH [N_PROCESSES]".format(sys.argv[0]))
    input_shelve = sys.argv[1]
    input_dir = sys.argv[2]
    db_path = sys.argv[3]
    # Optional 4th argument; defaults to the previously hard-coded 14 workers.
    n_processes = int(sys.argv[4]) if len(sys.argv) > 4 else 14
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)
    folders = glob.glob("{}/*".format(input_dir))
    # Copy the shelve into a plain dict for the workers, then close the file.
    with closing(shelve.open(input_shelve)) as shelf:
        train = dict(shelf)
    # `db` must exist before the Pool is created: calc_perp reads it as a global.
    db = TinyDB(db_path)
    done = {x["name"] for x in db.all()}
    # Skip finished folders up front so the (large) `train` dict is not sent
    # to a worker just to be told the folder is already done.
    todo = [f for f in folders if f.split("/")[-1] not in done]
    logging.warning("{} folders already done".format(len(folders) - len(todo)))
    pool = Pool(processes=n_processes, maxtasksperchild=10)
    start = time.time()
    try:
        # NOTE(review): every task tuple references the whole `train` dict,
        # which is pickled to the workers; presumably costly for big corpora.
        perplexs = pool.map(calc_perp, zip(todo, repeat(train, len(todo))))
    finally:
        pool.close()
        pool.join()
    for perp in perplexs:
        # calc_perp returns None for folders that are already in the DB.
        if perp:
            db.insert(perp)
    elapsed = time.time() - start
    # Report the number of results; the former loop index was unbound (and
    # raised NameError) when there was nothing to process.
    print("FIN : {} : {}".format(len(perplexs), elapsed))