Blame view
complexnn/init.py
3.37 KB
f2d3bd141 Initial commit wi... |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-

#
# Contributors: Titouan Parcollet
# Authors: Chiheb Trabelsi

import numpy as np
from numpy.random import RandomState
from random import gauss
import keras.backend as K
from keras import initializers
from keras.initializers import Initializer
from keras.utils.generic_utils import (serialize_keras_object,
                                       deserialize_keras_object)


#####################################################################
# Quaternion Implementations                                        #
#####################################################################

class QuaternionInit(Initializer):
    """Quaternion weight initializer using the He or the Glorot criterion.

    Each weight is generated in polar form: a Rayleigh-distributed modulus,
    a uniform phase in [-pi, pi] and a (near-)unit purely imaginary
    quaternion giving the direction of the imaginary part.

    Args:
        kernel_size: sequence of ints. Spatial kernel size for convolutional
            weights; for matrix (dense) weights its last entry is used as
            the output dimension.
        input_dim: number of input features (cast to ``int`` when the
            kernel shape is built).
        weight_dim: expected ``len(kernel_size)``; only used as a sanity
            check so that an integer is not passed as ``kernel_size``.
        nb_filters: number of filters, or ``None`` when the weights are a
            plain matrix instead of convolution filters.
        criterion: ``'he'`` or ``'glorot'``.
        seed: seed of the random generator; 1337 is used when ``None``.
    """

    def __init__(self, kernel_size, input_dim,
                 weight_dim, nb_filters=None,
                 criterion='he', seed=None):

        # `weight_dim` is used as a parameter for sanity check
        # as we should not pass an integer as kernel_size when
        # the weight dimension is >= 2.
        # nb_filters is None if weights are not convolutional
        # (matrix instead of filters).
        # conv1D: len(kernel_size) == 1 and weight_dim == 1
        # conv2D: len(kernel_size) == 2 and weight_dim == 2
        # conv3D: len(kernel_size) == 3 and weight_dim == 3
        assert len(kernel_size) == weight_dim and weight_dim in {0, 1, 2, 3}
        self.nb_filters = nb_filters
        self.kernel_size = kernel_size
        self.input_dim = input_dim
        self.weight_dim = weight_dim
        self.criterion = criterion
        self.seed = 1337 if seed is None else seed

    def __call__(self, shape, dtype=None):
        """Build the initial quaternion kernel.

        ``shape`` and ``dtype`` are accepted for compatibility with the
        Keras initializer call signature but are not used: the kernel shape
        is derived from the constructor arguments.

        Returns:
            A NumPy array holding the r, i, j and k components concatenated
            along the last axis (the last axis is four times the last entry
            of the kernel shape).

        Raises:
            ValueError: if ``criterion`` is neither ``'glorot'`` nor ``'he'``.
        """
        if self.nb_filters is not None:
            kernel_shape = (tuple(self.kernel_size)
                            + (int(self.input_dim), self.nb_filters))
        else:
            kernel_shape = (int(self.input_dim), self.kernel_size[-1])

        # Compute the fans from the resolved kernel shape. The previous code
        # always appended `self.nb_filters`, which is None for matrix
        # weights and broke the fan computation on that path. For
        # convolutional weights the shape is the same as before.
        fan_in, fan_out = initializers._compute_fans(kernel_shape)

        # Quaternion operations start here
        if self.criterion == 'glorot':
            s = 1. / np.sqrt(2 * (fan_in + fan_out))
        elif self.criterion == 'he':
            s = 1. / np.sqrt(2 * fan_in)
        else:
            raise ValueError('Invalid criterion: ' + str(self.criterion))

        # A single seeded generator drives every random draw so that the
        # result is reproducible for a given seed. Modulus and phase are
        # drawn first, in the original order, so their values for a given
        # seed are unchanged.
        rng = RandomState(self.seed)
        modulus = rng.rayleigh(scale=s, size=kernel_shape)
        phase = rng.uniform(low=-np.pi, high=np.pi, size=kernel_shape)

        # Generate random purely imaginary quaternions ...
        v_i = rng.uniform(0.0, 1.0, kernel_shape)
        v_j = rng.uniform(0.0, 1.0, kernel_shape)
        v_k = rng.uniform(0.0, 1.0, kernel_shape)

        # ... and make them (approximately) unitary. The small constant
        # guards against a division by zero.
        norm = np.sqrt(v_i ** 2 + v_j ** 2 + v_k ** 2) + 0.0001
        v_i /= norm
        v_j /= norm
        v_k /= norm

        sin_phase = np.sin(phase)
        weight_r = modulus * np.cos(phase)
        weight_i = modulus * v_i * sin_phase
        weight_j = modulus * v_j * sin_phase
        weight_k = modulus * v_k * sin_phase
        weight = np.concatenate([weight_r, weight_i, weight_j, weight_k],
                                axis=-1)

        return weight


class SqrtInit(Initializer):
    """Constant initializer that fills the weights with 1 / sqrt(2)."""

    def __call__(self, shape, dtype=None):
        # The scalar is computed with NumPy: the backend `K.sqrt` operates
        # on tensors, not on Python numbers.
        return K.constant(1 / np.sqrt(2), shape=shape, dtype=dtype)


# Aliases:

sqrt_init = SqrtInit
# NOTE: the former alias `quaternion_independent_filters` pointed to
# `QuaternionIndependentFilters`, which is not defined in this module and
# made the import fail with a NameError. Restore the alias together with
# the class if that initializer is added.
quaternion_init = QuaternionInit