LDA/vae.py
'''This script demonstrates how to build a variational autoencoder with Keras.

Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''
import itertools
import sys
import json
import warnings

import numpy as np
import matplotlib.pyplot as plt
from scipy import sparse
import scipy.io

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
from keras.callbacks import EarlyStopping, Callback
import pandas
import shelve
import pickle
class ZeroStopping(Callback):
    '''Stop training when a monitored quantity crosses a threshold.

    # Arguments
        monitor: quantity to be monitored.
        verbose: verbosity mode.
        mode: one of {auto, min, max}. In 'min' mode, training will stop when
            the quantity monitored drops below the threshold; in 'max' mode it
            will stop when the quantity monitored rises above it.
        thresh: threshold against which the monitored quantity is compared.
    '''
    def __init__(self, monitor='val_loss', verbose=0, mode='auto', thresh=0):
        super(ZeroStopping, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.thresh = thresh  # threshold for the monitored quantity

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ZeroStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Zero stopping requires %s available!' %
                          self.monitor, RuntimeWarning)
            return

        if self.monitor_op(current, self.thresh):
            self.best = current
            self.model.stop_training = True
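# Usage sketch (illustration only, not part of the original training code;
# assumes a compiled Keras 1.x model `model` and arrays `x`, `y`):
# ZeroStopping is passed through the `callbacks` argument of fit(), like
# EarlyStopping, and halts training as soon as the monitored quantity crosses
# `thresh`.
#
#   model.fit(x, y,
#             nb_epoch=20,
#             batch_size=8,
#             validation_split=0.1,
#             callbacks=[ZeroStopping(monitor='val_loss', thresh=0, mode='min')])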
#batch_size = 16
#original_dim = 784
#latent_dim = 2
#intermediate_dim = 128
#epsilon_std = 0.01
#nb_epoch = 40


def train_vae(x_train, x_dev, x_test, y_train=None, y_dev=None, y_test=None,
              hidden_size=80, latent_dim=12, batch_size=8, nb_epochs=10,
              sgd="rmsprop", input_activation="relu",
              output_activation="sigmoid", epsilon_std=0.01):

    def sampling(args):
        # reparameterization trick: z = mu + sigma * epsilon, epsilon ~ N(0, epsilon_std)
        z_mean, z_log_std = args
        epsilon = K.random_normal(shape=(batch_size, latent_dim),
                                  mean=0., std=epsilon_std)
        return z_mean + K.exp(z_log_std) * epsilon

    def vae_loss(x, x_decoded_mean):
        # reconstruction term plus KL divergence between q(z|x) and the unit Gaussian prior
        xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
        kl_loss = - 0.5 * K.mean(1 + z_log_std - K.square(z_mean) - K.exp(z_log_std),
                                 axis=-1)
        return xent_loss + kl_loss

    original_dim = x_train.shape[1]

    x = Input(batch_shape=(batch_size, original_dim))
    h = Dense(hidden_size, activation=input_activation)(x)
    z_mean = Dense(latent_dim)(h)
    z_log_std = Dense(latent_dim)(h)

    # note that "output_shape" isn't necessary with the TensorFlow backend
    # so you could write `Lambda(sampling)([z_mean, z_log_std])`
    z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_std])

    # we instantiate these layers separately so as to reuse them later
    decoder_h = Dense(hidden_size, activation=input_activation)
    decoder_mean = Dense(original_dim, activation=output_activation)
    h_decoded = decoder_h(z)
    x_decoded_mean = decoder_mean(h_decoded)

    vae = Model(x, x_decoded_mean)
    vae.compile(optimizer=sgd, loss=vae_loss)

    # if no targets are given, train as a plain autoencoder (reconstruct the input)
    if y_train is None or y_dev is None or y_test is None:
        y_train = x_train
        y_dev = x_dev
        y_test = x_test

    vae.fit(x_train, y_train,
            shuffle=True,
            nb_epoch=nb_epochs,
            verbose=1,
            batch_size=batch_size,
            validation_data=(x_dev, y_dev),
            callbacks=[ZeroStopping(monitor='val_loss', thresh=0, verbose=0, mode='min')]
            )
    # build a model to project inputs on the latent space
    encoder = Model(x, z_mean)
    pred_train = encoder.predict(x_train, batch_size=batch_size)
    pred_dev = encoder.predict(x_dev, batch_size=batch_size)
    pred_test = encoder.predict(x_test, batch_size=batch_size)
    return [[pred_train, pred_dev, pred_test]]

    # display a 2D plot of the digit classes in the latent space
    #x_test_encoded = encoder.predict(x_test, batch_size=batch_size)

    # build a digit generator that can sample from the learned distribution
    #decoder_input = Input(shape=(latent_dim,))
    #_h_decoded = decoder_h(decoder_input)
    #_x_decoded_mean = decoder_mean(_h_decoded)
    #generator = Model(decoder_input, _x_decoded_mean)
    #x_decoded = generator.predict(z_sample)
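# Minimal usage sketch (assumption, not in the original file): train the VAE on
# random data in [0, 1] and recover the latent projections. Sample counts are
# multiples of batch_size because the model is built with a fixed batch_shape.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    x_train = rng.rand(64, 100).astype("float32")
    x_dev = rng.rand(16, 100).astype("float32")
    x_test = rng.rand(16, 100).astype("float32")

    [[pred_train, pred_dev, pred_test]] = train_vae(
        x_train, x_dev, x_test,
        hidden_size=80, latent_dim=12, batch_size=8, nb_epochs=2)

    print(pred_train.shape, pred_dev.shape, pred_test.shape)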