Blame view

egs/cifar/v1/image/get_allowed_lengths.py 4.86 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
  #!/usr/bin/env python3
  
  # Copyright     2017  Hossein Hadian
  # Apache 2.0
  
  
  """ This script finds a set of allowed lengths for a given OCR/HWR data dir.
      The allowed lengths are spaced by a factor (like 10%) and are written
      in an output file named "allowed_lengths.txt" in the output data dir. This
      file is later used by make_features.py to pad each image sufficiently so that
      they all have an allowed length. This is intended for end2end chain training.
  """
  from __future__ import division
  
  import argparse
  import os
  import sys
  import copy
  import math
  import logging
  
  sys.path.insert(0, 'steps')
  import libs.common as common_lib
  
  # Logging setup for the 'libs' logger: INFO-level records are sent through a
  # StreamHandler (default stream) using a timestamped, source-location format.
  # NOTE(review): the name 'libs' presumably matches the `libs` package imported
  # above so that its loggers share this handler -- confirm.
  formatter = logging.Formatter(
      "%(asctime)s [%(pathname)s:%(lineno)s - %(funcName)s - %(levelname)s ] "
      "%(message)s")
  handler = logging.StreamHandler()
  handler.setFormatter(formatter)
  handler.setLevel(logging.INFO)
  logger = logging.getLogger('libs')
  logger.addHandler(handler)
  logger.setLevel(logging.INFO)
  
  def get_args():
      """Parse the command line of this script.

      Two positional arguments are expected, ``factor`` (float) and ``srcdir``
      (str), followed or preceded by the optional ``--coverage-factor``
      (float, default 0.05) and ``--frame-subsampling-factor`` (int,
      default 3).

      Returns
      -------
      args: argparse.Namespace with attributes factor, srcdir,
          coverage_factor and frame_subsampling_factor
      """
      description_text = """This script finds a set of
                                     allowed lengths for a given OCR/HWR data dir.
                                     Intended for chain training."""
      coverage_help = """Percentage of durations not covered from each
                               side of duration histogram."""
      subsampling_help = """Chain frame subsampling factor.
                               See steps/nnet3/chain/train.py"""

      cli = argparse.ArgumentParser(description=description_text)
      # NOTE: 'factor' is a required positional, so argparse never falls back
      # to the default given here; a value must always be supplied.
      cli.add_argument('factor', default=12, type=float,
                       help='Spacing (in percentage) between allowed lengths.')
      cli.add_argument('srcdir', help='path to source data dir', type=str)
      cli.add_argument('--coverage-factor', default=0.05, type=float,
                       help=coverage_help)
      cli.add_argument('--frame-subsampling-factor', default=3, type=int,
                       help=subsampling_help)
      return cli.parse_args()
  
  
  def read_kaldi_mapfile(path):
      """Read a Kaldi mapping file (text, .scp files, etc.) into a dict.

      Every non-blank line is stripped and split at its first space: the text
      before the space is the key, the rest of the line is the value.  A line
      that contains no space maps its whole content to the empty string.
      Blank lines are skipped.  If a key occurs several times, the last
      occurrence wins.

      Parameters
      ----------
      path: path of the mapping file; it is decoded as latin-1

      Returns
      -------
      m: dict mapping each key (str) to its value (str)
      """

      m = {}
      with open(path, 'r', encoding='latin-1') as f:
          for line in f:
              line = line.strip()
              if not line:
                  # Skip blank lines instead of recording a bogus '' -> '' entry.
                  continue
              # partition() keeps the whole line as the key when there is no
              # space; slicing on find()'s -1 would drop the last character.
              key, _, val = line.partition(' ')
              m[key] = val
      return m
  
  def find_duration_range(img2len, coverage_factor):
      """Find the range of image lengths the allowed lengths should cover.

      If we tried to cover all lengths which occur in the training set, the
      number of allowed lengths could become very large.  Instead, a share of
      the total summed length is left uncovered at each end of the sorted
      lengths.

      Parameters
      ----------
      img2len: mapping from image id to its length; each value must be
          accepted by int()
      coverage_factor: percentage of the total summed length that is left
          uncovered on each side

      Returns
      -------
      start_dur: int
          Smallest covered length, raised to 30 if it would be lower.
      end_dur: int
          Largest covered length (not clamped, so it can be below start_dur).

      Raises
      ------
      ValueError
          If img2len is empty, if the lengths do not sum to a positive number,
          or if coverage_factor is so large that nothing is left to cover.
      """
      durs = sorted(int(imlen) for imlen in img2len.values())
      if not durs:
          raise ValueError("cannot find a duration range: no lengths given")
      tot_dur = sum(durs)
      if tot_dur <= 0:
          raise ValueError("cannot find a duration range: total length is "
                           "{}".format(tot_dur))

      def first_covered(ordered_durs):
          # Walk from one end, accumulating the ignored share; the first
          # length that pushes it past coverage_factor percent is covered.
          to_ignore_dur = 0
          for d in ordered_durs:
              to_ignore_dur += d
              if to_ignore_dur * 100.0 / tot_dur > coverage_factor:
                  return d
          raise ValueError("coverage factor {} leaves no length to "
                           "cover".format(coverage_factor))

      start_dur = first_covered(durs)
      end_dur = first_covered(reversed(durs))
      if start_dur < 30:
          start_dur = 30  # a hard limit to avoid too many allowed lengths --not critical
      return start_dur, end_dur
  
  
  def find_allowed_durations(start_len, end_len, args):
      """Find a set of allowed lengths between start_len and end_len.

      Starting from start_len, each candidate is rounded down to a multiple of
      args.frame_subsampling_factor and recorded; the next candidate is the
      recorded length times args.factor, but always at least one subsampling
      step further.  The loop stops as soon as a candidate (before rounding)
      reaches end_len.  The lengths are also written, one per line, to
      "allowed_lengths.txt" inside args.srcdir.

      Parameters
      ----------
      start_len: first candidate length
      end_len: exclusive upper bound for the candidates
      args: object providing srcdir, factor (multiplicative spacing, e.g. 1.12)
          and frame_subsampling_factor

      Returns
      -------
      allowed_lengths: list of int, the allowed lengths in increasing order
      """

      subsampling = args.frame_subsampling_factor
      allowed_lengths = []
      length = start_len
      with open(os.path.join(args.srcdir, 'allowed_lengths.txt'), 'w', encoding='latin-1') as fp:
          while length < end_len:
              if length % subsampling != 0:
                  # Round down to a multiple of the subsampling factor.
                  length = subsampling * (length // subsampling)
              # Keep the returned list consistent with the integers written
              # to the file (multiplying by args.factor yields floats).
              length = int(length)
              allowed_lengths.append(length)
              fp.write("{}\n".format(length))
              # Grow by the factor, but by at least one subsampling step so
              # the loop always makes progress.
              length = max(length * args.factor, length + subsampling)
      return allowed_lengths
  
  
  
  def main():
      """Compute the allowed lengths for the data dir given on the command line.

      Reads "image2num_frames" from args.srcdir, picks the range of lengths to
      cover, writes the allowed lengths via find_allowed_durations() and logs
      the covered range and the number of lengths produced.
      """
      args = get_args()
      # Turn the percentage spacing into a multiplicative factor (12 -> 1.12).
      args.factor = 1.0 + args.factor/100.0

      image2length = read_kaldi_mapfile(os.path.join(args.srcdir, 'image2num_frames'))

      start_dur, end_dur = find_duration_range(image2length, args.coverage_factor)
      logger.info("Lengths in the range [{},{}] will be covered. "
                  "Coverage rate: {}%".format(start_dur, end_dur,
                                        100.0 - args.coverage_factor * 2))

      # Report the real number of lengths rather than a log-ratio estimate,
      # which failed for a factor of 0 or an end length of 0 and could go
      # negative when end_dur < start_dur.
      allowed_lengths = find_allowed_durations(start_dur, end_dur, args)
      logger.info("There will be {} unique allowed lengths "
                  "for the images.".format(len(allowed_lengths)))
  
  
  # Run main() only when this file is executed as a script, not when imported.
  if __name__ == '__main__':
        main()