# LDA/02-lda.py
import gensim
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import dill
from tinydb import TinyDB, where, Query
import time
from joblib import Parallel, delayed
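
# Grid search over gensim LDA hyper-parameters on two views of the same corpus:
# manual transcripts ("TRS") and automatic transcripts ("ASR").  The training
# data is read from a shelve file expected to hold bag-of-word-id lists under
# train["TRS_wid"] / train["ASR_wid"] (each with "TRAIN" and "DEV" splits) and
# an id-to-word mapping under train["vocab"].  Dev-set log-perplexities of each
# trained model pair are stored in a TinyDB database.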


def calc_perp(models, train):
    # Compute dev-set log-perplexity for the TRS and ASR models produced by
    # train_lda.  `models` is the list returned by train_lda:
    # [name, stop_words, lda_asr, asr_probs, topics_asr, lda_trs, trs_probs, topics_trs]
    stop_words = models[1]
    name = models[0]
    logging.warning(" go {} ".format(name))
    logging.warning("TRS to be done")
    # Skip configurations already stored in the TinyDB results database
    # (module-level `db`, opened in __main__).
    entry = Query()
    value = db.search(entry.name == name)
    if len(value) > 0:
        logging.warning("{} already done".format(name))
        return
    dev_trs = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
               for z in train["TRS_wid"]["DEV"]]
    lda_trs = models[2]
    perp_trs = lda_trs.log_perplexity(dev_trs)
    logging.warning("ASR to be done")
    dev_asr = [[(x, y) for x, y in Counter(z).items() if x not in stop_words]
               for z in train["ASR_wid"]["DEV"]]
    lda_asr = models[5]
    perp_asr = lda_asr.log_perplexity(dev_asr)
    logging.warning("ASR saving")
    res_dict = {"name": name, "asr": perp_asr, "trs": perp_trs}
    return res_dict


def train_lda(out_dir, train, size, it, sw_size, alpha, eta, passes, chunk):
    # Train one LDA model per view (TRS and ASR) for a single hyper-parameter
    # combination; the `sw_size` most frequent word ids of each view are
    # treated as stop words and removed from the corpora.
    name = "s{}_it{}_sw{}_a{}_e{}_p{}_c{}".format(size, it, sw_size, alpha, eta, passes, chunk)
    logging.warning(name)
    deep_out_dir = out_dir + "/" + name
    if os.path.isdir(deep_out_dir):
        logging.error(name + " already done")
        return
    logging.warning(name + " to be done")
    asr_count = Counter([x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count = Counter([x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [x[0] for x in asr_count.most_common(sw_size)]
    trs_sw = [x[0] for x in trs_count.most_common(sw_size)]
    stop_words = set(asr_sw) | set(trs_sw)
logging.warning("TRS to be done") lda_trs = LdaModel(corpus=[ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["TRS_wid"]["TRAIN"]], id2word=train["vocab"], num_topics=int(size), chunksize=chunk,iterations=it,alpha=alpha,eta=eta,passes=passes) logging.warning("ASR to be done") lda_asr = LdaModel(corpus=[ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["ASR_wid"]["TRAIN"]], id2word=train["vocab"], num_topics=int(size), chunksize=chunk,iterations=it,alpha=alpha,eta=eta,passes=passes) dico = train["vocab"] word_list = [ dico[x] for x in range(len(train["vocab"]))] asr_probs = [] for line in lda_asr.expElogbeta: nline = line / np.sum(line) |
        asr_probs.append([str(x) for x in nline])
    trs_probs = []
    for line in lda_trs.expElogbeta:
        nline = line / np.sum(line)
        trs_probs.append([str(x) for x in nline])
    K = lda_asr.num_topics
    topicWordProbMat_asr = lda_asr.print_topics(K, 10)

    K = lda_trs.num_topics
    topicWordProbMat_trs = lda_trs.print_topics(K, 10)
    os.mkdir(deep_out_dir)
    # dill, like pickle, writes bytes, so the dump files are opened in binary mode.
    dill.dump([x for x in stop_words], open(deep_out_dir + "/stopwords.dill", "wb"))
    lda_asr.save(deep_out_dir + "/lda_asr.model")
    lda_trs.save(deep_out_dir + "/lda_trs.model")
    dill.dump([x for x in asr_probs], open(deep_out_dir + "/lda_asr_probs.dill", "wb"))
    dill.dump([x for x in trs_probs], open(deep_out_dir + "/lda_trs_probs.dill", "wb"))
    return [name, stop_words, lda_asr, asr_probs, topicWordProbMat_asr,
            lda_trs, trs_probs, topicWordProbMat_trs]


def train_one(name, train, s, i, sw, a, e, p, c):
    # Train one configuration and, if training actually ran, score it on the dev sets.
    st = time.time()
    logging.warning(" ; ".join([str(x) for x in [s, i, sw, a, e, p, c]]))
    models = train_lda(name, train, s, i, sw, a, e, p, c)
    if models:
        m = calc_perp(models, train)
        #dill.dump(models, open("{}/{}.dill".format(name, models[0]), "wb"))
    else:
        m = None
    end = time.time()
    logging.warning("done in: {}".format(end - st))
    return m
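
# Example invocation (illustrative only; the file names and hyper-parameter
# values below are assumptions, not values from the original experiments).
# Each underscore-separated argument is expanded into a list, and the full
# cross product of those lists defines the grid of models to train:
#   python 02-lda.py corpus.shelve results_db.json 50_100 4 out_lda 200 50 0.1_0.5 0.01 1_2 1000
# Arguments: shelve file, TinyDB path, topic sizes, workers, output dir,
# iterations, stop-word list sizes, alpha values (or "None" for symmetric
# alpha / auto eta only), eta values, passes, chunk sizes.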
if __name__ == "__main__":
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)
    input_shelve = sys.argv[1]
    db_path = sys.argv[2]
    size = [int(x) for x in sys.argv[3].split("_")]
    workers = int(sys.argv[4])
    name = sys.argv[5]
    it = [int(x) for x in sys.argv[6].split("_")]
    sw_size = [int(x) for x in sys.argv[7].split("_")]
    if sys.argv[8] != "None":
        alpha = ["symmetric", "auto"] + [float(x) for x in sys.argv[8].split("_")]
        eta = ["auto"] + [float(x) for x in sys.argv[9].split("_")]
    else:
        alpha = ["symmetric"]
        eta = ["auto"]
    passes = [int(x) for x in sys.argv[10].split("_")]
    chunk = [int(x) for x in sys.argv[11].split("_")]
    #train = pickle.load(open("{}/newsgroup_bow_train.pk".format(input_dir)))
    train = shelve.open(input_shelve)
    try:
        os.mkdir(name)
    except OSError:
        logging.warning("folder already exists")

    db = TinyDB(db_path)
    nb_model = len(passes) * len(chunk) * len(it) * len(sw_size) * len(alpha) * len(eta) * len(size)
    logging.warning("will train {} models".format(nb_model))
    args_list = []
    for p in passes:
        for c in chunk:
            for i in it:
                for sw in sw_size:
                    for a in alpha:
                        for e in eta:
                            for s in size:
                                args_list.append((name, train, s, i, sw, a, e, p, c))
    # Train all configurations in parallel (n_jobs taken from the `workers`
    # command-line argument) and store the non-empty perplexity results.
    res_list = Parallel(n_jobs=workers)(delayed(train_one)(*args) for args in args_list)
    for m in res_list:
        if m is not None:
            db.insert(m)