dataset.py
7.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# -*- coding: utf-8 -*-
#
# Authors: Dmitriy Serdyuk
import itertools
import numpy
from six.moves import range
from itertools import chain
from scipy import fft
from scipy.signal import stft
# Module-level audio/windowing constants.
FS = 44100                  # sampling rate, in samples per second
DEFAULT_WINDOW_SIZE = 2048  # size of the Fourier window, in samples
OUTPUT_SIZE = 128           # how many distinct notes exist
STRIDE = 512                # hop between consecutive windows, in samples
WPS = float(FS) / STRIDE    # resulting number of windows per second
class MusicNet(object):
    """In-memory reader for a MusicNet-style ``.npz`` archive.

    Every key of the archive is a recording id. The code indexes each entry
    as ``entry[0]`` (a 1-D audio signal) and ``entry[1]`` (a label container
    that is indexed by sample position and yields objects whose ``data[1]``
    is a note number).
    NOTE(review): the labels are presumably ``intervaltree.IntervalTree``
    objects as in the public MusicNet release -- confirm against the data.

    The recordings are divided into train / validation / test splits (see
    ``splits``). Training batches are produced by ``train_iterator`` and
    fixed evaluation sets by ``eval_set``.

    Parameters
    ----------
    filename : str
        Path of the ``.npz`` archive.
    in_memory : bool
        Must be true; anything else raises ``NotImplementedError``.
    window_size : int
        Number of audio samples per example.
    output_size : int
        Number of output classes (columns of the label matrix).
    feature_size : int
        Stored on the instance; not used by this class itself.
    sample_freq : int
        Offset (in samples) of the first evaluation window.
    complex_ : bool
        Return real and imaginary parts instead of magnitudes when a
        spectral transform (``fourier`` or ``stft``) is enabled.
    fourier : bool
        Apply an FFT along the time axis of every batch.
    stft : bool
        Apply a short-time Fourier transform (only if ``fourier`` is false).
    fast_load : bool
        Load only the first six training recordings plus the validation and
        test recordings.
    rng : numpy.random.RandomState, optional
        Random generator used to sample training windows.
    seed : int
        Seed of the generator created when ``rng`` is not given.
    """

    def __init__(self, filename, in_memory=True, window_size=4096,
                 output_size=84, feature_size=1024, sample_freq=11000,
                 complex_=False, fourier=False, stft=False, fast_load=False,
                 rng=None, seed=123):
        if not in_memory:
            raise NotImplementedError

        self.filename = filename
        self.window_size = window_size
        self.output_size = output_size
        self.feature_size = feature_size
        self.sample_freq = sample_freq
        self.complex_ = complex_
        self.fourier = fourier
        self.stft = stft
        self.fast_load = fast_load

        if rng is not None:
            self.rng = rng
        else:
            self.rng = numpy.random.RandomState(seed)

        # Filled lazily by ``load``.
        self._train_data = {}
        self._valid_data = {}
        self._test_data = {}
        self._loaded = False
        # Cache of evaluation sets, keyed by 'valid' / 'test'.
        self._eval_sets = {}

    def splits(self, filename=None):
        """Return ``(train_inds, valid_inds, test_inds)`` lists of ids.

        The validation and test ids are fixed; every other key found in the
        archive is a training id.

        Parameters
        ----------
        filename : str, optional
            Archive to inspect; defaults to ``self.filename``.
        """
        if filename is None:
            filename = self.filename
        with open(filename, 'rb') as f:
            # Only the key listing is read here, not the arrays themselves.
            # Materialise it while the file is still open.
            all_inds = list(numpy.load(f).keys())

        test_inds = ['2303', '2382', '1819']
        valid_inds = ['2131', '2384', '1792',
                      '2514', '2567', '1876']
        # Training ids must exclude BOTH held-out sets.
        held_out = set(test_inds) | set(valid_inds)
        train_inds = [ind for ind in all_inds if ind not in held_out]
        return train_inds, valid_inds, test_inds

    @classmethod
    def note_to_class(cls, note):
        """Convert a note number into a zero-based output column.

        NOTE(review): the offset 21 is presumably the lowest MIDI pitch
        that is labelled -- confirm.
        """
        return note - 21

    @property
    def train_data(self):
        """Dict of training recordings, loading the archive if needed."""
        if not self._train_data:
            self.load()
        return self._train_data

    @property
    def valid_data(self):
        """Dict of validation recordings, loading the archive if needed."""
        if not self._valid_data:
            self.load()
        return self._valid_data

    @property
    def test_data(self):
        """Dict of test recordings, loading the archive if needed."""
        if not self._test_data:
            self.load()
        return self._test_data

    def load(self, filename=None, reload=False):
        """Read the archive and populate the three splits.

        Parameters
        ----------
        filename : str, optional
            Archive to read; defaults to ``self.filename``.
        reload : bool
            Read the archive again even if it was loaded before.

        Raises
        ------
        KeyError
            If a validation recording is missing from the loaded data.
            Missing test recordings are skipped silently.
        """
        if filename is None:
            filename = self.filename
        if self._loaded and not reload:
            return

        # Compute the split on the same file that is about to be loaded.
        train_inds, valid_inds, test_inds = self.splits(filename)

        with open(filename, 'rb') as f:
            # NOTE(review): recent NumPy refuses to unpickle object arrays
            # unless ``allow_pickle=True`` is passed. That flag executes
            # pickled code, so enable it only for trusted archives.
            data_file = numpy.load(f)
            if self.fast_load:
                wanted = chain(train_inds[:6], valid_inds, test_inds)
                train_data = {ind: data_file[ind] for ind in wanted}
            else:
                train_data = dict(data_file)

        # Move the held-out recordings out of the training dict.
        test_data = {}
        for ind in test_inds:
            if ind in train_data:
                test_data[ind] = train_data.pop(ind)

        valid_data = {}
        for ind in valid_inds:
            valid_data[ind] = train_data.pop(ind)

        self._train_data = train_data
        self._valid_data = valid_data
        self._test_data = test_data
        # Evaluation sets built from previous data are stale now.
        self._eval_sets = {}
        # Without this flag the early return above can never trigger.
        self._loaded = True

    def construct_eval_set(self, data, step=128, pos_per_file=7500):
        """Cut a fixed grid of windows out of every recording in ``data``.

        Parameters
        ----------
        data : dict
            Maps recording id to an ``(audio, labels)`` pair.
        step : int
            Distance in samples between consecutive windows.
        pos_per_file : int
            Number of windows taken from each recording.

        Returns
        -------
        features : 2D float array, ``(len(data) * pos_per_file, window_size)``
        outputs : 2D float array, ``(len(data) * pos_per_file, output_size)``
            Multi-hot rows; a column is 1 when a label is active at the
            centre sample of the window.
        """
        n_files = len(data)
        features = numpy.empty([n_files * pos_per_file, self.window_size])
        outputs = numpy.zeros([n_files * pos_per_file, self.output_size])

        features_ind = 0
        labels_ind = 1

        for i, ind in enumerate(data):
            print(ind)
            audio = data[ind][features_ind]
            labels = data[ind][labels_ind]
            for j in range(pos_per_file):
                if j % 1000 == 0:
                    print(j)
                row = pos_per_file * i + j
                # start from one second to give us some wiggle room for
                # larger segments
                index = self.sample_freq + j * step
                features[row] = audio[index:index + self.window_size]

                # label stuff that's on in the center of the window
                s = index + self.window_size // 2
                for label in labels[s]:
                    note = label.data[1]
                    outputs[row, self.note_to_class(note)] = 1
        return features, outputs

    @property
    def feature_dim(self):
        """Per-example feature shape after ``aggregate_raw_batch``.

        Determined by pushing a dummy batch of one example through the
        post-processing.
        """
        dummy_features = numpy.zeros((1, self.window_size, 1))
        dummy_output = numpy.zeros((1, self.output_size))
        dummy_features, _ = self.aggregate_raw_batch(
            dummy_features, dummy_output)
        return dummy_features.shape[1:]

    def aggregate_raw_batch(self, features, output):
        """Aggregate batch.

        All post processing goes here.

        Parameters
        ----------
        features : 3D float tensor
            Input tensor, indexed as (example, time, channel); only channel
            0 is used by the spectral branches that select a channel.
        output : 2D integer tensor
            Output classes

        Returns
        -------
        features_out : array
            ``features`` itself when no transform is enabled, otherwise the
            FFT / STFT of the batch (magnitudes, or real and imaginary
            parts when ``complex_`` is set).
        output : array
            The ``output`` argument, unchanged.
        """
        batch_size = features.shape[0]

        if self.fourier:
            # ``numpy.fft.fft`` is used because ``from scipy import fft``
            # yields a (non-callable) module on current SciPy releases.
            spectrum = numpy.fft.fft(features, axis=1)
            if self.complex_:
                features_out = numpy.zeros(
                    [batch_size, features.shape[1], 2])
                features_out[:, :, 0] = numpy.real(spectrum[:, :, 0])
                features_out[:, :, 1] = numpy.imag(spectrum[:, :, 0])
            else:
                features_out = numpy.abs(spectrum)
        elif self.stft:
            _, _, data = stft(features, nperseg=120, noverlap=60, axis=1)
            # Drop the channel axis (index 2), exactly as the magnitude
            # branch does, leaving a 3-D array.
            data = data[:, :, 0, :]
            if self.complex_:
                length = data.shape[1]
                n_feats = data.shape[2]
                # Size the buffer from the batch itself: this must also
                # work for evaluation and dummy batches.
                features_out = numpy.zeros(
                    [batch_size, length, n_feats * 2])
                features_out[:, :, :n_feats] = numpy.real(data)
                features_out[:, :, n_feats:] = numpy.imag(data)
            else:
                features_out = numpy.abs(data)
        else:
            features_out = features
        return features_out, output

    def train_iterator(self):
        """Yield training batches forever.

        Each batch holds one randomly positioned window per training
        recording, so the batch size equals ``len(self.train_data)``.
        Labels are taken at the sample ``s`` drawn for the window.

        Yields
        ------
        (features, output) as returned by ``aggregate_raw_batch``.
        """
        train_data = self.train_data
        n_files = len(train_data)
        # Integer half-window: ``/`` would give a float on Python 3 and
        # break the slicing below.
        half = self.window_size // 2

        while True:
            # Fresh buffers for every batch so that batches handed out
            # earlier are not overwritten by later iterations.
            features = numpy.zeros([n_files, self.window_size])
            output = numpy.zeros([n_files, self.output_size])
            for j, ind in enumerate(train_data):
                audio, labels = train_data[ind][0], train_data[ind][1]
                s = self.rng.randint(half, len(audio) - half)
                start = s - half
                features[j, :] = audio[start:start + self.window_size]
                for label in labels[s]:
                    note = label.data[1]
                    output[j, self.note_to_class(note)] = 1
            yield self.aggregate_raw_batch(features[:, :, None], output)

    def eval_set(self, set_name):
        """Return the cached ``(features, outputs)`` pair for a split.

        Both the 'valid' and the 'test' set are built on the first call.

        Parameters
        ----------
        set_name : str
            Either ``'valid'`` or ``'test'``.

        Raises
        ------
        KeyError
            If ``set_name`` is not one of the two names above.
        """
        if not self._eval_sets:
            for name in ['valid', 'test']:
                data = self.valid_data if name == 'valid' else self.test_data
                x, y = self.construct_eval_set(data)
                x, y = self.aggregate_raw_batch(x[:, :, None], y)
                self._eval_sets[name] = (x, y)
        return self._eval_sets[set_name]