from keras.models import Sequential
from keras.layers.core import TimeDistributedDense
from keras.layers.advanced_activations import PReLU
from keras.layers.normalization import BatchNormalization
from keras import backend as K
#from keras.layers.core import Dense, Dropout, Activation, Reshape, Flatten
from keras.layers import TimeDistributed, Lambda, Input, merge, Bidirectional
from keras.layers.core import Dense, Activation, Dropout, Flatten
from keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D, Conv1D, MaxPooling1D
#from keras.layers.convolutional import Convolution1D, MaxPooling1D, AveragePooling1D
from keras.optimizers import SGD, Adadelta, Adam, Adamax, RMSprop
from keras.models import Model
#from keras.layers.recurrent import LSTM, GRU, SimpleRNN
from keras.constraints import maxnorm
from keras.callbacks import Callback, EarlyStopping
from keras.preprocessing import sequence
from keras.layers.embeddings import Embedding
from keras.regularizers import l2, activity_l2
from keras import regularizers
from keras.models import model_from_json
import theano
from theano import tensor
import warnings
import sys
import time
import os
import numpy as np
import fileinput
import math

warnings.filterwarnings("ignore")

def create_class_weight(labels_dict, mu=0.55):
    """Compute log-scaled class weights from per-class sample counts.

    The weight of class ``k`` is ``log(mu * total / count_k)``, floored at
    1.0, where ``total`` is the sum of all counts. Rarer classes therefore
    get larger weights. A class with a zero count gets the floor weight 1.0
    instead of raising ``ZeroDivisionError``.

    Args:
        labels_dict: mapping of class key -> number of samples of that class.
        mu: scaling factor applied to ``total`` before taking the log.

    Returns:
        dict mapping each key of ``labels_dict`` to its weight (>= 1.0).
    """
    # The builtin sum works on dict views under both Python 2 and 3, unlike
    # np.sum(dict.values()), which does not add up a Python 3 view.
    total = float(sum(labels_dict.values()))
    class_weight = {}
    for key, count in labels_dict.items():
        if count <= 0:
            # Absent class: log(mu * total / 0) is undefined; use the floor.
            class_weight[key] = 1.0
            continue
        score = math.log(mu * total / float(count))
        class_weight[key] = max(score, 1.0)
    return class_weight


def score_deft_2017(pred, gold, task):
    """Score prediction rows against gold rows with the DEFT 2017 metric.

    Each row of ``pred`` and ``gold`` is reduced to the index of its first
    maximum. The resulting pairs are stored in a ``ConfusionMatrix``, and
    ``ConfusionMatrix.score_deft_2017(task)`` gives the final score.
    """
    def _first_argmax(row):
        # Index of the first occurrence of the row's maximum value.
        values = row.tolist()
        return values.index(max(values))

    matrix = ConfusionMatrix()
    for i in range(len(gold)):
        matrix.store(_first_argmax(pred[i]), _first_argmax(gold[i]))
    return matrix.score_deft_2017(task)


def score_semeval_2017(pred, gold):
    """Return the SemEval-2017 score of ``pred`` against ``gold``.

    The score is the mean recall of labels 0, 1 and 2. Each row is reduced
    to the index of its first maximum before being stored.
    """
    matrix = ConfusionMatrix()
    for row in range(len(gold)):
        predicted = pred[row].tolist()
        expected = gold[row].tolist()
        predicted_label = predicted.index(max(predicted))
        expected_label = expected.index(max(expected))
        matrix.store(predicted_label, expected_label)
    return matrix.score_semeval_2017()

def score_semeval_2016(pred, gold):
    """Return the SemEval-2016 score of ``pred`` against ``gold``.

    The score is the mean F-score of labels 0 and 1. Each row is reduced to
    the index of its first maximum before being stored.
    """
    confusion = ConfusionMatrix()
    n_rows = len(gold)
    for k in range(n_rows):
        guess_row = pred[k].tolist()
        truth_row = gold[k].tolist()
        confusion.store(guess_row.index(max(guess_row)),
                        truth_row.index(max(truth_row)))
    return confusion.score_semeval_2016()




class validation_semeval(Callback):
    """Keras callback that scores a dev file after each epoch and keeps the best model.

    After every epoch the dev set is predicted and scored with
    ``score_deft_2017``. Whenever the score beats the best seen so far, the
    architecture is written to ``<output>.json`` and the weights to
    ``<output>.h5``.

    Args:
        dev_files: path of the tab-separated dev file passed to ``read_sentiment``.
        word_word2vec: token -> embedding-row index mapping.
        maxlen: length the dev sequences are padded/truncated to.
        output: path prefix for the saved ``.json`` / ``.h5`` files.
        task: task identifier forwarded to ``read_sentiment`` / ``score_deft_2017``.
        patience: stored for reference only. This callback counts epochs
            without improvement in ``wait`` but never stops training.
    """

    def __init__(self, dev_files, word_word2vec, maxlen, output, task, patience=5):
        # Run Callback's own initialiser; super(Callback, self) would skip it.
        super(validation_semeval, self).__init__()
        self.best_result = -1.0
        self.best_round = 1
        self.wait = 0
        self.output = output
        self.patience = patience
        self.counter = 0
        self.maxlen = maxlen
        self.dev_files = dev_files
        self.word_word2vec = word_word2vec
        self.task = task
        # (padded dev inputs, dev labels), filled on the first epoch.
        self._dev_cache = None

    def _dev_data(self):
        """Read and pad the dev set once, then reuse it on later epochs."""
        if self._dev_cache is None:
            _, x_dev, y_dev = read_sentiment(self.dev_files, self.word_word2vec, self.task)
            x_dev = sequence.pad_sequences(x_dev, maxlen=self.maxlen, padding='post', truncating='post')
            self._dev_cache = (x_dev, y_dev)
        return self._dev_cache

    def on_epoch_end(self, epoch, logs=None):
        """Score the dev set; checkpoint the model when the score improves."""
        self.counter += 1

        x_dev, y_dev = self._dev_data()
        predict = self.model.predict(x_dev, batch_size=1024, verbose=1)
        current = score_deft_2017(predict, y_dev, self.task)

        if self.best_result < current:
            self.best_result = current
            self.best_round = self.counter
            self.wait = 0

            with open(self.output + ".json", 'w') as json_file:
                json_file.write(self.model.to_json())
            self.model.save_weights(self.output + ".h5", overwrite=True)
        else:
            # Bookkeeping only: no early stop is triggered from here.
            self.wait += 1

        print("\n\n")
        print("Dev score 2017 : " + str(current))
        print("Best score 2017 : " + str(self.best_result))
        print("\n\n")



class ConfusionMatrix(object):
    """Per-class confusion counts for single-label classification.

    ``h`` maps every label seen so far to a dict of ``tp``, ``tn``, ``fp``
    and ``fn`` counts. ``total`` is the number of (prediction, truth) pairs
    stored. Labels that were never stored score 0 in ``recall``,
    ``precision`` and ``fscore``.
    """

    # Class names (in label-index order) printed by score_deft_2017 per task.
    _DEFT_LABELS = {
        "task1": ("negative", "positive", "objective", "mixed"),
        "task2": ("figurative", "nonfigurative"),
        "task3": ("negative", "positive", "objective", "mixed"),
    }

    def __init__(self):
        self.h = {}
        self.total = 0

    def store(self, actual, truth):
        """Record one predicted label ``actual`` against the reference label ``truth``."""
        for label in (actual, truth):
            if label not in self.h:
                # No earlier sample involved this label, so every one of them
                # was a true negative for it.
                self.h[label] = {"tp": 0, "tn": self.total, "fp": 0, "fn": 0}

        if actual == truth:
            self.h[actual]["tp"] += 1
        else:
            self.h[actual]["fp"] += 1
            self.h[truth]["fn"] += 1

        # Every other known label is neither predicted nor true for this sample.
        for label, counts in self.h.items():
            if label != actual and label != truth:
                counts["tn"] += 1

        self.total += 1

    def recall(self, name):
        """Return tp / (tp + fn) for ``name``, or 0 when undefined or unseen."""
        counts = self.h.get(name)
        if counts is None:
            return 0
        t = counts["tp"] + counts["fn"]
        if t == 0:
            return 0
        return float(counts["tp"]) / float(t)

    def precision(self, name):
        """Return tp / (tp + fp) for ``name``, or 0 when undefined or unseen."""
        counts = self.h.get(name)
        if counts is None:
            return 0
        t = counts["tp"] + counts["fp"]
        if t == 0:
            return 0
        return float(counts["tp"]) / float(t)

    def fscore(self, name):
        """Return the F1 score of ``name`` (0 when precision + recall is 0)."""
        p = self.precision(name)
        r = self.recall(name)
        if p + r == 0:
            return 0
        return (2 * p * r) / (p + r)

    def _macro_fscore(self, labels):
        """Unweighted mean F1 over ``labels``."""
        return sum(self.fscore(label) for label in labels) / float(len(labels))

    def score_semeval_2016(self):
        """Mean F1 of labels 0 and 1."""
        return self._macro_fscore((0, 1))

    def score_semeval_2017(self):
        """Mean recall of labels 0, 1 and 2."""
        return (self.recall(0) + self.recall(1) + self.recall(2)) / 3.0

    def score_deft_2017(self, task):
        """Print the counts and per-class F1, then return the task's macro F1.

        Raises:
            ValueError: if ``task`` is not "task1", "task2" or "task3".
        """
        print(self.h)

        if task not in self._DEFT_LABELS:
            raise ValueError("unknown DEFT 2017 task: %r" % (task,))

        names = self._DEFT_LABELS[task]
        for index, name in enumerate(names):
            print(name + " : " + str(self.fscore(index)))
        return self._macro_fscore(range(len(names)))

    def info(self):
        """Print the raw per-label counts."""
        print(self.h)


def read_word2vec(embedding_file):
    """Load a word2vec text-format embedding file.

    Parsing rules, applied per space-separated line:

    - A line with more than three fields is read as ``word v1 v2 ...``.
    - A two-field line is taken as the ``count dimension`` header, and its
      second field becomes ``size``.
    - Blank lines, one-field lines and three-field lines are ignored.

    If a word appears twice, ``dico`` points at its last row.

    Returns:
        (counter, size, dico, ar):

        - ``counter``: number of vectors read.
        - ``size``: the embedding dimension. It comes from the header, or
          from the width of the first vector when no header was found.
        - ``dico``: word -> row index mapping.
        - ``ar``: float numpy array of the vectors.
    """
    vectors = []
    dico = {}
    size = 0

    with open(embedding_file) as f:
        for line in f:
            fields = line.strip().split(" ")
            if len(fields) > 3:
                dico[fields[0]] = len(vectors)
                # List comprehension (not map) so this also yields a list on Python 3.
                vectors.append([float(v) for v in fields[1:]])
            elif len(fields) == 2:
                size = int(fields[1])
            # One-field lines (e.g. blank lines) carry no data; skip them.

    if size == 0 and vectors:
        size = len(vectors[0])

    ar = np.array(vectors, dtype='float')
    return len(vectors), size, dico, ar

def read_sentiment_train(sentiment_file, word_word2vec, task):
    """Read a labelled training file and derive class weights from its label counts.

    Each line is tab-separated as ``id <TAB> label <TAB> space-separated tokens``.

    Labels are one-hot encoded in the task's order:

    - "task1" / "task3": negative, positive, objective, mixed.
    - "task2": figurative, nonfigurative.

    An unrecognised label yields an all-zero row and is not counted. Tokens
    missing from ``word_word2vec`` are dropped. The label counts, their dict
    form and the resulting class weights are printed.

    Returns:
        (Z, X_word, Y, cw_y):

        - ``Z``: the ids.
        - ``X_word``: per-line token-index lists.
        - ``Y``: the one-hot label matrix.
        - ``cw_y``: class weights from ``create_class_weight(counts, 0.5)``.

    Raises:
        ValueError: if ``task`` is not "task1", "task2" or "task3".
    """
    task_labels = {
        "task1": ["negative", "positive", "objective", "mixed"],
        "task2": ["figurative", "nonfigurative"],
        "task3": ["negative", "positive", "objective", "mixed"],
    }
    if task not in task_labels:
        raise ValueError("unknown task: %r" % (task,))
    label_index = {name: i for i, name in enumerate(task_labels[task])}
    n_labels = len(label_index)

    X_word = []
    Y = []
    Z = []
    label_counts = [0.0] * n_labels

    # One pass both counts labels and builds features; no need to read the file twice.
    with open(sentiment_file) as f:
        for line in f:
            fields = line.strip().split("\t")
            Z.append(fields[0])

            ar_y = [0] * n_labels
            position = label_index.get(fields[1])
            if position is not None:
                ar_y[position] = 1
                label_counts[position] += 1
            Y.append(ar_y)

            X_word.append([word_word2vec[tok] for tok in fields[2].split(" ")
                           if tok in word_word2vec])

    print(label_counts)
    hash_y_counter = dict(enumerate(label_counts))
    print(hash_y_counter)
    cw_y = create_class_weight(hash_y_counter, 0.5)
    print(cw_y)

    X_word = np.array(X_word)
    Y = np.array(Y)
    Z = np.array(Z)

    return (Z, X_word, Y, cw_y)


def read_sentiment_test(sentiment_file, word_word2vec):
    """Read an unlabelled tab-separated file as ``id <TAB> label <TAB> tokens``.

    The label column is ignored. Every line gets the placeholder label
    ``[0]``. Tokens absent from ``word_word2vec`` are dropped.

    Returns:
        (Z, X_word, Y):

        - ``Z``: numpy array of ids.
        - ``X_word``: array of per-line token-index lists.
        - ``Y``: array of ``[0]`` placeholders.
    """
    ids = []
    sequences = []

    with open(sentiment_file) as handle:
        for raw in handle:
            columns = raw.strip().split("\t")
            ids.append(columns[0])
            sequences.append([word_word2vec[w] for w in columns[2].split(" ")
                              if w in word_word2vec])

    placeholders = [[0] for _ in ids]

    seq_array = np.array(sequences)
    label_array = np.array(placeholders)
    id_array = np.array(ids)
    return (id_array, seq_array, label_array)




def read_sentiment(sentiment_file, word_word2vec, task):
    """Read a labelled tab-separated file into ids, token indices and one-hot labels.

    Each line is ``id <TAB> label <TAB> space-separated tokens``.

    Labels are one-hot encoded in the task's order:

    - "task1" / "task3": negative, positive, objective, mixed.
    - "task2": figurative, nonfigurative.

    An unrecognised label yields an all-zero row. Tokens missing from
    ``word_word2vec`` are dropped.

    Returns:
        (Z, X_word, Y): numpy arrays of ids, per-line token-index lists, and
        one-hot label rows.

    Raises:
        ValueError: if ``task`` is not "task1", "task2" or "task3".
    """
    task_labels = {
        "task1": ["negative", "positive", "objective", "mixed"],
        "task2": ["figurative", "nonfigurative"],
        "task3": ["negative", "positive", "objective", "mixed"],
    }
    if task not in task_labels:
        # Fail fast instead of emitting None label rows that break scoring later.
        raise ValueError("unknown task: %r" % (task,))
    label_index = {name: i for i, name in enumerate(task_labels[task])}
    n_labels = len(label_index)

    X_word = []
    Y = []
    Z = []

    with open(sentiment_file) as f:
        for line in f:
            fields = line.strip().split("\t")
            Z.append(fields[0])

            ar_y = [0] * n_labels
            position = label_index.get(fields[1])
            if position is not None:
                ar_y[position] = 1
            Y.append(ar_y)

            X_word.append([word_word2vec[tok] for tok in fields[2].split(" ")
                           if tok in word_word2vec])

    X_word = np.array(X_word)
    Y = np.array(Y)
    Z = np.array(Z)

    return (Z, X_word, Y)



# Hyper-parameters. Only ``maxlen`` is used below; presumably it must match
# the padding length the model was trained with (TODO confirm).
maxlen = 150
word_nb_feature_maps = 200
hidden_size = 64

if len(sys.argv) < 4:
    sys.exit("usage: %s <word_embedding_file> <test_file> <model_prefix>" % sys.argv[0])

word_embedding_file = sys.argv[1]
test_file = sys.argv[2]
model_file = sys.argv[3]

# Rebuild the architecture from <model_prefix>.json, then load <model_prefix>.h5.
with open(model_file + ".json") as json_file:
    model = model_from_json(json_file.read())
model.load_weights(model_file + ".h5")

word_vocab_size, word_embedding_size, word_dico, word_initialize_weight = read_word2vec(word_embedding_file)
ID_train, X_word_train, Y_train = read_sentiment_test(test_file, word_dico)
X_word_train = sequence.pad_sequences(X_word_train, maxlen=maxlen, padding='post', truncating='post')

predict = model.predict(X_word_train, batch_size=2048, verbose=0)

# One output line per input: "<id>\t<space-separated model outputs>".
for doc_id, scores in zip(ID_train, predict):
    print(doc_id + "\t" + " ".join(map(str, scores)))





