#!/usr/bin/env python
# Copyright 2016 Vijayaditya Peddinti
# Apache 2.0
from __future__ import print_function
import argparse
import sys
import os
import subprocess
import errno
import copy
import shutil
import warnings


def GetArgs():
    # we add compulsory arguments as named arguments for readability
    parser = argparse.ArgumentParser(description="""
**Warning, this script is deprecated. Please use utils/data/combine_short_segments.sh**
This script concatenates segments in the input_data_dir to ensure that"""
        " the segments in the output_data_dir have a specified minimum length.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--minimum-duration", type=float, required = True,
                        help="Minimum duration of the segments in the output directory")
    parser.add_argument("--input-data-dir", type=str, required = True)
    parser.add_argument("--output-data-dir", type=str, required = True)

    print(' '.join(sys.argv))
    args = parser.parse_args()
    return args
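
# Example invocation (the data directory paths are only illustrative):
#   steps/cleanup/combine_short_segments.py --minimum-duration 1.5 \
#     --input-data-dir data/train --output-data-dir data/train_min1.5sec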


def RunKaldiCommand(command, wait = True):
    """ Runs commands frequently seen in Kaldi scripts. These are usually a
        sequence of commands connected by pipes, so we use shell=True """
    p = subprocess.Popen(command, shell = True,
                         stdout = subprocess.PIPE,
                         stderr = subprocess.PIPE)

    if wait:
        stdout, stderr = p.communicate()
        if p.returncode != 0:
            raise Exception("There was an error while running the command {0}\n".format(command) + "-" * 10 + "\n" + stderr)
        return stdout, stderr
    else:
        return p
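
# RunKaldiCommand is used below for helper scripts such as
# utils/data/get_utt2dur.sh. A minimal sketch of a call with the default
# wait=True (the data directory path is just an example):
#   stdout, stderr = RunKaldiCommand("utils/data/get_utt2dur.sh data/train")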


def MakeDir(dir):
    try:
        os.mkdir(dir)
    except OSError as exc:
        if exc.errno != errno.EEXIST:
            raise exc
        raise Exception("Directory {0} already exists".format(dir))
    pass


def CheckFiles(input_data_dir):
    for file_name in ['spk2utt', 'text', 'utt2spk', 'feats.scp']:
        file_name = '{0}/{1}'.format(input_data_dir, file_name)
        if not os.path.exists(file_name):
            raise Exception("There is no such file {0}".format(file_name))


def ParseFileToDict(file, assert2fields = False, value_processor = None):
    if value_processor is None:
        value_processor = lambda x: x[0]

    dict = {}
    for line in open(file, 'r'):
        parts = line.split()
        if assert2fields:
            assert(len(parts) == 2)

        dict[parts[0]] = value_processor(parts[1:])
    return dict
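
# For example (hypothetical file contents), a utt2dur line "spk1-utt1 1.23"
# parsed with value_processor = lambda x: float(x[0]) yields
# {'spk1-utt1': 1.23}, while a spk2utt line "spk1 utt1 utt2" parsed with
# value_processor = lambda x: x yields {'spk1': ['utt1', 'utt2']}.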


def WriteDictToFile(dict, file_name):
    file = open(file_name, 'w')
    keys = sorted(dict.keys())
    for key in keys:
        value = dict[key]
        if type(value) in [list, tuple]:
            if type(value) is tuple:
                value = list(value)
            value.sort()
            value = ' '.join(value)
        file.write('{0}\t{1}\n'.format(key, value))
    file.close()


def ParseDataDirInfo(data_dir):
    data_dir_file = lambda file_name: '{0}/{1}'.format(data_dir, file_name)

    utt2spk = ParseFileToDict(data_dir_file('utt2spk'))
    spk2utt = ParseFileToDict(data_dir_file('spk2utt'), value_processor = lambda x: x)
    text = ParseFileToDict(data_dir_file('text'), value_processor = lambda x: " ".join(x))
    # we want to assert feats.scp has just 2 fields, as we don't know how
    # to process it otherwise
    feat = ParseFileToDict(data_dir_file('feats.scp'), assert2fields = True)
    utt2dur = ParseFileToDict(data_dir_file('utt2dur'), value_processor = lambda x: float(x[0]))
    utt2uniq = None
    if os.path.exists(data_dir_file('utt2uniq')):
        utt2uniq = ParseFileToDict(data_dir_file('utt2uniq'))
    return utt2spk, spk2utt, text, feat, utt2dur, utt2uniq
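
# The returned dictionaries are all keyed on utterance id, except spk2utt,
# which is keyed on speaker id; utt2uniq is None when the input directory
# has no utt2uniq file.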


def GetCombinedUttIndexRange(utt_index, utts, utt_durs, minimum_duration):
    # We want the minimum number of concatenations
    # to reach the minimum_duration. If two concatenations satisfy
    # the minimum duration constraint we choose the shorter one.
    left_index = utt_index - 1
    right_index = utt_index + 1
    num_remaining_segments = len(utts) - 1
    cur_utt_dur = utt_durs[utts[utt_index]]
    while num_remaining_segments > 0:
        left_utt_dur = 0
        if left_index >= 0:
            left_utt_dur = utt_durs[utts[left_index]]
        right_utt_dur = 0
        if right_index <= len(utts) - 1:
            right_utt_dur = utt_durs[utts[right_index]]

        right_combined_utt_dur = cur_utt_dur + right_utt_dur
        left_combined_utt_dur = cur_utt_dur + left_utt_dur
        left_right_combined_utt_dur = cur_utt_dur + left_utt_dur + right_utt_dur

        combine_left_exit = False
        combine_right_exit = False
        if right_combined_utt_dur >= minimum_duration:
            if left_combined_utt_dur >= minimum_duration:
                if left_combined_utt_dur <= right_combined_utt_dur:
                    combine_left_exit = True
                else:
                    combine_right_exit = True
            else:
                combine_right_exit = True
        elif left_combined_utt_dur >= minimum_duration:
            combine_left_exit = True
        elif left_right_combined_utt_dur >= minimum_duration:
            combine_left_exit = True
            combine_right_exit = True

        if combine_left_exit and combine_right_exit:
            cur_utt_dur = left_right_combined_utt_dur
            break
        elif combine_left_exit:
            cur_utt_dur = left_combined_utt_dur
            # move back the right_index as we don't need to combine it
            right_index = right_index - 1
            break
        elif combine_right_exit:
            cur_utt_dur = right_combined_utt_dur
            # move back the left_index as we don't need to combine it
            left_index = left_index + 1
            break

        # couldn't satisfy minimum duration requirement so continue search
        if left_index >= 0:
            num_remaining_segments = num_remaining_segments - 1
        if right_index <= len(utts) - 1:
            num_remaining_segments = num_remaining_segments - 1
        left_index = left_index - 1
        right_index = right_index + 1
        cur_utt_dur = left_right_combined_utt_dur

    left_index = max(0, left_index)
    right_index = min(len(utts) - 1, right_index)
    return left_index, right_index, cur_utt_dur
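
# A worked example (the durations are made up): with utts = [u1, u2, u3],
# utt_durs = {u1: 0.5, u2: 0.6, u3: 0.7}, minimum_duration = 1.0 and
# utt_index = 1, combining to the left gives 1.1s and combining to the
# right gives 1.3s; both satisfy the constraint, so the shorter, left
# combination wins and the function returns left_index 0, right_index 1
# and a combined duration of 1.1s.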


def WriteCombinedDirFiles(output_dir, utt2spk, spk2utt, text, feat, utt2dur, utt2uniq):
    out_dir_file = lambda file_name: '{0}/{1}'.format(output_dir, file_name)

    total_combined_utt_list = []
    for speaker in spk2utt.keys():
        utts = spk2utt[speaker]
        for utt in utts:
            if type(utt) is tuple:
                # this is a combined utt
                total_combined_utt_list.append((speaker, utt))

    for speaker, combined_utt_tuple in total_combined_utt_list:
        combined_utt_list = list(combined_utt_tuple)
        combined_utt_list.sort()
        new_utt_name = "-".join(combined_utt_list) + '-appended'

        # updating the utt2spk dict
        for utt in combined_utt_list:
            spk_name = utt2spk.pop(utt)
        utt2spk[new_utt_name] = spk_name

        # updating the spk2utt dict
        spk2utt[speaker].remove(combined_utt_tuple)
        spk2utt[speaker].append(new_utt_name)

        # updating the text dict
        combined_text = []
        for utt in combined_utt_list:
            combined_text.append(text.pop(utt))
        text[new_utt_name] = ' '.join(combined_text)

        # updating the feat dict
        combined_feat = []
        for utt in combined_utt_list:
            combined_feat.append(feat.pop(utt))
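        # The new feats.scp entry is a piped command rather than a direct
        # archive entry: concat-feats (a Kaldi binary) appends the feature
        # matrices of the original utterances, in the sorted order above,
        # whenever the entry is read.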
        feat_command = "concat-feats --print-args=false {feats} - |".format(feats = " ".join(combined_feat))
        feat[new_utt_name] = feat_command

        # updating utt2dur
        combined_dur = 0
        for utt in combined_utt_list:
            combined_dur += utt2dur.pop(utt)
        utt2dur[new_utt_name] = combined_dur

        # updating utt2uniq
        if utt2uniq is not None:
            combined_uniqs = []
            for utt in combined_utt_list:
                combined_uniqs.append(utt2uniq.pop(utt))
            # The utt2uniq file maps perturbed utterances back to their original,
            # unperturbed versions so that training and cross-validation sets can
            # avoid overlapping data. However, if the perturbation changes the
            # utterance length (e.g. speed perturbation), the utterance combinations
            # in each perturbed copy of the original recording can be very different,
            # so there is no good way to define a utt2uniq mapping that avoids
            # overlap. We just keep the mapping of the first combined utterance.
            utt2uniq[new_utt_name] = combined_uniqs[0]

    WriteDictToFile(utt2spk, out_dir_file('utt2spk'))
    WriteDictToFile(spk2utt, out_dir_file('spk2utt'))
    WriteDictToFile(feat, out_dir_file('feats.scp'))
    WriteDictToFile(text, out_dir_file('text'))
    if utt2uniq is not None:
        WriteDictToFile(utt2uniq, out_dir_file('utt2uniq'))
    WriteDictToFile(utt2dur, out_dir_file('utt2dur'))


def CombineSegments(input_dir, output_dir, minimum_duration):
    utt2spk, spk2utt, text, feat, utt2dur, utt2uniq = ParseDataDirInfo(input_dir)
    total_combined_utt_list = []

    # copy the duration dictionary so that we can modify it
    utt_durs = copy.deepcopy(utt2dur)

    speakers = sorted(spk2utt.keys())
    for speaker in speakers:
        # Note: this is an assignment of the reference, not a copy, so the
        # in-place modifications below are visible through spk2utt, which
        # WriteCombinedDirFiles later reads.
        utts = spk2utt[speaker]

        # We assume that the sorted utterance list corresponds to contiguous
        # segments. This is true only if utterance naming follows the accepted
        # conventions; it is an easily violated assumption, and a better way
        # to do this is still needed.
        utts.sort()
        utt_index = 0
        while utt_index < len(utts):
            if utt_durs[utts[utt_index]] < minimum_duration:
                left_index, right_index, cur_utt_dur = GetCombinedUttIndexRange(
                    utt_index, utts, utt_durs, minimum_duration)
                if cur_utt_dur < minimum_duration:
                    # this is a rare occurrence, better make the user aware of this
                    # situation and let them deal with it
                    warnings.warn('Speaker {0} does not have enough utterances to satisfy the minimum duration '
                                  'constraint. Not modifying these utterances'.format(speaker))
                    utt_index = utt_index + 1
                    continue
                combined_duration = 0
                combined_utts = []
                # update the utt_durs dictionary
                for utt in utts[left_index:right_index + 1]:
                    combined_duration += utt_durs.pop(utt)
                    if type(utt) is tuple:
                        for item in utt:
                            combined_utts.append(item)
                    else:
                        combined_utts.append(utt)
                combined_utts = tuple(combined_utts)  # converting to immutable type to use as dictionary key
                # durations are floats accumulated in different orders in the two
                # code paths, so compare with a small tolerance rather than exactly
                assert abs(cur_utt_dur - combined_duration) < 1e-6

                # now modify the utts list
                combined_indices = list(range(left_index, right_index + 1))
                # start popping from the largest index so that the lower
                # indexes are valid
                for i in combined_indices[::-1]:
                    utts.pop(i)
                utts.insert(left_index, combined_utts)
                utt_durs[combined_utts] = combined_duration
                utt_index = left_index
            utt_index = utt_index + 1

    WriteCombinedDirFiles(output_dir, utt2spk, spk2utt, text, feat, utt2dur, utt2uniq)
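
# Note that CombineSegments (via WriteCombinedDirFiles) writes utt2spk,
# spk2utt, feats.scp, text, utt2dur and optionally utt2uniq; Main below still
# regenerates spk2utt from utt2spk with utils/utt2spk_to_spk2utt.pl and runs
# utils/fix_data_dir.sh so that the output directory ends up consistently
# sorted and filtered.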


def Main():
    print("""steps/cleanup/combine_short_segments.py: warning: this script is deprecated and will be removed.
Please use utils/data/combine_short_segments.sh""", file = sys.stderr)

    args = GetArgs()
    CheckFiles(args.input_data_dir)
    MakeDir(args.output_data_dir)

    RunKaldiCommand("utils/data/get_utt2dur.sh {0}".format(args.input_data_dir))
    CombineSegments(args.input_data_dir, args.output_data_dir, args.minimum_duration)
    RunKaldiCommand("utils/utt2spk_to_spk2utt.pl {od}/utt2spk > {od}/spk2utt".format(od = args.output_data_dir))
    if os.path.exists('{0}/cmvn.scp'.format(args.input_data_dir)):
        shutil.copy('{0}/cmvn.scp'.format(args.input_data_dir), args.output_data_dir)
    RunKaldiCommand("utils/fix_data_dir.sh {0}".format(args.output_data_dir))


if __name__ == "__main__":
    Main()