Blame view

scripts/models_timit.py 6.79 KB
f2d3bd141   Parcollet Titouan   Initial commit wi...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
  #!/usr/bin/env python
  # -*- coding: utf-8 -*-
  # Authors: Parcollet Titouan
  
  # Imports
  
  import complexnn
  from   complexnn                             import *
  import keras
  from   keras.callbacks                       import Callback, ModelCheckpoint, LearningRateScheduler
  from   keras.datasets                        import cifar10, cifar100
  from   keras.initializers                    import Orthogonal
  from   keras.layers                          import Layer, Dropout, AveragePooling1D, 
                                                      AveragePooling2D, AveragePooling3D, add, Add, concatenate, 
                                                      Concatenate, Input, Flatten, Dense, Convolution2D, BatchNormalization, 
                                                      Activation, Reshape, ConvLSTM2D, Conv2D, Lambda, Permute, TimeDistributed, 
                                                      SpatialDropout1D, PReLU
  from   keras.models                          import Model, load_model, save_model
  from   keras.optimizers                      import SGD, Adam, RMSprop
  from   keras.preprocessing.image             import ImageDataGenerator
  from   keras.regularizers                    import l2
  from   keras.utils.np_utils                  import to_categorical
  import keras.backend                         as     K
  import keras.models                          as     KM
  from keras.utils.training_utils              import multi_gpu_model
  import logging                               as     L
  import numpy                                 as     np
  import os, pdb, socket, sys, time
  import theano                                as     T
  from keras.backend.tensorflow_backend        import set_session
  import tensorflow                            as tf
  
  
  #
  # CTC Loss
  #
  
  def ctc_lambda_func(args):
      """Evaluate the CTC batch cost from the inputs of a ``Lambda`` layer.

      ``args`` must unpack into exactly four items, in this order:
      ``(y_pred, labels, input_length, label_length)``.  The backend call
      takes the labels before the predictions, so the first two items are
      swapped when forwarded to ``K.ctc_batch_cost``.
      """
      predictions, targets, prediction_lengths, target_lengths = args
      return K.ctc_batch_cost(
          targets,
          predictions,
          prediction_lengths,
          target_lengths,
      )
  
  #
  # Get the TIMIT 2D CNN + CTC Model
  #
  
  def getTimitModel2D(d):
      """Build the 2D convolutional CTC model used for TIMIT.

      Layer sequence: input convolution -> max pooling -> two stacks of
      ``d.num_layers // 2`` convolutions (``d.start_filter`` filters in the
      first stack, twice as many in the second) -> permute/reshape into a
      sequence -> three time-distributed dense layers -> 62-way softmax ->
      CTC cost ``Lambda`` layer.

      Real-valued convolutions are used when ``d.model == "real"`` and
      quaternion convolutions for every other value.  Quaternion dense layers
      (256 units) are used only when ``d.model == "quaternion"``; otherwise
      real dense layers (1024 units) are used.

      Parameters
      ----------
      d : object
          Configuration holder.  The attributes read here are ``num_layers``,
          ``start_filter``, ``dataset``, ``act``, ``aact``, ``dropout``,
          ``l2``, ``model`` and, for the quaternion model, ``quat_init``.

      Returns
      -------
      tuple
          ``(model, val_function)``.  ``model`` takes
          ``[I, input_length, labels, label_length]`` as inputs and outputs
          the CTC cost layer.  ``val_function`` is a backend function mapping
          ``[I]`` to the softmax predictions ``[pred]``.

      Raises
      ------
      ValueError
          If ``d.dataset`` is not ``"timit"``.

      Notes
      -----
      When ``d.aact != "none"``, ``d.act`` is overwritten with ``'linear'``
      on the object passed in, so the caller sees the change as well.
      """
      # Fail fast: reject unsupported datasets before any layer is created.
      if d.dataset != "timit":
          raise ValueError("Unknown dataset "+str(d.dataset))

      n             = d.num_layers
      sf            = d.start_filter
      inputShape    = (3,41,None)
      filsize       = (3, 5)

      # With an advanced activation the layers themselves stay linear and the
      # activation is added as a separate layer afterwards.
      # NOTE: this mutates the caller's configuration object.
      if d.aact != "none":
          d.act = 'linear'

      convArgs      = {
              "activation":               d.act,
              "data_format":              "channels_first",
              "padding":                  "same",
              "bias_initializer":         "zeros",
              "kernel_regularizer":       l2(d.l2),
              "kernel_initializer":       "random_uniform",
              }
      denseArgs     = {
              "activation":               d.act,
              "kernel_regularizer":       l2(d.l2),
              "kernel_initializer":       "random_uniform",
              "bias_initializer":         "zeros",
              "use_bias":                 True
              }

      if d.model == "quaternion":
          convArgs.update({"kernel_initializer": d.quat_init})

      #
      # Input Layer & CTC Parameters for TIMIT
      #
      if d.model == "quaternion":
          I = Input(shape=(4,41,None))
      else:
          I = Input(shape=inputShape)

      labels = Input(name='the_labels', shape=[None], dtype='float32')
      input_length = Input(name='input_length', shape=[1], dtype='int64')
      label_length = Input(name='label_length', shape=[1], dtype='int64')

      # Only the exact value "real" selects the real-valued convolution.
      conv_cls = Conv2D if d.model == "real" else QuaternionConv2D

      #
      # Input stage (no dropout)
      #
      O = _timit2d_conv_block(d, conv_cls, I, sf, filsize, 'conv', convArgs, False)

      #
      # Pooling
      #
      # NOTE(review): no data_format is passed here, so the pooled axes follow
      # the globally configured Keras image data format -- confirm this is
      # the intended behaviour alongside the channels_first convolutions.
      O = keras.layers.MaxPooling2D(pool_size=(1, 3), padding='same')(O)

      #
      # Stage 1 (sf filters) and Stage 2 (2 * sf filters)
      #
      # Integer division keeps the layer count and the layer names integral
      # on both Python 2 and Python 3.
      stage_depth = n // 2
      for stage, filters in enumerate((sf, sf*2)):
          for i in range(stage_depth):
              layer_name = 'conv'+str(stage * stage_depth + i)
              O = _timit2d_conv_block(d, conv_cls, O, filters, filsize,
                                      layer_name, convArgs, True)

      #
      # Permutation for CTC
      #
      # Move the third non-batch axis to the front, then merge the two
      # remaining axes so that each step carries a single feature vector.
      O = Permute((3,1,2))(O)
      O = Lambda(lambda x: K.reshape(x, (K.shape(x)[0], K.shape(x)[1],
                                         K.shape(x)[2] * K.shape(x)[3])),
                 output_shape=lambda x: (None, None, x[2] * x[3]))(O)

      #
      # Dense
      #
      if d.model == "quaternion":
          dense_cls, dense_units = QuaternionDense, 256
      else:
          dense_cls, dense_units = Dense, 1024

      num_dense = 3
      for k in range(num_dense):
          O = TimeDistributed(dense_cls(dense_units, **denseArgs))(O)
          O = _timit2d_advanced_act(d, O)
          # Dropout sits between dense layers, not after the last one.
          if k < num_dense - 1:
              O = Dropout(d.dropout)(O)

      pred = TimeDistributed( Dense(62,  activation='softmax', kernel_regularizer=l2(d.l2), use_bias=True, bias_initializer="zeros", kernel_initializer='random_uniform' ))(O)

      #
      # CTC For sequence labelling
      #
      O = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([pred, labels,input_length,label_length])

      #
      # Creating a function for testing and validation purpose
      #
      val_function = K.function([I],[pred])
      return Model(inputs=[I, input_length, labels, label_length], outputs=O), val_function


  def _timit2d_advanced_act(d, tensor):
      """Pass ``tensor`` through a PReLU layer when ``d.aact == "prelu"``.

      Any other value of ``d.aact`` returns ``tensor`` unchanged.
      """
      if d.aact == "prelu":
          # NOTE(review): the same shared_axes value is used after both the
          # convolutional and the dense layers -- confirm the intended axes.
          return PReLU(shared_axes=[1,0])(tensor)
      return tensor


  def _timit2d_conv_block(d, conv_cls, tensor, filters, kernel_size, name,
                          conv_args, with_dropout):
      """Apply one convolution, the optional PReLU, and optional dropout.

      ``conv_cls`` is the convolution layer class to instantiate and
      ``conv_args`` the keyword arguments shared by all convolutions.
      Dropout with rate ``d.dropout`` is appended only if ``with_dropout``.
      """
      tensor = conv_cls(filters, kernel_size, name=name, use_bias=True, **conv_args)(tensor)
      tensor = _timit2d_advanced_act(d, tensor)
      if with_dropout:
          tensor = Dropout(d.dropout)(tensor)
      return tensor