Blame view
egs/iam/v1/local/remove_test_utterances_from_lob.py
4.16 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
#!/usr/bin/env python3
# Copyright  2018  Ashish Arora

"""Removes dev/test set lines from the LOB corpus.

Reads the corpus from stdin and writes the filtered corpus to stdout.
Dev and test transcriptions (Kaldi-style 'utt-id word word ...' files)
are given as arguments; any corpus line whose 3-line neighbourhood
contains a normalized dev/test transcript is dropped, so the LOB text
used for language modelling does not overlap the evaluation data.
Summary statistics are written to stderr.
"""

import argparse
import sys

# Characters stripped during normalization (in addition to digits).
_DROPPED_CHARS = frozenset("+~?#=-!,.)'(:;\"")


def remove_punctuations(transcript):
    """Return the characters of *transcript*, minus digits and the
    punctuation marks in _DROPPED_CHARS, as a list of characters."""
    return [char for char in transcript
            if not char.isdigit() and char not in _DROPPED_CHARS]


def remove_special_words(words):
    """Return *words* without the special tokens '<SIC>' and '#'."""
    return [word for word in words if word not in ('<SIC>', '#')]


def normalize_words(words):
    """Join *words* without spaces, lowercase, and strip digits and
    punctuation, yielding a string suitable for substring matching."""
    transcript = ''.join(words).lower()
    return ''.join(remove_punctuations(transcript))


def read_utterances(text_file_path, utterance_dict=None):
    """Read a transcription file and map utterance id -> normalized
    transcript.

    Each line is 'utt-id word word ...'.  Entries are added to
    *utterance_dict* (a new dict when None), which is returned, so
    repeated calls can accumulate dev and test sets into one dict.
    """
    if utterance_dict is None:
        utterance_dict = {}
    with open(text_file_path, 'rt') as in_file:
        for line in in_file:
            words_wo_sw = remove_special_words(line.strip().split())
            utterance_dict[words_wo_sw[0]] = normalize_words(words_wo_sw[1:])
    return utterance_dict


def main():
    parser = argparse.ArgumentParser(description="""Removes dev/test set lines
                        from the LOB corpus. Reads the corpus from stdin,
                        and writes it to stdout.""")
    parser.add_argument('dev_text', type=str,
                        help='dev transcription location.')
    parser.add_argument('test_text', type=str,
                        help='test transcription location.')
    args = parser.parse_args()

    # Normalized transcripts of every dev and test utterance.
    utterance_dict = read_utterances(args.dev_text)
    read_utterances(args.test_text, utterance_dict)

    # Read the corpus from stdin, keeping the raw line (for output) and
    # a normalized version (for matching against the utterances).
    original_corpus_text = []
    corpus_text_lowercase_wo_sc = []
    for line in sys.stdin:
        original_corpus_text.append(line)
        words_wo_sw = remove_special_words(line.strip().split())
        corpus_text_lowercase_wo_sc.append(normalize_words(words_wo_sw))

    # A dev/test sentence may span up to three consecutive corpus lines,
    # so slide a 3-line window over the corpus and mark all three lines
    # for removal whenever the window contains an utterance transcript.
    # Utterances never matched by any window are reported on stderr.
    row_to_keep = [True] * len(original_corpus_text)
    remaining_utterances = {}
    for line_id, line_to_find in utterance_dict.items():
        found_line = False
        # i is the middle line of the window.  Upper bound is len - 1 so
        # the final window is checked too (the original stopped one
        # window short with len - 2, an off-by-one).
        for i in range(1, len(corpus_text_lowercase_wo_sc) - 1):
            window = (corpus_text_lowercase_wo_sc[i - 1].strip()
                      + corpus_text_lowercase_wo_sc[i].strip()
                      + corpus_text_lowercase_wo_sc[i + 1].strip())
            if line_to_find in window:
                found_line = True
                row_to_keep[i - 1] = False
                row_to_keep[i] = False
                row_to_keep[i + 1] = False
        if not found_line:
            remaining_utterances[line_id] = line_to_find

    # Emit the surviving corpus lines.
    for keep, line in zip(row_to_keep, original_corpus_text):
        if keep:
            print(line.strip())

    print('Sentences not removed from LOB: {}'.format(remaining_utterances),
          file=sys.stderr)
    print('Total test+dev sentences: {}'.format(len(utterance_dict)),
          file=sys.stderr)
    print('Number of sentences not removed from LOB: {}'.format(
        len(remaining_utterances)), file=sys.stderr)
    print('LOB lines: Before: {} After: {}'.format(
        len(original_corpus_text), row_to_keep.count(True)), file=sys.stderr)


if __name__ == '__main__':
    main()