Blame view
LDA/04a-mmdf.py
2.99 KB
7db73861f add vae et mmf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
# coding: utf-8
"""Train one MLP per corpus key on pre-computed features and store its scores.

Usage: 04a-mmdf.py IN_DIR CORPUS_SHELVE [FEATURES_KEY]

  IN_DIR         directory in which ``mlp_scores.shelve`` is created/updated
  CORPUS_SHELVE  shelve holding the feature sets and a "LABEL" entry, each
                 indexed as [key]["TRAIN" | "DEV" | "TEST"]
  FEATURES_KEY   entry of CORPUS_SHELVE used as input features (default "LDA")
"""
# Import order is kept as-is: `from mlp import *` may shadow earlier names.
import itertools
import shelve
import pickle
import numpy
import scipy
from scipy import sparse
import scipy.sparse
import scipy.io
from mlp import *
import mlp
import sys
import utils
import dill
from collections import Counter
from gensim.models import LdaModel

# Fail with a readable message instead of an IndexError on sys.argv.
if len(sys.argv) < 3:
    sys.exit("usage: {} IN_DIR CORPUS_SHELVE [FEATURES_KEY]".format(sys.argv[0]))

origin_corps = shelve.open(sys.argv[2])
in_dir = sys.argv[1]
features_key = sys.argv[3] if len(sys.argv) > 3 else "LDA"

# writeback=True: entries are cached in memory and written back on close().
out_db = shelve.open("{}/mlp_scores.shelve".format(in_dir), writeback=True)

# MLP hyper-parameters shared by every per-key run.
mlp_h = [250, 250]                    # presumably hidden-layer sizes
mlp_loss = "categorical_crossentropy"
mlp_dropouts = [0.25] * len(mlp_h)    # one dropout rate per entry of mlp_h
mlp_sgd = Adam(lr=0.0001)             # NOTE(review): Adam presumably comes from `from mlp import *`
mlp_epochs = 3000
mlp_batch_size = 5
# NOTE(review): the two activations below are not passed to mlp.train_mlp
# anywhere in this script.
mlp_input_activation = "relu"
mlp_output_activation = "softmax"

# One (key, dev_best, test_best, test_max) summary per trained key.
ress = []
2af8e57f4 change all |
51 |
# Number of best DEV-scoring positions recorded per corpus key.
N_BEST = 12


def _best_dev_positions(dev_scores, test_scores, n_best=N_BEST):
    """Return the ``n_best`` best DEV positions and the scores found there.

    Positions are ordered by decreasing DEV score. Equal scores are taken in
    increasing index order, which is the same tie-break as ``numpy.argmax``.
    Each position is selected at most once, and neither input is modified.

    Returns ``(arg_best, dev_best, test_best)`` where ``dev_best[i]`` is
    ``dev_scores[arg_best[i]]`` and ``test_best[i]`` is
    ``test_scores[arg_best[i]]``.
    """
    dev = numpy.asarray(dev_scores, dtype=float)
    # A stable ("mergesort") sort of the negated scores gives descending
    # order and keeps ties in index order.
    order = numpy.argsort(-dev, kind="mergesort")[:n_best]
    arg_best = [int(i) for i in order]
    dev_best = [dev_scores[i] for i in arg_best]
    test_best = [test_scores[i] for i in arg_best]
    return arg_best, dev_best, test_best


# Read each shelve entry once: without writeback, every lookup unpickles
# the whole stored value again.
features = origin_corps[features_key]
labels = origin_corps["LABEL"]

# Train and score one MLP per key of the selected feature set.
for key in features.keys():
    corpus, gold = features[key], labels[key]
    res = mlp.train_mlp(
        corpus["TRAIN"], gold["TRAIN"],
        corpus["DEV"], gold["DEV"],
        corpus["TEST"], gold["TEST"],
        mlp_h, dropouts=mlp_dropouts, sgd=mlp_sgd,
        epochs=mlp_epochs, batch_size=mlp_batch_size,
        save_pred=False, keep_histo=False,
        loss=mlp_loss, fit_verbose=0)
    # NOTE(review): res[1] / res[2] look like per-epoch DEV / TEST scores
    # (they are indexed by the same positions); confirm against mlp.train_mlp.
    # The selection works on a copy, so res is saved exactly as returned.
    _, dev_best, test_best = _best_dev_positions(res[1], res[2])
    test_max = numpy.max(res[2])
    out_db[key] = (res, (dev_best, test_best, test_max))
    ress.append((key, dev_best, test_best, test_max))
e5108393c replace du mlp.p... |
107 |
# Report the processed corpus and every per-key summary, then close both
# shelves. out_db was opened with writeback=True, so close() flushes its
# cached entries to disk.
# print(x) with a single argument prints the same output on Python 2 and 3.
print(sys.argv[2])
for el in ress:
    print(el)
out_db.close()
origin_corps.close()