Blame view
egs/wsj/s5/utils/data/internal/modify_speaker_info.py
4.34 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
#!/usr/bin/env python

"""Combine consecutive utterances into fake speaker ids.

This is a kind of poor man's segmentation.  An old utt2spk is read from
standard input ("<utt> <spk>" per line) and a new utt2spk is written to
standard output.  Either each existing speaker's utterances are split
(--respect-speaker-info=true) or the sorted list of all utterances is split
(--respect-speaker-info=false).  Each split produces groups of at most
--utts-per-spk-max utterances and/or at most --seconds-per-spk-max seconds
(a limit <= 0 disables it).
"""

from __future__ import print_function
import argparse, sys, os
from collections import defaultdict


def MakeParser():
    """Return the argparse parser describing this script's command line."""
    parser = argparse.ArgumentParser(description="""
    Combine consecutive utterances into fake speaker ids for a kind of poor
    man's segmentation.  Reads old utt2spk from standard input, outputs new
    utt2spk to standard output.""")
    parser.add_argument("--utts-per-spk-max", type=int, required=True,
                        help="Maximum number of utterances allowed per speaker")
    parser.add_argument("--seconds-per-spk-max", type=float, required=True,
                        help="""Maximum duration in seconds allowed per speaker.
                        If this option is >0, --utt2dur option must be provided.""")
    parser.add_argument("--utt2dur", type=str,
                        help="""Filename of input 'utt2dur' file (needed only if
                        --seconds-per-spk-max is >0)""")
    parser.add_argument("--respect-speaker-info", type=str, default='true',
                        choices=['true', 'false'],
                        help="""If true, the output speakers will be split from
                        existing speakers.""")
    return parser


def GetArgs(argv=None):
    """Parse and validate the command line.

    argv: list of argument strings; None means sys.argv[1:] (argparse default).
    Returns the argparse Namespace.  Exits through parser.error() when
    --seconds-per-spk-max > 0 but --utt2dur was not given, instead of failing
    later with an obscure error from open(None).
    """
    parser = MakeParser()
    args = parser.parse_args(argv)
    if args.seconds_per_spk_max > 0 and args.utt2dur is None:
        parser.error("--utt2dur must be provided when --seconds-per-spk-max > 0")
    return args


def ReadUtt2spk(stream):
    """Read "<utt> <spk>" lines from a text stream.

    Returns (utt2spk, spk2utt): utt2spk maps utterance -> speaker, and
    spk2utt maps speaker -> list of its utterances in input order.  Exits
    with an error message on any line that does not have exactly two fields.
    """
    utt2spk = dict()
    # An undefined spk2utt entry defaults to an empty list.
    spk2utt = defaultdict(list)
    for line in stream:
        fields = line.split()
        if len(fields) != 2:
            sys.exit("modify_speaker_info.py: bad utt2spk line from standard input "
                     "(expected two fields): " + line)
        utt, spk = fields
        utt2spk[utt] = spk
        spk2utt[spk].append(utt)
    return utt2spk, spk2utt


def ReadUtt2dur(filename, utt2spk):
    """Read "<utt> <duration-in-seconds>" lines from `filename`.

    Returns a dict mapping utterance -> float duration.  Exits with an error
    message if a line is malformed, a duration is not a number, the file
    cannot be read, or any utterance in `utt2spk` has no duration.
    """
    utt2dur = dict()
    try:
        with open(filename) as f:
            for line in f:
                fields = line.split()
                if len(fields) != 2:
                    sys.exit("modify_speaker_info.py: bad utt2dur line in {0} "
                             "(expected two fields): {1}".format(filename, line))
                utt, dur = fields
                utt2dur[utt] = float(dur)
    except (IOError, OSError, ValueError) as e:
        # sys.exit() raises SystemExit, which is not caught here.
        sys.exit("modify_speaker_info.py: problem reading utt2dur info: " + str(e))
    for utt in utt2spk:
        if utt not in utt2dur:
            sys.exit("modify_speaker_info.py: utterance {0} not in utt2dur file {1}".format(
                utt, filename))
    return utt2dur


def SplitIntoGroups(uttlist, utts_per_spk_max=0, seconds_per_spk_max=0.0,
                    utt2dur=None):
    """Split `uttlist` into a list of consecutive sub-lists.

    A new group is started when the current one already holds
    `utts_per_spk_max` utterances (if that limit is > 0), or when adding the
    next utterance would push its total duration, looked up in `utt2dur`,
    above `seconds_per_spk_max` (if that limit is > 0).  A single utterance
    longer than the duration limit still forms its own group.  The last group
    will tend to be shorter than the others; no attempt is made to fix this.
    An empty `uttlist` gives [].
    """
    ans = []  # list of lists.
    cur_uttlist = []
    cur_dur = 0.0
    for utt in uttlist:
        # Durations are only looked up when the duration limit is active, so
        # utt2dur may be None otherwise.
        dur = utt2dur[utt] if seconds_per_spk_max > 0 else 0.0
        if cur_uttlist and (
                (utts_per_spk_max > 0 and len(cur_uttlist) >= utts_per_spk_max) or
                (seconds_per_spk_max > 0 and cur_dur + dur > seconds_per_spk_max)):
            ans.append(cur_uttlist)
            cur_uttlist = []
            cur_dur = 0.0
        cur_uttlist.append(utt)
        cur_dur += dur
    if cur_uttlist:
        ans.append(cur_uttlist)
    return ans


def GetFormatString(d):
    """Return a printf format '%0<w>d' where w is the number of digits in d.

    For example '%01d' if d < 10, '%02d' if d < 100, and so on, so that
    numbers 1..d printed with it sort correctly as strings.
    """
    ans = 1
    while d >= 10:
        d //= 10  # integer division
        ans += 1
    return '%0{0}d'.format(ans)


def AssignNewSpeakers(utt2spk, spk2utt, respect_speaker_info=True,
                      utts_per_spk_max=0, seconds_per_spk_max=0.0,
                      utt2dur=None):
    """Yield (utt, new_speaker) pairs in output order.

    If respect_speaker_info is true, each original speaker (in sorted order)
    is split separately and the pieces are named '<spk>-<n>'.  Otherwise the
    sorted list of all utterances is split and the pieces are named
    'speaker-<n>'.  <n> counts from 1 and is zero-padded to the width of the
    number of pieces.
    """
    def Split(uttlist):
        return SplitIntoGroups(uttlist, utts_per_spk_max,
                               seconds_per_spk_max, utt2dur)

    if respect_speaker_info:
        for spk in sorted(spk2utt):
            uttlists = Split(spk2utt[spk])
            # e.g. '%s-%02d' % ('john_smith', 9 + 1) gives 'john_smith-10'.
            format_string = '%s-' + GetFormatString(len(uttlists))
            for i, uttlist in enumerate(uttlists, 1):
                this_spk = format_string % (spk, i)
                for utt in uttlist:
                    yield utt, this_spk
    else:
        uttlists = Split(sorted(utt2spk))
        # e.g. 'speaker-%04d' % (105 + 1) gives 'speaker-0106'.
        format_string = 'speaker-' + GetFormatString(len(uttlists))
        for i, uttlist in enumerate(uttlists, 1):
            this_spk = format_string % i
            for utt in uttlist:
                yield utt, this_spk


def main():
    """Entry point: read utt2spk from stdin, write the new utt2spk to stdout."""
    args = GetArgs()
    utt2spk, spk2utt = ReadUtt2spk(sys.stdin)
    utt2dur = None
    if args.seconds_per_spk_max > 0:
        utt2dur = ReadUtt2dur(args.utt2dur, utt2spk)
    for utt, spk in AssignNewSpeakers(utt2spk, spk2utt,
                                      args.respect_speaker_info == 'true',
                                      args.utts_per_spk_max,
                                      args.seconds_per_spk_max, utt2dur):
        print(utt, spk)


if __name__ == "__main__":
    main()