Blame view
egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py
6.53 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
#!/usr/bin/env python

# Dongji Gao

# We're using python 3.x style but want it to work in python 2.x

from __future__ import print_function
import argparse
import sys
import math

parser = argparse.ArgumentParser(
    description="This script evaluates the log probability (default log base is e) of each sentence "
                "from data (in text form), given a language model in arpa form "
                "and a specific ngram order.",
    epilog="e.g. ./compute_sentence_probs_arpa.py ARPA_LM NGRAM_ORDER TEXT_IN PROB_FILE --log-base=LOG_BASE",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("arpa_lm", type=str,
                    help="Input language model in arpa form.")
parser.add_argument("ngram_order", type=int,
                    help="Order of ngram")
parser.add_argument("text_in", type=str,
                    help="Filename of input text file (each line will be interpreted as a sentence).")
parser.add_argument("prob_file", type=str,
                    help="Filename of output probability file.")
parser.add_argument("--log-base", type=float, default=math.exp(1),
                    help="Log base for log probability")

# Module-level state read by the functions below.  Both names are assigned in
# the __main__ section at the bottom of the file; the command line is parsed
# there (not at import time) so that importing this module has no side effects.
args = None
ngram_dict = {}


def check_args(args):
    """Validate the parsed options and attach the open I/O handles to them.

    Sets args.text_in_handle and args.prob_file_handle ("-" selects
    stdin / stdout respectively).  Exits via sys.exit() when the log base
    cannot be used as a logarithm base.
    """
    # Validate before opening anything: opening the output file with "w"
    # truncates it, which must not happen for a run that is going to abort.
    # A base of exactly 1 is rejected too because math.log(10, 1) divides by
    # log(1) == 0.
    if args.log_base <= 0 or args.log_base == 1:
        sys.exit("compute_sentence_probs_arpa.py: Invalid log base (must be greater than 0 and not equal to 1)")
    args.text_in_handle = sys.stdin if args.text_in == "-" else open(args.text_in, "r")
    args.prob_file_handle = sys.stdout if args.prob_file == "-" else open(args.prob_file, "w")


def is_logprob(input):
    """Return True if `input` is a "-" followed by text that parses as a float."""
    if not input.startswith("-"):
        return False
    try:
        float(input[1:])
    except ValueError:
        return False
    return True


def check_number(model_file, tot_num):
    """Compare the n-gram counts declared in the arpa header with `tot_num`.

    Skips the first line of `model_file`, then reads the following
    "ngram N=COUNT" lines until the first line that contains no "=".

    Returns a tuple (counts_match, max_ngram_order): whether the sum of the
    declared counts equals `tot_num`, and the order N of the last count line
    read (0 if there was none).
    """
    cur_num = 0
    max_ngram_order = 0
    with open(model_file) as model:
        lines = model.readlines()
    for line in lines[1:]:
        if "=" not in line:
            break
        pieces = line.split("=")
        cur_num += int(pieces[-1])
        max_ngram_order = int(pieces[0].split()[-1])
    # Returning here (rather than only inside the loop) also covers a file
    # that ends right after its count lines.
    return (cur_num == tot_num), max_ngram_order


def _section_order(line):
    """Return N if `line` is a "\\N-grams:" section header, otherwise None."""
    text = line.strip()
    suffix = "-grams:"
    if text.startswith("\\") and text.endswith(suffix):
        digits = text[1:-len(suffix)]
        if digits.isdigit():
            return int(digits)
    return None


# This function loads a language model in arpa form and saves it in a
# dictionary used for computing sentence probabilities of the input text file.
def load_model(model_file):
    """Read an arpa language model into a dictionary.

    Keys are the n-grams (words joined by single spaces).  Values are tuples
    of strings exactly as they appear in the file: (logprob,) or
    (logprob, backoff_weight).

    Returns (ngram_dict, number_of_entries).  Exits via sys.exit() if the
    first line is not "\\data\\", if an entry has the wrong number of fields
    for its section, or if an n-gram is listed twice.
    """
    with open(model_file) as model:
        lines = model.readlines()

    # check arpa form
    if not lines or lines[0].rstrip("\r\n") != "\\data\\":
        sys.exit("compute_sentence_probs_arpa.py: Please make sure that language model is in arpa form.")

    ngram_dict = {}
    section_order = None
    for line in lines:
        fields = line.split()
        if not fields:
            continue
        if fields[0].startswith("\\"):
            # "\N-grams:" opens a section of order N; any other backslash
            # line ("\data\", "\end\") leaves us outside of a section.
            section_order = _section_order(line)
            continue

        if section_order is None:
            # No section header seen: fall back to guessing from the sign of
            # the last field whether a back-off weight is present.
            if line[0] != "-":
                continue
            has_backoff = is_logprob(fields[-1])
        else:
            # Inside a section the number of words is known, so the back-off
            # weight is recognised by position.  Its sign must not be used:
            # back-off weights may be zero or positive.
            extra_fields = len(fields) - section_order
            if extra_fields not in (1, 2):
                sys.exit("compute_sentence_probs_arpa.py: Malformed {}-gram entry in arpa language model: {}.".format(
                    section_order, line.strip()))
            has_backoff = (extra_fields == 2)

        if has_backoff:
            ngram_key = " ".join(fields[1:-1])
            value = (fields[0], fields[-1])
        else:
            ngram_key = " ".join(fields[1:])
            value = (fields[0],)
        if ngram_key in ngram_dict:
            sys.exit("compute_sentence_probs_arpa.py: Duplicated ngram in arpa language model: {}.".format(ngram_key))
        ngram_dict[ngram_key] = value
    return ngram_dict, len(ngram_dict)


def compute_sublist_prob(sub_list):
    """Return the log probability of the last word of `sub_list` given the rest.

    Looks the n-gram up in the module-level ngram_dict.  If it is missing,
    backs off: the first word is dropped and the back-off weight of the
    context (all words but the last) is added; a missing weight counts as 0.0.
    Exits via sys.exit() when even the single last word is not in the model.
    """
    if len(sub_list) == 0:
        sys.exit("compute_sentence_probs_arpa.py: Ngram substring not found in arpa language model, please check.")

    sub_string = " ".join(sub_list)
    if sub_string in ngram_dict:
        return float(ngram_dict[sub_string][0])

    context_entry = ngram_dict.get(" ".join(sub_list[:-1]), ())
    backoff_weight = float(context_entry[1]) if len(context_entry) > 1 else 0.0
    return compute_sublist_prob(sub_list[1:]) + backoff_weight


def compute_begin_prob(sub_list):
    """Sum compute_sublist_prob over the prefixes sub_list[:2] ... sub_list[:-1].

    The one-word prefix and the full list are deliberately left to the caller.
    """
    return sum(compute_sublist_prob(sub_list[:end])
               for end in range(2, len(sub_list)))


# The probability is computed in this way:
# p(word_N | word_N-1 ... word_1) = ngram_dict[word_1 ... word_N][0].
# Here ngram_dict is a dictionary that stores one tuple per ngram.
# The first element of the tuple is the log probability and the second one is
# the backoff weight (if it exists).
# If the particular ngram (word_1 ... word_N) is not in the dictionary, then
# p(word_N | word_N-1 ... word_1) =
#     p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1)
# If the sequence (word_(N-1) ... word_1) is not in the dictionary, then the
# backoff_weight gets replaced with 0.0 (log 1).
# More details can be found in https://cmusphinx.github.io/wiki/arpaformat/
def compute_sentence_prob(sentence, ngram_order):
    """Return the log probability of `sentence` using ngrams up to `ngram_order`.

    `sentence` is a whitespace separated string.  Words that are not keys of
    the module-level ngram_dict are replaced by "<unk>".
    """
    words = [word if word in ngram_dict else "<unk>" for word in sentence.split()]
    sen_length = len(words)

    if sen_length < ngram_order:
        # The whole sentence is shorter than one full ngram.
        # compute_begin_prob() stops one word early, so the last word
        # (normally "</s>") has to be scored here with the full prefix.
        logprob = compute_begin_prob(words)
        if sen_length > 1:
            logprob += compute_sublist_prob(words)
        return logprob

    # Words 2 .. ngram_order-1 are scored with their (shorter) prefixes ...
    logprob = compute_begin_prob(words[:ngram_order])
    # ... and every later word with a full window of ngram_order words.
    for start in range(sen_length - ngram_order + 1):
        logprob += compute_sublist_prob(words[start:start + ngram_order])
    return logprob


def output_result(text_in_handle, output_file_handle, ngram_order):
    """Score every line of `text_in_handle` and write the results.

    Each line is wrapped in "<s>" / "</s>", scored with
    compute_sentence_prob(), converted to base args.log_base (the values
    stored from the model are used as base-10 logs) and written to
    `output_file_handle`.  Both handles are closed before returning,
    also when scoring aborts.
    """
    logbase_modifier = math.log(10, args.log_base)
    try:
        for line in text_in_handle:
            # Strip only the line terminator; the last line of a file may
            # not have one, and none of its characters must be lost.
            new_line = "<s> " + line.rstrip("\r\n") + " </s>"
            logprob = compute_sentence_prob(new_line, ngram_order)
            new_logprob = logprob * logbase_modifier
            # NOTE(review): values are separated by a single space, so all
            # of them end up on one line - confirm whether one value per
            # line ("\n") was intended.
            output_file_handle.write("{} ".format(new_logprob))
    finally:
        text_in_handle.close()
        output_file_handle.close()


if __name__ == "__main__":
    args = parser.parse_args()
    check_args(args)
    ngram_dict, tot_num = load_model(args.arpa_lm)

    num_valid, max_ngram_order = check_number(args.arpa_lm, tot_num)
    if not num_valid:
        sys.exit("compute_sentence_probs_arpa.py: Wrong loading model.")
    if args.ngram_order <= 0 or args.ngram_order > max_ngram_order:
        sys.exit("compute_sentence_probs_arpa.py: " +
                 "Invalid ngram_order (either negative or greater than maximum ngram number ({}) allowed)".format(max_ngram_order))

    output_result(args.text_in_handle, args.prob_file_handle, args.ngram_order)