modify_speaker_info.py
4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python
from __future__ import print_function
import argparse, sys,os
from collections import defaultdict
# Command-line interface.  Both limits are mandatory on the command line; the
# grouping code only enforces a limit when its value is > 0.
parser = argparse.ArgumentParser(description="""
Combine consecutive utterances into fake speaker ids for a kind of
poor man's segmentation. Reads old utt2spk from standard input,
outputs new utt2spk to standard output.""")
parser.add_argument("--utts-per-spk-max", type=int, required=True,
                    help="Maximum number of utterances allowed per speaker")
parser.add_argument("--seconds-per-spk-max", type=float, required=True,
                    help="""Maximum duration in seconds allowed per speaker.
                    If this option is >0, --utt2dur option must be provided.""")
parser.add_argument("--utt2dur", type=str,
                    help="""Filename of input 'utt2dur' file (needed only if
                    --seconds-per-spk-max is provided)""")
# The original help text carried stray double quotes left over from a botched
# string-literal join; they showed up verbatim in the --help output.
parser.add_argument("--respect-speaker-info", type=str, default='true',
                    choices=['true', 'false'],
                    help="""If true, the output speakers will be split from
                    existing speakers.""")
args = parser.parse_args()

# utt2spk maps utterance-id -> speaker-id; spk2utt maps speaker-id -> list of
# its utterance-ids in the order they were read (missing speakers start out
# with an empty list).
utt2spk = {}
spk2utt = defaultdict(list)

# The two-argument form of iter() keeps calling readline() until it returns
# the empty string, i.e. until end of input.
for raw_line in iter(sys.stdin.readline, ''):
    fields = raw_line.split()
    if len(fields) != 2:
        sys.exit("modify_speaker_info.py: bad utt2spk line from standard input (expected two fields): " +
                 raw_line)
    utt, spk = fields
    utt2spk[utt] = spk
    spk2utt[spk].append(utt)
# Load per-utterance durations.  They are only consulted when a duration
# limit is active (--seconds-per-spk-max > 0), so the file is not read
# otherwise.
if args.seconds_per_spk_max > 0:
    if args.utt2dur is None:
        # Fail with a clear message instead of letting open(None) raise.
        sys.exit("modify_speaker_info.py: --utt2dur must be provided when "
                 "--seconds-per-spk-max is > 0")
    utt2dur = dict()
    try:
        # 'with' closes the file even if a line fails to parse.
        with open(args.utt2dur) as f:
            for line in f:
                a = line.split()
                if len(a) != 2:
                    # sys.exit raises SystemExit, which the handler below
                    # deliberately does not catch.
                    sys.exit("modify_speaker_info.py: bad line in utt2dur file {0} "
                             "(expected two fields): {1}".format(args.utt2dur, line))
                utt, dur = a
                utt2dur[utt] = float(dur)
    except (IOError, OSError, ValueError) as e:
        # I/O failures, undecodable input and non-numeric durations.
        sys.exit("modify_speaker_info.py: problem reading utt2dur info: " + str(e))
    # Every utterance read from utt2spk needs a duration, otherwise the
    # grouping step would fail on the lookup.
    for utt in utt2spk:
        if utt not in utt2dur:
            sys.exit("modify_speaker_info.py: utterance {0} not in utt2dur file {1}".format(
                utt, args.utt2dur))
# Partition a list of utterances into consecutive groups that obey the
# command-line limits.  The final group is usually smaller than the rest;
# nothing is done to even the groups out.
def SplitIntoGroups(uttlist):
    """Split uttlist into a list of lists, preserving order.

    Reads the module-level 'args' for the limits (a limit is enforced only
    when its value is > 0) and the module-level 'utt2dur' for durations when
    the duration limit is active.  A single utterance longer than the
    duration limit still gets a group of its own.
    """
    max_utts = args.utts_per_spk_max
    max_seconds = args.seconds_per_spk_max
    limit_count = max_utts > 0
    limit_duration = max_seconds > 0

    groups = []
    group = []
    group_dur = 0.0
    for utt in uttlist:
        # Decide whether adding this utterance would overflow the current
        # group; the duration is only looked up if the count test did not
        # already settle it.
        if limit_count and len(group) == max_utts:
            start_new = True
        elif limit_duration and len(group) > 0 and group_dur + utt2dur[utt] > max_seconds:
            start_new = True
        else:
            start_new = False
        if start_new:
            groups.append(group)
            group = []
            group_dur = 0.0
        group.append(utt)
        if limit_duration:
            group_dur += utt2dur[utt]
    if group:
        groups.append(group)
    return groups
# Build a zero-padded printf-style integer format just wide enough for d:
# '%01d' when d < 10, '%02d' when d < 100, and so on.  Numbers printed with
# it sort correctly as strings.
def GetFormatString(d):
    """Return '%0<w>d' where <w> is the number of decimal digits in d."""
    width = 1
    bound = 10
    # Each power of ten that d reaches adds one digit to the width.
    while d >= bound:
        bound *= 10
        width += 1
    return '%0' + str(width) + 'd'
# Emit the new utt2spk on standard output, one "<utt> <new-speaker>" pair per
# line.
if args.respect_speaker_info == 'true':
    # Split each existing speaker separately, so a new speaker never mixes
    # utterances from two original speakers.
    for spk in sorted(spk2utt):
        uttlists = SplitIntoGroups(spk2utt[spk])
        format_string = '%s-' + GetFormatString(len(uttlists))
        for group_number, group in enumerate(uttlists, 1):
            # e.g. '%s-%02d' % ('john_smith', 10) gives 'john_smith-10'.
            this_spk = format_string % (spk, group_number)
            for utt in group:
                print(utt, this_spk)
else:
    # Ignore the original speakers: group all utterances in sorted order.
    uttlists = SplitIntoGroups(sorted(utt2spk))
    format_string = 'speaker-' + GetFormatString(len(uttlists))
    for group_number, group in enumerate(uttlists, 1):
        # e.g. 'speaker-%04d' % 106 gives 'speaker-0106'.
        this_spk = format_string % group_number
        for utt in group:
            print(utt, this_spk)