Blame view

egs/wsj/s5/utils/data/extend_segment_times.py 4.71 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  #!/usr/bin/env python
  
  from __future__ import print_function
  import sys
  import argparse
  from collections import defaultdict
  
  
  parser = argparse.ArgumentParser(description="""
   Usage: extend_segment_times.py [options] <input-segments >output-segments
   This program pads the times in a 'segments' file (e.g. data/train/segments)
   with specified left and right context (for cases where there was no
   silence padding in the original segments file)""")

  # The three padding options share the same shape: a float number of
  # seconds with a default of 0.1.  They are registered in this order so the
  # --help output lists them exactly as before.
  for option_name, option_help in (
          ("--start-padding",
           "Amount of padding, in seconds, for the start time of "
           "each segment (start times <0 will be set to zero)."),
          ("--end-padding",
           "Amount of padding, in seconds, for the end time of "
           "each segment."),
          ("--last-segment-end-padding",
           "Amount of padding, in seconds, for the end time of "
           "the last segment of each file (maximum allowed).")):
      parser.add_argument(option_name, type=float, default=0.1, help=option_help)

  # A string flag restricted to 'true'/'false' (compared as a string later),
  # rather than a store_true switch.
  parser.add_argument("--fix-overlapping-segments", type=str,
                      default='true', choices=['true', 'false'],
                      help="If true, prevent segments from overlapping as a result "
                      "of the padding (or that were already overlapping)")
  args = parser.parse_args()
  
  
  # The input is a sequence of lines, each of the form:
  # <utterance-id> <recording-id> <start-time> <end-time>
  # e.g.
  # utt-1 recording-1 0.62 5.40
  # The output is in the same format and in the same order, except with
  # modified times.  Input segments whose end-time is not greater than their
  # start-time are reported on stderr and dropped.

  # Maps from a recording-id to the list of utterance indexes (as integer
  # indexes into 'entries') that are part of that recording.
  recording_to_utt_indexes = defaultdict(list)

  # The accepted entries of the segments file, each in the format:
  # [utterance-id as string, recording-id as string,
  #  start-time as float, end-time as float]
  # Stored as lists (not tuples) so the times can be modified in place.
  entries = []


  for line in sys.stdin:
      try:
          # Wrong field count (unpacking) and non-numeric times (float())
          # both raise ValueError.
          utt_id, recording_id, start_time, end_time = line.split()
          start_time = float(start_time)
          end_time = float(end_time)
      except ValueError:
          sys.exit("extend_segment_times.py: could not interpret line: " + line)
      if not end_time > start_time:
          # The message says the segment is ignored, so actually skip it
          # instead of letting it take part in padding and overlap fixing.
          print("extend_segment_times.py: bad segment (ignoring): " + line,
                file = sys.stderr)
          continue
      recording_to_utt_indexes[recording_id].append(len(entries))
      entries.append([utt_id, recording_id, start_time, end_time])
  
  num_times_fixed = 0
  fix_overlaps = (args.fix_overlapping_segments == 'true')

  for utt_indexes in recording_to_utt_indexes.values():
      # This recording's segments, ordered by mid-point.  The inner lists are
      # the same objects stored in 'entries', so assigning to them here
      # changes what gets printed later.
      segs = sorted((entries[i] for i in utt_indexes),
                    key=lambda seg: 0.5 * (seg[2] + seg[3]))
      # No end time may exceed the recording's latest original end time plus
      # --last-segment-end-padding.
      upper_limit = max(seg[3] for seg in segs) + args.last_segment_end_padding
      for seg in segs:
          seg[2] = max(0, seg[2] - args.start_padding)
          seg[3] = min(upper_limit, seg[3] + args.end_padding)

      if not fix_overlaps:
          continue
      # Where neighbouring segments overlap, move both boundaries to the
      # midpoint of the overlapping region.
      for cur, nxt in zip(segs, segs[1:]):
          if cur[3] > nxt[2]:
              boundary = 0.5 * (cur[3] + nxt[2])
              cur[3] = boundary
              nxt[2] = boundary
              num_times_fixed += 1
  
  
  def FloatToString(f):
      """Return f formatted with '%g', which drops trailing zeros.

      The precision starts at 6 significant digits and grows by one for each
      factor of ten by which abs(f) exceeds 1.0.  Values of 1 or more therefore
      keep about 6 digits after the decimal point (e.g. 12.3456789 ->
      '12.345679').  As with any '%g' formatting, very small magnitudes may be
      printed in exponent form (e.g. 1e-05 -> '1e-05').
      """
      precision = 6
      magnitude = abs(f)
      while magnitude > 1.0:
          magnitude *= 0.1
          precision += 1
      return ('%.' + str(precision) + 'g') % f
  
  # Print every segment in input order.  Segments whose start is no longer
  # strictly before their end after padding and overlap fixing are reported
  # on stderr and skipped.
  for entry in entries:
      utt_id, recording_id, start_time, end_time = entry
      if not start_time < end_time:
          # 'entry' holds two floats, so each field is converted to str
          # before joining; str.join on the raw list raises TypeError.
          print("extend_segment_times.py: bad segment after processing (ignoring): " +
                ' '.join(str(field) for field in entry), file = sys.stderr)
          continue
      print(utt_id, recording_id, FloatToString(start_time), FloatToString(end_time))


  print("extend_segment_times.py: extended {0} segments; fixed {1} "
        "overlapping segments".format(len(entries), num_times_fixed),
        file = sys.stderr)
  
  ## test:
  #  (echo utt1 reco1 0.2 6.2; echo utt2 reco1 6.3 9.8 )| extend_segment_times.py
  # and also try the above with the options --last-segment-end-padding=0.0 --fix-overlapping-segments=false