Blame view
build/scripts-2.7/train.py
4.38 KB
f2d3bd141 Initial commit wi... |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
#!/users/parcollet/.pyenv/versions/2.7.13/bin/python
# -*- coding: utf-8 -*-
#
# Authors: Dmitriy Serdyuk
"""Train a real- or complex-valued MusicNet model with Keras.

Usage::

    train.py MODEL_NAME --in-memory [--complex] [--model ARCH] [--epochs N]
             [--fourier] [--stft] [--fast-load] [--local-data PATH]

Training logs go to ``models/log_<MODEL_NAME>.jsonl.gz``; the
``SaveLastModel`` callback is given the ``./models/`` directory.
"""
from __future__ import print_function

import argparse
from os import path

import numpy
import numpy as np

import mimir
import keras
import musicnet.models.complex
from musicnet.callbacks import (
    SaveLastModel, Performance, Validation, LearningRateScheduler)
from musicnet.dataset import MusicNet
from musicnet import models

# input dimensions
# Floor division keeps these integers on Python 3 as well as Python 2
# (plain ``/`` would yield floats on Python 3).
d = 16384 // 4
window_size = d
# number of notes
m = 84
step = 512
step = step // 4

# (first epoch, learning rate) pairs, sorted by first epoch. Each rate applies
# from its first epoch up to (but excluding) the next entry's first epoch.
_LR_SCHEDULE = (
    (0, 1e-3),
    (10, 1e-4),
    (100, 5e-5),
    (120, 1e-5),
    (150, 1e-6),
)


def schedule(epoch):
    """Return the learning rate to use at ``epoch``.

    The rate is piecewise constant (see ``_LR_SCHEDULE``). When ``epoch`` is
    exactly the first epoch of a new segment, the new rate is printed.

    Args:
        epoch: zero-based epoch index (as passed by ``LearningRateScheduler``).

    Returns:
        The learning rate as a float.

    Raises:
        ValueError: if ``epoch`` is before the first scheduled epoch (i.e.
            negative), for which no rate is defined.
    """
    # Walk the table from the latest segment backwards; the first segment
    # whose start is <= epoch is the active one.
    for start, lrate in reversed(_LR_SCHEDULE):
        if epoch >= start:
            if epoch == start:
                print(' current learning rate value is ' + str(lrate))
            return lrate
    raise ValueError(
        'no learning rate defined for epoch {!r}'.format(epoch))


def get_model(model, feature_dim):
    """Build the Keras model for the architecture named ``model``.

    Args:
        model: one of ``'mlp'``, ``'shallow_convnet'`` or ``'deep_convnet'``,
            optionally prefixed with ``'complex_'`` to take the builders from
            ``models.complex`` instead of ``models``.
        feature_dim: feature shape of the dataset. ``mlp`` receives
            ``numpy.prod(feature_dim)`` as its window size; the convnets
            receive ``feature_dim[0]`` as window size and ``feature_dim[1]``
            as channel count.

    Returns:
        The object returned by the selected ``get_*`` builder.

    Raises:
        ValueError: if the architecture name is not recognised.
    """
    requested = model
    complex_ = model.startswith('complex')
    if complex_:
        # Split on the FIRST underscore only: 'complex_shallow_convnet' must
        # become 'shallow_convnet', not 'shallow'.
        model = model.partition('_')[2]
        model_module = models.complex
        print('.. complex network')
    else:
        model_module = models

    if model == 'mlp':
        print('.. using MLP')
        return model_module.get_mlp(window_size=numpy.prod(feature_dim))
    elif model == 'shallow_convnet':
        print('.. using shallow convnet')
        return model_module.get_shallow_convnet(window_size=feature_dim[0],
                                                channels=feature_dim[1])
    elif model == 'deep_convnet':
        print('.. using deep convnet')
        return model_module.get_deep_convnet(window_size=feature_dim[0],
                                             channels=feature_dim[1])
    raise ValueError('unknown model: {!r}'.format(requested))


def main(model_name, in_memory, complex_, model, local_data, epochs, fourier,
         stft, fast_load):
    """Load MusicNet, build the requested model and train it.

    Args:
        model_name: run identifier; used for the log file
            ``models/log_<model_name>.jsonl.gz`` and as the checkpoint name.
        in_memory: must be True; only in-memory loading is implemented.
        complex_: forwarded to ``MusicNet`` as ``complex_``.
        model: architecture name understood by ``get_model``.
        local_data: dataset path forwarded to ``MusicNet``.
        epochs: number of training epochs.
        fourier: forwarded to ``MusicNet``.
        stft: forwarded to ``MusicNet``.
        fast_load: forwarded to ``MusicNet``.

    Raises:
        ValueError: if ``in_memory`` is False.
    """
    if not in_memory:
        raise ValueError(
            'only in-memory training is supported; pass --in-memory')

    rng = numpy.random.RandomState(123)

    # Warning: the full dataset is over 40GB. Make sure you have enough RAM!
    # This can take a few minutes to load.
    print('.. loading train data')
    dataset = MusicNet(local_data, complex_=complex_, fourier=fourier,
                       stft=stft, rng=rng, fast_load=fast_load)
    dataset.load()
    print('.. train data loaded')
    Xvalid, Yvalid = dataset.eval_set('valid')
    Xtest, Ytest = dataset.eval_set('test')

    print(".. building model")
    # Keep ``model`` as the architecture string; the built network gets its
    # own name so it cannot be mistaken for the string further down.
    net = get_model(model, dataset.feature_dim)
    net.summary()
    print(".. parameters: {:03.2f}M".format(net.count_params() / 1000000.))

    logger = mimir.Logger(
        filename='models/log_{}.jsonl.gz'.format(model_name))

    it = dataset.train_iterator()
    callbacks = [Validation(Xvalid, Yvalid, 'valid', logger),
                 Validation(Xtest, Ytest, 'test', logger),
                 # Name checkpoints after the run, like the log file above.
                 SaveLastModel("./models/", 1, name=model_name),
                 Performance(logger),
                 LearningRateScheduler(schedule)]

    print('.. start training')
    net.fit_generator(
        it, steps_per_epoch=1000, epochs=epochs,
        callbacks=callbacks, workers=1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('model_name')
    parser.add_argument('--in-memory', action='store_true', default=False)
    parser.add_argument('--complex', dest='complex_', action='store_true',
                        default=False)
    parser.add_argument('--model', default='shallow_convnet')
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--fourier', action='store_true', default=False)
    parser.add_argument('--stft', action='store_true', default=False)
    parser.add_argument('--fast-load', action='store_true', default=False)
    parser.add_argument(
        '--local-data', default="/Tmp/serdyuk/data/musicnet_11khz.npz")

    main(**vars(parser.parse_args()))