Commit 7453639646158900fca7644f7bbc94cdbdf746fe

Authored by Parcollet Titouan
1 parent f2d3bd1415
Exists in master

Cleaning

Showing 1 changed file with 3 additions and 1 deletion (inline diff)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Authors: Parcollet Titouan

# Imports
import editdistance
import h5py
import datasets.timit
from datasets.timit import Timit
from datasets.utils import construct_conv_stream, phone_to_phoneme_dict
import complexnn
from complexnn import *
import h5py as H
import keras
from keras.callbacks import Callback, ModelCheckpoint, LearningRateScheduler
from keras.initializers import Orthogonal
from keras.layers import (Layer, Dropout, AveragePooling1D, AveragePooling2D,
                          AveragePooling3D, add, Add, concatenate, Concatenate,
                          Input, Flatten, Dense, Convolution2D, BatchNormalization,
                          Activation, Reshape, ConvLSTM2D, Conv2D, Lambda)
from keras.models import Model, load_model, save_model
from keras.optimizers import SGD, Adam, RMSprop
from keras.regularizers import l2
from keras.utils.np_utils import to_categorical
import keras.backend as K
import keras.models as KM
from keras.utils.training_utils import multi_gpu_model
import logging as L
import numpy as np
import os, pdb, socket, sys, time
import theano as T
from keras.backend.tensorflow_backend import set_session
from models_timit import getTimitResnetModel2D, getTimitModel2D, ctc_lambda_func  # getTimitModel2D is called below when starting fresh
import tensorflow as tf
import itertools
import random


#
# Generator wrapper for timit
#
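# Keras's fit_generator expects a generator that never terminates, so we
# restart the stream's epoch iterator forever.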

def timitGenerator(stream):
    while True:
        for data in stream.get_epoch_iterator():
            yield data

#
# Custom metrics
#
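# EditDistance computes the phoneme error rate (PER): the edit distance
# between decoded and ground-truth phoneme sequences, normalized by the
# ground-truth length and averaged over a whole split after every epoch.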
class EditDistance(Callback):
    def __init__(self, func, dataset, quaternion, save_prefix):
        self.func = func
        if dataset in ['train', 'test', 'dev']:
            self.dataset_type = dataset
            self.save_prefix = save_prefix
            self.dataset = Timit(str(dataset))
            self.full_phonemes_dict = self.dataset.get_phoneme_dict()
            self.ind_phonemes_dict = self.dataset.get_phoneme_ind_dict()
            self.rng = np.random.RandomState(123)
            self.data_stream = construct_conv_stream(self.dataset, self.rng, 20, 1000, quaternion)
        else:
            raise ValueError("Unknown dataset for edit distance " + dataset)

    def labels_to_text(self, labels):
        ret = []
        for c in labels:
            if c == len(self.full_phonemes_dict) - 2:
                ret.append("")
            else:
                c_ = self.full_phonemes_dict[c + 1]
                ret.append(phone_to_phoneme_dict.get(c_, c_))
        ret = [k for k, g in itertools.groupby(ret)]
        return list(filter(lambda c: c != "", ret))

    def decode_batch(self, out, mask):
        ret = []
        for j in range(out.shape[0]):
            out_best = list(np.argmax(out[j], 1))[:int(mask[j])]
            out_best = [k for k, g in itertools.groupby(out_best)]
            # map from 61-d to 39-d
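            # (TIMIT is annotated with 61 phones; scoring conventionally folds
            # them down to 39 phonemes. groupby() above collapses the repeated
            # symbols of the CTC best path before the mapping.)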
            out_str = self.labels_to_text(out_best)
            ret.append(out_str)
        return ret

    def on_epoch_end(self, epoch, logs={}):
        mean_norm_ed = 0.
        num = 0
        for data in self.data_stream.get_epoch_iterator():
            x, y = data
            y_pred = self.func([x[0]])[0]
            decoded_y_pred = self.decode_batch(y_pred, x[1])
            decoded_gt = []
            for i in range(x[2].shape[0]):
                decoded_gt.append(self.labels_to_text(x[2][i][:int(x[3][i])]))
            num += len(decoded_y_pred)
            for i, (_pred, _gt) in enumerate(zip(decoded_y_pred, decoded_gt)):
                edit_dist = editdistance.eval(_pred, _gt)
                mean_norm_ed += float(edit_dist) / float(len(_gt))
        mean_norm_ed = mean_norm_ed / num

        # Dump the logs to a file at every epoch, for cluster sbatch runs.
        f = open(str(self.save_prefix) + "_" + str(self.dataset_type) + "_PER.txt", 'ab')
        mean = np.array([mean_norm_ed])
        np.savetxt(f, mean)
        f.close()
        L.getLogger("train").info("PER on " + str(self.dataset_type) + " : " + str(mean_norm_ed) + " at epoch " + str(epoch))

#
# Callbacks:
#

class TrainLoss(Callback):
    def __init__(self, savedir):
        self.savedir = savedir

    def on_epoch_end(self, epoch, logs={}):
        f = open(str(self.savedir) + "_train_loss.txt", 'ab')
        f2 = open(str(self.savedir) + "_dev_loss.txt", 'ab')
        value = float(logs['loss'])
        np.savetxt(f, np.array([value]))
        f.close()
        value = float(logs['val_loss'])
        np.savetxt(f2, np.array([value]))
        f2.close()

#
# Print a newline after each epoch, because Keras doesn't. Grumble.
#

class PrintNewlineAfterEpochCallback(Callback):
    def on_epoch_end(self, epoch, logs={}):
        sys.stdout.write("\n")

#
# Save checkpoints.
#

class SaveLastModel(Callback):
    def __init__(self, workdir, save_prefix, model_mono, period=10):
        self.workdir = workdir
        self.model_mono = model_mono
        self.chkptsdir = os.path.join(self.workdir, "chkpts")
        self.save_prefix = save_prefix
        if not os.path.isdir(self.chkptsdir):
            os.mkdir(self.chkptsdir)
        self.period_of_epochs = period
        self.linkFilename = os.path.join(self.chkptsdir, str(save_prefix) + "ModelChkpt.hdf5")
        self.linkFilename_weight = os.path.join(self.chkptsdir, str(save_prefix) + "ModelChkpt_weight.hdf5")

    def on_epoch_end(self, epoch, logs={}):
        if (epoch + 1) % self.period_of_epochs == 0:

            # Filenames
            baseHDF5Filename = str(self.save_prefix) + "ModelChkpt{:06d}.hdf5".format(epoch + 1)
            baseHDF5Filename_weight = str(self.save_prefix) + "ModelChkpt{:06d}_weight.hdf5".format(epoch + 1)
            baseYAMLFilename = str(self.save_prefix) + "ModelChkpt{:06d}.yaml".format(epoch + 1)
            hdf5Filename = os.path.join(self.chkptsdir, baseHDF5Filename)
            hdf5Filename_weight = os.path.join(self.chkptsdir, baseHDF5Filename_weight)
            yamlFilename = os.path.join(self.chkptsdir, baseYAMLFilename)

            # YAML
            yamlModel = self.model_mono.to_yaml()
            with open(yamlFilename, "w") as yamlFile:
                yamlFile.write(yamlModel)

            # HDF5
            KM.save_model(self.model_mono, hdf5Filename)
            self.model_mono.save_weights(hdf5Filename_weight)
            with H.File(hdf5Filename, "r+") as f:
                f.require_dataset("initialEpoch", (), "uint64", True)[...] = int(epoch + 1)
                f.flush()
            with H.File(hdf5Filename_weight, "r+") as f:
                f.require_dataset("initialEpoch", (), "uint64", True)[...] = int(epoch + 1)
                f.flush()

            # Symlink to new HDF5 file, then atomically rename and replace.
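            # (os.rename() is atomic on POSIX when source and destination are
            # on the same filesystem, so the ModelChkpt link never points at a
            # half-written checkpoint.)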
            os.symlink(baseHDF5Filename_weight, self.linkFilename_weight + ".rename")
            os.rename(self.linkFilename_weight + ".rename",
                      self.linkFilename_weight)

            # Symlink to new HDF5 file, then atomically rename and replace.
            os.symlink(baseHDF5Filename, self.linkFilename + ".rename")
            os.rename(self.linkFilename + ".rename",
                      self.linkFilename)

            # Print
            L.getLogger("train").info("Saved checkpoint to {:s} at epoch {:5d}".format(hdf5Filename, epoch + 1))

#
# Summarize environment variable.
#

def summarizeEnvvar(var):
    if var in os.environ: return var + "=" + os.environ.get(var)
    else:                 return var + " unset"

#
# TRAINING PROCESS
#

def train(d):

    #
    # Log important data about how we were invoked.
    #
    L.getLogger("entry").info("INVOCATION:  " + " ".join(sys.argv))
    L.getLogger("entry").info("HOSTNAME:    " + socket.gethostname())
    L.getLogger("entry").info("PWD:         " + os.getcwd())
    L.getLogger("entry").info("CUDA DEVICE: " + str(d.device))
    os.environ["CUDA_VISIBLE_DEVICES"] = str(d.device)

    #
    # Setup GPUs
    #
    config = tf.ConfigProto()

    #
    # Don't pre-allocate memory; allocate as needed.
    #
    config.gpu_options.allow_growth = True

    #
    # Only allow a fraction (d.memory) of the GPU memory to be allocated.
    #
    config.gpu_options.per_process_gpu_memory_fraction = d.memory

    #
    # Create a session with the above options specified.
    #
    K.tensorflow_backend.set_session(tf.Session(config=config))

    summary = "\n"
    summary += "Environment:\n"
    summary += summarizeEnvvar("THEANO_FLAGS") + "\n"
    summary += "\n"
    summary += "Software Versions:\n"
    summary += "Theano: " + T.__version__ + "\n"
    summary += "Keras:  " + keras.__version__ + "\n"
    summary += "\n"
    summary += "Arguments:\n"
    summary += "Path to Datasets: " + str(d.datadir) + "\n"
    summary += "Number of GPUs: " + str(d.gpus) + "\n"
    summary += "Path to Workspace: " + str(d.workdir) + "\n"
    summary += "Model: " + str(d.model) + "\n"
    summary += "Number of Epochs: " + str(d.num_epochs) + "\n"
    summary += "Number of Start Filters: " + str(d.start_filter) + "\n"
    summary += "Number of Layers: " + str(d.num_layers) + "\n"
    summary += "Optimizer: " + str(d.optimizer) + "\n"
    summary += "Learning Rate: " + str(d.lr) + "\n"
    summary += "Learning Rate Decay: " + str(d.decay) + "\n"
    summary += "Clipping Norm: " + str(d.clipnorm) + "\n"
    summary += "Clipping Value: " + str(d.clipval) + "\n"
    summary += "Dropout Probability: " + str(d.dropout) + "\n"
    if d.optimizer in ["adam"]:
        summary += "Beta 1: " + str(d.beta1) + "\n"
        summary += "Beta 2: " + str(d.beta2) + "\n"
    else:
        summary += "Momentum: " + str(d.momentum) + "\n"
    summary += "Save Prefix: " + str(d.save_prefix) + "\n"
    L.getLogger("entry").info(summary[:-1])

    #
    # Load dataset
    #
    L.getLogger("entry").info("Loading dataset {:s} ...".format(d.dataset))
    np.random.seed(d.seed % 2**32)

    #
    # Create training data generator
    #
    dataset = Timit('train')
    rng = np.random.RandomState(123)
    if d.model == "quaternion":
        data_stream_train = construct_conv_stream(dataset, rng, 200, 1000, quaternion=True)
    else:
        data_stream_train = construct_conv_stream(dataset, rng, 200, 1000, quaternion=False)

    #
    # Create dev data generator
    #
    dataset = Timit('dev')
    rng = np.random.RandomState(123)
    if d.model == "quaternion":
        data_stream_dev = construct_conv_stream(dataset, rng, 200, 10000, quaternion=True)
    else:
        data_stream_dev = construct_conv_stream(dataset, rng, 200, 1000, quaternion=False)


    L.getLogger("entry").info("Training set length:   " + str(Timit('train').num_examples))
    L.getLogger("entry").info("Validation set length: " + str(Timit('dev').num_examples))
    L.getLogger("entry").info("Test set length:       " + str(Timit('test').num_examples))
    L.getLogger("entry").info("Loaded dataset {:s}.".format(d.dataset))

    #
    # Optimizers
    #
    if d.optimizer in ["sgd", "nag"]:
        opt = SGD(lr       = d.lr,
                  momentum = d.momentum,
                  decay    = d.decay,
                  nesterov = (d.optimizer == "nag"),
                  clipnorm = d.clipnorm)
    elif d.optimizer == "rmsprop":
        opt = RMSprop(lr       = d.lr,
                      decay    = d.decay,
                      clipnorm = d.clipnorm)
    elif d.optimizer == "adam":
        opt = Adam(lr       = d.lr,
                   beta_1   = d.beta1,
                   beta_2   = d.beta2,
                   decay    = d.decay,
                   clipnorm = d.clipnorm)
    else:
        raise ValueError("Unknown optimizer " + d.optimizer)


    #
    # Initial Entry or Resume?
    #

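    # Checkpoints written by SaveLastModel embed, inside the HDF5 file itself,
    # the epoch they were taken at ("initialEpoch"); if such a file exists we
    # reload the model/weights and resume training from that epoch.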
    initialEpoch = 0
    chkptFilename = os.path.join(d.workdir, "chkpts", str(d.save_prefix) + "ModelChkpt.hdf5")
    chkptFilename_weight = os.path.join(d.workdir, "chkpts", str(d.save_prefix) + "ModelChkpt_weight.hdf5")
    isResuming = os.path.isfile(chkptFilename)
    isResuming_weight = os.path.isfile(chkptFilename_weight)

    if isResuming or isResuming_weight:

        # Reload Model and Optimizer
        if d.dataset == "timit":
            L.getLogger("entry").info("Re-Creating the model from scratch.")
            model_mono, test_func = getTimitResnetModel2D(d)
            model_mono.load_weights(chkptFilename_weight)
            with H.File(chkptFilename_weight, "r") as f:
                initialEpoch = int(f["initialEpoch"][...])
            L.getLogger("entry").info("Training will restart at epoch {:5d}.".format(initialEpoch + 1))
            L.getLogger("entry").info("Compilation Started.")

        else:

            L.getLogger("entry").info("Reloading a model from " + chkptFilename + " ...")
            np.random.seed(d.seed % 2**32)
            model = KM.load_model(chkptFilename, custom_objects={
                "QuaternionConv2D": QuaternionConv2D,
                "QuaternionConv1D": QuaternionConv1D,
                "GetIFirst": GetIFirst,
                "GetJFirst": GetJFirst,
                "GetKFirst": GetKFirst,
                "GetRFirst": GetRFirst,
            })
            L.getLogger("entry").info("... reloading complete.")
            with H.File(chkptFilename, "r") as f:
                initialEpoch = int(f["initialEpoch"][...])
            L.getLogger("entry").info("Training will restart at epoch {:5d}.".format(initialEpoch + 1))
            L.getLogger("entry").info("Compilation Started.")
    else:
        model_mono, test_func = getTimitModel2D(d)

        L.getLogger("entry").info("Compilation Started.")

    #
    # Multi GPU: we can only save model_mono, because of a Keras bug.
    #
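    # (The wrapper returned by multi_gpu_model cannot be saved and reloaded
    # reliably, so every checkpoint above goes through the single-GPU template
    # model, model_mono.)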
    if d.gpus > 1:
        model = multi_gpu_model(model_mono, gpus=d.gpus)
    else:
        model = model_mono

    #
    # Compile with the CTC loss function.
    #
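    # The model's 'ctc' output already *is* the CTC loss, computed inside a
    # Lambda layer (see ctc_lambda_func in models_timit), so the Keras-level
    # loss below simply passes that output through unchanged.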
    model.compile(opt, loss={'ctc': lambda y_true, y_pred: y_pred})

    #
    # Precompile several backend functions
    #
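    # (Each block below uses the t = -time.time(); ...; t += time.time()
    # idiom: after the second statement, t holds the elapsed wall-clock time.)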
    if d.summary:
        model.summary()
    L.getLogger("entry").info("# of Parameters: {:10d}".format(model.count_params()))
    L.getLogger("entry").info("Compiling Train Function...")
    t = -time.time()
    model._make_train_function()
    t += time.time()
    L.getLogger("entry").info("                            {:10.3f}s".format(t))
    L.getLogger("entry").info("Compiling Predict Function...")
    t = -time.time()
    model._make_predict_function()
    t += time.time()
    L.getLogger("entry").info("                            {:10.3f}s".format(t))
    L.getLogger("entry").info("Compiling Test Function...")
    t = -time.time()
    model._make_test_function()
    t += time.time()
    L.getLogger("entry").info("                            {:10.3f}s".format(t))
    L.getLogger("entry").info("Compilation Ended.")


    #
    # Create Callbacks
    #
    newLineCb  = PrintNewlineAfterEpochCallback()
    saveLastCb = SaveLastModel(d.workdir, d.save_prefix, model_mono, period=10)

    callbacks = []

    #
    # Newline at the end of each epoch, for better-looking logs.
    #
    callbacks += [newLineCb]

    quaternion = (d.model == "quaternion")

+   if not os.path.exists(d.workdir + "/LOGS"):
+       os.makedirs(d.workdir + "/LOGS")
    savedir = d.workdir + "/LOGS/" + d.save_prefix

    #
    # Save the train loss
    #
    trainLoss = TrainLoss(savedir)

    #
    # Compute the PER and save it
    #
    editDistValCb  = EditDistance(test_func, 'dev',  quaternion, savedir)
    editDistTestCb = EditDistance(test_func, 'test', quaternion, savedir)
    callbacks += [trainLoss]
    callbacks += [editDistValCb]
    callbacks += [editDistTestCb]

    callbacks += [newLineCb]

436 # 438 #
437 # Save the model 439 # Save the model
438 # 440 #
439 callbacks += [saveLastCb] 441 callbacks += [saveLastCb]
440 442
441 # 443 #
442 # Enter training loop. 444 # Enter training loop.
443 # 445 #
444 L .getLogger("entry").info("**********************************************") 446 L .getLogger("entry").info("**********************************************")
445 if isResuming: L.getLogger("entry").info("*** Reentering Training Loop @ Epoch {:5d} ***".format(initialEpoch+1)) 447 if isResuming: L.getLogger("entry").info("*** Reentering Training Loop @ Epoch {:5d} ***".format(initialEpoch+1))
446 else: L.getLogger("entry").info("*** Entering Training Loop @ First Epoch ***") 448 else: L.getLogger("entry").info("*** Entering Training Loop @ First Epoch ***")
447 L .getLogger("entry").info("**********************************************") 449 L .getLogger("entry").info("**********************************************")
448 450
449 451
    #
    # TRAIN
    #

    ########
    # Make sure to give the right number of mini-batches
    # needed to complete ONE epoch (according to your data generator).
    ########
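    # (steps_per_epoch / validation_steps below are how many batches Keras
    # draws from each endless generator per "epoch"; 1144 and 121 presumably
    # correspond to one full pass over the TIMIT train and dev streams.)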

    epochs_train = 1144
    epochs_dev   = 121

    model.fit_generator(generator        = timitGenerator(data_stream_train),
                        steps_per_epoch  = epochs_train,
                        epochs           = d.num_epochs,
                        verbose          = 1,
                        validation_data  = timitGenerator(data_stream_dev),
                        validation_steps = epochs_dev,
                        callbacks        = callbacks,
                        initial_epoch    = initialEpoch)