LDA/02-lda_split.py
import gensim
import os
import sys
import pickle
import codecs
import shelve
import logging

import numpy as np
from collections import Counter
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
# `Query` comes from tinydb; `db` (a TinyDB handle on the results store) is
# assumed to be initialised elsewhere before calc_perp is called.
from tinydb import Query


def calc_perp(in_dir, train):
    """Compute dev-set log-perplexity of previously trained ASR/TRS LDA models."""
    name = in_dir.split("/")[-1]  # e.g. s40_it1_sw50_a0.01_e0.1_p6_c1000
    sw_size = int(name.split("_")[2][2:])
    logging.warning(" go {} ".format(name))

    # Rebuild the stop-word list: the sw_size most frequent word ids
    # in the ASR and TRS training corpora.
    logging.warning("Redo Vocab and stop")
    asr_count = Counter([x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count = Counter([x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [x[0] for x in asr_count.most_common(sw_size)]
    trs_sw = [x[0] for x in trs_count.most_common(sw_size)]
    stop_words = set(asr_sw) | set(trs_sw)

    logging.warning("TRS to be done")
    entry = Query()
    value = db.search(entry.name == name)
    if len(value) > 0:
        logging.warning("{} already done".format(name))
        return

    dev_trs = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
               for z in train["TRS_wid"]["DEV"]]
    lda_trs = LdaModel.load("{}/lda_trs.model".format(in_dir))
    perp_trs = lda_trs.log_perplexity(dev_trs)

    logging.warning("ASR to be done")
    dev_asr = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
               for z in train["ASR_wid"]["DEV"]]
    lda_asr = LdaModel.load("{}/lda_asr.model".format(in_dir))
    perp_asr = lda_asr.log_perplexity(dev_asr)

    return {"name": name, "asr": perp_asr, "trs": perp_trs}


def train_lda(out_dir, train, name, size, it, sw_size, alpha, eta, passes, chunk):
    """Train one LDA model per corpus (ASR and TRS) for one hyper-parameter setting."""
    output_dir = "{}/s{}_it{}_sw{}_a{}_e{}_p{}_c{}".format(
        out_dir, size, it, sw_size, alpha, eta, passes, chunk)
    os.mkdir(output_dir)
    logging.info(output_dir + " to be done")

    # Stop words: the sw_size most frequent word ids in each training corpus.
    asr_count = Counter([x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count = Counter([x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [x[0] for x in asr_count.most_common(sw_size)]
    trs_sw = [x[0] for x in trs_count.most_common(sw_size)]
    stop_words = set(asr_sw) | set(trs_sw)

    logging.info("TRS to be done")
    lda_trs = LdaModel(
        corpus=[[(x, y) for x, y in Counter(z).items() if x not in stop_words]
                for z in train["TRS_wid"]["TRAIN"]],
        id2word=train["vocab"], num_topics=int(size),
        alpha=alpha, eta=eta, passes=passes, chunksize=chunk, iterations=it)

    logging.info("ASR to be done")
    lda_asr = LdaModel(
        corpus=[[(x, y) for x, y in Counter(z).items() if x not in stop_words]
                for z in train["ASR_wid"]["TRAIN"]],
        id2word=train["vocab"], num_topics=int(size),
        alpha=alpha, eta=eta, passes=passes, chunksize=chunk, iterations=it)

    # logging.info("ASR saving")
    # lda_asr.save("{}/lda_asr.model".format(output_dir))
    # lda_trs.save("{}/lda_trs.model".format(output_dir))

    # Dump the full word/topic matrices (rows = topics, columns = vocabulary).
    out_file_asr = codecs.open("{}/asr_wordTopic.txt".format(output_dir), "w", "utf-8")
    out_file_trs = codecs.open("{}/trs_wordTopic.txt".format(output_dir), "w", "utf-8")
    dico = train["vocab"]
    print >>out_file_asr, ",\t".join([dico[x] for x in range(len(train["vocab"]))])
    for line in lda_asr.expElogbeta:
        nline = line / np.sum(line)
        print >>out_file_asr, ",\t".join(str(x) for x in nline)
    out_file_asr.close()

    print >>out_file_trs, ",\t".join([dico[x] for x in range(len(train["vocab"]))])
    for line in lda_trs.expElogbeta:
        nline = line / np.sum(line)
        print >>out_file_trs, ",\t".join(str(x) for x in nline)
    out_file_trs.close()

    # Dump the 10 most probable words of every topic.
    K = lda_asr.num_topics
    topicWordProbMat = lda_asr.print_topics(K, 10)
    out_file_asr = codecs.open("{}/asr_best10.txt".format(output_dir), "w", "utf-8")
    for i in topicWordProbMat:
        print >>out_file_asr, i
    out_file_asr.close()

    K = lda_trs.num_topics
    topicWordProbMat = lda_trs.print_topics(K, 10)
    out_file_trs = codecs.open("{}/trs_best10.txt".format(output_dir), "w", "utf-8")
    for i in topicWordProbMat:
        print >>out_file_trs, i
    out_file_trs.close()


if __name__ == "__main__":
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)

    # "_"-separated command-line lists define the hyper-parameter grid to sweep.
    input_shelve = sys.argv[1]
    output_dir = sys.argv[2]
    size = [int(x) for x in sys.argv[3].split("_")]
    workers = int(sys.argv[4])  # parsed but unused in this script
    name = sys.argv[5]
    it = [int(x) for x in sys.argv[6].split("_")]
    sw_size = [int(x) for x in sys.argv[7].split("_")]
    alpha = ["auto", "symmetric"] + [float(x) for x in sys.argv[8].split("_")]
    eta = ["auto"] + [float(x) for x in sys.argv[9].split("_")]
    passes = [int(x) for x in sys.argv[10].split("_")]
    chunk = [int(x) for x in sys.argv[11].split("_")]

    # train = pickle.load(open("{}/newsgroup_bow_train.pk".format(input_dir)))
    train = shelve.open(input_shelve)
    out_dir = "{}/{}".format(output_dir, name)
    os.mkdir(out_dir)

    # Exhaustive grid search over every hyper-parameter combination.
    for s in size:
        for i in it:
            for sw in sw_size:
                for a in alpha:
                    for e in eta:
                        for p in passes:
                            for c in chunk:
                                train_lda(out_dir, train, name, s, i, sw, a, e, p, c)
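A hypothetical invocation, for reference (file names and values below are illustrative, not from the repository): the script takes eleven positional arguments, and the swept hyper-parameters (topic count, iterations, stop-word list size, alpha, eta, passes, chunk size) are each given as a "_"-separated list:

    python LDA/02-lda_split.py data.shelve out 40_100 4 exp1 100 50 0.01 0.1 1_6 1000

This would train one ASR model and one TRS model for each of the 2 (sizes) x 3 (alpha: auto, symmetric, 0.01) x 2 (eta: auto, 0.1) x 2 (passes) = 24 settings, writing each run's word/topic matrices and top-10 word lists under out/exp1/.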