Blame view
egs/wsj/s5/utils/data/internal/choose_utts_to_combine.py
16.7 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 |
#!/usr/bin/env python

# Copyright 2016  Vijayaditya Peddinti
#           2016  Johns Hopkins University (author: Daniel Povey)
# Apache 2.0

from __future__ import print_function
import argparse
from random import randint
import sys
import os
from collections import defaultdict

parser = argparse.ArgumentParser(description="""
This script, called from data/utils/combine_short_segments.sh, chooses
consecutive utterances to concatenate that will satisfy the minimum segment
length.  It uses the --spk2utt file to ensure that utterances from the same
speaker are preferentially combined (as far as possible while respecting the
minimum segment length).  If it has to combine utterances across different
speakers in order to satisfy the duration constraint, it will assign the
combined utterances to the speaker which contributed the most to the duration
of the combined utterances.

The utt2utts output of this program is a map from new utterance-id to a list
of old utterance-ids, so for example if the inputs were utt1, utt2 and utt3,
and utterances 2 and 3 were combined, the output might look like:
  utt1 utt1
  utt2-comb2 utt2 utt3
The utt2spk output of this program assigns utterances to the speakers of the
input; in the (hopefully rare) case where utterances were combined across
speakers, it will assign the utterance to whichever of the original speakers
contributed the most to the grouped utterance.
""")

parser.add_argument("--min-duration", type=float, default=1.55,
                    help="Minimum utterance duration")
parser.add_argument("spk2utt_in", type=str, metavar="<spk2utt-in>",
                    help="Filename of [input] speaker to utterance map needed "
                    "because this script tries to merge utterances from the "
                    "same speaker as much as possible, and also needs to produce "
                    "an output utt2spk map.")
parser.add_argument("utt2dur_in", type=str, metavar="<utt2dur-in>",
                    help="Filename of [input] utterance-to-duration map, with lines like 'utt1 1.23'.")
parser.add_argument("utt2utts_out", type=str, metavar="<utt2utts-out>",
                    help="Filename of [output] new-utterance-to-old-utterances map, with lines "
                    "like 'utt1 utt1' or 'utt2-comb2 utt2 utt3'")
parser.add_argument("utt2spk_out", type=str, metavar="<utt2spk-out>",
                    help="Filename of [output] utt2spk map, which maps new utterances to original "
                    "speakers.  If utterances were combined across speakers, we map the new "
                    "utterance to the speaker that contributed the most to them.")
parser.add_argument("utt2dur_out", type=str, metavar="<utt2dur-out>",
                    help="Filename of [output] utt2dur map, which is just the summations of "
                    "the durations of the source utterances.")

# 'args' holds the parsed command-line arguments.  It is filled in by main(),
# so that merely importing this file does not try to parse sys.argv.
args = None


# This LessThan is designed to be impervious to roundoff effects in cases where
# numbers are really always separated by a distance >> 1.0e-05.  It will return
# false if x and y are almost identical, differing only by roundoff effects.
def LessThan(x, y):
    return x < y - 1.0e-5


# This function implements the core of the utterance-combination code.
# The input 'durations' is a list of durations, which must all be
# > 0.0 (a duration of exactly 0.0 is used internally to mean "there is no
# neighboring group").  This function tries to combine consecutive indexes
# into groups such that for each group, the total duration is at
# least 'min_duration'.  It returns a list of (start,end) indexes.
# For example, CombineList(0.1, [5.0,6.0,7.0]) would return
# [ (0,1), (1,2), (2,3) ] because no combination is necessary; each
# returned pair represents a singleton group.
# Or CombineList(1.0, [0.5, 0.6, 0.7]) would return
# [ (0,3) ].
# Or CombineList(1.0, [0.5, 0.6, 1.7]) would return
# [ (0,2), (2,3) ].
# Note: if sum(durations) < min_duration, this function will
# return everything in one group but of course the sum of durations
# will be less than min_duration.
# If 'durations' is empty, an empty list is returned.
def CombineList(min_duration, durations):
    assert min_duration >= 0.0
    if len(durations) == 0:
        # nothing to combine; min() of an empty list would raise ValueError.
        return []
    assert min(durations) > 0.0

    num_utts = len(durations)

    # for each utterance-index i, group_start[i] gives us the
    # start-index of the group of utterances of which it's currently
    # a member.
    group_start = list(range(num_utts))
    # if utterance-index i currently corresponds to the start of a group
    # of utterances, then group_durations[i] is the total duration of
    # that utterance-group, otherwise undefined.
    group_durations = list(durations)
    # if utterance-index i currently corresponds to the start of a group
    # of utterances, then group_end[i] is the end-index (i.e. last index plus
    # one) of that utterance-group, otherwise undefined.
    group_end = [x + 1 for x in range(num_utts)]

    # the queue holds start-indexes of groups that are (or were, when they
    # were added) below the minimum duration.  It is popped from the end.
    queue = [i for i in range(num_utts)
             if LessThan(group_durations[i], min_duration)]

    while len(queue) > 0:
        i = queue.pop()
        if group_start[i] != i or not LessThan(group_durations[i], min_duration):
            # this group no longer exists or already has at least the minimum
            # duration.
            continue
        this_dur = group_durations[i]
        # left_dur is the duration of the group to the left of this group,
        # or 0.0 if there is no such group.
        left_dur = group_durations[group_start[i - 1]] if i > 0 else 0.0
        # right_dur is the duration of the group to the right of this group,
        # or 0.0 if there is no such group.
        right_dur = group_durations[group_end[i]] if group_end[i] < num_utts else 0.0

        if left_dur == 0.0 and right_dur == 0.0:
            # there is only one group.  Nothing more to merge; break
            assert group_start[i] == 0 and group_end[i] == num_utts
            break
        # work out whether to combine left or right,
        # by means of the combine_left variable [ True or False ]
        if left_dur == 0.0:
            combine_left = False
        elif right_dur == 0.0:
            combine_left = True
        elif LessThan(left_dur + this_dur, min_duration):
            # combining left would still be below the minimum duration->
            # combine right... if it's above the min duration then good;
            # otherwise it still doesn't really matter so we might as well
            # pick one.
            combine_left = False
        elif LessThan(right_dur + this_dur, min_duration):
            # combining right would still be below the minimum duration,
            # and combining left would be >= the min duration (else we wouldn't
            # have reached this line) -> combine left.
            combine_left = True
        elif LessThan(left_dur, right_dur):
            # if we reached here then combining either way would take us >= the
            # minimum duration; but if left_dur < right_dur then we combine left
            # because that would give us more evenly sized segments.
            combine_left = True
        else:
            # if we reached here then combining either way would take us >= the
            # minimum duration; but left_dur >= right_dur, so we combine right
            # because that would give us more evenly sized segments.
            combine_left = False

        if combine_left:
            assert left_dur != 0.0
            new_group_start = group_start[i - 1]
            group_end[new_group_start] = group_end[i]
            for j in range(group_start[i], group_end[i]):
                group_start[j] = new_group_start
                group_durations[new_group_start] += durations[j]
            # note: there is no need to add group_durations[new_group_start] to
            # the queue even if it is still below the minimum length, because it
            # would have previously had to have been below the minimum length,
            # therefore it would already be in the queue.
        else:
            assert right_dur != 0.0
            # group start doesn't change, group end changes.
            old_group_end = group_end[i]
            new_group_end = group_end[old_group_end]
            group_end[i] = new_group_end
            for j in range(old_group_end, new_group_end):
                group_durations[i] += durations[j]
                group_start[j] = i
            if LessThan(group_durations[i], min_duration):
                # the group starting at i is still below the minimum length, so
                # we need to put it back on the queue.
                queue.append(i)

    # walk along the groups from left to right to produce the answer.
    ans = []
    cur_group_start = 0
    while cur_group_start < num_utts:
        ans.append((cur_group_start, group_end[cur_group_start]))
        cur_group_start = group_end[cur_group_start]
    return ans


# Sanity checks for CombineList; this is run every time the script is invoked.
def SelfTest():
    assert CombineList(0.1, [5.0, 6.0, 7.0]) == [(0, 1), (1, 2), (2, 3)]
    assert CombineList(0.5, [0.1, 6.0, 7.0]) == [(0, 2), (2, 3)]
    assert CombineList(0.5, [6.0, 7.0, 0.1]) == [(0, 1), (1, 3)]
    # in the two examples below, it combines with the shorter one if both would
    # be above min-dur.
    assert CombineList(0.5, [6.0, 0.1, 7.0]) == [(0, 2), (2, 3)]
    assert CombineList(0.5, [7.0, 0.1, 6.0]) == [(0, 1), (1, 3)]
    # in the example below, it combines with whichever one would
    # take it above the min-dur, if there is only one such.
    # note, it tests the 0.1 first as the queue is popped from the end.
    assert CombineList(1.0, [1.0, 0.5, 0.1, 6.0]) == [(0, 2), (2, 4)]

    for _ in range(100):
        min_duration = 0.05
        num_utts = randint(1, 15)
        durations = [0.01 * randint(1, 10) for _ in range(num_utts)]
        ranges = CombineList(min_duration, durations)
        if len(ranges) > 1:
            # check that each range's duration is >= min_duration
            for (start, end) in ranges:
                this_dur = sum(durations[start:end])
                assert not LessThan(this_dur, min_duration)

        # check that the list returned is not affected by very tiny differences
        # in the inputs.
        durations2 = [d + 1.0e-07 * randint(-5, 5) for d in durations]
        ranges2 = CombineList(min_duration, durations2)
        assert ranges2 == ranges


# This function figures out the grouping of utterances.
# The input is:
# 'min_duration' which is the minimum utterance length in seconds.
# 'spk2utt' which is a list of pairs (speaker-id, [list-of-utterances])
# 'utt2dur' which is a dict from utterance-id to duration (as a float)
# It returns a lists of lists of utterances; each list corresponds to
# a group, e.g.
# [ ['utt1'], ['utt2', 'utt3'] ]
# It exits the program if an utterance in 'spk2utt' has no entry in 'utt2dur',
# and prints a summary of what it did to stderr.
def GetUtteranceGroups(min_duration, spk2utt, utt2dur):
    # utt_groups will be a list of lists of utterance-ids formed from the
    # first pass of combination.
    utt_groups = []
    # group_durations will be the durations of the corresponding elements of
    # 'utt_groups'.
    group_durations = []

    # This block calls CombineList for the utterances of each speaker
    # separately, in the 'first pass' of combination.
    for (spk, utts) in spk2utt:
        durations = []  # durations for this group of utts.
        for utt in utts:
            try:
                durations.append(utt2dur[utt])
            except KeyError:
                # 'args' is only set when running as a script; fall back to
                # the argument's name if this function is called directly.
                utt2dur_name = args.utt2dur_in if args is not None else "<utt2dur-in>"
                sys.exit("choose_utts_to_combine.py: no duration available "
                         "in utt2dur file {0} for utterance {1}".format(
                             utt2dur_name, utt))
        ranges = CombineList(min_duration, durations)
        for start, end in ranges:  # each element of 'ranges' is a 2-tuple (start, end)
            utt_groups.append(utts[start:end])
            group_durations.append(sum(durations[start:end]))

    old_dur_sum = sum(utt2dur.values())
    new_dur_sum = sum(group_durations)
    if abs(old_dur_sum - new_dur_sum) > 0.0001 * old_dur_sum:
        print("choose_utts_to_combine.py: large difference in total "
              "durations: {0} vs {1} ".format(old_dur_sum, new_dur_sum),
              file=sys.stderr)

    # Now we combine the groups obtained above, in case we had situations where
    # the combination of all the utterances of one speaker were still below
    # the minimum duration.
    new_utt_groups = []
    ranges = CombineList(min_duration, group_durations)
    for start, end in ranges:
        this_group = []
        for k in range(start, end):
            this_group.extend(utt_groups[k])
        new_utt_groups.append(this_group)
    print("choose_utts_to_combine.py: combined {0} utterances to {1} utterances "
          "while respecting speaker boundaries, and then to {2} utterances "
          "with merging across speaker boundaries.".format(
              len(utt2dur), len(utt_groups), len(new_utt_groups)),
          file=sys.stderr)
    return new_utt_groups


# Reads the spk2utt file 'filename' and returns a pair (spk2utt, utt2spk) where
# spk2utt is a list of 2-tuples (speaker-id, [list-of-utterances]), in file
# order, and utt2spk is a dict from utterance-id to speaker-id.
# Exits the program if the file can't be opened, if a line has fewer than two
# fields, or if an utterance is listed more than once.
def ReadSpk2utt(filename):
    try:
        f = open(filename)
    except IOError:
        sys.exit("choose_utts_to_combine.py: error opening --spk2utt={0}".format(filename))
    spk2utt = []
    utt2spk = dict()
    with f:
        for line in f:
            a = line.split()
            if len(a) < 2:
                sys.exit("choose_utts_to_combine.py: bad line in spk2utt file: " + line)
            spk = a[0]
            utts = a[1:]
            spk2utt.append((spk, utts))
            for utt in utts:
                if utt in utt2spk:
                    sys.exit("choose_utts_to_combine.py: utterance {0} is listed more than once "
                             "in the spk2utt file {1}".format(utt, filename))
                utt2spk[utt] = spk
    return (spk2utt, utt2spk)


# Reads the utt2dur file 'filename' and returns a dict from utterance-id (as a
# string) to duration in seconds (as a float).
# Exits the program if the file can't be opened or if a line does not consist
# of exactly an utterance-id and a number.
def ReadUtt2dur(filename):
    try:
        f = open(filename)
    except IOError:
        sys.exit("choose_utts_to_combine.py: error opening utt2dur file {0}".format(filename))
    utt2dur = dict()
    with f:
        for line in f:
            try:
                [utt, dur] = line.split()
                utt2dur[utt] = float(dur)
            except ValueError:
                # wrong number of fields, or the duration is not a number.
                sys.exit("choose_utts_to_combine.py: bad line in utt2dur file {0}: {1}".format(
                    filename, line))
    return utt2dur


# Returns the speaker-id that the combined utterance 'utt_group' (a nonempty
# list of original utterance-ids) should be assigned to.  If all the utterances
# come from one speaker it is that speaker; otherwise it is the speaker that
# contributed the most to the duration of the group.
def ChooseSpeaker(utt_group, utt2spk, utt2dur):
    spk_list = [utt2spk[utt] for utt in utt_group]
    if spk_list == [spk_list[0]] * len(utt_group):
        return spk_list[0]
    # spk2dur is a map from the speaker-id to the duration within this
    # utt, that it comprises.
    spk2dur = defaultdict(float)
    for utt in utt_group:
        spk2dur[utt2spk[utt]] += utt2dur[utt]
    # the following code, which picks the speaker that contributed
    # the most to the duration of this utterance, is a little
    # complex because we want to break ties in a deterministic way
    # picking the earlier speaker in case of a tied duration.
    longest_spk_dur = -1.0
    spk = None
    for this_spk in sorted(spk2dur.keys()):
        if LessThan(longest_spk_dur, spk2dur[this_spk]):
            longest_spk_dur = spk2dur[this_spk]
            spk = this_spk
    assert spk is not None
    return spk


# The top-level program: parses the command line, reads the inputs, works out
# the utterance groups and writes the utt2utts, utt2spk and utt2dur outputs.
def main():
    global args
    args = parser.parse_args()

    SelfTest()

    if args.min_duration < 0.0:
        # a negative minimum would only trip an assertion later on, so stop
        # here with a clear message.
        sys.exit("choose_utts_to_combine.py: bad minimum duration {0}".format(
            args.min_duration))

    (spk2utt, utt2spk) = ReadSpk2utt(args.spk2utt_in)
    utt2dur = ReadUtt2dur(args.utt2dur_in)

    utt_groups = GetUtteranceGroups(args.min_duration, spk2utt, utt2dur)

    # set utt_group names to an array like [ 'utt1', 'utt2-comb2', 'utt4', ... ]
    utt_group_names = [group[0] if len(group) == 1
                       else "{0}-comb{1}".format(group[0], len(group))
                       for group in utt_groups]

    # write the utt2utts file.
    try:
        with open(args.utt2utts_out, 'w') as f:
            for name, group in zip(utt_group_names, utt_groups):
                print(name, ' '.join(group), file=f)
    except Exception as e:
        sys.exit("choose_utts_to_combine.py: exception writing to "
                 "<utt2utts-out>={0}: {1}".format(args.utt2utts_out, str(e)))

    # write the utt2spk file.
    try:
        with open(args.utt2spk_out, 'w') as f:
            for name, group in zip(utt_group_names, utt_groups):
                print(name, ChooseSpeaker(group, utt2spk, utt2dur), file=f)
    except Exception as e:
        sys.exit("choose_utts_to_combine.py: exception writing to "
                 "<utt2spk-out>={0}: {1}".format(args.utt2spk_out, str(e)))

    # write the utt2dur file.
    try:
        with open(args.utt2dur_out, 'w') as f:
            for name, group in zip(utt_group_names, utt_groups):
                duration = sum([utt2dur[utt] for utt in group])
                print(name, duration, file=f)
    except Exception as e:
        sys.exit("choose_utts_to_combine.py: exception writing to "
                 "<utt2dur-out>={0}: {1}".format(args.utt2dur_out, str(e)))


if __name__ == "__main__":
    main()