Commit 7453639646158900fca7644f7bbc94cdbdf746fe ("Cleaning")
1 parent: f2d3bd1415
Exists in: master
Showing 1 changed file with 3 additions and 1 deletion. Added lines are marked with "+".
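The visible additions create the LOGS directory inside the work directory before the log/save path prefix is built in scripts/training.py.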
scripts/training.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Authors: Parcollet Titouan

# Imports
import editdistance
import h5py
import datasets.timit
from datasets.timit import Timit
from datasets.utils import construct_conv_stream, phone_to_phoneme_dict
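# complexnn provides the quaternion layers (QuaternionConv1D, QuaternionConv2D,
# GetRFirst/GetIFirst/GetJFirst/GetKFirst) referenced below when a checkpoint
# is reloaded via custom_objects.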
import complexnn
from complexnn import *
import h5py as H
import keras
from keras.callbacks import Callback, ModelCheckpoint, LearningRateScheduler
from keras.initializers import Orthogonal
from keras.layers import (Layer, Dropout, AveragePooling1D, AveragePooling2D,
                          AveragePooling3D, add, Add, concatenate, Concatenate,
                          Input, Flatten, Dense, Convolution2D, BatchNormalization,
                          Activation, Reshape, ConvLSTM2D, Conv2D, Lambda)
from keras.models import Model, load_model, save_model
from keras.optimizers import SGD, Adam, RMSprop
from keras.regularizers import l2
from keras.utils.np_utils import to_categorical
import keras.backend as K
import keras.models as KM
from keras.utils.training_utils import multi_gpu_model
import logging as L
import numpy as np
import os, pdb, socket, sys, time
import theano as T
from keras.backend.tensorflow_backend import set_session
from models_timit import getTimitResnetModel2D, ctc_lambda_func
import tensorflow as tf
import itertools
import random


#
# Generator wrapper for timit
#

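# Keras' fit_generator expects a generator that never stops, so the finite
# data stream is simply re-iterated forever, one epoch at a time.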
def timitGenerator(stream):
    while True:
        for data in stream.get_epoch_iterator():
            yield data

#
# Custom metrics
#
class EditDistance(Callback):
    def __init__(self, func, dataset, quaternion, save_prefix):
        self.func = func
        if(dataset in ['train','test','dev']):
            self.dataset_type = dataset
            self.save_prefix = save_prefix
            self.dataset = Timit(str(dataset))
            self.full_phonemes_dict = self.dataset.get_phoneme_dict()
            self.ind_phonemes_dict = self.dataset.get_phoneme_ind_dict()
            self.rng = np.random.RandomState(123)
            self.data_stream = construct_conv_stream(self.dataset, self.rng, 20, 1000, quaternion)

        else:
            raise ValueError("Unknown dataset for edit distance "+dataset)

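    # Convert a sequence of network label indices to a phoneme string:
    # the index len(full_phonemes_dict)-2 is treated as the blank label and
    # dropped, repeated symbols are collapsed, and phones are folded to the
    # reduced phoneme set through phone_to_phoneme_dict.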
    def labels_to_text(self, labels):
        ret = []
        for c in labels:
            if c == len(self.full_phonemes_dict) - 2:
                ret.append("")
            else:
                c_ = self.full_phonemes_dict[c + 1]
                ret.append(phone_to_phoneme_dict.get(c_, c_))
        ret = [k for k, g in itertools.groupby(ret)]
        return list(filter(lambda c: c != "", ret))

    def decode_batch(self, out, mask):
        ret = []
        for j in range(out.shape[0]):
            out_best = list(np.argmax(out[j], 1))[:int(mask[j])]
            out_best = [k for k, g in itertools.groupby(out_best)]
            # map from 61-d to 39-d
            out_str = self.labels_to_text(out_best)
            ret.append(out_str)
        return ret

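    # At the end of every epoch, greedily decode the evaluation stream and
    # compute the mean normalized edit distance (phoneme error rate, PER):
    # the edit distance between prediction and ground truth divided by the
    # reference length, averaged over all utterances.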
    def on_epoch_end(self, epoch, logs={}):
        mean_norm_ed = 0.
        num = 0
        for data in self.data_stream.get_epoch_iterator():
            x, y = data
            y_pred = self.func([x[0]])[0]
            decoded_y_pred = self.decode_batch(y_pred, x[1])
            decoded_gt = []
            for i in range(x[2].shape[0]):
                decoded_gt.append(self.labels_to_text(x[2][i][:int(x[3][i])]))
            num += len(decoded_y_pred)
            for i, (_pred, _gt) in enumerate(zip(decoded_y_pred, decoded_gt)):
                edit_dist = editdistance.eval(_pred, _gt)
                mean_norm_ed += float(edit_dist) / float(len(_gt))
        mean_norm_ed = mean_norm_ed / num

        # Dump the PER to a log file at every epoch (useful for sbatch cluster jobs)
        f = open(str(self.save_prefix)+"_"+str(self.dataset_type)+"_PER.txt", 'ab')
        mean = np.array([mean_norm_ed])
        np.savetxt(f, mean)
        f.close()
        L.getLogger("train").info("PER on "+str(self.dataset_type)+" : "+str(mean_norm_ed)+" at epoch "+str(epoch))

#
# Callbacks:
#

class TrainLoss(Callback):
    def __init__(self, savedir):
        self.savedir = savedir
    def on_epoch_end(self, epoch, logs={}):
        f = open(str(self.savedir)+"_train_loss.txt", 'ab')
        f2 = open(str(self.savedir)+"_dev_loss.txt", 'ab')
        value = float(logs['loss'])
        np.savetxt(f, np.array([value]))
        f.close()
        value = float(logs['val_loss'])
        np.savetxt(f2, np.array([value]))
        f2.close()
#
# Print a newline after each epoch, because Keras doesn't. Grumble.
#

class PrintNewlineAfterEpochCallback(Callback):
    def on_epoch_end(self, epoch, logs={}):
        sys.stdout.write("\n")
#
# Save checkpoints.
#

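# Every `period` epochs this callback saves the architecture as YAML and the
# full model plus its weights as HDF5, stores the epoch number inside the file
# as "initialEpoch" so training can later resume, and repoints a stable
# "<prefix>ModelChkpt.hdf5" name at the newest checkpoint via symlink + rename.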
class SaveLastModel(Callback):
    def __init__(self, workdir, save_prefix, model_mono, period=10):
        self.workdir = workdir
        self.model_mono = model_mono
        self.chkptsdir = os.path.join(self.workdir, "chkpts")
        self.save_prefix = save_prefix
        if not os.path.isdir(self.chkptsdir):
            os.mkdir(self.chkptsdir)
        self.period_of_epochs = period
        self.linkFilename = os.path.join(self.chkptsdir, str(save_prefix)+"ModelChkpt.hdf5")
        self.linkFilename_weight = os.path.join(self.chkptsdir, str(save_prefix)+"ModelChkpt_weight.hdf5")

    def on_epoch_end(self, epoch, logs={}):
        if (epoch + 1) % self.period_of_epochs == 0:

            # Filenames
            baseHDF5Filename = str(self.save_prefix)+"ModelChkpt{:06d}.hdf5".format(epoch+1)
            baseHDF5Filename_weight = str(self.save_prefix)+"ModelChkpt{:06d}_weight.hdf5".format(epoch+1)
            baseYAMLFilename = str(self.save_prefix)+"ModelChkpt{:06d}.yaml".format(epoch+1)
            hdf5Filename = os.path.join(self.chkptsdir, baseHDF5Filename)
            hdf5Filename_weight = os.path.join(self.chkptsdir, baseHDF5Filename_weight)
            yamlFilename = os.path.join(self.chkptsdir, baseYAMLFilename)

            # YAML
            yamlModel = self.model_mono.to_yaml()
            with open(yamlFilename, "w") as yamlFile:
                yamlFile.write(yamlModel)

            # HDF5
            KM.save_model(self.model_mono, hdf5Filename)
            self.model_mono.save_weights(hdf5Filename_weight)
            with H.File(hdf5Filename, "r+") as f:
                f.require_dataset("initialEpoch", (), "uint64", True)[...] = int(epoch+1)
                f.flush()
            with H.File(hdf5Filename_weight, "r+") as f:
                f.require_dataset("initialEpoch", (), "uint64", True)[...] = int(epoch+1)
                f.flush()


            # Symlink to new HDF5 file, then atomically rename and replace.
            os.symlink(baseHDF5Filename_weight, self.linkFilename_weight+".rename")
            os.rename(self.linkFilename_weight+".rename",
                      self.linkFilename_weight)


            # Symlink to new HDF5 file, then atomically rename and replace.
            os.symlink(baseHDF5Filename, self.linkFilename+".rename")
            os.rename(self.linkFilename+".rename",
                      self.linkFilename)


            L.getLogger("train").info("Saved checkpoint to {:s} at epoch {:5d}".format(hdf5Filename, epoch+1))

#
# Summarize environment variable.
#

def summarizeEnvvar(var):
    if var in os.environ: return var+"="+os.environ.get(var)
    else: return var+" unset"

#
# TRAINING PROCESS
#

def train(d):

    #
    #
    # Log important data about how we were invoked.
    #
    L.getLogger("entry").info("INVOCATION: "+" ".join(sys.argv))
    L.getLogger("entry").info("HOSTNAME: "+socket.gethostname())
    L.getLogger("entry").info("PWD: "+os.getcwd())
    L.getLogger("entry").info("CUDA DEVICE: "+str(d.device))
    os.environ["CUDA_VISIBLE_DEVICES"] = str(d.device)

    #
    # Setup GPUs
    #
    config = tf.ConfigProto()

    #
    # Don't pre-allocate memory; allocate as-needed
    #
    config.gpu_options.allow_growth = True

    #
    # Only allow the given fraction (d.memory) of the GPU memory to be allocated
    #
    config.gpu_options.per_process_gpu_memory_fraction = d.memory

    #
    # Create a session with the above options specified.
    #
    K.tensorflow_backend.set_session(tf.Session(config=config))

    summary = "\n"
    summary += "Environment:\n"
    summary += summarizeEnvvar("THEANO_FLAGS")+"\n"
    summary += "\n"
    summary += "Software Versions:\n"
    summary += "Theano: "+T.__version__+"\n"
    summary += "Keras: "+keras.__version__+"\n"
    summary += "\n"
    summary += "Arguments:\n"
    summary += "Path to Datasets: "+str(d.datadir)+"\n"
    summary += "Number of GPUs: "+str(d.gpus)+"\n"
    summary += "Path to Workspace: "+str(d.workdir)+"\n"
    summary += "Model: "+str(d.model)+"\n"
    summary += "Number of Epochs: "+str(d.num_epochs)+"\n"
    summary += "Number of Start Filters: "+str(d.start_filter)+"\n"
    summary += "Number of Layers: "+str(d.num_layers)+"\n"
    summary += "Optimizer: "+str(d.optimizer)+"\n"
    summary += "Learning Rate: "+str(d.lr)+"\n"
    summary += "Learning Rate Decay: "+str(d.decay)+"\n"
    summary += "Clipping Norm: "+str(d.clipnorm)+"\n"
    summary += "Clipping Value: "+str(d.clipval)+"\n"
    summary += "Dropout Probability: "+str(d.dropout)+"\n"
    if d.optimizer in ["adam"]:
        summary += "Beta 1: "+str(d.beta1)+"\n"
        summary += "Beta 2: "+str(d.beta2)+"\n"
    else:
        summary += "Momentum: "+str(d.momentum)+"\n"
    summary += "Save Prefix: "+str(d.save_prefix)+"\n"
    L.getLogger("entry").info(summary[:-1])

    #
    # Load dataset
    #
    L.getLogger("entry").info("Loading dataset {:s} ...".format(d.dataset))
    np.random.seed(d.seed % 2**32)

    #
    # Create training data generator
    #
    dataset = Timit('train')
    rng = np.random.RandomState(123)
    if d.model == "quaternion":
        data_stream_train = construct_conv_stream(dataset, rng, 200, 1000, quaternion=True)
    else:
        data_stream_train = construct_conv_stream(dataset, rng, 200, 1000, quaternion=False)

    #
    # Create dev data generator
    #
    dataset = Timit('dev')
    rng = np.random.RandomState(123)
    if d.model == "quaternion":
        data_stream_dev = construct_conv_stream(dataset, rng, 200, 10000, quaternion=True)
    else:
        data_stream_dev = construct_conv_stream(dataset, rng, 200, 1000, quaternion=False)


    L.getLogger("entry").info("Training set length: "+str(Timit('train').num_examples))
    L.getLogger("entry").info("Validation set length: "+str(Timit('dev').num_examples))
    L.getLogger("entry").info("Test set length: "+str(Timit('test').num_examples))
    L.getLogger("entry").info("Loaded dataset {:s}.".format(d.dataset))

    #
    # Optimizers
    #
    if d.optimizer in ["sgd", "nag"]:
        opt = SGD(lr       = d.lr,
                  momentum = d.momentum,
                  decay    = d.decay,
                  nesterov = (d.optimizer=="nag"),
                  clipnorm = d.clipnorm)
    elif d.optimizer == "rmsprop":
        opt = RMSprop(lr       = d.lr,
                      decay    = d.decay,
                      clipnorm = d.clipnorm)
    elif d.optimizer == "adam":
        opt = Adam(lr       = d.lr,
                   beta_1   = d.beta1,
                   beta_2   = d.beta2,
                   decay    = d.decay,
                   clipnorm = d.clipnorm)
    else:
        raise ValueError("Unknown optimizer "+d.optimizer)


    #
    # Initial Entry or Resume ?
    #

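    # If a checkpoint exists in <workdir>/chkpts, training resumes: for TIMIT the
    # graph is rebuilt and only the weights are reloaded, otherwise the whole
    # model is deserialized with the quaternion custom_objects. The epoch to
    # resume from is read back from the "initialEpoch" dataset written into the
    # checkpoint by SaveLastModel.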
    initialEpoch = 0
    chkptFilename = os.path.join(d.workdir, "chkpts", str(d.save_prefix)+"ModelChkpt.hdf5")
    chkptFilename_weight = os.path.join(d.workdir, "chkpts", str(d.save_prefix)+"ModelChkpt_weight.hdf5")
    isResuming = os.path.isfile(chkptFilename)
    isResuming_weight = os.path.isfile(chkptFilename_weight)

    if isResuming or isResuming_weight:

        # Reload Model and Optimizer
        if d.dataset == "timit":
            L.getLogger("entry").info("Re-Creating the model from scratch.")
            model_mono, test_func = getTimitResnetModel2D(d)
            model_mono.load_weights(chkptFilename_weight)
            with H.File(chkptFilename_weight, "r") as f:
                initialEpoch = int(f["initialEpoch"][...])
            L.getLogger("entry").info("Training will restart at epoch {:5d}.".format(initialEpoch+1))
            L.getLogger("entry").info("Compilation Started.")

        else:

            L.getLogger("entry").info("Reloading a model from "+chkptFilename+" ...")
            np.random.seed(d.seed % 2**32)
            model = KM.load_model(chkptFilename, custom_objects={
                "QuaternionConv2D": QuaternionConv2D,
                "QuaternionConv1D": QuaternionConv1D,
                "GetIFirst": GetIFirst,
                "GetJFirst": GetJFirst,
                "GetKFirst": GetKFirst,
                "GetRFirst": GetRFirst,
            })
            L.getLogger("entry").info("... reloading complete.")
            with H.File(chkptFilename, "r") as f:
                initialEpoch = int(f["initialEpoch"][...])
            L.getLogger("entry").info("Training will restart at epoch {:5d}.".format(initialEpoch+1))
            L.getLogger("entry").info("Compilation Started.")
    else:
        model_mono, test_func = getTimitResnetModel2D(d)

        L.getLogger("entry").info("Compilation Started.")

    #
    # Multi GPU: Can only save the model_mono because of keras bug
    #
    if d.gpus > 1:
        model = multi_gpu_model(model_mono, gpus=d.gpus)
    else:
        model = model_mono

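    # The 'ctc' output of the network already computes the CTC loss inside the
    # graph (via ctc_lambda_func), so the Keras loss below just passes y_pred
    # through unchanged.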
    #
    # Compile with CTC loss function
    #
    model.compile(opt, loss={'ctc': lambda y_true, y_pred: y_pred})


    #
    # Precompile several backend functions
    #
    if d.summary:
        model.summary()
    L.getLogger("entry").info("# of Parameters: {:10d}".format(model.count_params()))
    L.getLogger("entry").info("Compiling Train Function...")
    t = -time.time()
    model._make_train_function()
    t += time.time()
    L.getLogger("entry").info(" {:10.3f}s".format(t))
    L.getLogger("entry").info("Compiling Predict Function...")
    t = -time.time()
    model._make_predict_function()
    t += time.time()
    L.getLogger("entry").info(" {:10.3f}s".format(t))
    L.getLogger("entry").info("Compiling Test Function...")
    t = -time.time()
    model._make_test_function()
    t += time.time()
    L.getLogger("entry").info(" {:10.3f}s".format(t))
    L.getLogger("entry").info("Compilation Ended.")

    #
    # Create Callbacks
    #
    newLineCb = PrintNewlineAfterEpochCallback()
    saveLastCb = SaveLastModel(d.workdir, d.save_prefix, model_mono, period=10)


    callbacks = []

    #
    # Print a newline after each epoch for nicer-looking console output
    #
    callbacks += [newLineCb]
    if d.model == "quaternion":
        quaternion = True
    else:
        quaternion = False

+    if not os.path.exists(d.workdir+"/LOGS"):
+        os.makedirs(d.workdir+"/LOGS")
    savedir = d.workdir+"/LOGS/"+d.save_prefix

    #
    # Save the Train loss
    #
    trainLoss = TrainLoss(savedir)

    #
    # Compute accuracies and save
    #
    editDistValCb = EditDistance(test_func, 'dev', quaternion, savedir)
    editDistTestCb = EditDistance(test_func, 'test', quaternion, savedir)
    callbacks += [trainLoss]
    callbacks += [editDistValCb]
    callbacks += [editDistTestCb]

    callbacks += [newLineCb]

    #
    # Save the model
    #
    callbacks += [saveLastCb]

    #
    # Enter training loop.
    #
    L.getLogger("entry").info("**********************************************")
    if isResuming: L.getLogger("entry").info("*** Reentering Training Loop @ Epoch {:5d} ***".format(initialEpoch+1))
    else:          L.getLogger("entry").info("*** Entering Training Loop @ First Epoch ***")
    L.getLogger("entry").info("**********************************************")


    #
    # TRAIN
    #

    ########
    # Make sure to give the number of mini-batches needed to complete
    # ONE epoch (according to your data generator)
    ########

    epochs_train = 1144
    epochs_dev = 121
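    # Despite the names, these are the number of mini-batches drawn per epoch;
    # they are passed to fit_generator as steps_per_epoch / validation_steps
    # and must match what the train/dev streams above actually yield per pass.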

    model.fit_generator(generator        = timitGenerator(data_stream_train),
                        steps_per_epoch  = epochs_train,
                        epochs           = d.num_epochs,
                        verbose          = 1,
                        validation_data  = timitGenerator(data_stream_dev),
                        validation_steps = epochs_dev,
                        callbacks        = callbacks,
                        initial_epoch    = initialEpoch)
