LDA/02-lda.py
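"""Grid-search training of gensim LDA models on the ASR and TRS bag-of-words
corpora stored in a shelve file, with dev-set perplexities logged to a TinyDB.

Positional arguments (as parsed in the __main__ block below; underscore-separated
values are expanded into a hyper-parameter grid):
    input_shelve  db_path  sizes  workers  out_dir  iterations  stopword_sizes
    alphas(or "None")  etas  passes  chunksizes
"""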
import gensim
import os
import sys
import pickle
from gensim.models.ldamodel import LdaModel
from gensim.models.ldamulticore import LdaMulticore
from collections import Counter
import numpy as np
import codecs
import shelve
import logging
import dill
from tinydb import TinyDB, where, Query
import time
from joblib import Parallel, delayed


def calc_perp(models, train):
    stop_words = models[1]
    name = models[0]
    logging.warning(" go {} ".format(name))
    logging.warning("TRS to be done")
    # NOTE: relies on the module-level TinyDB handle `db` opened in __main__.
    entry = Query()
    value = db.search(entry.name == name)
    if len(value) > 0:
        logging.warning("{} already done".format(name))
        return
    # Dev-set perplexity on the TRS side.
    dev_trs = [[(x, y) for x, y in Counter(z).items() if x not in stop_words] for z in train["TRS_wid"]["DEV"]]
    lda_trs = models[2]
    perp_trs = lda_trs.log_perplexity(dev_trs)
    logging.warning("ASR to be done")
    # Dev-set perplexity on the ASR side.
    dev_asr = [[(x, y) for x, y in Counter(z).items() if x not in stop_words] for z in train["ASR_wid"]["DEV"]]
    lda_asr = models[5]
    perp_asr = lda_asr.log_perplexity(dev_asr)
    logging.warning("ASR saving")
    res_dict = {"name": name, "asr": perp_asr, "trs": perp_trs}
    return res_dict


def train_lda(out_dir, train, size, it, sw_size, alpha, eta, passes, chunk):
    name = "s{}_it{}_sw{}_a{}_e{}_p{}_c{}".format(size, it, sw_size, alpha, eta, passes, chunk)
    logging.warning(name)
    deep_out_dir = out_dir + "/" + name
    if os.path.isdir(deep_out_dir):
        logging.error(name + " already done")
        return
    logging.warning(name + " to be done")
    # The sw_size most frequent words of each corpus are pooled into one stop-word set.
    asr_count = Counter([x for y in train["ASR_wid"]["TRAIN"] for x in y])
    trs_count = Counter([x for y in train["TRS_wid"]["TRAIN"] for x in y])
    asr_sw = [x[0] for x in asr_count.most_common(sw_size)]
    trs_sw = [x[0] for x in trs_count.most_common(sw_size)]
    stop_words = set(asr_sw) | set(trs_sw)
logging.warning("TRS to be done") lda_trs = LdaModel(corpus=[ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["TRS_wid"]["TRAIN"]], id2word=train["vocab"], num_topics=int(size), chunksize=chunk,iterations=it,alpha=alpha,eta=eta,passes=passes) logging.warning("ASR to be done") lda_asr = LdaModel(corpus=[ [ (x,y) for x,y in Counter(z).items() if x not in stop_words] for z in train["ASR_wid"]["TRAIN"]], id2word=train["vocab"], num_topics=int(size), chunksize=chunk,iterations=it,alpha=alpha,eta=eta,passes=passes) dico = train["vocab"] word_list = [ dico[x] for x in range(len(train["vocab"]))] asr_probs = [] for line in lda_asr.expElogbeta: nline = line / np.sum(line) |
        asr_probs.append([str(x) for x in nline])
    trs_probs = []
    for line in lda_trs.expElogbeta:
        nline = line / np.sum(line)
        trs_probs.append([str(x) for x in nline])
    K = lda_asr.num_topics
    topicWordProbMat_asr = lda_asr.print_topics(K, 10)
    K = lda_trs.num_topics
    topicWordProbMat_trs = lda_trs.print_topics(K, 10)
    os.mkdir(deep_out_dir)
    # dill pickles need binary file handles ("wb", not "w").
    dill.dump([x for x in stop_words], open(deep_out_dir + "/stopwords.dill", "wb"))
    lda_asr.save(deep_out_dir + "/lda_asr.model")
    lda_trs.save(deep_out_dir + "/lda_trs.model")
    dill.dump([x for x in asr_probs], open(deep_out_dir + "/lda_asr_probs.dill", "wb"))
    dill.dump([x for x in trs_probs], open(deep_out_dir + "/lda_trs_probs.dill", "wb"))
    return [name, stop_words, lda_asr, asr_probs, topicWordProbMat_asr, lda_trs, trs_probs, topicWordProbMat_trs]
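

# Illustrative sketch (not part of the original script): the artefacts written by
# train_lda() could be read back with the standard gensim/dill load calls below.
# The helper name and signature are hypothetical; deep_out_dir is the per-model
# directory created above.
def load_lda_outputs(deep_out_dir):
    lda_asr = LdaModel.load(deep_out_dir + "/lda_asr.model")
    lda_trs = LdaModel.load(deep_out_dir + "/lda_trs.model")
    stop_words = set(dill.load(open(deep_out_dir + "/stopwords.dill", "rb")))
    asr_probs = dill.load(open(deep_out_dir + "/lda_asr_probs.dill", "rb"))
    trs_probs = dill.load(open(deep_out_dir + "/lda_trs_probs.dill", "rb"))
    return lda_asr, lda_trs, stop_words, asr_probs, trs_probs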


def train_one(name, train, s, i, sw, a, e, p, c):
    st = time.time()
    logging.warning(" ; ".join([str(x) for x in [s, i, sw, a, e, p, c]]))
    models = train_lda(name, train, s, i, sw, a, e, p, c)
    if models:
        m = calc_perp(models, train)
        # dill.dump(models, open("{}/{}.dill".format(name, models[0]), "wb"))
    else:
        m = None
    end = time.time()
    logging.warning("finished in: {}".format(end - st))
    return m


if __name__ == "__main__":
    logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.WARNING)
    # Underscore-separated arguments are expanded into value lists that define
    # the hyper-parameter grid.
    input_shelve = sys.argv[1]
    db_path = sys.argv[2]
    size = [int(x) for x in sys.argv[3].split("_")]
    workers = int(sys.argv[4])
    name = sys.argv[5]
    it = [int(x) for x in sys.argv[6].split("_")]
    sw_size = [int(x) for x in sys.argv[7].split("_")]
    if sys.argv[8] != "None":
        alpha = ["symmetric", "auto"] + [float(x) for x in sys.argv[8].split("_")]
        eta = ["auto"] + [float(x) for x in sys.argv[9].split("_")]
    else:
        alpha = ["symmetric"]
        eta = ["auto"]
    passes = [int(x) for x in sys.argv[10].split("_")]
    chunk = [int(x) for x in sys.argv[11].split("_")]

    # train = pickle.load(open("{}/newsgroup_bow_train.pk".format(input_dir)))
    train = shelve.open(input_shelve)

    try:
        os.mkdir(name)
    except OSError:
        logging.warning(" folder already exists ")

    db = TinyDB(db_path)
    nb_model = len(passes) * len(chunk) * len(it) * len(sw_size) * len(alpha) * len(eta) * len(size)
    logging.warning(" will train {} models ".format(nb_model))
    args_list = []
    for p in passes:
        for c in chunk:
            for i in it:
                for sw in sw_size:
                    for a in alpha:
                        for e in eta:
                            for s in size:
                                args_list.append((name, train, s, i, sw, a, e, p, c))

    res_list = Parallel(n_jobs=workers)(delayed(train_one)(*args) for args in args_list)
    for m in res_list:
        if m is not None:
            db.insert(m)
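
# Example invocation (hypothetical file names and grid values, shown only to
# illustrate the positional-argument layout parsed in __main__):
#   python 02-lda.py corpus.shelve perplexity_db.json 50_100 4 lda_out 200 100 \
#       0.01_0.1 0.01 1_2 1000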