Blame view

egs/iam/v1/local/remove_test_utterances_from_lob.py 4.16 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
  #!/usr/bin/env python3
  # Copyright   2018 Ashish Arora
  
  import argparse
  import os
  import numpy as np
  import sys
  import re
  
  # Command line: two positional paths (dev and test transcriptions).  The
  # LOB corpus itself is streamed through stdin/stdout.
  parser = argparse.ArgumentParser(description="""Removes dev/test set lines
                                                  from the LOB corpus. Reads the
                                                  corpus from stdin, and writes it to stdout.""")
  parser.add_argument('dev_text',
                      help='dev transcription location.',
                      type=str)
  parser.add_argument('test_text',
                      help='test transcription location.',
                      type=str)
  args = parser.parse_args(sys.argv[1:])
  
  def remove_punctuations(transcript):
      """Return the characters of *transcript* minus digits and punctuation.

      A character is dropped when ``str.isdigit()`` is true for it or when it
      is one of  + ~ ? # = - ! , . ) ' ( : ; "  .  Every other character,
      including spaces, is kept in its original order.  The result is a list
      of characters; callers join it back into a string.
      """
      dropped_chars = frozenset('+~?#=-!,.)\'(:;"')
      return [ch for ch in transcript
              if not (ch.isdigit() or ch in dropped_chars)]
  
  
  def remove_special_words(words):
      """Return *words* with every '<SIC>' and '#' token dropped, order kept."""
      return [token for token in words if token not in ('<SIC>', '#')]
  
  
  def read_utterances(text_file_path, utterances=None):
      """Read a dev/test transcription file into a dict of normalized transcripts.

      The first whitespace-separated token of each line is used as the
      utterance id; the remaining tokens are the transcript.  Special words
      ('<SIC>', '#') are dropped, the words are concatenated without spaces,
      lowercased, and passed through remove_punctuations, so the result can be
      substring-matched against corpus text normalized the same way.

      Args:
          text_file_path: path of the transcription file.
          utterances: dict to fill (utterance id -> normalized transcript).
              Defaults to the module-level ``utterance_dict`` so existing
              calls keep their behavior.

      Returns:
          The dict that was filled.
      """
      if utterances is None:
          utterances = utterance_dict
      with open(text_file_path, 'rt') as in_file:
          for line in in_file:
              words_wo_sw = remove_special_words(line.split())
              if not words_wo_sw:
                  # Blank line (or one made only of special words): there is
                  # no utterance id, and indexing [0] would raise IndexError.
                  continue
              transcript = ''.join(words_wo_sw[1:]).lower()
              utterances[words_wo_sw[0]] = ''.join(remove_punctuations(transcript))
      return utterances
  
  
  ### main ###

  # Read the dev/test transcripts; read_utterances fills the module-level
  # utterance_dict (utterance id -> normalized transcript).
  utterance_dict = dict()
  read_utterances(args.dev_text)
  read_utterances(args.test_text)

  # Read the corpus from stdin.  original_corpus_text keeps every line
  # verbatim for the final output; corpus_text_lowercase_wo_sc holds the same
  # line normalized like the dev/test transcripts (special words and
  # punctuation removed, spaces dropped, lowercased) so the two can be
  # compared by substring search.
  corpus_text_lowercase_wo_sc = list()
  original_corpus_text = list()
  for line in sys.stdin:
      original_corpus_text.append(line)
      words_wo_sw = remove_special_words(line.split())
      transcript = ''.join(words_wo_sw).lower()
      corpus_text_lowercase_wo_sc.append(''.join(remove_punctuations(transcript)))
  
  # Locate every dev/test utterance in the corpus and mark the corpus lines
  # that contain it for removal.  A transcript line may be split across
  # consecutive corpus lines, so each corpus line i is searched together with
  # its neighbours i-1 and i+1 (clipped at the start and end of the corpus).
  # Utterances that are not found are collected in remaining_utterances so
  # they can be reported on stderr.
  num_corpus_lines = len(corpus_text_lowercase_wo_sc)
  row_to_keep = [True] * len(original_corpus_text)
  remaining_utterances = dict()

  # The search windows do not depend on the utterance, so build them once
  # instead of re-concatenating three lines per utterance per corpus line.
  # Every corpus line is the centre of a window, so the last line (and
  # corpora shorter than three lines) are searched as well.
  windows = []
  for i in range(num_corpus_lines):
      first = max(i - 1, 0)
      last = min(i + 2, num_corpus_lines)
      windows.append((first, last, ''.join(corpus_text_lowercase_wo_sc[first:last])))

  for line_id, line_to_find in utterance_dict.items():
      if not line_to_find:
          # An empty transcript (e.g. one made only of digits, punctuation or
          # special words) is a substring of every window; searching for it
          # would remove the entire corpus.
          remaining_utterances[line_id] = line_to_find
          continue
      found_line = False
      for first, last, window in windows:
          if line_to_find in window:
              found_line = True
              for row in range(first, last):
                  row_to_keep[row] = False
      if not found_line:
          remaining_utterances[line_id] = line_to_find
  
  
  # Write the corpus lines that were not marked for removal (surrounding
  # whitespace stripped) to stdout, then a summary to stderr.
  for keep, corpus_line in zip(row_to_keep, original_corpus_text):
      if keep:
          print(corpus_line.strip())

  num_kept_lines = sum(1 for keep in row_to_keep if keep)
  summary_lines = (
      'Sentences not removed from LOB: {}'.format(remaining_utterances),
      'Total test+dev sentences: {}'.format(len(utterance_dict)),
      'Number of sentences not removed from LOB: {}'.format(len(remaining_utterances)),
      'LOB lines: Before: {}   After: {}'.format(len(original_corpus_text),
                                               num_kept_lines),
  )
  for summary_line in summary_lines:
      print(summary_line, file=sys.stderr)