'''This script demonstrates how to build a variational autoencoder with Keras.
Reference: "Auto-Encoding Variational Bayes" https://arxiv.org/abs/1312.6114
'''
import itertools
import sys
import json
import warnings
import numpy as np
import matplotlib.pyplot as plt
from scipy import sparse
import scipy.io
from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import objectives
from keras.datasets import mnist
from keras.callbacks import EarlyStopping, Callback
import pandas
import shelve
import pickle
class ZeroStopping(Callback):
    '''Stop training as soon as a monitored quantity reaches a threshold.

    # Arguments
        monitor: quantity to be monitored.
        thresh: threshold the monitored quantity must reach
            for training to be stopped.
        verbose: verbosity mode.
        mode: one of {auto, min, max}. In 'min' mode, training stops
            when the monitored quantity falls to the threshold or below;
            in 'max' mode it stops when the monitored quantity reaches
            the threshold or above.
    '''
    def __init__(self, monitor='val_loss', verbose=0, mode='auto', thresh=0):
        super(ZeroStopping, self).__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.thresh = thresh  # threshold the monitored quantity is compared against

        if mode not in ['auto', 'min', 'max']:
            warnings.warn('ZeroStopping mode %s is unknown, '
                          'fallback to auto mode.' % mode,
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

    def on_epoch_end(self, epoch, logs={}):
        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('ZeroStopping requires %s available!' %
                          self.monitor, RuntimeWarning)
            return

        if self.monitor_op(current, self.thresh):
            self.best = current
            self.model.stop_training = True
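
# Hypothetical standalone usage (not part of the original script): pass the
# callback to `fit` on any compiled Keras model, e.g.
#   model.fit(x, y, validation_split=0.1,
#             callbacks=[ZeroStopping(monitor='val_loss', thresh=0.05, mode='min')])
# Training then stops at the first epoch whose val_loss is <= 0.05.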
#batch_size = 16
#original_dim = 784
#latent_dim = 2
#intermediate_dim = 128
#epsilon_std = 0.01
#nb_epoch = 40
def train_vae(x_train, x_dev, x_test, y_train=None, y_dev=None, y_test=None,
              hidden_size=80, latent_dim=12, batch_size=8, nb_epochs=10,
              sgd="rmsprop", input_activation="relu", output_activation="sigmoid",
              epsilon_std=0.01):

    def sampling(args):
        z_mean, z_log_std = args
        epsilon = K.random_normal(shape=(batch_size, latent_dim),
                                  mean=0., std=epsilon_std)
        return z_mean + K.exp(z_log_std) * epsilon

    def vae_loss(x, x_decoded_mean):
        xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
        kl_loss = - 0.5 * K.mean(1 + z_log_std - K.square(z_mean) - K.exp(z_log_std), axis=-1)
        return xent_loss + kl_loss

    original_dim = x_train.shape[1]

    # encoder: input -> hidden -> (z_mean, z_log_std) -> sampled z
    x = Input(batch_shape=(batch_size, original_dim))
    h = Dense(hidden_size, activation=input_activation)(x)
    z_mean = Dense(latent_dim)(h)
    z_log_std = Dense(latent_dim)(h)

    # note that "output_shape" isn't necessary with the TensorFlow backend
    # so you could write `Lambda(sampling)([z_mean, z_log_std])`
    z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_std])

    # we instantiate these layers separately so as to reuse them later
    decoder_h = Dense(hidden_size, activation=input_activation)
    decoder_mean = Dense(original_dim, activation=output_activation)
    h_decoded = decoder_h(z)
    x_decoded_mean = decoder_mean(h_decoded)

    vae = Model(x, x_decoded_mean)
    vae.compile(optimizer=sgd, loss=vae_loss)

    # train the VAE; if no targets are given, reconstruct the inputs
    if y_train is None or y_dev is None or y_test is None:
        y_train = x_train
        y_dev = x_dev
        y_test = x_test

    vae.fit(x_train, y_train,
            shuffle=True,
            nb_epoch=nb_epochs,
            verbose=1,
            batch_size=batch_size,
            validation_data=(x_dev, y_dev),
            callbacks=[ZeroStopping(monitor='val_loss', thresh=0, verbose=0, mode='min')])

    # build a model to project inputs on the latent space
    encoder = Model(x, z_mean)
    pred_train = encoder.predict(x_train, batch_size=batch_size)
    pred_dev = encoder.predict(x_dev, batch_size=batch_size)
    pred_test = encoder.predict(x_test, batch_size=batch_size)
    return [[pred_train, pred_dev, pred_test]]
# display a 2D plot of the digit classes in the latent space
#x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
# build a digit generator that can sample from the learned distribution
#decoder_input = Input(shape=(latent_dim,))
#_h_decoded = decoder_h(decoder_input)
#_x_decoded_mean = decoder_mean(_h_decoded)
#generator = Model(decoder_input, _x_decoded_mean)
#x_decoded = generator.predict(z_sample)
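
# A minimal, hypothetical usage sketch (not part of the original script): train
# the VAE on flattened MNIST digits, carving a development split out of the
# training set. The split sizes are chosen to divide evenly by batch_size, which
# the fixed batch_shape Input above requires. Assumes the Keras 1.x API used here.
if __name__ == "__main__":
    (x_train_full, _), (x_test, _) = mnist.load_data()
    x_train_full = x_train_full.reshape(-1, 784).astype("float32") / 255.
    x_test = x_test.reshape(-1, 784).astype("float32") / 255.
    x_dev = x_train_full[-10000:]
    x_train = x_train_full[:-10000]

    [[pred_train, pred_dev, pred_test]] = train_vae(
        x_train, x_dev, x_test,
        hidden_size=128, latent_dim=2, batch_size=100, nb_epochs=5)

    # each prediction is an (n_samples, latent_dim) array of latent means
    print(pred_test.shape)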