allocate_multilingual_examples.py
10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
#!/usr/bin/env python3
# Copyright 2017 Pegah Ghahremani
# 2018 Hossein Hadian
#
# Apache 2.0.
""" This script generates examples for multilingual training of neural network.
This script produces 3 sets of files --
egs.*.scp, egs.output.*.ark, egs.weight.*.ark
egs.*.scp are the SCP files of the training examples.
egs.weight.*.ark map from the key of the example to the language-specific
weight of that example.
egs.output.*.ark map from the key of the example to the name of
the output-node in the neural net for that specific language, e.g.
'output-2'.
--egs-prefix option can be used to generate train and diagnostics egs files.
If --egs-prefix=train_diagnostics. is passed, then the files produced by the
script will be named with the prefix as "train_diagnostics."
instead of "egs."
i.e. the files produced are -- train_diagnostics.*.scp,
train_diagnostics.output.*.ark, train_diagnostics.weight.*.ark and
train_diagnostics.ranges.*.txt.
The other egs-prefix options used in the recipes are "valid_diagnostics."
for validation examples and "combine." for examples used for model
combination.
For chain training egs, the --egs-prefix option should be "cegs."
You can call this script as (e.g.):
allocate_multilingual_examples.py [opts] example-scp-lists
multilingual-egs-dir
allocate_multilingual_examples.py --block-size 512
--lang2weight "0.2,0.8" exp/lang1/egs.scp exp/lang2/egs.scp
exp/multi/egs
"""
import os, argparse, sys, random
import logging
import traceback
# Put the relative directory 'steps' at the front of the module search path.
# NOTE(review): nothing in the visible code imports from 'steps'; presumably
# kept for parity with sibling scripts -- confirm before removing.
sys.path.insert(0, 'steps')

# Module-level logger used by every function in this file.  It is registered
# under the name 'libs' and emits INFO-and-above records through a dedicated
# stream handler.
logger = logging.getLogger('libs')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
# Each record carries a timestamp, the source location (path, line, function)
# and the level ahead of the message.
formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - "
                              "%(funcName)s - %(levelname)s ] %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
# Emitted as soon as the module is loaded, i.e. before any argument parsing.
logger.info('Start generating multilingual examples')
def get_args():
    """Parse and validate the command line of this script.

    The raw ``sys.argv`` is echoed to stderr first, so that the exact
    invocation is visible in the log, and is then parsed.

    Returns:
        argparse.Namespace with the attributes ``num_archives`` (int),
        ``block_size`` (int), ``egs_prefix`` (str), ``lang2weight``
        (str or None), ``egs_scp_lists`` (list of str) and ``egs_dir`` (str).

    Raises:
        SystemExit: with status 2 (through ``parser.error``) if the command
            line is malformed, if ``--num-archives`` is missing or not
            positive, or if ``--block-size`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description=""" This script generates examples for multilingual training
        of neural network by producing 3 sets of primary files
        as egs.*.scp, egs.output.*.ark, egs.weight.*.ark.
        egs.*.scp are the SCP files of the training examples.
        egs.weight.*.ark map from the key of the example to the language-specific
        weight of that example.
        egs.output.*.ark map from the key of the example to the name of
        the output-node in the neural net for that specific language, e.g.
        'output-2'.""",
        epilog="Called by steps/nnet3/multilingual/combine_egs.sh")

    parser.add_argument("--num-archives", type=int, default=None,
                        help="Number of archives to split the data into. (Note: in reality they are not "
                        "archives, only scp files, but we use this notation by analogy with the "
                        "conventional egs-creating script).")
    # The adjacent literals below are joined verbatim by Python, so each
    # fragment ends with an explicit space to keep the help text readable.
    parser.add_argument("--block-size", type=int, default=512,
                        help="This relates to locality of disk access. 'block-size' is "
                        "the average number of examples that are read consecutively "
                        "from each input scp file (and are written in the same order to the output scp files). "
                        "Smaller values lead to more random disk access (during "
                        "the nnet3 training process).")
    parser.add_argument("--egs-prefix", type=str, default="egs.",
                        help="This option can be used to add a prefix to the filenames "
                        "of the output files. For e.g. "
                        "if --egs-prefix=combine. , then the files produced "
                        "by this script will be "
                        "combine.output.*.ark, combine.weight.*.ark, and combine.*.scp")
    parser.add_argument("--lang2weight", type=str,
                        help="Comma-separated list of weights, one per language. "
                        "The language order is as egs_scp_lists.")
    # now the positional arguments
    parser.add_argument("egs_scp_lists", nargs='+',
                        help="List of egs.scp files per input language. "
                        "e.g. exp/lang1/egs/egs.scp exp/lang2/egs/egs.scp")
    parser.add_argument("egs_dir",
                        help="Name of output egs directory e.g. exp/tdnn_multilingual_sp/egs")

    print(sys.argv, file=sys.stderr)
    args = parser.parse_args()

    # Both values are used as divisors when the examples are distributed, so
    # reject unusable settings here with a clear message instead of failing
    # later with a TypeError / ZeroDivisionError.
    if args.num_archives is None or args.num_archives < 1:
        parser.error("--num-archives is required and must be a positive integer")
    if args.block_size < 1:
        parser.error("--block-size must be a positive integer")
    return args
def read_lines(file_handle, num_lines):
    """Read up to ``num_lines`` lines from ``file_handle``.

    Lines are fetched one at a time with ``readline()`` and returned with
    surrounding whitespace stripped.  Reading stops early once the handle is
    exhausted, so the result may hold fewer than ``num_lines`` entries.

    Args:
        file_handle: object exposing ``readline()``.
        num_lines: maximum number of lines to return.

    Returns:
        List of the stripped lines, in the order they were read.
    """
    collected = []
    while len(collected) < num_lines:
        raw = file_handle.readline()
        if raw:
            collected.append(raw.strip())
        else:
            # readline() returns an empty value only at end of file.
            break
    return collected
def process_multilingual_egs(args):
    """Interleave the per-language egs scp files into multilingual "archives".

    For every output archive ``k`` (numbered from 1) three files are written
    under ``args.egs_dir``:

      * ``<prefix><k>.scp``         -- the selected scp lines, unchanged;
      * ``<prefix>output.<k>.ark``  -- ``<eg-id> output-<lang-index>``;
      * ``<prefix>weight.<k>.ark``  -- ``<eg-id> <weight of that language>``.

    In addition ``info/<prefix>num_tasks`` (number of languages) and
    ``info/<prefix>num_archives`` are written.

    Examples are copied in blocks of ``args.block_size`` consecutive lines.
    Each block is taken from the language that currently has the largest
    fraction of its examples still unread (ties go to the lowest language
    index).  After the regular pass over all archives, one extra round
    appends whatever is still unread to the last archive.

    Args:
        args: namespace providing ``egs_scp_lists``, ``egs_dir``,
            ``egs_prefix``, ``num_archives``, ``block_size`` and
            ``lang2weight`` (comma-separated string or None).  The weights
            are written to the weight files verbatim.

    Raises:
        ValueError: if ``num_archives`` or ``block_size`` is not a positive
            number, or if the number of weights differs from the number of
            languages.
        OSError: if an input file cannot be read or an output file cannot be
            written.
    """
    # Use the namespace handed in by the caller; do not re-parse sys.argv.
    scp_lists = args.egs_scp_lists
    num_langs = len(scp_lists)
    num_archives = args.num_archives
    block_size = args.block_size

    # Both values are divisors below; fail before any file is touched.
    if num_archives is None or num_archives < 1:
        raise ValueError("--num-archives must be a positive integer, "
                         "got {0}".format(num_archives))
    if block_size is None or block_size < 1:
        raise ValueError("--block-size must be a positive integer, "
                         "got {0}".format(block_size))

    lang_to_num_examples = []
    for lang, scp_path in enumerate(scp_lists):
        with open(scp_path) as fh:
            num_examples = sum(1 for _ in fh)
        lang_to_num_examples.append(num_examples)
        logger.info("Number of examples for language {0} "
                    "is {1}.".format(lang, num_examples))

    # If weights are not provided, the weights are 1.0.
    if args.lang2weight is None:
        lang2weight = [1.0] * num_langs
    else:
        lang2weight = args.lang2weight.split(",")
        if len(lang2weight) != num_langs:
            raise ValueError("--lang2weight has {0} weights but {1} egs scp "
                             "files were given".format(len(lang2weight), num_langs))

    os.makedirs(os.path.join(args.egs_dir, 'info'), exist_ok=True)

    with open("{0}/info/{1}num_tasks".format(args.egs_dir, args.egs_prefix), "w") as fh:
        print("{0}".format(num_langs), file=fh)

    # Total number of egs in all languages
    tot_num_egs = sum(lang_to_num_examples)

    with open("{0}/info/{1}num_archives".format(args.egs_dir, args.egs_prefix), "w") as fh:
        print("{0}".format(num_archives), file=fh)

    logger.info("There are a total of {} examples in the input scp "
                "files.".format(tot_num_egs))
    logger.info("Number of blocks in each output archive will be approximately "
                "{}, and block-size is {}.".format(int(round(tot_num_egs / num_archives / block_size)),
                                                   block_size))
    for lang in range(num_langs):
        blocks_per_archive_this_lang = lang_to_num_examples[lang] / num_archives / block_size
        warning = ""
        if blocks_per_archive_this_lang < 1.0:
            warning = ("Warning: This means some of the output archives might "
                       "not include any examples from this lang.")
        # Guard against all input files being empty.
        proportion = float(lang_to_num_examples[lang]) / tot_num_egs if tot_num_egs else 0.0
        logger.info("The proportion of egs from lang {} is {:.2f}. The number of blocks "
                    "per archive for this lang is approximately {:.2f}. "
                    "{}".format(lang, proportion,
                                blocks_per_archive_this_lang,
                                warning))

    num_remaining_egs = tot_num_egs
    lang_to_num_remaining_egs = list(lang_to_num_examples)

    in_scp_file_handles = []
    try:
        for scp_path in scp_lists:
            in_scp_file_handles.append(open(scp_path, 'r'))

        # +1 is because we write to the last archive in two rounds
        for round_index in range(num_archives + 1):
            last_round = (round_index == num_archives)
            if last_round:
                # This is the second round for the last archive.
                # Flush all the remaining egs...
                archive_index = num_archives - 1
                num_blocks_this_archive = num_langs
                logger.info("Writing all the {} remaining egs to the last "
                            "archive...".format(num_remaining_egs))
            else:
                archive_index = round_index
                num_remaining_archives = num_archives - archive_index
                num_remaining_blocks = float(num_remaining_egs) / block_size
                num_blocks_this_archive = int(round(num_remaining_blocks / num_remaining_archives))
                logger.info("Generating archive {} containing {} "
                            "blocks...".format(archive_index, num_blocks_this_archive))

            # The flush round must append to the files the regular round
            # already created for the last archive.
            mode = 'a' if last_round else 'w'
            scp_out = '{0}/{1}{2}.scp'.format(args.egs_dir, args.egs_prefix, archive_index + 1)
            output_out = "{0}/{1}output.{2}.ark".format(args.egs_dir, args.egs_prefix, archive_index + 1)
            weight_out = "{0}/{1}weight.{2}.ark".format(args.egs_dir, args.egs_prefix, archive_index + 1)

            with open(scp_out, mode) as out_scp_file_handle, \
                    open(output_out, mode) as eg_to_output_file_handle, \
                    open(weight_out, mode) as eg_to_weight_file_handle:
                for _ in range(num_blocks_this_archive):
                    # Find the lang with the highest proportion of remaining
                    # examples.  A language with no examples at all counts as
                    # fully consumed (proportion 0).
                    remaining_proportions = [
                        float(remain) / tot if tot else 0.0
                        for remain, tot in zip(lang_to_num_remaining_egs, lang_to_num_examples)]
                    lang_index, _ = max(enumerate(remaining_proportions), key=lambda a: a[1])

                    # Read 'block_size' examples from the selected lang and
                    # write them to the current output scp file:
                    example_lines = read_lines(in_scp_file_handles[lang_index], block_size)
                    for eg_line in example_lines:
                        eg_id = eg_line.split()[0]
                        print(eg_line, file=out_scp_file_handle)
                        print("{0} output-{1}".format(eg_id, lang_index), file=eg_to_output_file_handle)
                        print("{0} {1}".format(eg_id, lang2weight[lang_index]), file=eg_to_weight_file_handle)

                    num_remaining_egs -= len(example_lines)
                    lang_to_num_remaining_egs[lang_index] -= len(example_lines)
    finally:
        # Release the inputs even if writing an archive failed half-way.
        for handle in in_scp_file_handles:
            handle.close()

    logger.info("Finished generating {0}*.scp, {0}output.*.ark "
                "and {0}weight.*.ark files. Wrote a total of {1} examples "
                "to {2} archives.".format(args.egs_prefix,
                                          tot_num_egs - num_remaining_egs, num_archives))
def main():
    """Script entry point: parse the command line and allocate the examples.

    Any exception raised while parsing or processing is reported with a full
    traceback and converted into exit status 1.
    """
    try:
        process_multilingual_egs(get_args())
    except Exception:
        traceback.print_exc()
        raise SystemExit(1)
# Run the allocation only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()