# 04a-mlp.py
# coding: utf-8
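"""Train per-view MLP classifiers on LDA topic features.

Reads bag-of-words corpora (ASR and TRS views, word-id sequences) and labels
from a shelve, removes stop words, infers topic proportions with a per-view
LDA model, trains an MLP on each view, and stores the per-view results in
<in_dir>/mlp_scores.shelve and the inferred features in <in_dir>/infer.shelve.
"""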

# In[29]:

# Imports
import shelve
import sys
from collections import Counter

import numpy
import dill
from gensim.models import LdaModel

import utils        # only needed if the commented-out label-binarization block below is re-enabled
import mlp
from mlp import *   # the Adam optimizer used below is expected to come from this wildcard import



# In[3]:

#30_50_50_150_0.0001

# In[4]:

#db=shelve.open("SPELIKE_MLP_DB.shelve",writeback=True)
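# Usage: python 04a-mlp.py <in_dir> <corpus_shelve>
#   <in_dir>        directory holding lda_asr.model, lda_trs.model and stopwords.dill; output shelves are written here
#   <corpus_shelve> shelve with the "LABEL", "ASR_wid" and "TRS_wid" corpora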
in_dir = sys.argv[1]
origin_corps = shelve.open(sys.argv[2])

## ['vocab',
#'ASR_AE_OUT_RELU',
#'ASR_AE_H2_RELU',
#'ASR_H1_TRANSFORMED_W2_RELU',
#'ASR_AE_H1_RELU',
#'ASR_H1_TRANFORMED_OUT_RELU',
#'ASR_H1_TRANFORMED_TRSH2_RELU',
#'TRS_AE_H2_RELU',
#'ASR_H2_TRANSFORMED_W1_RELU',
#'ASR_H2_TRANSFORMED_W2_RELU',
#'TRS_AE_H1_RELU',
#'ASR_H2_TRANFORMED_OUT_RELU',
#'ASR_SPARSE',
#'ASR_H2_TRANFORMED_TRSH2_RELU',
#'ASR_H1_TRANSFORMED_W1_RELU',
#'TRS_AE_OUT_RELU']
##
# ['vocab', 'LABEL', 'TRS_SPARSE', 'ASR_SPARSE']

out_db = shelve.open("{}/mlp_scores.shelve".format(in_dir), writeback=True)   # MLP results per view (TRS / ASR)

infer_db = shelve.open("{}/infer.shelve".format(in_dir), writeback=True)      # raw and LDA-inferred features
#lb=LabelBinarizer()
#y_train=lb.fit_transform([utils.select(ligneid) for ligneid in origin_corps["LABEL"]["TRAIN"]])
#y_dev=lb.transform([utils.select(ligneid) for ligneid in origin_corps["LABEL"]["DEV"]])
#y_test=lb.transform([utils.select(ligneid) for ligneid in origin_corps["LABEL"]["TEST"]])


y_train = origin_corps["LABEL"]["TRAIN"]
y_dev = origin_corps["LABEL"]["DEV"]
y_test = origin_corps["LABEL"]["TEST"]
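# Note: the loss used below is categorical_crossentropy, so the labels stored in
# the shelve are presumably already one-hot encoded (the commented-out
# LabelBinarizer block above would produce exactly that).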

sw = dill.load(open("{}/stopwords.dill".format(in_dir), "rb"))   # stop words (word ids to filter out)
LDAs = {}
LDAs["ASR"] = LdaModel.load("{}/lda_asr.model".format(in_dir))   # one LDA topic model per view
LDAs["TRS"] = LdaModel.load("{}/lda_trs.model".format(in_dir))

data = {"RAW": {"ASR": {}, "TRS": {}}, "LDA": {"ASR": {}, "TRS": {}}}

# Turn each word-id sequence into a stop-word-filtered (word_id, count) list
for key in ["ASR", "TRS"]:
    for split in ["TRAIN", "DEV", "TEST"]:
        data["RAW"][key][split] = [
            [(x, y) for x, y in Counter(z).items() if x not in sw]
            for z in origin_corps["{}_wid".format(key)][split]
        ]

nb_epochs = 500
for key in ["TRS", "ASR"]:
    # Infer topic features for every split with the view-specific LDA model
    for corp_key in data["RAW"][key].keys():
        data["LDA"][key][corp_key] = LDAs[key].inference(data["RAW"][key][corp_key])[0]

    # Train an MLP with three hidden layers (40-25-40) on the topic features
    res = mlp.train_mlp(data["LDA"][key]["TRAIN"], y_train,
                        data["LDA"][key]["DEV"], y_dev,
                        data["LDA"][key]["TEST"], y_test,
                        [40, 25, 40], dropouts=[0, 0, 0, 0],
                        sgd=Adam(lr=0.0001), epochs=nb_epochs, batch_size=8,
                        save_pred=False, keep_histo=False,
                        loss="categorical_crossentropy", fit_verbose=0)

    # Pick the epoch with the best DEV score (res[1]) and report the matching TEST score (res[2])
    arg_best = numpy.argmax(res[1])
    dev_best = res[1][arg_best]
    test_best = res[2][arg_best]
    out_db[key] = (res, (dev_best, test_best))
    print in_dir, dev_best, test_best


# Persist both the raw bag-of-words and the LDA features for later reuse
for k, v in data.items():
    infer_db[k] = v

# Summary: best (DEV, TEST) score pair for each view
for key in out_db.keys():
    print key, out_db[key][1]

out_db.close()
infer_db.close()
origin_corps.close()
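# Minimal sketch of how the stored scores can be read back later, assuming the
# layout written above (a (res, (dev_best, test_best)) tuple per view):
#
#   scores = shelve.open("{}/mlp_scores.shelve".format(in_dir))
#   for view in scores.keys():
#       res, (dev_best, test_best) = scores[view]
#       print view, dev_best, test_best
#   scores.close()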