# 03-perplex.py
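"""Compute the log-perplexity of previously trained LDA models on the DEV data.

For every model directory under input_dir (one directory per configuration,
e.g. s40_it1_sw50_a0.01_e0.1_p6_c1000), load the saved ASR and TRS LDA
models, rebuild the stop-word list used at training time, score the DEV
documents, and record the results in a TinyDB database.

Usage (Python 2):
    python 03-perplex.py <input_shelve> <input_dir> <db_path>
"""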
import gensim
import time
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import glob
from tinydb import TinyDB, where, Query
from itertools import izip_longest, repeat
from multiprocessing import Pool

# Standard itertools "grouper" recipe (currently unused in this script).
def grouper(n, iterable, fillvalue=None):
    "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    return izip_longest(fillvalue=fillvalue, *args)


def calc_perp(params):
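    """Score one trained model directory on the DEV set.

    params is a (model_dir, train) tuple, where train is the corpus dict
    loaded from the shelve in the main process.  Returns a dict with the
    ASR and TRS log-perplexities, or None when the model is already in the
    result database.
    """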
    in_dir, train = params
    name = in_dir.split("/")[-1]
    # e.g. s40_it1_sw50_a0.01_e0.1_p6_c1000

    # Skip this model if its perplexity is already recorded in the result DB.
    entry = Query()
    value = db.search(entry.name == name)
    if len(value) > 0:
        logging.warning("{} already done".format(name))
        return

    # Recover the stop-word list size from the directory name (e.g. "sw50" -> 50).
    sw_size = int(name.split("_")[2][2:])

    logging.warning(" go {} ".format(name))


    logging.warning("Redo Vocab and stop")
    # Rebuild the stop-word list the same way as at training time: the union of
    # the sw_size most frequent word ids in the ASR and TRS training documents.
    asr_count = Counter([x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count = Counter([x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [x[0] for x in asr_count.most_common(sw_size)]
    trs_sw = [x[0] for x in trs_count.most_common(sw_size)]
    stop_words = set(asr_sw) | set(trs_sw)

    logging.warning("TRS  to be done")
    
    dev_trs=[ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["TRS_wid"]["DEV"]]
    lda_trs = LdaModel.load("{}/lda_trs.model".format(in_dir))
    perp_trs = lda_trs.log_perplexity(dev_trs)
    logging.warning("ASR  to be done")
    dev_asr = [ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["ASR_wid"]["DEV"]]
    lda_asr = LdaModel.load("{}/lda_asr.model".format(in_dir))
    perp_asr = lda_asr.log_perplexity(dev_asr)
    logging.warning("ASR  saving")
    res_dict = {"name" : name, "asr" : perp_asr, "trs" : perp_trs}
    return res_dict

if __name__ == "__main__": 
    input_shelve = sys.argv[1]
    input_dir = sys.argv[2]
    db_path = sys.argv[3]
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)
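    # One sub-directory of input_dir per trained model configuration.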
    folders = glob.glob("{}/*".format(input_dir))

    #train=pickle.load(open("{}/newsgroup_bow_train.pk".format(input_dir)))
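    # Materialise the whole shelve as a plain dict so it can be pickled and
    # shipped to the worker processes.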
    train = dict(shelve.open(input_shelve))
    db = TinyDB(db_path)
    names = [x["name"] for x in db.all()]  # not used below; dedup is done via db.search in calc_perp
    p = Pool(processes=14, maxtasksperchild=10)
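    # Note: calc_perp reads the global db; the workers only see it because they
    # are forked after it is created (fork-based platforms only).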

    s = time.time()
    # One task per model folder; the full train dict is pickled along with each task argument.
    perplexs = p.map(calc_perp, zip(folders, repeat(train, len(folders))))

    # Workers return None for folders that were already scored; insert the rest.
    for perp in perplexs:
        if perp:
            db.insert(perp)
    e = time.time()
    print "FIN : {} : {}".format(len(perplexs), e - s)