Blame view

egs/wsj/s5/utils/lang/compute_sentence_probs_arpa.py 6.53 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
  #!/usr/bin/env python
  
  # Dongji Gao
  
  # We're using python 3.x style but want it to work in python 2.x
  
  from __future__ import print_function
  import argparse
  import sys
  import math
  
  parser = argparse.ArgumentParser(description="This script evaluates the log probabilty (default log base is e) of each sentence "
                                               "from data (in text form), given a language model in arpa form "
                                               "and a specific ngram order.",
                                   epilog="e.g. ./compute_sentence_probs_arpa.py ARPA_LM NGRAM_ORDER TEXT_IN PROB_FILE --log-base=LOG_BASE",
                                   formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument("arpa_lm", type=str,
                      help="Input language model in arpa form.")
  parser.add_argument("ngram_order", type=int,
                      help="Order of ngram")
  parser.add_argument("text_in", type=str,
                      help="Filename of input text file (each line will be interpreted as a sentence).")
  parser.add_argument("prob_file", type=str,
                      help="Filename of output probability file.")
  parser.add_argument("--log-base", type=float, default=math.exp(1),
                      help="Log base for log porbability")
  args = parser.parse_args()
  
  def check_args(args):
      args.text_in_handle = sys.stdin if args.text_in == "-" else open(args.text_in, "r")
      args.prob_file_handle = sys.stdout if args.prob_file == "-" else open(args.prob_file, "w")
      if args.log_base <= 0:
          sys.exit("compute_sentence_probs_arpa.py: Invalid log base (must be greater than 0)")
  
  def is_logprob(input):
      if input[0] == "-":
          try:
              float(input[1:])
              return True
          except:
              return False
      else:
          return False
  
  def check_number(model_file, tot_num):
      cur_num = 0
      max_ngram_order = 0
      with open(model_file) as model:
          lines = model.readlines()
          for line in lines[1:]:
              if "=" not in line:
                  return (cur_num == tot_num), max_ngram_order
              cur_num += int(line.split("=")[-1])
              max_ngram_order = int(line.split("=")[0].split()[-1])
  
  # This function load language model in arpa form and save in a dictionary for
  # computing sentence probabilty of input text file.
  def load_model(model_file):
      with open(model_file) as model:
          ngram_dict = {}
          lines = model.readlines()
  
          # check arpa form
          if lines[0][:-1] != "\\data\\":
              sys.exit("compute_sentence_probs_arpa.py: Please make sure that language model is in arpa form.")
  
          # read line
          for line in lines:
              if line[0] == "-":
                  line_split = line.split()
                  if is_logprob(line_split[-1]):
                      ngram_key = " ".join(line_split[1:-1])
                      if ngram_key in ngram_dict:
                          sys.exit("compute_sentence_probs_arpa.py: Duplicated ngram in arpa language model: {}.".format(ngram_key))
                      ngram_dict[ngram_key] = (line_split[0], line_split[-1])
                  else:
                      ngram_key = " ".join(line_split[1:])
                      if ngram_key in ngram_dict:
                          sys.exit("compute_sentence_probs_arpa.py: Duplicated ngram in arpa language model: {}.".format(ngram_key))
                      ngram_dict[ngram_key] = (line_split[0],)
  
      return ngram_dict, len(ngram_dict)
  
  def compute_sublist_prob(sub_list):
      if len(sub_list) == 0:
          sys.exit("compute_sentence_probs_arpa.py: Ngram substring not found in arpa language model, please check.")
  
      sub_string = " ".join(sub_list)
      if sub_string in ngram_dict:
          return -float(ngram_dict[sub_string][0][1:])
      else:
          backoff_substring = " ".join(sub_list[:-1])
          backoff_weight = 0.0 if (backoff_substring not in ngram_dict or len(ngram_dict[backoff_substring]) < 2) \
                           else -float(ngram_dict[backoff_substring][1][1:])
          return compute_sublist_prob(sub_list[1:]) + backoff_weight
  
  def compute_begin_prob(sub_list):
      logprob = 0
      for i in range(1, len(sub_list) - 1):
          logprob += compute_sublist_prob(sub_list[:i + 1])
      return logprob
  
  # The probability is computed in this way:
  # p(word_N | word_N-1 ... word_1) = ngram_dict[word_1 ... word_N][0].
  # Here gram_dict is a dictionary stores a tuple corresponding to ngrams.
  # The first element of tuple is probablity and the second is backoff probability (if exists).
  # If the particular ngram (word_1 ... word_N) is not in the dictionary, then
  # p(word_N | word_N-1 ... word_1) = p(word_N | word_(N-1) ... word_2) * backoff_weight(word_(N-1) | word_(N-2) ... word_1)
  # If the sequence (word_(N-1) ... word_1) is not in the dictionary, then the backoff_weight gets replaced with 0.0 (log1)
  # More details can be found in https://cmusphinx.github.io/wiki/arpaformat/
  def compute_sentence_prob(sentence, ngram_order):
      sentence_split = sentence.split()
      for i in range(len(sentence_split)):
          if sentence_split[i] not in ngram_dict:
              sentence_split[i] = "<unk>"
      sen_length = len(sentence_split)
  
      if sen_length < ngram_order:
          return compute_begin_prob(sentence_split)
      else:
          logprob = 0
          begin_sublist = sentence_split[:ngram_order]
          logprob += compute_begin_prob(begin_sublist)
  
          for i in range(sen_length - ngram_order + 1):
              cur_sublist = sentence_split[i : i + ngram_order]
              logprob += compute_sublist_prob(cur_sublist)
  
      return logprob
  
  
  def output_result(text_in_handle, output_file_handle, ngram_order):
      lines = text_in_handle.readlines()
      logbase_modifier = math.log(10, args.log_base)
      for line in lines:
          new_line = "<s> " + line[:-1] + " </s>"
          logprob = compute_sentence_prob(new_line, ngram_order)
          new_logprob = logprob * logbase_modifier
          output_file_handle.write("{}
  ".format(new_logprob))
      text_in_handle.close()
      output_file_handle.close()
  
  
  if __name__ == "__main__":
      check_args(args)
      ngram_dict, tot_num = load_model(args.arpa_lm)
  
      num_valid, max_ngram_order = check_number(args.arpa_lm, tot_num)
      if not num_valid:
          sys.exit("compute_sentence_probs_arpa.py: Wrong loading model.")
      if args.ngram_order <= 0 or args.ngram_order > max_ngram_order:
          sys.exit("compute_sentence_probs_arpa.py: " +
              "Invalid ngram_order (either negative or greater than maximum ngram number ({}) allowed)".format(max_ngram_order))
  
      output_result(args.text_in_handle, args.prob_file_handle, args.ngram_order)