# 05-mmf_getscore.py (2.54 KB)
import numpy as np 
import shelve 
import sys
import glob
from collections import defaultdict
from tinydb import TinyDB, Query
from mako.template import Template
import time 

def get_best(x):
    """Summarise one (train, dev, test) triple of score series.

    ``x[0]``, ``x[1]`` and ``x[2]`` are indexable score sequences. The
    position of the highest dev score (first one on ties, per ``np.argmax``)
    is used to pick the matching test score.

    Returns a 4-tuple:
    ``(max train, max dev, max test, test score at the best dev position)``.
    """
    train_scores, dev_scores, test_scores = x[0], x[1], x[2]
    best_pos = np.argmax(dev_scores)
    return (
        np.max(train_scores),
        dev_scores[best_pos],
        np.max(test_scores),
        test_scores[best_pos],
    )
def depth(L):
    """Return the list-nesting depth of *L*.

    Anything that is not a ``list`` (numbers, tuples, numpy arrays, ...) has
    depth 0; a list has depth one more than its deepest element. An empty
    list has depth 1 -- previously ``max`` over an empty sequence raised
    ValueError, so a single empty list anywhere aborted the whole report.
    """
    if not isinstance(L, list):
        return 0
    # `or [0]` instead of max(default=0): the `default` keyword is Python 3.4+.
    return max([depth(item) for item in L] or [0]) + 1


# Mako template for the per-experiment header: a setext-style markdown title
# showing `name`, followed by an "MLP scores" subheading.
template_name = '''
${name}
========================

MLP scores : 
-------------------
'''
# Mako template for one markdown score table. It is rendered with `models`,
# `model` and `ttype`, and emits one numbered row per entry of
# models[model][ttype]. Each entry is indexed as line[0]..line[3], filling
# the train / dev / max test / best test columns. The "\n" escapes pad the
# table with blank lines; note that callers that .strip() the rendered text
# remove that padding.
template_value='''\n\n
| ${model} ${ttype}   | train    | dev       |max test| best test|
| -------------------:|:--------:|:---------:|:------:|:--------:|
% for cpt,line in enumerate(models[model][ttype]):
| ${cpt}      | ${line[0]} | ${line[1]}  |${line[2]} | ${line[3]} |
% endfor
\n
'''

# ae_model.shelve
def get_folder_file(x):
    """Split a matched ``<in_folder>/<name>/ae_model.shelve.dir`` path.

    Returns ``(name, shelve_path)``:

    * ``name`` is the directory that directly contains the file.
    * ``shelve_path`` is *x* with its last extension removed, i.e. the path
      handed to ``shelve.open``.
    """
    # The experiment directory is always the file's parent, i.e. the
    # second-to-last path component. Indexing [1] returned the wrong component
    # whenever in_folder itself contained a slash ("./results", "a/b",
    # or a trailing "/").
    folder = x.split("/")[-2]
    shelve_file = ".".join(x.split(".")[:-1])
    return (folder, shelve_file)

# Script body: print a markdown score report for every
# <in_folder>/*/ae_model.shelve database.
if len(sys.argv) < 2:
    sys.exit("usage: {} IN_FOLDER".format(sys.argv[0]))
in_folder = sys.argv[1]


# models[working_key][key] holds the rows of the table about to be rendered;
# the value template reads them through its `models` argument.
models = defaultdict(dict)

# Compile each Mako template once instead of once per rendered table.
name_template = Template(template_name)
value_template = Template(template_value)


def _render_scores(model, ttype):
    """Render models[model][ttype] as a whitespace-stripped markdown table."""
    return value_template.render(model=model, ttype=ttype, models=models).strip()


ae_model_list = sorted(glob.glob("{}/*/ae_model.shelve.dir".format(in_folder)))
ae_model_list = [get_folder_file(path) for path in ae_model_list]
for name, shelve_file in ae_model_list:
    print(name_template.render(name=name))
    # Read-only: this script never writes. With flag "r", an unopenable path
    # raises instead of silently creating a fresh empty database.
    opened_shelve = shelve.open(shelve_file, flag="r")
    try:
        # list() so .remove() also works on Python 3, where keys() is a view.
        keys = list(opened_shelve.keys())
        for meta_key in ("LABEL", "params"):
            if meta_key in keys:
                keys.remove(meta_key)
        to_print = []
        for working_key in keys:
            # Every shelve lookup unpickles the stored value from disk again,
            # so fetch each entry (and each table inside it) exactly once.
            entry = opened_shelve[working_key]
            for key in entry.keys():
                table = entry[key]
                table_depth = depth(table)
                if table_depth == 3:
                    # A list of (train, dev, test) triples: one row per triple.
                    models[working_key][key] = [get_best(x) for x in table]
                    to_print.append(_render_scores(working_key, key))
                elif table_depth == 2:
                    # A single (train, dev, test) triple: a one-row table.
                    models[working_key][key] = [get_best(table)]
                    to_print.append(_render_scores(working_key, key))
                elif table_depth == 4:
                    # One list of triples per layer: one table per layer.
                    for layer in table:
                        models[working_key][key] = [get_best(x) for x in layer]
                        to_print.append(_render_scores(working_key, key))
                # Any other depth is skipped, as before.
    finally:
        opened_shelve.close()
    # Blank line between tables: a single newline makes markdown merge
    # consecutive tables into one (the rendered text is stripped above).
    print("\n\n".join(to_print))