Blame view
LDA/03-mono_perplex.py
2.21 KB
b6d0165d1 Initial commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import gensim
import time
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import glob
from tinydb import TinyDB, where, Query


def _stop_words(asr_docs, trs_docs, sw_size):
    """Return the union of the `sw_size` most frequent ASR and TRS word ids.

    Each argument is an iterable of documents, each document an iterable of
    word ids; frequencies are counted separately per stream.
    """
    stop_words = set()
    for docs in (asr_docs, trs_docs):
        counts = Counter(wid for doc in docs for wid in doc)
        stop_words.update(wid for wid, _ in counts.most_common(sw_size))
    return stop_words


def _bow(doc, stop_words):
    """Return `doc` as a list of ``(word_id, count)`` pairs without stop words."""
    return [(wid, cnt) for wid, cnt in Counter(doc).items()
            if wid not in stop_words]


def _dev_log_perplexity(model_path, dev_docs, stop_words):
    """Load the LdaModel at `model_path` and score it on `dev_docs`.

    Returns whatever gensim's ``LdaModel.log_perplexity`` returns for the
    stop-word-filtered bag-of-words version of `dev_docs`.
    """
    corpus = [_bow(doc, stop_words) for doc in dev_docs]
    return LdaModel.load(model_path).log_perplexity(corpus)


def calc_perp(in_dir, train, db=None):
    """Evaluate the TRS and ASR LDA models stored in `in_dir` on the DEV split.

    The folder name encodes the run configuration, e.g.
    ``s40_it1_sw50_a0.01_e0.1_p6_c1000``; its third ``_``-separated field
    (``sw50``) gives how many of the most frequent TRAIN words of each
    stream are removed as stop words before scoring.

    Args:
        in_dir: directory holding ``lda_trs.model`` and ``lda_asr.model``.
        train: mapping (a shelve when run as a script) with ``"ASR_wid"``
            and ``"TRS_wid"`` entries, each indexable by ``"TRAIN"`` and
            ``"DEV"`` and yielding documents as iterables of word ids.
        db: optional TinyDB of previous results. When given and it already
            holds a record whose ``name`` equals this folder's name, the
            folder is skipped. When None, no cache check is performed.

    Returns:
        ``{"name": name, "asr": value, "trs": value}``, or None when the
        folder was already in `db`. The values are gensim's
        ``log_perplexity`` output, which per the gensim docs is a per-word
        likelihood bound (perplexity = ``2 ** -bound``), not the perplexity.

    Raises:
        IndexError / ValueError: if the folder name does not follow the
        ``<x>_<y>_sw<N>_...`` pattern.
    """
    # normpath+basename tolerates a trailing "/" that split("/") did not.
    name = os.path.basename(os.path.normpath(in_dir))
    sw_size = int(name.split("_")[2][2:])  # "sw50" -> 50
    logging.warning(" go {} ".format(name))

    # Skip finished folders before doing any expensive counting.
    if db is not None and db.search(where("name") == name):
        logging.warning("{} already done".format(name))
        return None

    logging.warning("Redo Vocab and stop")
    # Read each shelve entry once; every shelve access unpickles it again.
    asr = train["ASR_wid"]
    trs = train["TRS_wid"]
    stop_words = _stop_words(asr["TRAIN"], trs["TRAIN"], sw_size)

    logging.warning("TRS to be done")
    perp_trs = _dev_log_perplexity(os.path.join(in_dir, "lda_trs.model"),
                                   trs["DEV"], stop_words)

    logging.warning("ASR to be done")
    perp_asr = _dev_log_perplexity(os.path.join(in_dir, "lda_asr.model"),
                                   asr["DEV"], stop_words)

    logging.warning("ASR saving")
    return {"name": name, "asr": perp_asr, "trs": perp_trs}


if __name__ == "__main__":
    if len(sys.argv) < 4:
        sys.exit("usage: {} INPUT_SHELVE INPUT_DIR DB_PATH".format(sys.argv[0]))
    input_shelve = sys.argv[1]
    input_dir = sys.argv[2]
    db_path = sys.argv[3]
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)

    # Sorted so that successive runs visit folders in a reproducible order.
    folders = sorted(glob.glob("{}/s*".format(input_dir)))
    train = shelve.open(input_shelve)
    db = TinyDB(db_path)
    try:
        for indx, folder in enumerate(folders):
            s = time.time()
            r = calc_perp(folder, train, db)
            if r:
                db.insert(r)
            e = time.time()
            # Single-argument print(): identical output on Python 2 and 3.
            print("FIN : {} {} : {}".format(folder, indx, e - s))
    finally:
        # Release the shelve file and flush/close the TinyDB storage.
        db.close()
        train.close()