'''This script demonstrates how to build a variational autoencoder with Keras.
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''

import warnings
import itertools
import sys
import json

import numpy as np
import matplotlib.pyplot as plt
from scipy import sparse
import scipy.io

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
from keras.callbacks import EarlyStopping, Callback

import pandas
import shelve
import pickle


class ZeroStopping(Callback):
    '''Stop training as soon as a monitored quantity crosses a threshold.
    # Arguments
        monitor: quantity to be monitored.
        thresh: threshold for the monitored quantity; training stops
            once the quantity falls below it (in 'min' mode) or rises
            above it (in 'max' mode).
        verbose: verbosity mode.
        mode: one of {auto, min, max}. In 'min' mode, training stops
            when the monitored quantity drops below `thresh`; in 'max'
            mode it stops when the quantity exceeds `thresh`. In 'auto'
            mode the direction is inferred from the name of the monitored
            quantity.
    '''
    def __init__(self, monitor='val_loss', verbose=0, mode='auto', thresh=0):
        super(ZeroStopping, self).__init__()

        self.monitor = monitor
        self.verbose = verbose
        self.thresh = thresh  # stop as soon as the monitored quantity crosses this value

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ZeroStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('ZeroStopping requires %s available!' %
                          self.monitor, RuntimeWarning)
            return

        if self.monitor_op(current, self.thresh):
            self.best = current
            self.model.stop_training = True
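
# Example usage of ZeroStopping (a sketch; in this script the callback only
# appears commented out in the `fit` call below): pass it to `fit` via
# `callbacks`, e.g.
#   vae.fit(x_train, x_train, validation_data=(x_dev, x_dev),
#           callbacks=[ZeroStopping(monitor='val_loss', thresh=0, mode='min')])
# which aborts training as soon as the validation loss falls below `thresh`.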

#batch_size = 16
#original_dim = 784
#latent_dim = 2
#intermediate_dim = 128
#epsilon_std = 0.01
#nb_epoch = 40




def train_vae(x_train, x_dev, x_test, y_train=None, y_dev=None, y_test=None,
              hidden_size=80, latent_dim=12, batch_size=8, nb_epochs=10,
              sgd="rmsprop", input_activation="relu",
              output_activation="sigmoid", epsilon_std=0.01):
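    '''Train a VAE on `x_train` (validating on `x_dev`) and return
    [[pred_train, pred_dev, pred_test]]: the latent-space projections
    (the means of q(z|x)) of the train, dev and test splits.'''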



    def sampling(args):
        # reparameterization trick: draw epsilon ~ N(0, epsilon_std) and
        # shift/scale it with the encoder outputs so sampling stays differentiable
        z_mean, z_log_std = args
        epsilon = K.random_normal(shape=(batch_size, latent_dim),
                                  mean=0., std=epsilon_std)
        return z_mean + K.exp(z_log_std) * epsilon

    def vae_loss(x, x_decoded_mean):
        # reconstruction term (binary cross-entropy) plus the KL divergence
        # between the approximate posterior q(z|x) and the unit-Gaussian prior
        xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
        kl_loss = - 0.5 * K.mean(1 + z_log_std - K.square(z_mean) - K.exp(z_log_std), axis=-1)
        return xent_loss + kl_loss

    original_dim = x_train.shape[1]


    x = Input(batch_shape=(batch_size, original_dim))
    h = Dense(hidden_size, activation=input_activation)(x)
    z_mean = Dense(latent_dim)(h)
    z_log_std = Dense(latent_dim)(h)


    # note that "output_shape" isn't necessary with the TensorFlow backend
    # so you could write `Lambda(sampling)([z_mean, z_log_std])`
    z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_std])

    # we instantiate these layers separately so as to reuse them later
    decoder_h = Dense(hidden_size, activation=input_activation)
    decoder_mean = Dense(original_dim, activation=output_activation)
    h_decoded = decoder_h(z)
    x_decoded_mean = decoder_mean(h_decoded)


    vae = Model(x, x_decoded_mean)
    vae.compile(optimizer=sgd, loss=vae_loss)

    # if no explicit targets are given, train as a plain autoencoder,
    # i.e. reconstruct the inputs themselves
    if y_train is None or y_dev is None or y_test is None:
        y_train = x_train
        y_dev = x_dev
        y_test = x_test

    vae.fit(x_train, y_train,
            shuffle=True,
            nb_epoch=nb_epochs,
            verbose=1,
            batch_size=batch_size,
            validation_data=(x_dev, y_dev),
            #callbacks=[ZeroStopping(monitor='val_loss', thresh=0, verbose=0, mode='min')]
            )

    # build a model to project inputs on the latent space
    encoder = Model(x, z_mean)
    pred_train = encoder.predict(x_train, batch_size=batch_size)
    pred_dev = encoder.predict(x_dev, batch_size=batch_size)
    pred_test = encoder.predict(x_test, batch_size=batch_size)
    return [[pred_train, pred_dev, pred_test]]
    # display a 2D plot of the digit classes in the latent space
    #x_test_encoded = encoder.predict(x_test, batch_size=batch_size)

    # build a digit generator that can sample from the learned distribution
    #decoder_input = Input(shape=(latent_dim,))
    #_h_decoded = decoder_h(decoder_input)
    #_x_decoded_mean = decoder_mean(_h_decoded)
    #generator = Model(decoder_input, _x_decoded_mean)
    #x_decoded = generator.predict(z_sample)
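

# Minimal usage sketch (not part of the original script): train the VAE on
# flattened MNIST digits with the loader imported above. The split sizes and
# hyper-parameter values below are illustrative assumptions, not the author's.
if __name__ == '__main__':
    (x_all, _), (x_test, _) = mnist.load_data()
    x_all = x_all.reshape((len(x_all), -1)).astype('float32') / 255.
    x_test = x_test.reshape((len(x_test), -1)).astype('float32') / 255.
    x_train, x_dev = x_all[:-10000], x_all[-10000:]  # hold out a dev set
    # batch_size must divide each split size because the model is built with a
    # fixed batch_shape
    [[z_train, z_dev, z_test]] = train_vae(x_train, x_dev, x_test,
                                           hidden_size=128, latent_dim=2,
                                           batch_size=100, nb_epochs=5)
    print(z_test.shape)  # (10000, 2): one latent mean vector per test digit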