get_allowed_lengths.py
4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python3
# Copyright 2017 Hossein Hadian
# Apache 2.0
""" This script finds a set of allowed lengths for a given OCR/HWR data dir.
The allowed lengths are spaced by a factor (like 10%) and are written
in an output file named "allowed_lengths.txt" in the output data dir. This
file is later used by make_features.py to pad each image sufficiently so that
they all have an allowed length. This is intended for end2end chain training.
"""
from __future__ import division
import argparse
import os
import sys
import copy
import math
import logging
sys.path.insert(0, 'steps')
import libs.common as common_lib
# Send this script's log records through the shared 'libs' logger at INFO
# level. A StreamHandler (stderr by default) prints each record prefixed with
# its timestamp, source file/line, function name and level.
formatter = logging.Formatter(
    "%(asctime)s [%(pathname)s:%(lineno)s - "
    "%(funcName)s - %(levelname)s ] %(message)s")
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
handler.setFormatter(formatter)
logger = logging.getLogger('libs')
logger.setLevel(logging.INFO)
logger.addHandler(handler)
def get_args(argv=None):
    """Parse and validate the command-line options of this script.

    Parameters
    ----------
    argv: list of str, optional
        Arguments to parse, excluding the program name. When None (the
        default) argparse reads them from sys.argv, as before.

    Returns
    -------
    argparse.Namespace with ``factor`` (still a percentage at this point),
    ``srcdir``, ``coverage_factor`` and ``frame_subsampling_factor``.

    Invalid values are reported through ``parser.error``, which prints the
    usage message and exits, instead of crashing later with a
    ZeroDivisionError / UnboundLocalError deep inside the computation.
    """
    parser = argparse.ArgumentParser(description="""This script finds a set of
        allowed lengths for a given OCR/HWR data dir.
        Intended for chain training.""")
    # 'factor' is a required positional, so argparse would never use a
    # `default` here; none is declared to avoid advertising one.
    parser.add_argument('factor', type=float,
                        help='Spacing (in percentage) between allowed lengths, '
                             'e.g. 12 for 12%%.')
    parser.add_argument('srcdir', type=str,
                        help='path to source data dir')
    parser.add_argument('--coverage-factor', type=float, default=0.05,
                        help="""Percentage of durations not covered from each
                        side of duration histogram.""")
    parser.add_argument('--frame-subsampling-factor', type=int, default=3,
                        help="""Chain frame subsampling factor.
                        See steps/nnet3/chain/train.py""")
    args = parser.parse_args(argv)

    # A factor of 0 makes the length multiplier exactly 1.0; negative spacing
    # is meaningless.
    if args.factor <= 0:
        parser.error("factor must be a positive percentage, got {}".format(
            args.factor))
    # Allowed lengths are rounded to multiples of this value, so it must be >= 1.
    if args.frame_subsampling_factor < 1:
        parser.error("--frame-subsampling-factor must be >= 1, got {}".format(
            args.frame_subsampling_factor))
    # This percentage is dropped from *each* side of the histogram, so 50 or
    # more would leave nothing to cover.
    if not 0 <= args.coverage_factor < 50:
        parser.error("--coverage-factor must be in [0, 50), got {}".format(
            args.coverage_factor))
    return args
def read_kaldi_mapfile(path):
    """Read any Kaldi mapping file - like text, .scp, image2num_frames, etc.

    Each non-empty line has the form "<key> <value>". The key is the first
    whitespace-delimited token (space or tab). The value is the rest of the
    line with surrounding whitespace stripped; internal whitespace is kept.
    A line holding only a key maps it to ''. Blank lines are skipped. When a
    key repeats, the last occurrence wins.

    Returns
    -------
    dict mapping key (str) -> value (str)
    """
    m = {}
    with open(path, 'r', encoding='latin-1') as f:
        for line in f:
            # Split on the first run of whitespace only. str.find(' ') returned
            # -1 for single-token lines, which truncated the key.
            fields = line.split(None, 1)
            if not fields:
                continue  # blank / whitespace-only line
            key = fields[0]
            m[key] = fields[1].strip() if len(fields) > 1 else ''
    return m
def find_duration_range(img2len, coverage_factor, min_start_dur=30):
    """Find the start and end length (in frames) that allowed lengths must cover.

    Covering every length that occurs in the training set could make the
    number of allowed lengths very large. So the extremes are trimmed: the
    lengths are sorted, and from each end the running sum is accumulated. The
    first length at which it exceeds ``coverage_factor`` percent of the total
    becomes the boundary on that side.

    Parameters
    ----------
    img2len: dict
        Maps image id -> length. Values may be anything ``int()`` accepts,
        e.g. the strings returned by read_kaldi_mapfile.
    coverage_factor: float
        Percentage of the summed length to leave uncovered at each end.
    min_start_dur: int, optional
        Lower bound applied to start_dur (default 30). It is a hard limit to
        avoid too many allowed lengths and is not critical.

    Returns
    -------
    start_dur: int
    end_dur: int

    Raises
    ------
    ValueError
        If img2len is empty, its lengths do not sum to a positive value, or
        coverage_factor >= 100 (then no boundary can ever be found).
    """
    durs = sorted(int(imlen) for imlen in img2len.values())
    tot_dur = sum(durs)
    if not durs or tot_dur <= 0:
        raise ValueError("Cannot find a duration range: no positive lengths "
                         "were given.")
    if coverage_factor >= 100:
        raise ValueError("coverage_factor must be below 100, got {}".format(
            coverage_factor))

    def _boundary(ordered):
        # Return the first length at which the running sum exceeds
        # coverage_factor percent of the total.
        running = 0
        for d in ordered:
            running += d
            if running * 100.0 / tot_dur > coverage_factor:
                break
        # With coverage_factor < 100 the full sum (100%) always exceeds it,
        # so the loop breaks; the last element is only a float-safety fallback.
        return d

    start_dur = max(_boundary(durs), min_start_dur)
    end_dur = _boundary(reversed(durs))
    return start_dur, end_dur
def find_allowed_durations(start_len, end_len, args):
    """Compute the allowed lengths in [start_len, end_len) and write them out.

    Starting at start_len, each candidate is rounded down to a multiple of
    ``args.frame_subsampling_factor`` and kept. The next candidate is the kept
    length times ``args.factor``, but at least one subsampling factor larger,
    so the lengths are spaced by roughly ``args.factor`` and strictly increase.
    The kept lengths are written, one per line, to
    ``<args.srcdir>/allowed_lengths.txt``.

    Parameters
    ----------
    start_len, end_len: int
        Range to cover in frames; end_len is exclusive.
    args: argparse.Namespace
        Uses ``srcdir``, ``factor`` (a multiplier such as 1.12, i.e. already
        converted from the command-line percentage) and
        ``frame_subsampling_factor``.

    Returns
    -------
    allowed_lengths: list of int
        Allowed lengths in frames (not seconds), in increasing order.

    Raises
    ------
    ValueError
        If args.frame_subsampling_factor < 1.
    """
    subsampling = args.frame_subsampling_factor
    if subsampling < 1:
        # 0 would divide by zero below; negative values never terminate.
        raise ValueError("frame_subsampling_factor must be >= 1, got {}".format(
            subsampling))

    allowed_lengths = []
    length = start_len
    while length < end_len:
        if length % subsampling != 0:
            # Round down so the length is divisible by the subsampling factor.
            length = subsampling * (length // subsampling)
        allowed_lengths.append(int(length))
        # Grow geometrically, but by at least one subsampled frame so that
        # consecutive allowed lengths are always distinct.
        length = max(length * args.factor, length + subsampling)

    with open(os.path.join(args.srcdir, 'allowed_lengths.txt'), 'w',
              encoding='latin-1') as fp:
        fp.writelines("{}\n".format(n) for n in allowed_lengths)
    return allowed_lengths
def main():
    """Compute the allowed lengths for args.srcdir and write allowed_lengths.txt.

    Reads <srcdir>/image2num_frames, picks the length range to cover, writes
    the allowed lengths via find_allowed_durations, and logs a summary.
    """
    args = get_args()
    # The command line takes the spacing as a percentage (e.g. 12); the
    # helpers below expect the corresponding multiplier (e.g. 1.12).
    args.factor = 1.0 + args.factor / 100.0
    image2length = read_kaldi_mapfile(os.path.join(args.srcdir, 'image2num_frames'))
    start_dur, end_dur = find_duration_range(image2length, args.coverage_factor)
    logger.info("Lengths in the range [{},{}] will be covered. "
                "Coverage rate: {}%".format(start_dur, end_dur,
                                            100.0 - args.coverage_factor * 2))
    allowed_lengths = find_allowed_durations(start_dur, end_dur, args)
    # Report the count actually produced. Estimating it as
    # log(end/start)/log(factor) ignores the minimum step and the rounding to
    # multiples of the subsampling factor, and divides by zero when the
    # multiplier is exactly 1.0.
    logger.info("There will be {} unique allowed lengths "
                "for the images.".format(len(allowed_lengths)))


if __name__ == '__main__':
    main()