Blame view
LDA/03-perplex.py
2.76 KB
b6d0165d1 Initial commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import gensim
import time
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import glob
from tinydb import TinyDB, where, Query
try:
    from itertools import izip_longest          # Python 2
except ImportError:
    from itertools import zip_longest as izip_longest  # Python 3 rename
from itertools import repeat
from multiprocessing import Pool


def grouper(n, iterable, fillvalue=None):
    "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    return izip_longest(fillvalue=fillvalue, *args)


def calc_perp(params):
    """Compute dev-set log-perplexity of the two LDA models stored in one run folder.

    params : (in_dir, train) tuple
        in_dir : path to a model folder whose basename encodes the run
                 hyper-parameters (e.g. "s40_it1_sw50_a0.01_e0.1_p6_c1000";
                 field 2, "swNN", is the stop-word list size).
        train  : dict with "ASR_wid"/"TRS_wid" keys, each holding
                 "TRAIN"/"DEV" lists of word-id sequences.

    Returns a dict {"name", "asr", "trs"} with the two log-perplexities,
    {"name": name} alone if the computation failed, or None when this run
    is already recorded in the TinyDB results database.

    NOTE: relies on the module-level `db` handle created in __main__ before
    the worker Pool is forked.
    """
    # Unpack OUTSIDE the try: the original assigned `name` inside it, so a
    # failure before that line made the except handler itself raise NameError.
    in_dir, train = params
    name = in_dir.split("/")[-1]  # s40_it1_sw50_a0.01_e0.1_p6_c1000
    try:
        # Skip runs whose perplexity was already stored (cheap resume).
        entry = Query()
        value = db.search(entry.name == name)
        if len(value) > 0:
            logging.warning("{} already done".format(name))
            return

        sw_size = int(name.split("_")[2][2:])  # "sw50" -> 50

        logging.warning(" go {} ".format(name))

        # Rebuild the stop-word list exactly as at training time: the
        # sw_size most frequent word ids of each condition, unioned.
        logging.warning("Redo Vocab and stop")
        asr_count = Counter([x for y in train["ASR_wid"]["TRAIN"] for x in y])
        trs_count = Counter([x for y in train["TRS_wid"]["TRAIN"] for x in y])
        asr_sw = [x[0] for x in asr_count.most_common(sw_size)]
        trs_sw = [x[0] for x in trs_count.most_common(sw_size)]
        stop_words = set(asr_sw) | set(trs_sw)

        # Dev documents as (word_id, count) bags with stop words removed,
        # i.e. the BoW format gensim's log_perplexity expects.
        logging.warning("TRS to be done")
        dev_trs = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
                   for z in train["TRS_wid"]["DEV"]]
        lda_trs = LdaModel.load("{}/lda_trs.model".format(in_dir))
        perp_trs = lda_trs.log_perplexity(dev_trs)

        logging.warning("ASR to be done")
        dev_asr = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
                   for z in train["ASR_wid"]["DEV"]]
        lda_asr = LdaModel.load("{}/lda_asr.model".format(in_dir))
        perp_asr = lda_asr.log_perplexity(dev_asr)

        logging.warning("ASR saving")
        res_dict = {"name": name, "asr": perp_asr, "trs": perp_trs}
        return res_dict
    except Exception:
        # Was a bare `except:` that silently discarded every error (including
        # KeyboardInterrupt); log the traceback so failed runs are diagnosable,
        # but keep the best-effort contract of returning a name-only record.
        logging.exception("perplexity computation failed for {}".format(name))
        return {"name": name}
b6d0165d1 Initial commit |
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
if __name__ == "__main__": input_shelve = sys.argv[1] input_dir = sys.argv[2] db_path = sys.argv[3] logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING) folders = glob.glob("{}/*".format(input_dir)) #train=pickle.load(open("{}/newsgroup_bow_train.pk".format(input_dir))) train = dict(shelve.open(input_shelve)) db = TinyDB(db_path) names = [ x["name"] for x in db.all()] p = Pool(processes=14,maxtasksperchild=10) s = time.time() perplexs = p.map(calc_perp,zip(folders,repeat(train,len(folders)))) for indx, perp in enumerate(perplexs) : if perp : db.insert(perp) e = time.time() print "FIN : {} : {}".format(indx,e-s) |