# 03-mono_perplex.py
import gensim
import time
import os
import sys
import pickle
from gensim.models.ldamodel import  LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import glob
from tinydb import TinyDB, where, Query


def calc_perp(in_dir, train, db=None):
    """Compute held-out log-perplexity of the TRS and ASR LDA models in *in_dir*.

    The folder name encodes the run parameters, e.g.
    ``s40_it1_sw50_a0.01_e0.1_p6_c1000``; the third ``_``-separated field
    (``sw<N>``) gives how many of the most frequent TRAIN tokens of each
    corpus are dropped as stop words before scoring the DEV documents.

    :param in_dir: folder containing ``lda_trs.model`` and ``lda_asr.model``.
    :param train: mapping (e.g. an open shelve) where ``train["ASR_wid"]`` and
        ``train["TRS_wid"]`` each hold ``"TRAIN"`` and ``"DEV"`` lists of
        token sequences (presumably word ids -- TODO confirm).
    :param db: TinyDB instance used to skip folders already processed.
        Defaults to the module-level ``db`` created by the ``__main__`` block.
    :returns: ``{"name": ..., "trs": ..., "asr": ...}`` with the
        ``log_perplexity`` values, or ``None`` if *db* already has an entry
        for this folder name.
    """
    if db is None:
        # Backward compatibility: this function historically read the global
        # ``db`` defined in the ``__main__`` block.
        db = globals()["db"]

    # normpath + basename also copes with a trailing separator.
    name = os.path.basename(os.path.normpath(in_dir))
    logging.warning(" go {} ".format(name))

    # Check the DB before any heavy work so already-processed folders do not
    # pay for the full-corpus vocabulary counting below.
    if db.search(Query().name == name):
        logging.warning("{} already done".format(name))
        return None

    # e.g. "sw50" -> 50
    sw_size = int(name.split("_")[2][2:])

    logging.warning("Redo Vocab and stop")
    stop_words = (set(_top_words(train["ASR_wid"]["TRAIN"], sw_size))
                  | set(_top_words(train["TRS_wid"]["TRAIN"], sw_size)))

    res_dict = {"name": name}
    # Same order as before (TRS then ASR). Rebinding ``model`` lets the
    # previous model be garbage-collected before the next one is loaded.
    for key, corpus in (("trs", "TRS_wid"), ("asr", "ASR_wid")):
        logging.warning("{}  to be done".format(key.upper()))
        dev_bow = _to_bow(train[corpus]["DEV"], stop_words)
        model = LdaModel.load(os.path.join(in_dir, "lda_{}.model".format(key)))
        res_dict[key] = model.log_perplexity(dev_bow)
    logging.warning("{} perplexity computed".format(name))
    return res_dict


def _top_words(docs, k):
    """Return the *k* most frequent tokens across every sequence in *docs*."""
    counts = Counter(tok for doc in docs for tok in doc)
    return [tok for tok, _ in counts.most_common(k)]


def _to_bow(docs, stop_words):
    """Turn each token sequence into ``(token, count)`` pairs, minus *stop_words*.

    This is the bag-of-words form passed to ``LdaModel.log_perplexity``.
    """
    return [[(tok, cnt) for tok, cnt in Counter(doc).items() if tok not in stop_words]
            for doc in docs]

if __name__ == "__main__":
    # Usage: 03-mono_perplex.py <train_shelve> <models_dir> <tinydb_path>
    if len(sys.argv) != 4:
        sys.exit("usage: {} <train_shelve> <models_dir> <tinydb_path>".format(sys.argv[0]))
    input_shelve = sys.argv[1]
    input_dir = sys.argv[2]
    db_path = sys.argv[3]
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)

    # glob returns entries in arbitrary filesystem order; sorting makes the
    # processing order (and the FIN index) reproducible. Only directories can
    # hold the model files, so stray files matching "s*" are skipped.
    folders = sorted(f for f in glob.glob("{}/s*".format(input_dir))
                     if os.path.isdir(f))

    # Read-only: the training data is never modified by this script.
    train = shelve.open(input_shelve, flag="r")
    # Must stay a module-level name: calc_perp falls back to the global ``db``.
    db = TinyDB(db_path)
    try:
        for indx, folder in enumerate(folders):
            s = time.time()
            r = calc_perp(folder, train)
            if r:
                db.insert(r)
            e = time.time()
            print("FIN : {}  {} : {}".format(folder, indx, e - s))
    finally:
        # Close both stores even if a model fails to load mid-run.
        db.close()
        train.close()