Blame view
LDA/03-perplex.py
2.6 KB
b6d0165d1 Initial commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Compute held-out (DEV) log-perplexity of previously trained ASR/TRS LDA
# models for every experiment folder under input_dir, and record each result
# in a TinyDB so reruns can skip folders already done.
#
# Usage: python 03-perplex.py <input_shelve> <input_dir> <db_path>

import gensim
import time
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import glob
from tinydb import TinyDB, where, Query
try:
    from itertools import izip_longest, repeat          # Python 2
except ImportError:
    from itertools import zip_longest as izip_longest, repeat  # Python 3
from multiprocessing import Pool


def grouper(n, iterable, fillvalue=None):
    """Collect data into fixed-length chunks, padding the last one.

    grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx
    """
    args = [iter(iterable)] * n
    return izip_longest(fillvalue=fillvalue, *args)


def calc_perp(params):
    """Compute DEV-set log-perplexity for one experiment folder.

    params: (in_dir, train) tuple, where in_dir is the experiment folder
    (its basename encodes the settings, e.g. s40_it1_sw50_a0.01_e0.1_p6_c1000)
    and train is the corpus dict with train["ASR_wid"/"TRS_wid"]["TRAIN"/"DEV"].

    NOTE(review): relies on the module-level global ``db`` created in
    __main__ — worker processes inherit it via fork; this will not work
    with spawn-based multiprocessing (e.g. Windows / macOS default).

    Returns {"name", "asr", "trs"} on success, or None if this folder's
    result is already in the database.
    """
    in_dir, train = params
    name = in_dir.split("/")[-1]

    # Skip folders already recorded in the results DB.
    entry = Query()
    if db.search(entry.name == name):
        logging.warning("{} already done".format(name))
        return None

    # sw_size is parsed from the "swNN" field of the folder name
    # (third underscore-separated token, minus the "sw" prefix).
    sw_size = int(name.split("_")[2][2:])

    logging.warning(" go {} ".format(name))
    logging.warning("Redo Vocab and stop")
    # Rebuild the stop-word list exactly as at training time: the sw_size
    # most frequent word ids of each condition (ASR and TRS), unioned.
    asr_count = Counter([x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count = Counter([x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [x[0] for x in asr_count.most_common(sw_size)]
    trs_sw = [x[0] for x in trs_count.most_common(sw_size)]
    stop_words = set(asr_sw) | set(trs_sw)

    logging.warning("TRS to be done")
    # Bag-of-words DEV corpus as (word_id, count) pairs, stop words removed.
    dev_trs = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
               for z in train["TRS_wid"]["DEV"]]
    lda_trs = LdaModel.load("{}/lda_trs.model".format(in_dir))
    perp_trs = lda_trs.log_perplexity(dev_trs)

    logging.warning("ASR to be done")
    dev_asr = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
               for z in train["ASR_wid"]["DEV"]]
    lda_asr = LdaModel.load("{}/lda_asr.model".format(in_dir))
    perp_asr = lda_asr.log_perplexity(dev_asr)

    logging.warning("ASR saving")
    return {"name": name, "asr": perp_asr, "trs": perp_trs}


if __name__ == "__main__":
    input_shelve = sys.argv[1]
    input_dir = sys.argv[2]
    db_path = sys.argv[3]

    logging.basicConfig(format='%(levelname)s:%(message)s',
                        level=logging.WARNING)

    folders = glob.glob("{}/*".format(input_dir))
    # train=pickle.load(open("{}/newsgroup_bow_train.pk".format(input_dir)))

    # Materialize the shelve into a plain dict so it can be shipped to
    # worker processes; close the handle afterwards (BUG FIX: the original
    # never closed it).
    shelf = shelve.open(input_shelve)
    try:
        train = dict(shelf)
    finally:
        shelf.close()

    db = TinyDB(db_path)

    p = Pool(processes=14, maxtasksperchild=10)
    s = time.time()
    perplexs = p.map(calc_perp, zip(folders, repeat(train, len(folders))))
    p.close()
    p.join()

    # BUG FIX: `indx` was only bound inside the loop, so the final print
    # raised NameError when no folders were found; pre-bind it.
    indx = -1
    for indx, perp in enumerate(perplexs):
        if perp:
            db.insert(perp)
    e = time.time()
    print("FIN : {} : {}".format(indx, e - s))