# coding: utf-8
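"""Train stacked autoencoders (SAE) on LDA topic features for the ASR and
TRS views of a corpus, then probe each learned layer with an MLP
classifier. Results are stored in a shelve database under <in_dir>/<name>/.

Usage (argument order taken from the sys.argv reads below):
    python 04c-mmf_sae.py <in_dir> <infer_model.shelve> <config.json>
"""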


# Standard library
import os
import sys
import json
import shelve

# Keras optimizers, instantiated when the config describes one by name
from keras.optimizers import SGD, Adam

# Local helpers; train_sae and train_mlp are expected to come from mlp
from mlp import *
from utils import *

in_dir = sys.argv[1]
infer_model = shelve.open("{}".format(sys.argv[2]))
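# Expected layout of infer_model, inferred from the accesses below
# (it is not documented elsewhere in this script):
#   infer_model["LDA"][mod]["TRAIN" | "DEV" | "TEST"]   -> LDA topic features
#   infer_model["LABEL"][mod]["TRAIN" | "DEV" | "TEST"] -> document labels
# with mod in {"ASR", "TRS"}.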
with open(sys.argv[3]) as conf_file:
    json_conf = json.load(conf_file)
sae_conf = json_conf["sae"]

hidden_size = sae_conf["hidden_size"]
input_activation = sae_conf["input_activation"]
output_activation = sae_conf["output_activation"]
loss = sae_conf["loss"]
epochs = sae_conf["epochs"]
batch = sae_conf["batch"]
patience = sae_conf["patience"]
do_do = sae_conf["do"]          # dropout configuration for the SAE

# "sgd" is either a dict ({"name": ..., "lr": ...}) describing the
# optimizer, or a plain Keras optimizer identifier string passed through.
sgd = sae_conf["sgd"]
if isinstance(sgd, dict):
    if sgd["name"] == "adam":
        sgd = Adam(lr=sgd["lr"])
    elif sgd["name"] == "sgd":
        sgd = SGD(lr=sgd["lr"])

name = json_conf["name"]
out_dir = "{}/{}".format(in_dir, name)
if not os.path.isdir(out_dir):
    os.mkdir(out_dir)
db = shelve.open("{}/ae_model.shelve".format(out_dir), writeback=True)

keys = ["ASR", "TRS"]
mlp_conf = json_conf["mlp"]
mlp_h = mlp_conf["hidden_size"]
mlp_loss = mlp_conf["loss"]
mlp_dropouts = mlp_conf["do"]
mlp_epochs = mlp_conf["epochs"]
mlp_batch_size = mlp_conf["batch"]
# Read from the config but not passed to train_mlp below, which is
# called with its default activations.
mlp_input_activation = mlp_conf["input_activation"]
mlp_output_activation = mlp_conf["output_activation"]

# Same convention as the SAE optimizer: a dict or a plain name string.
mlp_sgd = mlp_conf["sgd"]
if isinstance(mlp_sgd, dict):
    if mlp_sgd["name"] == "adam":
        mlp_sgd = Adam(lr=mlp_sgd["lr"])
    elif mlp_sgd["name"] == "sgd":
        mlp_sgd = SGD(lr=mlp_sgd["lr"])
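
# Illustrative shape of the JSON config consumed above. Only the key
# names are taken from this script; every value below is a made-up
# example, not a reference configuration.
#
# {
#   "name": "my_run",
#   "sae": {
#     "hidden_size": [512, 256, 128],
#     "input_activation": "tanh",
#     "output_activation": "tanh",
#     "loss": "mse",
#     "epochs": 500,
#     "batch": 8,
#     "patience": 60,
#     "do": [0.25, 0.25, 0.25],
#     "sgd": {"name": "adam", "lr": 0.0001}
#   },
#   "mlp": {
#     "hidden_size": [128],
#     "loss": "categorical_crossentropy",
#     "do": [0.25],
#     "epochs": 1000,
#     "batch": 8,
#     "input_activation": "tanh",
#     "output_activation": "softmax",
#     "sgd": "adam"
#   }
# }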


db["SAE"] = {}

db["SAEFT"] = {}
for mod in keys:
    # Train the stacked autoencoder on the LDA features of this modality.
    res_tuple = train_sae(infer_model["LDA"][mod]["TRAIN"],
                          infer_model["LDA"][mod]["DEV"],
                          infer_model["LDA"][mod]["TEST"],
                          hidden_size, dropouts=do_do,
                          patience=patience, sgd=sgd,
                          input_activation=input_activation,
                          output_activation=output_activation,
                          loss=loss, epochs=epochs,
                          batch_size=batch, verbose=0)
    # res_tuple pairs with ("SAE", "SAEFT"); for each depth level and
    # each layer, probe the learned representation with an MLP classifier.
    for sae_key, levels in zip(["SAE", "SAEFT"], res_tuple):
        mlp_res_by_level = []
        for res in levels:
            mlp_res_list = []
            for layer in res:
                mlp_res_list.append(train_mlp(layer[0], infer_model["LABEL"][mod]["TRAIN"],
                                              layer[1], infer_model["LABEL"][mod]["DEV"],
                                              layer[2], infer_model["LABEL"][mod]["TEST"],
                                              mlp_h, loss=mlp_loss, dropouts=mlp_dropouts,
                                              sgd=mlp_sgd, epochs=mlp_epochs,
                                              batch_size=mlp_batch_size,
                                              fit_verbose=0))
            mlp_res_by_level.append(mlp_res_list)
        db[sae_key][mod] = mlp_res_by_level

mod = "ASR"
mod2= "TRS"
res_tuple = train_sae(infer_model["LDA"][mod]["TRAIN"],
                      infer_model["LDA"][mod]["DEV"],
                      infer_model["LDA"][mod]["TEST"],
                      hidden_size,dropouts=[0],patience="patience",
                      sgd=sgd,input_activation=input_activation,output_activation=input_activation,
                      loss=loss,epochs=epochs,batch_size=batch,
                      y_train=infer_model["LDA"][mod2]["TRAIN"],
                      y_dev=infer_model["LDA"][mod2]["DEV"],
                      y_test=infer_model["LDA"][mod2]["TEST"])

for sae_key, levels in zip(["SAE", "SAEFT"], res_tuple):
    mlp_res_by_level = []
    for res in levels:
        mlp_res_list = []
        for layer in res:
            mlp_res_list.append(train_mlp(layer[0], infer_model["LABEL"][mod]["TRAIN"],
                                          layer[1], infer_model["LABEL"][mod]["DEV"],
                                          layer[2], infer_model["LABEL"][mod]["TEST"],
                                          mlp_h, loss=mlp_loss, dropouts=mlp_dropouts,
                                          sgd=mlp_sgd, epochs=mlp_epochs,
                                          batch_size=mlp_batch_size,
                                          fit_verbose=0))
        mlp_res_by_level.append(mlp_res_list)
    db[sae_key]["SPE"] = mlp_res_by_level

db.sync()
db.close()
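
# Reading the results back (sketch; the key layout mirrors the writes above):
#   db = shelve.open("<in_dir>/<name>/ae_model.shelve")
#   db["SAE"]["ASR"]   -> per-level lists of train_mlp results on ASR
#   db["SAEFT"]["TRS"] -> the same for the fine-tuned variant, on TRS
#   db["SAE"]["SPE"]   -> the cross-modal (ASR -> TRS) run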