train.py
4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/users/parcollet/.pyenv/versions/2.7.13/bin/python
# -*- coding: utf-8 -*-
#
# Authors: Dmitriy Serdyuk
from __future__ import print_function
import numpy
import numpy as np
from os import path
import argparse
import mimir
import keras
import musicnet.models.complex
from musicnet.callbacks import (
SaveLastModel, Performance, Validation, LearningRateScheduler)
from musicnet.dataset import MusicNet
from musicnet import models
# input dimensions
# Floor division keeps these values integers on Python 3 as well as on
# Python 2; plain `/` would produce 4096.0 / 128.0 under Python 3.
d = 16384 // 4
window_size = d
# number of notes
m = 84
# step, reduced by the same factor of 4 as the input dimension
step = 512
step = step // 4
def schedule(epoch):
    """Return the learning rate for the given epoch.

    Piecewise-constant schedule:

        epoch   0 ..   9  -> 1e-3
        epoch  10 ..  99  -> 1e-4
        epoch 100 .. 119  -> 5e-5
        epoch 120 .. 149  -> 1e-5
        epoch 150 and up  -> 1e-6

    The rate is printed once, on the first epoch of each stage.

    Args:
        epoch: Epoch index, starting at 0.

    Returns:
        The learning rate (float) for that epoch.

    Raises:
        ValueError: If ``epoch`` is below 0 or does not compare against
            the stage boundaries at all (e.g. NaN).  Previously this
            surfaced as an UnboundLocalError on ``return lrate``.
    """
    # (first epoch of the stage, learning rate), latest stage first so
    # that the first boundary reached by `epoch` is the right one.
    stages = (
        (150, 1e-6),
        (120, 1e-5),
        (100, 5e-5),
        (10, 1e-4),
        (0, 1e-3),
    )
    for first_epoch, lrate in stages:
        if epoch >= first_epoch:
            if epoch == first_epoch:
                print('\ncurrent learning rate value is ' + str(lrate))
            return lrate
    raise ValueError(
        'epoch must be a non-negative number, got {!r}'.format(epoch))
def get_model(model, feature_dim):
    """Build the network selected by ``model``.

    Args:
        model: Architecture name: ``'mlp'``, ``'shallow_convnet'`` or
            ``'deep_convnet'``.  A name starting with ``'complex'``
            (e.g. ``'complex_deep_convnet'``) selects the builder of the
            same name from ``models.complex`` instead of ``models``.
        feature_dim: Feature shape.  The MLP builder receives the product
            of all entries as ``window_size``; the convnet builders
            receive ``feature_dim[0]`` as ``window_size`` and
            ``feature_dim[1]`` as ``channels``.

    Returns:
        Whatever the selected ``get_*`` builder returns.

    Raises:
        ValueError: If the architecture name is not recognised.
    """
    requested = model
    if model.startswith('complex'):
        # Strip only the leading "complex_" prefix.  Splitting on every
        # underscore and taking element 1 reduced
        # "complex_shallow_convnet" to "shallow", which matched nothing.
        model = model.partition('_')[2]
        model_module = models.complex
        print('.. complex network')
    else:
        model_module = models

    if model == 'mlp':
        print('.. using MLP')
        return model_module.get_mlp(window_size=numpy.prod(feature_dim))
    if model == 'shallow_convnet':
        print('.. using shallow convnet')
        return model_module.get_shallow_convnet(window_size=feature_dim[0],
                                                channels=feature_dim[1])
    if model == 'deep_convnet':
        print('.. using deep convnet')
        return model_module.get_deep_convnet(window_size=feature_dim[0],
                                             channels=feature_dim[1])
    raise ValueError(
        "unknown model {!r}; expected 'mlp', 'shallow_convnet' or "
        "'deep_convnet', optionally prefixed with 'complex_'".format(
            requested))
def main(model_name, in_memory, complex_, model, local_data, epochs, fourier,
         stft, fast_load):
    """Load MusicNet, build the requested network and train it.

    Args:
        model_name: Run name; used in the log file name
            ``models/log_<model_name>.jsonl.gz``.
        in_memory: Must be true.  Only in-memory loading is implemented.
        complex_: Forwarded to ``MusicNet`` as ``complex_``.
        model: Architecture name handed to ``get_model``.
        local_data: First argument of ``MusicNet``.
        epochs: Number of epochs passed to ``fit_generator``.
        fourier: Forwarded to ``MusicNet``.
        stft: Forwarded to ``MusicNet``.
        fast_load: Forwarded to ``MusicNet``.

    Raises:
        ValueError: If ``in_memory`` is false.
    """
    # Fail early with an explanation instead of a bare ValueError; this
    # single guard also replaces the second, unreachable in_memory check
    # that used to follow model construction.
    if not in_memory:
        raise ValueError(
            'only in-memory training is supported; pass --in-memory')

    rng = numpy.random.RandomState(123)

    # Warning: the full dataset is over 40GB. Make sure you have enough RAM!
    # This can take a few minutes to load
    print('.. loading train data')
    dataset = MusicNet(local_data, complex_=complex_, fourier=fourier,
                       stft=stft, rng=rng, fast_load=fast_load)
    dataset.load()
    print('.. train data loaded')
    Xvalid, Yvalid = dataset.eval_set('valid')
    Xtest, Ytest = dataset.eval_set('test')

    print(".. building model")
    # From here on `model` is the built network, not the name string.
    model = get_model(model, dataset.feature_dim)
    model.summary()
    print(".. parameters: {:03.2f}M".format(model.count_params() / 1000000.))

    logger = mimir.Logger(
        filename='models/log_{}.jsonl.gz'.format(model_name))

    it = dataset.train_iterator()

    # NOTE(review): `name=model` passes the network object rather than a
    # string, because `model` was rebound above.  Confirm what
    # SaveLastModel expects for `name`; `model_name` may have been meant.
    callbacks = [Validation(Xvalid, Yvalid, 'valid', logger),
                 Validation(Xtest, Ytest, 'test', logger),
                 SaveLastModel("./models/", 1, name=model),
                 Performance(logger),
                 LearningRateScheduler(schedule)]

    print('.. start training')
    model.fit_generator(
        it, steps_per_epoch=1000, epochs=epochs,
        callbacks=callbacks, workers=1)
if __name__ == "__main__":
    # Command-line entry point: every option maps one-to-one onto a
    # keyword argument of main().
    cli = argparse.ArgumentParser()
    cli.add_argument('model_name')
    cli.add_argument('--in-memory', default=False, action='store_true')
    cli.add_argument('--complex', default=False, action='store_true',
                     dest='complex_')
    cli.add_argument('--model', default='shallow_convnet')
    cli.add_argument('--epochs', type=int, default=200)
    cli.add_argument('--fourier', default=False, action='store_true')
    cli.add_argument('--stft', default=False, action='store_true')
    cli.add_argument('--fast-load', default=False, action='store_true')
    cli.add_argument('--local-data',
                     default="/Tmp/serdyuk/data/musicnet_11khz.npz")
    options = cli.parse_args()
    main(**vars(options))