fft.py 5.45 KB
#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
# Authors: Olexa Bilaniuk
#

import keras.backend                        as KB
import keras.engine                         as KE
import keras.layers                         as KL
import keras.optimizers                     as KO
import theano                               as T
import theano.ifelse                        as TI
import theano.tensor                        as TT
import theano.tensor.fft                    as TTF
import numpy                                as np

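#
# FFT functions:
#
#  fft():   Batched 1-D FFT  (Input: (Batch, TimeSamples))
#  ifft():  Batched 1-D IFFT (Input: (Batch, FreqSamples))
#  fft2():  Batched 2-D FFT  (Input: (Batch, TimeSamplesH, TimeSamplesW))
#  ifft2(): Batched 2-D IFFT (Input: (Batch, FreqSamplesH, FreqSamplesW))
#
#  Complex values are carried in real tensors: the first half of the Batch
#  axis holds the real parts and the second half holds the imaginary parts,
#  so Batch is expected to be even. Outputs use the same layout.
#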

def fft(z):
    """Batched 1-D orthonormal FFT of complex signals packed in a real tensor.

    The input `z` is a 2-D tensor whose first `z.shape[0]//2` rows are the
    real parts and whose remaining rows are the imaginary parts of the
    signals. The result has the same layout: real parts of the spectrum
    stacked on top of the imaginary parts.

    The transform is assembled from two real FFTs using linearity:
    FFT(r + i*j) = FFT(r) + j*FFT(i).
    """
    half   = z.shape[0]//2
    length = z.shape[1]
    # Multiplying a (..., 2) real/imag pair by [1, -1] conjugates it.
    conj   = TT.as_tensor_variable(np.asarray([[[1,-1]]], dtype=T.config.floatX))

    def _mirror(half_spec):
        # rfft only returns bins 0 .. length//2. The remaining bins are the
        # conjugates of bins 1 .. length - length//2 - 1 in reverse order.
        # This single symbolic slice covers both odd lengths (bins 1:) and
        # even lengths (bins 1:-1), so no ifelse on the parity is needed.
        tail = half_spec[:, 1:length - length//2][:, ::-1]
        return TT.concatenate([half_spec, conj*tail], axis=1)

    Zr = _mirror(TTF.rfft(z[:half], norm="ortho"))
    Zi = _mirror(TTF.rfft(z[half:], norm="ortho"))
    # Multiply Zi by i: (a + b*j)*j = -b + a*j, i.e. conjugate then swap.
    Zi = (conj*Zi)[:,:,::-1]
    Z  = Zr+Zi
    return TT.concatenate([Z[:,:,0], Z[:,:,1]], axis=0)
def ifft(z):
    """Batched 1-D orthonormal inverse FFT, in the same packed layout as fft().

    The input `z` is a 2-D tensor whose first `z.shape[0]//2` rows are the
    real parts and whose remaining rows are the imaginary parts. The result
    uses the same layout.

    Uses the identity ifft(x) == conj(fft(conj(x))), which holds for the
    orthonormal transform, instead of repeating the body of fft().
    """
    half       = z.shape[0]//2
    # Conjugate the input: negate the imaginary rows.
    conjugated = TT.concatenate([z[:half], z[half:]*-1], axis=0)
    spectrum   = fft(conjugated)
    # Conjugate the output: negate the imaginary rows again.
    return TT.concatenate([spectrum[:half], spectrum[half:]*-1], axis=0)
def fft2(x):
    """Batched 2-D FFT built from two passes of the 1-D fft().

    `x` is a 3-D tensor. The last axis is transformed first, then the middle
    axis. The result has the same shape as `x`.
    """
    n, h, w = x.shape[0], x.shape[1], x.shape[2]

    # Pass 1: transform along the last axis, one row per (batch, h) pair.
    rows = fft(KB.reshape(x, (n *h, w)))
    rows = KB.reshape(rows, (n, h, w))

    # Pass 2: bring the middle axis last and transform along it.
    cols = KB.permute_dimensions(rows, (0, 2, 1))
    cols = fft(KB.reshape(cols, (n *w, h)))
    cols = KB.reshape(cols, (n, w, h))

    # Restore the original axis order.
    return KB.permute_dimensions(cols, (0, 2, 1))
def ifft2(x):
    """Batched 2-D inverse FFT built from two passes of the 1-D ifft().

    `x` is a 3-D tensor. The middle axis is inverted first, then the last
    axis. The result has the same shape as `x`.
    """
    n, h, w = x.shape[0], x.shape[1], x.shape[2]

    # Pass 1: bring the middle axis last and invert along it.
    cols = KB.permute_dimensions(x, (0, 2, 1))
    cols = ifft(KB.reshape(cols, (n *w, h)))
    cols = KB.reshape(cols, (n, w, h))

    # Pass 2: restore the axis order and invert along the last axis.
    rows = KB.permute_dimensions(cols, (0, 2, 1))
    rows = ifft(KB.reshape(rows, (n *h, w)))
    return KB.reshape(rows, (n, h, w))

#
# FFT Layers:
#
#  FFT:   Batched 1-D FFT  (Input: (Batch, FeatureMaps, TimeSamples))
#  IFFT:  Batched 1-D IFFT (Input: (Batch, FeatureMaps, FreqSamples))
#  FFT2:  Batched 2-D FFT  (Input: (Batch, FeatureMaps, TimeSamplesH, TimeSamplesW))
#  IFFT2: Batched 2-D IFFT (Input: (Batch, FeatureMaps, FreqSamplesH, FreqSamplesW))
#
#  The FeatureMaps axis carries the complex values: its first half holds the
#  real parts and its second half holds the imaginary parts.
#

class FFT(KL.Layer):
    """Keras layer applying fft() along the last axis of a 3-D input."""

    def call(self, x, mask=None):
        batch, maps, samples = x.shape[0], x.shape[1], x.shape[2]
        # Put the feature-map axis first so that fft() splits the real and
        # imaginary halves along it.
        flat = KB.permute_dimensions(x, (1,0,2))
        flat = KB.reshape(flat, (maps *batch, samples))
        spec = KB.reshape(fft(flat), (maps, batch, samples))
        return KB.permute_dimensions(spec, (1,0,2))
class IFFT(KL.Layer):
    """Keras layer applying ifft() along the last axis of a 3-D input."""

    def call(self, x, mask=None):
        batch, maps, samples = x.shape[0], x.shape[1], x.shape[2]
        # Put the feature-map axis first so that ifft() splits the real and
        # imaginary halves along it.
        flat = KB.permute_dimensions(x, (1,0,2))
        flat = KB.reshape(flat, (maps *batch, samples))
        sig  = KB.reshape(ifft(flat), (maps, batch, samples))
        return KB.permute_dimensions(sig, (1,0,2))
class FFT2(KL.Layer):
    """Keras layer applying fft2() over the last two axes of a 4-D input."""

    def call(self, x, mask=None):
        batch, maps, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        # Put the feature-map axis first so that the real and imaginary halves
        # end up split along the leading axis that fft2() works on.
        flat = KB.permute_dimensions(x, (1,0,2,3))
        flat = KB.reshape(flat, (maps *batch, h, w))
        spec = KB.reshape(fft2(flat), (maps, batch, h, w))
        return KB.permute_dimensions(spec, (1,0,2,3))
class IFFT2(KL.Layer):
    """Keras layer applying ifft2() over the last two axes of a 4-D input."""

    def call(self, x, mask=None):
        batch, maps, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3]
        # Put the feature-map axis first so that the real and imaginary halves
        # end up split along the leading axis that ifft2() works on.
        flat = KB.permute_dimensions(x, (1,0,2,3))
        flat = KB.reshape(flat, (maps *batch, h, w))
        sig  = KB.reshape(ifft2(flat), (maps, batch, h, w))
        return KB.permute_dimensions(sig, (1,0,2,3))

#
# Tests
#
# Note: The IFFT is the conjugate of the FFT of the conjugate.
#
#     np.fft.ifft(x) == np.conj(np.fft.fft(np.conj(x)))
#

if __name__ == "__main__":
    # Numpy
    #
    # Check the construction used by fft(): a complex FFT assembled from the
    # real FFTs of the real and imaginary parts, each extended to the full
    # spectrum by conjugate symmetry.
    np.random.seed(1)
    L   = 19
    r   = np.random.normal(0.8, size=(L,))
    i   = np.random.normal(0.8, size=(L,))
    x   = r+i*1j
    R   = np.fft.rfft(r, norm="ortho")
    I   = np.fft.rfft(i, norm="ortho")
    X   = np.fft.fft (x, norm="ortho")

    if L&1:
        # Odd length: every bin except DC is mirrored.
        R   = np.concatenate([R, np.conj(R[1:  ][::-1])])
        I   = np.concatenate([I, np.conj(I[1:  ][::-1])])
    else:
        # Even length: the last rfft bin is not mirrored either.
        R   = np.concatenate([R, np.conj(R[1:-1][::-1])])
        I   = np.concatenate([I, np.conj(I[1:-1][::-1])])
    Y   = R+I*1j
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(np.allclose(X, Y))

    # Theano
    #
    # Round trip: ifft(fft(v)) should reproduce v, where v stacks the real
    # parts of x on top of its imaginary parts.
    z   = TT.dmatrix()
    f   = T.function([z], ifft(fft(z)))
    v   = np.concatenate([np.real(x)[np.newaxis,:], np.imag(x)[np.newaxis,:]], axis=0)
    w   = f(v)  # evaluate the compiled function once and reuse the result
    print(v)
    print(w)
    print(np.allclose(v, w))

    # Keras
    #
    # NOTE: x and i are rebound here from the numpy arrays above to Keras tensors.
    x = i = KL.Input(shape=(128, 32,32))
    x = IFFT2()(x)
    model = KE.Model([i],[x])

    loss  = "mse"