# LDA/04c-mmf_sae.py
# coding: utf-8

# In[2]:

# Import
import gensim
from scipy import sparse
import itertools
from sklearn import preprocessing
from keras.models import Sequential
from keras.optimizers import SGD, Adam
from mlp import *
import mlp
import sklearn.metrics
import shelve
import pickle
import pandas  # required by the save_projection blocks below
from utils import *
import sys
import os
import json

# In[4]:

infer_model = shelve.open("{}".format(sys.argv[2]))
in_dir = sys.argv[1]
if len(sys.argv) > 4:
    features_key = sys.argv[4]
else:
    features_key = "LDA"

save_projection = True
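
# Usage (inferred from the argument handling above; the invocation line itself
# is an assumption, it is not documented in the repository):
#   python 04c-mmf_sae.py <in_dir> <infer_model.shelve> <config.json> [features_key]
# sys.argv[1] is the output directory, sys.argv[2] the shelve file holding the
# inferred features, sys.argv[3] the JSON experiment config, and the optional
# sys.argv[4] selects which feature key to read (defaults to "LDA").
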
# ['ASR', 'TRS', 'LABEL']

# In[6]:
json_conf = json.load(open(sys.argv[3]))
sae_conf = json_conf["sae"]
|
hidden_size = sae_conf["hidden_size"]
input_activation = sae_conf["input_activation"]
output_activation = sae_conf["output_activation"]
loss = sae_conf["loss"]
epochs = sae_conf["epochs"]
batch = sae_conf["batch"]
patience = sae_conf["patience"]
do_do = sae_conf["do"]
|
try:
    # "sgd" may be a dict such as {"name": "adam", "lr": 0.001}; build the
    # matching Keras optimizer object.
    if sae_conf["sgd"]["name"] == "adam":
        sgd = Adam(lr=sae_conf["sgd"]["lr"])
    elif sae_conf["sgd"]["name"] == "sgd":
        sgd = SGD(lr=sae_conf["sgd"]["lr"])
except (KeyError, TypeError):
    # Otherwise it is a plain optimizer name string understood by Keras.
    sgd = sae_conf["sgd"]
|
name = json_conf["name"] |
print(name)
try:
    os.mkdir("{}/{}".format(in_dir, name))
except OSError:
    # The run directory already exists; reuse it.
    pass
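
# All results are accumulated in a shelve database inside the run directory.
# writeback=True keeps a write-back cache so that nested structures such as
# db["SAE"][mod] can be mutated in place and persisted on sync()/close().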
db = shelve.open("{}/{}/ae_model.shelve".format(in_dir, name), writeback=True)
# |
mlp_conf = json_conf["mlp"]
mlp_h = mlp_conf["hidden_size"]
mlp_loss = mlp_conf["loss"]
mlp_dropouts = mlp_conf["do"]
mlp_epochs = mlp_conf["epochs"]
mlp_batch_size = mlp_conf["batch"]
mlp_input_activation = mlp_conf["input_activation"]
mlp_output_activation = mlp_conf["output_activation"]

try:
    # Same convention as for the SAE: "sgd" is either a dict or a plain string.
    if mlp_conf["sgd"]["name"] == "adam":
        mlp_sgd = Adam(lr=mlp_conf["sgd"]["lr"])
    elif mlp_conf["sgd"]["name"] == "sgd":
        mlp_sgd = SGD(lr=mlp_conf["sgd"]["lr"])
except (KeyError, TypeError):
    mlp_sgd = mlp_conf["sgd"]
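
# For reference, a config of the shape this script expects (a sketch: the keys
# are the ones read above, the values are illustrative assumptions only):
#
# {
#   "name": "sae_run1",
#   "sae": {"hidden_size": [512, 256], "input_activation": "tanh",
#           "output_activation": "tanh", "loss": "mse", "epochs": 100,
#           "batch": 32, "patience": 10, "do": [0.25, 0.25],
#           "sgd": {"name": "adam", "lr": 0.001}},
#   "mlp": {"hidden_size": [128], "input_activation": "tanh",
#           "output_activation": "softmax", "loss": "categorical_crossentropy",
#           "epochs": 50, "batch": 16, "do": [0.25],
#           "sgd": "adam"}
# }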
|
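# The feature shelve is keyed by modality (e.g. "ASR", "TRS"; see the comment
# near the top of the file), with "TRAIN"/"DEV"/"TEST" splits under each key.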
keys = infer_model[features_key].keys() |
db["SAE"] = {}
db["SAEFT"] = {}

for mod in keys:
    res_tuple = train_sae(infer_model[features_key][mod]["TRAIN"],
                          infer_model[features_key][mod]["DEV"],
                          infer_model[features_key][mod]["TEST"],
                          hidden_size, dropouts=do_do,
                          patience=patience, sgd=sgd,
                          input_activation="tanh", output_activation="tanh",
                          loss=loss, epochs=epochs,
                          batch_size=batch, verbose=0)
    # print len(res), [len(x) for x in res[0]], [len(x) for x in res[1]]
    for i, levels in zip(["SAE", "SAEFT"], res_tuple):
        mlp_res_by_level = []
        for lvl, res in enumerate(levels):
            mlp_res_list = []
            for nb, layer in enumerate(res):
                if save_projection:
                    # Keep only the columns that are not all-zero in the TRAIN
                    # projection, and apply the same mask to DEV and TEST.
                    pd = pandas.DataFrame(layer[0])
                    col_count = (pd.sum(axis=0) != 0)
                    pd = pd.loc[:, col_count]
                    hdffile = "{}/{}/{}_{}_{}_{}_df.hdf".format(in_dir, name, i, lvl, nb, mod)
                    print(hdffile)
                    pd.to_hdf(hdffile, "TRAIN")
                    pd = pandas.DataFrame(layer[1])
                    pd = pd.loc[:, col_count]
                    pd.to_hdf(hdffile, "DEV")
                    pd = pandas.DataFrame(layer[2])
                    pd = pd.loc[:, col_count]
                    pd.to_hdf(hdffile, "TEST")
                    del pd
                mlp_res_list.append(train_mlp(layer[0], infer_model["LABEL"][mod]["TRAIN"],
                                              layer[1], infer_model["LABEL"][mod]["DEV"],
                                              layer[2], infer_model["LABEL"][mod]["TEST"],
                                              mlp_h, loss=mlp_loss, dropouts=mlp_dropouts,
                                              sgd=mlp_sgd, epochs=mlp_epochs,
                                              batch_size=mlp_batch_size,
                                              fit_verbose=0))
            mlp_res_by_level.append(mlp_res_list)
        db[i][mod] = mlp_res_by_level

# Cross-modal variant: when both ASR and TRS features are present, train an
# SAE that reads ASR features and reconstructs the TRS features; results are
# stored under the key "SPE".
if "ASR" in keys and "TRS" in keys:
    mod = "ASR"
    mod2 = "TRS"
    res_tuple = train_sae(infer_model[features_key][mod]["TRAIN"],
                          infer_model[features_key][mod]["DEV"],
                          infer_model[features_key][mod]["TEST"],
                          hidden_size, dropouts=[0], patience=patience,
                          sgd=sgd, input_activation=input_activation,
                          output_activation=input_activation,
                          loss=loss, epochs=epochs, batch_size=batch,
                          y_train=infer_model[features_key][mod2]["TRAIN"],
                          y_dev=infer_model[features_key][mod2]["DEV"],
                          y_test=infer_model[features_key][mod2]["TEST"])
    for i, levels in zip(["SAE", "SAEFT"], res_tuple):
        mlp_res_by_level = []
        for lvl, res in enumerate(levels):
            mlp_res_list = []
            for nb, layer in enumerate(res):
                if save_projection:
                    # Same non-zero-column filtering as above, keyed "SPE".
                    pd = pandas.DataFrame(layer[0])
                    col_count = (pd.sum(axis=0) != 0)
                    pd = pd.loc[:, col_count]
                    hdffile = "{}/{}/{}_{}_{}_{}_df.hdf".format(in_dir, name, i, lvl, nb, "SPE")
                    pd.to_hdf(hdffile, "TRAIN")
                    pd = pandas.DataFrame(layer[1])
                    pd = pd.loc[:, col_count]
                    pd.to_hdf(hdffile, "DEV")
                    pd = pandas.DataFrame(layer[2])
                    pd = pd.loc[:, col_count]
                    pd.to_hdf(hdffile, "TEST")
                    del pd
                mlp_res_list.append(train_mlp(layer[0], infer_model["LABEL"][mod]["TRAIN"],
                                              layer[1], infer_model["LABEL"][mod]["DEV"],
                                              layer[2], infer_model["LABEL"][mod]["TEST"],
                                              mlp_h, loss=mlp_loss, dropouts=mlp_dropouts,
                                              sgd=mlp_sgd, epochs=mlp_epochs,
                                              batch_size=mlp_batch_size,
                                              fit_verbose=0))
            mlp_res_by_level.append(mlp_res_list)
        db[i]["SPE"] = mlp_res_by_level
|
db.sync() |
db.close() |
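
# To inspect the stored results afterwards, something along these lines works
# (a sketch; the exact structure of each train_mlp result depends on mlp.py):
#
#   import shelve
#   db = shelve.open("<in_dir>/<name>/ae_model.shelve")
#   print(db["SAE"].keys())  # one entry per modality, plus "SPE" if ASR+TRS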