remove_test_utterances_from_lob.py
4.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python3
# Copyright 2018 Ashish Arora
import argparse
import os
import numpy as np
import sys
import re
# Command-line interface: two positional transcription paths. The corpus
# itself is not an argument; the description states it is read from stdin
# and written to stdout.
parser = argparse.ArgumentParser(
    description="""Removes dev/test set lines
from the LOB corpus. Reads the
corpus from stdin, and writes it to stdout.""")
for _arg_name, _arg_help in (
        ('dev_text', 'dev transcription location.'),
        ('test_text', 'test transcription location.')):
    parser.add_argument(_arg_name, type=str, help=_arg_help)
# Parsed at import time; the rest of the script reads args.dev_text and
# args.test_text.
args = parser.parse_args()
def remove_punctuations(transcript):
    """Drop digits and a fixed set of punctuation marks from a transcript.

    Despite the name, digits (any character for which ``str.isdigit()`` is
    true) are removed as well. Whitespace and every character outside the
    fixed set below are kept unchanged.

    Args:
        transcript: iterable of single characters, normally a ``str``.

    Returns:
        A list of the retained characters in their original order. Callers
        join it back into a string.
    """
    # The exact characters the filter removes: + ~ ? # = - ! , . ) ' ( : ; "
    dropped = frozenset('+~?#=-!,.)\'(:;"')
    return [char for char in transcript
            if not (char.isdigit() or char in dropped)]
def remove_special_words(words):
    """Filter the marker tokens ``<SIC>`` and ``#`` out of a word sequence.

    Only exact matches are removed; a token that merely contains one of the
    markers (e.g. ``'##'`` or ``'<sic>'``) is kept.

    Args:
        words: iterable of word strings.

    Returns:
        A new list with the remaining words in their original order.
    """
    special_words = ('<SIC>', '#')
    return [word for word in words if word not in special_words]
def read_utterances(text_file_path):
    """Read a dev/eval transcription file into the global ``utterance_dict``.

    Each line is split on whitespace and the special words are removed. The
    first remaining token becomes the dictionary key and the rest form the
    transcript. The transcript is stored with the spaces between words
    removed, lowercased, and with digits/punctuation stripped.

    Lines that are blank, or that contain nothing but special words, are
    skipped: they have no token to use as a key and previously raised
    ``IndexError``.

    Args:
        text_file_path: path of the transcription file to read.

    Returns:
        None. Entries are added to (or overwritten in) ``utterance_dict``.
    """
    with open(text_file_path, 'rt') as in_file:
        for line in in_file:
            tokens = remove_special_words(line.split())
            if not tokens:
                continue
            utt_id, *utt_words = tokens
            normalized = ''.join(utt_words).lower()
            utterance_dict[utt_id] = ''.join(remove_punctuations(normalized))
### main ###
# Read the dev and test utterances into utterance_dict (key -> normalized
# transcript).
utterance_dict = dict()
read_utterances(args.dev_text)
read_utterances(args.test_text)

# Read the corpus from stdin. Keep every original line for output, plus a
# normalized copy (special words removed, words concatenated, lowercased,
# digits/punctuation stripped) that is used only for matching.
corpus_text_lowercase_wo_sc = list()
original_corpus_text = list()
for line in sys.stdin:
    original_corpus_text.append(line)
    words_wo_sw = remove_special_words(line.split())
    transcript = ''.join(words_wo_sw).lower()
    corpus_text_lowercase_wo_sc.append(
        ''.join(remove_punctuations(transcript)))

# An utterance may span corpus line breaks, so it is searched for in the
# concatenation of each line with its neighbours. The windows do not depend
# on the utterance, so they are built once. A missing neighbour at either
# end counts as empty, which makes the first and last corpus lines (and
# corpora shorter than three lines) searchable too.
num_lines = len(corpus_text_lowercase_wo_sc)
corpus_windows = []
for i in range(num_lines):
    prev_words = corpus_text_lowercase_wo_sc[i - 1] if i > 0 else ''
    curr_words = corpus_text_lowercase_wo_sc[i]
    next_words = corpus_text_lowercase_wo_sc[i + 1] if i + 1 < num_lines else ''
    corpus_windows.append(prev_words + curr_words + next_words)

# Find the utterances in the corpus. Every line of a matching window is
# marked for removal; utterances that match nowhere are collected in
# remaining_utterances.
row_to_keep = [True] * num_lines
remaining_utterances = dict()
for line_id, line_to_find in utterance_dict.items():
    found_line = False
    # An empty transcript is a substring of every window and would remove
    # the whole corpus, so it is never searched for and is reported as
    # not removed.
    if line_to_find:
        for i, window in enumerate(corpus_windows):
            if line_to_find in window:
                found_line = True
                # No early exit: the same text may occur more than once.
                for j in range(max(i - 1, 0), min(i + 2, num_lines)):
                    row_to_keep[j] = False
    if not found_line:
        remaining_utterances[line_id] = line_to_find

# Write the surviving corpus lines, whitespace-stripped, to stdout.
for keep_row, corpus_line in zip(row_to_keep, original_corpus_text):
    if keep_row:
        print(corpus_line.strip())

# Summary statistics go to stderr so they do not mix with the corpus output.
print('Sentences not removed from LOB: {}'.format(remaining_utterances), file=sys.stderr)
print('Total test+dev sentences: {}'.format(len(utterance_dict)), file=sys.stderr)
print('Number of sentences not removed from LOB: {}'. format(len(remaining_utterances)), file=sys.stderr)
print('LOB lines: Before: {} After: {}'.format(len(original_corpus_text),
                                               row_to_keep.count(True)), file=sys.stderr)