Blame view
LDA/02-lda.py
4.95 KB
b6d0165d1 Initial commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""Grid-search gensim LDA models on ASR and TRS transcripts.

For every hyper-parameter combination, one LdaModel is trained on the ASR
"TRAIN" documents and one on the TRS "TRAIN" documents. Their log-perplexity
on the matching "DEV" documents is then recorded in a TinyDB database, and
the models are dumped with dill to ``<OUT_DIR>/<model name>.dill``.

Usage::

    python 02-lda.py INPUT_SHELVE DB_PATH SIZES WORKERS OUT_DIR ITERATIONS \
        SW_SIZES ALPHAS ETAS PASSES CHUNKS

Every list argument is underscore-separated (e.g. ``50_100_200``). When
ALPHAS is the literal ``None``, only alpha="symmetric" and eta="auto" are
tried and ETAS is ignored.

The shelve must provide ``"ASR_wid"`` and ``"TRS_wid"`` (each with
``"TRAIN"`` and ``"DEV"`` splits) and ``"vocab"``.
NOTE(review): the splits are presumably lists of documents, each a list of
integer word ids indexing ``vocab`` — confirm against the producer script.
"""
import codecs
import itertools
import logging
import os
import pickle
import shelve
import sys
import time
from collections import Counter

import dill
import gensim
import numpy as np
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from tinydb import TinyDB, Query, where


def _bag_of_words(doc, stop_words):
    """Turn a list of word ids into gensim (id, count) pairs, dropping stop words."""
    return [(wid, cnt) for wid, cnt in Counter(doc).items() if wid not in stop_words]


def _most_frequent_ids(docs, k):
    """Return the set of the ``k`` most frequent word ids over all ``docs``."""
    counts = Counter(wid for doc in docs for wid in doc)
    return {wid for wid, _ in counts.most_common(k)}


def _load_stop_word_ids(path, vocab):
    """Map the words listed one per line in ``path`` to their ids in ``vocab``.

    The corpus holds word *ids*, so the stop-word strings must be translated
    before they can be filtered. Words missing from the vocabulary are
    ignored. Returns an empty set when ``path`` is None.
    """
    if path is None:
        return set()
    # NOTE(review): assumes the stop-word file is UTF-8 — confirm its encoding.
    with open(path, encoding="utf-8") as handle:
        words = {line.strip() for line in handle}
    words.discard("")
    return {wid for wid in range(len(vocab)) if vocab[wid] in words}


def _topic_word_probs(lda):
    """Return each row of ``lda.expElogbeta``, renormalised to sum to 1, as strings."""
    # Real lists, not generators: dill cannot pickle generator objects.
    return [[str(p) for p in row / np.sum(row)] for row in lda.expElogbeta]


def _parse_int_list(arg):
    """Parse an underscore-separated list of integers, e.g. ``"10_20"`` -> ``[10, 20]``."""
    return [int(x) for x in arg.split("_")]


def calc_perp(models, train, db=None):
    """Score a trained model pair on the DEV splits.

    :param models: the list returned by :func:`train_lda`:
        ``[name, stop_words, lda_asr, asr_probs, asr_topics,
        lda_trs, trs_probs, trs_topics]``.
    :param train: mapping holding ``"ASR_wid"``/``"TRS_wid"`` DEV documents.
    :param db: TinyDB used to skip already-scored models. Defaults to the
        module-level ``db`` created by the ``__main__`` block, if any.
    :returns: ``{"name", "asr", "trs"}`` holding ``LdaModel.log_perplexity``
        values (gensim documents these as per-word likelihood bounds), or
        None when ``name`` is already present in ``db``.
    """
    name = models[0]
    stop_words = models[1]
    lda_asr = models[2]  # index 2 is the ASR model (see train_lda's return)
    lda_trs = models[5]  # index 5 is the TRS model
    logging.warning(" go {} ".format(name))
    logging.warning("TRS to be done")
    if db is None:
        # Backward compatibility: this function used to read the global `db`.
        db = globals().get("db")
    if db is not None and db.search(Query().name == name):
        logging.warning("{} already done".format(name))
        return None
    dev_trs = [_bag_of_words(doc, stop_words) for doc in train["TRS_wid"]["DEV"]]
    perp_trs = lda_trs.log_perplexity(dev_trs)
    logging.warning("ASR to be done")
    dev_asr = [_bag_of_words(doc, stop_words) for doc in train["ASR_wid"]["DEV"]]
    perp_asr = lda_asr.log_perplexity(dev_asr)
    logging.warning("ASR saving")
    return {"name": name, "asr": perp_asr, "trs": perp_trs}


def train_lda(out_dir, train, size, it, sw_size, alpha, eta, passes, chunk,
              stop_list="french.txt"):
    """Train one ASR and one TRS LdaModel for a single hyper-parameter setting.

    Stop words are the union of:
    - the ``sw_size`` most frequent word ids of each TRAIN split;
    - the ids of the words listed in ``stop_list`` (pass None to skip the file).

    :returns: None if ``<out_dir>/<name>.dill`` already exists. Otherwise
        ``[name, stop_words, lda_asr, asr_probs, asr_topics, lda_trs,
        trs_probs, trs_topics]``, where ``*_probs`` are per-topic word
        distributions (as strings) and ``*_topics`` come from
        ``print_topics(num_topics, 10)``.
    """
    name = "s{}_it{}_sw{}_a{}_e{}_p{}_c{}".format(size, it, sw_size, alpha, eta, passes, chunk)
    logging.warning(name)
    if os.path.isfile(os.path.join(out_dir, name + ".dill")):
        logging.error(name + " already done")
        return None
    logging.warning(name + " to be done")

    vocab = train["vocab"]  # read once: each shelve access unpickles again
    stop_words = (_most_frequent_ids(train["ASR_wid"]["TRAIN"], sw_size)
                  | _most_frequent_ids(train["TRS_wid"]["TRAIN"], sw_size)
                  | _load_stop_word_ids(stop_list, vocab))

    def fit(docs):
        # Same hyper-parameters for both transcript sources.
        return LdaModel(corpus=[_bag_of_words(doc, stop_words) for doc in docs],
                        id2word=vocab, num_topics=int(size), chunksize=chunk,
                        iterations=it, alpha=alpha, eta=eta, passes=passes)

    logging.warning("TRS to be done")
    lda_trs = fit(train["TRS_wid"]["TRAIN"])
    logging.warning("ASR to be done")
    lda_asr = fit(train["ASR_wid"]["TRAIN"])

    topics_asr = lda_asr.print_topics(lda_asr.num_topics, 10)
    topics_trs = lda_trs.print_topics(lda_trs.num_topics, 10)
    return [name, stop_words, lda_asr, _topic_word_probs(lda_asr), topics_asr,
            lda_trs, _topic_word_probs(lda_trs), topics_trs]


if __name__ == "__main__":
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)
    input_shelve = sys.argv[1]
    db_path = sys.argv[2]
    size = _parse_int_list(sys.argv[3])
    workers = int(sys.argv[4])  # validated but unused: training uses LdaModel, not LdaMulticore
    name = sys.argv[5]
    it = _parse_int_list(sys.argv[6])
    sw_size = _parse_int_list(sys.argv[7])
    if sys.argv[8] != "None":
        alpha = ["symmetric", "auto"] + [float(x) for x in sys.argv[8].split("_")]
        eta = ["auto"] + [float(x) for x in sys.argv[9].split("_")]
    else:
        alpha = ["symmetric"]
        eta = ["auto"]
    passes = _parse_int_list(sys.argv[10])
    chunk = _parse_int_list(sys.argv[11])

    train = shelve.open(input_shelve)
    try:
        try:
            os.mkdir(name)
        except FileExistsError:
            logging.warning(" folder already existe " )
        db = TinyDB(db_path)
        try:
            # Same iteration order as the former nested loops (size innermost).
            grid = list(itertools.product(passes, chunk, it, sw_size, alpha, eta, size))
            logging.warning(" hey will train {} models ".format(len(grid)))
            for p, c, i, sw, a, e, s in grid:
                st = time.time()
                logging.warning(" ; ".join([str(x) for x in [s, i, sw, a, e, p, c]]))
                models = train_lda(name, train, s, i, sw, a, e, p, c)
                if models:
                    m = calc_perp(models, train, db)
                    with open(os.path.join(name, "{}.dill".format(models[0])), "wb") as out:
                        dill.dump(models, out)
                    if m:  # None when this model was already scored in the db
                        db.insert(m)
                # Separate name: the timer used to reuse `e`, clobbering the eta loop variable.
                elapsed = time.time() - st
                logging.warning("fin en : {}".format(elapsed))
        finally:
            db.close()
    finally:
        train.close()