Blame view
LDA/04c-mmf_sae.py
4.26 KB
7db73861f add vae et mmf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
# coding: utf-8

# Command-line arguments, as read by this script:
#   sys.argv[1]  working directory; the results folder is created below it
#   sys.argv[2]  path of the input shelve (read under "LDA" and "LABEL")
#   sys.argv[3]  path of the JSON configuration file

# Imports stay in their original order: the two wildcard imports may rebind
# names, so reordering them could change which definition wins.
import gensim
from scipy import sparse
import itertools
from sklearn import preprocessing
from keras.models import Sequential
from keras.optimizers import SGD,Adam
from mlp import *
import mlp
import sklearn.metrics
import shelve
import pickle
from utils import *
import sys
import os
import json

# Shelve holding the input features and labels used further down.
infer_model = shelve.open("{}".format(sys.argv[2]))
# Base directory for everything this run writes.
in_dir = sys.argv[1]
# ['ASR', 'TRS', 'LABEL']
e5108393c replace du mlp.p... |
28 29 |
# Load the experiment configuration. The file is opened in a `with` block so
# the handle is closed as soon as parsing is done.
with open(sys.argv[3]) as conf_file:
    json_conf = json.load(conf_file)

# Hyper-parameters of the "sae" section; they are forwarded to train_sae.
sae_conf = json_conf["sae"]

hidden_size = sae_conf["hidden_size"]
input_activation = sae_conf["input_activation"]
output_activation = sae_conf["output_activation"]
loss = sae_conf["loss"]
epochs = sae_conf["epochs"]
batch = sae_conf["batch"]        # passed as batch_size
patience = sae_conf["patience"]
do_do = sae_conf["do"]           # passed as dropouts
7db73861f add vae et mmf |
39 |
|
e5108393c replace du mlp.p... |
40 41 42 43 44 45 46 47 |
# Optimizer for the SAE. The "sgd" entry is either a mapping with "name" and
# "lr" (an optimizer object is built from it) or any other value, which is
# handed over unchanged.
sae_sgd_conf = sae_conf["sgd"]
try:
    sae_sgd_name = sae_sgd_conf["name"]
except (TypeError, KeyError):
    # Not a {"name": ..., "lr": ...} mapping: use the value as given.
    sgd = sae_sgd_conf
else:
    if sae_sgd_name == "adam":
        sgd = Adam(lr=sae_sgd_conf["lr"])
    elif sae_sgd_name == "sgd":
        sgd = SGD(lr=sae_sgd_conf["lr"])
    else:
        # Previously this case left `sgd` undefined and failed much later
        # with a NameError; fail here with an explicit message instead.
        raise ValueError(
            "unsupported optimizer name in sae config: {!r}".format(sae_sgd_name))
7db73861f add vae et mmf |
48 |
|
e5108393c replace du mlp.p... |
49 |
# Results of this run are stored under <in_dir>/<name>/.
name = json_conf["name"]
out_dir = "{}/{}".format(in_dir, name)
try:
    os.mkdir(out_dir)
except OSError:
    # Re-running into an existing directory is fine; any other failure
    # (missing parent, permissions, ...) must not be silently ignored.
    if not os.path.isdir(out_dir):
        raise

# writeback=True keeps accessed entries cached so that nested updates such as
# db[key][subkey] = value are written out on sync()/close().
db = shelve.open("{}/{}/ae_model.shelve".format(in_dir, name), writeback=True)
7db73861f add vae et mmf |
55 |
# |
7db73861f add vae et mmf |
56 |
# Modalities processed by the per-modality training loop.
keys = ["ASR", "TRS"]

# Hyper-parameters of the "mlp" section; they are forwarded to train_mlp.
mlp_conf = json_conf["mlp"]
mlp_h = mlp_conf["hidden_size"]
mlp_loss = mlp_conf["loss"]
mlp_dropouts = mlp_conf["do"]
mlp_epochs = mlp_conf["epochs"]
mlp_batch_size = mlp_conf["batch"]
mlp_input_activation = mlp_conf["input_activation"]
mlp_output_activation = mlp_conf["output_activation"]

# Optimizer for the MLP. The "sgd" entry is either a mapping with "name" and
# "lr" (an optimizer object is built from it) or any other value, which is
# handed over unchanged.
mlp_sgd_conf = mlp_conf["sgd"]
try:
    mlp_sgd_name = mlp_sgd_conf["name"]
except (TypeError, KeyError):
    # Not a {"name": ..., "lr": ...} mapping: use the value as given.
    mlp_sgd = mlp_sgd_conf
else:
    if mlp_sgd_name == "adam":
        mlp_sgd = Adam(lr=mlp_sgd_conf["lr"])
    elif mlp_sgd_name == "sgd":
        mlp_sgd = SGD(lr=mlp_sgd_conf["lr"])
    else:
        # Previously this case left `mlp_sgd` undefined and failed much later
        # with a NameError; fail here with an explicit message instead.
        raise ValueError(
            "unsupported optimizer name in mlp config: {!r}".format(mlp_sgd_name))
7db73861f add vae et mmf |
74 |
|
7db73861f add vae et mmf |
75 76 77 78 79 |
def _mlp_results_by_level(levels, labels):
    """Train one MLP per layer of every entry in *levels*.

    levels: iterable of entries; each entry is an iterable of layers, and
        each layer is indexed as [0], [1], [2] to get the train, dev and
        test inputs given to train_mlp.
    labels: mapping providing the "TRAIN", "DEV" and "TEST" targets.

    Returns a list with one item per entry of *levels*; each item is the
    list of train_mlp results for that entry's layers.
    """
    results_by_level = []
    for res in levels:
        results_by_level.append([
            train_mlp(layer[0], labels["TRAIN"],
                      layer[1], labels["DEV"],
                      layer[2], labels["TEST"],
                      mlp_h, loss=mlp_loss, dropouts=mlp_dropouts,
                      sgd=mlp_sgd, epochs=mlp_epochs,
                      batch_size=mlp_batch_size,
                      fit_verbose=0)
            for layer in res
        ])
    return results_by_level


# A shelve un-pickles the stored value on every key access, so read the two
# top-level entries once instead of once per argument.
lda = infer_model["LDA"]
label_sets = infer_model["LABEL"]

db["SAE"] = {}
db["SAEFT"] = {}

# One SAE per modality, then MLPs trained on what train_sae returns.
for mod in keys:
    # `patience` now receives the configured value; the literal string
    # "patience" was passed before and the configured value was never used.
    # NOTE(review): both activations are hard-coded to "tanh" here although
    # the configuration provides input_activation/output_activation - confirm
    # whether that is intended.
    res_tuple = train_sae(lda[mod]["TRAIN"], lda[mod]["DEV"],
                          lda[mod]["TEST"],
                          hidden_size, dropouts=do_do,
                          patience=patience, sgd=sgd,
                          input_activation="tanh",
                          output_activation="tanh",
                          loss=loss, epochs=epochs,
                          batch_size=batch, verbose=0)
    # res_tuple is paired positionally with the "SAE" and "SAEFT" result keys.
    for sae_kind, levels in zip(["SAE", "SAEFT"], res_tuple):
        db[sae_kind][mod] = _mlp_results_by_level(levels, label_sets[mod])

# Cross-modality run: ASR features as input, TRS features passed as the
# y_train/y_dev/y_test arguments. Results are stored under the "SPE" key.
mod = "ASR"
mod2 = "TRS"
# NOTE(review): output_activation is given input_activation here, and
# dropouts is fixed to [0] - confirm both are intended.
res_tuple = train_sae(lda[mod]["TRAIN"],
                      lda[mod]["DEV"],
                      lda[mod]["TEST"],
                      hidden_size, dropouts=[0], patience=patience,
                      sgd=sgd, input_activation=input_activation,
                      output_activation=input_activation,
                      loss=loss, epochs=epochs, batch_size=batch,
                      y_train=lda[mod2]["TRAIN"],
                      y_dev=lda[mod2]["DEV"],
                      y_test=lda[mod2]["TEST"])

for sae_kind, levels in zip(["SAE", "SAEFT"], res_tuple):
    db[sae_kind]["SPE"] = _mlp_results_by_level(levels, label_sets[mod])

# Flush the cached entries to disk and release the results shelve.
db.sync()
db.close()