# 02-lda_split.py
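"""Grid-train gensim LDA models on the ASR and TRS word streams.

For every hyperparameter combination given on the command line, train_lda fits
one LDA model on train["ASR_wid"]["TRAIN"] and one on train["TRS_wid"]["TRAIN"]
(with the most frequent word ids removed as stop words), then writes the
normalized word/topic matrices and the top-10 words per topic to a per-run
directory. calc_perp can reload previously saved models from such a directory
and report their dev-set log-perplexity.
"""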
import gensim
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
# calc_perp queries a TinyDB results cache through a module-level handle `db`
# that is not created in this file (assumed; e.g. db = TinyDB("results.json")).
from tinydb import Query

def calc_perp(in_dir,train):
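    """Return dev-set log-perplexity of the lda_asr / lda_trs models saved in in_dir.

    The stop-word size is parsed back from the run directory name (see the
    example below). Relies on a module-level TinyDB handle `db` (not created in
    this file) used as a cache of already-computed results; returns None when
    the run is already in the cache.
    """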
    name = in_dir.split("/")[-1]
    # s40_it1_sw50_a0.01_e0.1_p6_c1000
    sw_size = int(name.split("_")[2][2:])

    logging.warning(" go {} ".format(name))


    logging.warning("Redo Vocab and stop")
    asr_count=Counter([ x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count=Counter([ x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [ x[0] for x in asr_count.most_common(sw_size) ]
    trs_sw = [ x[0] for x in trs_count.most_common(sw_size) ]
    stop_words=set(asr_sw) | set(trs_sw)

    logging.warning("TRS  to be done")
    entry = Query()
    value=db.search(entry.name == name)
    if len(value) > 0 :
        logging.warning("{} already done".format(name))
        return 

    dev_trs=[ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["TRS_wid"]["DEV"]]
    lda_trs = LdaModel.load("{}/lda_trs.model".format(in_dir))
    perp_trs = lda_trs.log_perplexity(dev_trs)
    logging.warning("ASR  to be done")
    dev_asr = [ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["ASR_wid"]["DEV"]]
    lda_asr = LdaModel.load("{}/lda_asr.model".format(in_dir))
    perp_asr = lda_asr.log_perplexity(dev_asr)
    logging.warning("ASR  saving")
    res_dict = {"name" : name, "asr" : perp_asr, "trs" : perp_trs}
    return res_dict




def train_lda(out_dir,train,name,size,it,sw_size,alpha,eta,passes,chunk):
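    """Train one LDA model per word stream (ASR and TRS) for a single
    hyperparameter combination and dump the normalized word/topic matrices
    and the top-10 words per topic under out_dir."""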
    output_dir = "{}/s{}_it{}_sw{}_a{}_e{}_p{}_c{}".format(out_dir,size,it,sw_size,alpha,eta,passes,chunk)
    os.mkdir(output_dir)
    logging.info(output_dir+" to be done")
    asr_count=Counter([ x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count=Counter([ x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [ x[0] for x in asr_count.most_common(sw_size) ]
    trs_sw = [ x[0] for x in trs_count.most_common(sw_size) ]
    stop_words=set(asr_sw) | set(trs_sw)

    logging.info("TRS  to be done")

    lda_trs = LdaModel(corpus=[[(x, y) for x, y in Counter(z).items() if x not in stop_words] for z in train["TRS_wid"]["TRAIN"]],
                       id2word=train["vocab"], num_topics=int(size),
                       chunksize=chunk, passes=passes, iterations=it, alpha=alpha, eta=eta)

    logging.info("ASR  to be done")
    lda_asr = LdaModel(corpus=[[(x, y) for x, y in Counter(z).items() if x not in stop_words] for z in train["ASR_wid"]["TRAIN"]],
                       id2word=train["vocab"], num_topics=int(size),
                       chunksize=chunk, passes=passes, iterations=it, alpha=alpha, eta=eta)

    # Model saving is disabled; calc_perp expects lda_asr.model / lda_trs.model
    # inside the run directory, so re-enable these lines to compute perplexity later.
    #logging.info("ASR  saving")
    #lda_asr.save("{}/lda_asr.model".format(output_dir))
    #lda_trs.save("{}/lda_trs.model".format(output_dir))


    out_file_asr=codecs.open("{}/asr_wordTopic.txt".format(output_dir),"w","utf-8")
    out_file_trs=codecs.open("{}/trs_wordTopic.txt".format(output_dir),"w","utf-8")

    dico = train["vocab"]
    # Header row: the full vocabulary, then one row per topic of word
    # probabilities (each expElogbeta row is normalized to sum to 1).
    print(",\t".join([dico[x] for x in range(len(train["vocab"]))]), file=out_file_asr)
    for line in lda_asr.expElogbeta:
        nline = line / np.sum(line)
        print(",\t".join(str(x) for x in nline), file=out_file_asr)
    out_file_asr.close()

    print(",\t".join([dico[x] for x in range(len(train["vocab"]))]), file=out_file_trs)
    for line in lda_trs.expElogbeta:
        nline = line / np.sum(line)
        print(",\t".join(str(x) for x in nline), file=out_file_trs)
    out_file_trs.close()

    # Top-10 words for every topic, one topic per line.
    K = lda_asr.num_topics
    topicWordProbMat = lda_asr.print_topics(K, 10)
    out_file_asr = codecs.open("{}/asr_best10.txt".format(output_dir), "w", "utf-8")
    for i in topicWordProbMat:
        print(i, file=out_file_asr)
    out_file_asr.close()

    K = lda_trs.num_topics
    topicWordProbMat = lda_trs.print_topics(K, 10)
    out_file_trs = codecs.open("{}/trs_best10.txt".format(output_dir), "w", "utf-8")
    for i in topicWordProbMat:
        print(i, file=out_file_trs)
    out_file_trs.close()

if __name__ == "__main__": 
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)
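    # Positional arguments (multi-valued ones are "_"-separated lists):
    #   1 input shelve  2 output dir  3 topic sizes  4 workers (currently unused)
    #   5 run name  6 iteration counts  7 stop-word sizes
    #   8 extra alpha values (the grid always includes "auto" and "symmetric")
    #   9 extra eta values (the grid always includes "auto")
    #  10 passes  11 chunk sizes
    # Example invocation (values are only illustrative):
    #   python 02-lda_split.py corpus.shelve /tmp/lda 50_100 4 run1 100 50 0.01 0.1 1_5 1000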

    input_shelve = sys.argv[1]
    output_dir = sys.argv[2]
    size = [ int(x) for x in sys.argv[3].split("_")]
    workers = int(sys.argv[4])
    name = sys.argv[5]
    it = [ int(x) for x in sys.argv[6].split("_")]
    sw_size = [ int(x) for x in sys.argv[7].split("_")]
    alpha = ["auto" , "symmetric"] + [ float(x) for x in sys.argv[8].split("_")]
    eta = ["auto"] + [ float(x) for x in sys.argv[9].split("_")]
    passes = [ int(x) for x in sys.argv[10].split("_")]
    chunk = [ int(x) for x in sys.argv[11].split("_")]

    #train=pickle.load(open("{}/newsgroup_bow_train.pk".format(input_dir)))
    train = shelve.open(input_shelve)
    out_dir = "{}/{}".format(output_dir,name)
    os.mkdir(out_dir)

    for s in size: 
        for i in it :
            for sw in sw_size:
                for a in alpha:
                    for e in eta:
                        for p in passes:
                            for c in chunk: 
                                train_lda(out_dir,train,name,s,i,sw,a,e,p,c)
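
    # Close the shelve handle once the full hyperparameter grid has been trained.
    train.close()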