make_phone_lm.py 44.6 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886
#!/usr/bin/env python

# Copyright 2016  Johns Hopkins University (Author: Daniel Povey)
# Apache 2.0.

from __future__ import print_function
from __future__ import division
import sys
import argparse
import math
from collections import defaultdict

# note, this was originally based

# Command-line interface.  Note that the arguments are parsed at module level,
# and that the classes defined below read the resulting module-level 'args'
# object directly (e.g. args.ngram_order, args.verbose).
parser = argparse.ArgumentParser(description="""
This script creates a language model that's intended to be used in modeling
phone sequences (either of sentences or of dictionary entries), although of
course it will work for any type of data.  The easiest way
to describe it is as a Kneser-Ney language model (unmodified, with addition)
with a fixed discounting constant equal to 1, except with no smoothing of the
bigrams (and hence no unigram state).  This is (a) because we want to keep the
graph after context expansion small, (b) because languages tend to have
constraints on which phones can follow each other, and (c) in order to get valid
sequences of word-position-dependent phones so that lattice-align-words can
work.  It also includes a special entropy-based pruning technique that
backs off the statistics of pruned n-grams to lower-order states.

This script reads lines from its standard input, each
consisting of a sequence of integer symbol-ids (which should be > 0),
representing the phone sequences of a sentence or dictionary entry.
This script outputs a backoff language model in FST format""",
                                 epilog="See also utils/lang/make_phone_bigram_lang.sh")


parser.add_argument("--phone-disambig-symbol", type = int, required = False,
                    help = "Integer corresponding to an otherwise-unused "
                    "phone-level disambiguation symbol (e.g. #5).  This is "
                    "inserted at the beginning of the phone sequence and "
                    "whenever we back off.")
# The help text below used to refer to a non-existent '--num-extra-states'
# option; the option that controls pruning is --num-extra-ngrams.
parser.add_argument("--ngram-order", type = int, default = 4,
                    choices = [2,3,4,5,6,7],
                    help = "Order of n-gram to use (but see also --num-extra-ngrams; "
                    "the effective order after pruning may be less).")
parser.add_argument("--num-extra-ngrams", type = int, default = 20000,
                    help = "Target number of n-grams in addition to the n-grams in "
                    "the bigram LM states which can't be pruned away.  n-grams "
                    "will be pruned to reach this target.")
parser.add_argument("--no-backoff-ngram-order", type = int, default = 2,
                    choices = [1,2,3,4,5],
                    help = "This specifies the n-gram order at which (and below which) "
                    "no backoff or pruning should be done.  This is expected to normally "
                    "be bigram, but for testing purposes you may want to set it to "
                    "1.")
# Boolean-like options are passed as the strings "true"/"false".
parser.add_argument("--print-as-arpa", type = str, default = "false",
                    choices = ["true", "false"],
                    help = "If true, print LM in ARPA format (default is to print "
                    "as FST).  You must also set --no-backoff-ngram-order=1 or "
                    "this is not allowed.")
parser.add_argument("--verbose", type = int, default = 0,
                    choices=[0,1,2,3,4,5], help = "Verbose level")

args = parser.parse_args()

# Echo the command line to stderr so that it ends up in the log.
if args.verbose >= 1:
    print(' '.join(sys.argv), file = sys.stderr)



class CountsForHistory(object):
    """Counts of the words seen in one particular history-state.

    This class (which is more like a struct) is used inside class NgramCounts.
    It does the job of a dict from word (int) to count, but it also keeps
    track of the total of the counts added through AddCount().
    """

    def __init__(self):
        # Maps predicted word -> count.  It is a defaultdict(int), so looking
        # up an absent word with [] yields 0 (and creates the entry).
        self.word_to_count = defaultdict(int)
        # Running total of all the counts passed to AddCount().
        self.total_count = 0

    def Words(self):
        """Returns the words that currently have an entry, as a new list.

        Because it is a copy, callers may add/remove counts while iterating
        over the returned list.
        """
        return list(self.word_to_count.keys())

    def __str__(self):
        # e.g. returns ' total=12 3 -> 4 4 -> 6 -1 -> 2'
        return ' total={0} {1}'.format(
            str(self.total_count),
            ' '.join(['{0} -> {1}'.format(word, count)
                      for word, count in self.word_to_count.items()]))

    def AddCount(self, predicted_word, count):
        """Adds 'count' (expected to be an integer, possibly negative) to the
        count of 'predicted_word' and to the total count.

        If the resulting count for this word is zero, the dict entry is
        removed from word_to_count.  [Note, though, that in some circumstances
        other code 'adds back' zero counts directly in word_to_count where the
        presence of n-grams is structurally required.]

        Raises AssertionError if either the total count or the count of this
        word would become negative; in that case this object is left
        unmodified.
        """
        # Validate before modifying anything, so that a failed call does not
        # leave a half-updated state behind.
        new_total = self.total_count + count
        assert new_total >= 0
        # Use .get() so that merely checking does not create a dict entry.
        old_count = self.word_to_count.get(predicted_word, 0)
        new_count = old_count + count
        if new_count < 0:
            # Diagnostics go to stderr; stdout is where the LM gets written.
            print("predicted-word={0}, old-count={1}, count={2}".format(
                    predicted_word, old_count, count), file = sys.stderr)
        assert new_count >= 0

        self.total_count = new_total
        if new_count == 0:
            self.word_to_count.pop(predicted_word, None)
        else:
            self.word_to_count[predicted_word] = new_count

class NgramCounts(object):
    ## A note on data-structure.  Firstly, all words are represented as
    ## integers.  We store n-gram counts as an array, indexed by (history-length
    ## == n-gram order minus one) (note: python calls arrays "lists") of dicts
    ## from histories to counts, where histories are tuples of integers and
    ## "counts" are CountsForHistory objects (essentially dicts from integer
    ## word to count).  For instance, when accumulating the 4-gram count for
    ## the '8' in the sequence '5 6 7 8', we'd do as follows:
    ## self.counts[3][(5,6,7)].AddCount(8, 1), where the [3] indexes an array,
    ## the [(5,6,7)] indexes a dict, and AddCount() updates the count of word 8
    ## in that history-state.
    def __init__(self, ngram_order):
        assert ngram_order >= 2
        # Integerized counts will never contain negative numbers, so
        # inside this program, we use -3 and -2 for the BOS and EOS symbols
        # respectively.
        # Note: it's actually important that the bos-symbol is the most negative;
        # it helps ensure that we print the state with left-context <s> first
        # when we print the FST, and this means that the start-state will have
        # the correct value.
        self.bos_symbol = -3
        self.eos_symbol = -2
        # backoff_symbol is kind of a pseudo-word, it's used in keeping track of
        # the backoff counts in each state.
        self.backoff_symbol = -1
        self.total_num_words = 0  # count includes EOS but not BOS.
        self.counts = []
        for n in range(ngram_order):
            self.counts.append(defaultdict(lambda: CountsForHistory()))

    # adds a raw count (called while processing input data).
    # Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
    # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
    # 1.
    def AddCount(self, history, predicted_word, count):
        self.counts[len(history)][history].AddCount(predicted_word, count)


    # 'line' is a string containing a sequence of integer word-ids.
    # This function adds the un-smoothed counts from this line of text.
    def AddRawCountsFromLine(self, line):
        try:
            words = [self.bos_symbol] + [ int(x) for x in line.split() ] + [self.eos_symbol]
        except:
            sys.exit("make_phone_lm.py: bad input line {0} (expected a sequence "
                     "of integers)".format(line))

        for n in range(1, len(words)):
            predicted_word = words[n]
            history_start = max(0, n + 1 - args.ngram_order)
            history = tuple(words[history_start:n])
            self.AddCount(history, predicted_word, 1)
            self.total_num_words += 1

    def AddRawCountsFromStandardInput(self):
        """Reads standard input line by line until end-of-file, accumulating
        the raw counts of each line via AddRawCountsFromLine().

        Reports the number of lines to stderr if no line at all was read, or
        if the verbose level is above zero.
        """
        num_lines = 0
        # readline() returns '' only at end-of-file, which ends the loop.
        for line in iter(sys.stdin.readline, ''):
            self.AddRawCountsFromLine(line)
            num_lines += 1
        if num_lines == 0 or args.verbose > 0:
            print("make_phone_lm.py: processed {0} lines of input".format(
                    num_lines), file = sys.stderr)


    # This backs off the counts by subtracting 1 and assigning the subtracted
    # count to the backoff state.  It's like a special case of Kneser-Ney with D
    # = 1.  The optimal D would likely be something like 0.9, but we plan to
    # later do entropy-pruning, and the remaining small counts of 0.1 would
    # essentially all get pruned away anyway, so we don't lose much by doing it
    # like this.
    def ApplyBackoff(self):
        # note: in the normal case where args.no_backoff_ngram_order == 2 we
        # don't do backoff for history-length = 1 (i.e. for bigrams)... this is
        # a kind of special LM where we're not going to back off to unigram,
        # there will be no unigram.
        if args.verbose >= 1:
            initial_num_ngrams = self.GetNumNgrams()
        for n in reversed(list(range(args.no_backoff_ngram_order, args.ngram_order))):
            this_order_counts = self.counts[n]
            for hist, counts_for_hist in this_order_counts.items():
                backoff_hist = hist[1:]
                backoff_counts_for_hist = self.counts[n-1][backoff_hist]
                this_discount_total = 0
                for word in counts_for_hist.Words():
                    counts_for_hist.AddCount(word, -1)
                    # You can interpret the following line as incrementing the
                    # count-of-counts for the next-lower order.  Note, however,
                    # that later when we remove n-grams, we'll also add their
                    # counts to the next-lower-order history state, so the
                    # resulting counts won't strictly speaking be
                    # counts-of-counts.
                    backoff_counts_for_hist.AddCount(word, 1)
                    this_discount_total += 1
                counts_for_hist.AddCount(self.backoff_symbol, this_discount_total)

        if args.verbose >= 1:
            # Note: because D == 1, we completely back off singletons.
            print("make_phone_lm.py: ApplyBackoff() reduced the num-ngrams from "
                  "{0} to {1}".format(initial_num_ngrams, self.GetNumNgrams()),
                  file = sys.stderr)


    # This function prints out to stderr the n-gram counts stored in this
    # object; it's used for debugging.
    def Print(self, info_string):
        print(info_string, file=sys.stderr)
        # these are useful for debug.
        total = 0.0
        total_excluding_backoff = 0.0
        for this_order_counts in self.counts:
            for hist, counts_for_hist in this_order_counts.items():
                print(str(hist) + str(counts_for_hist), file = sys.stderr)
                total += counts_for_hist.total_count
                total_excluding_backoff += counts_for_hist.total_count
                if self.backoff_symbol in counts_for_hist.word_to_count:
                    total_excluding_backoff -= counts_for_hist.word_to_count[self.backoff_symbol]
        print('total count = {0}, excluding backoff = {1}'.format(
                total, total_excluding_backoff), file = sys.stderr)

    def GetHistToStateMap(self):
        # This function, called from PrintAsFst, returns a map from
        # history to integer FST-state.
        hist_to_state = dict()
        fst_state_counter = 0
        for n in range(0, args.ngram_order):
            for hist in self.counts[n].keys():
                hist_to_state[hist] = fst_state_counter
                fst_state_counter += 1
        return hist_to_state

    # Returns the probability of word 'word' in history-state 'hist'.
    # If 'word' is self.backoff_symbol, returns the backoff prob
    # of this history-state.
    # Returns None if there is no such word in this history-state, or this
    # history-state does not exist.
    def GetProb(self, hist, word):
        if len(hist) >= args.ngram_order or not hist in self.counts[len(hist)]:
            return None
        counts_for_hist = self.counts[len(hist)][hist]
        total_count = float(counts_for_hist.total_count)
        if not word in counts_for_hist.word_to_count:
            print("make_phone_lm.py: no prob for {0} -> {1} "
                  "[no such count]".format(hist, word),
                  file = sys.stderr)
            return None
        prob = float(counts_for_hist.word_to_count[word]) / total_count
        if len(hist) > 0 and word != self.backoff_symbol and \
          self.backoff_symbol in counts_for_hist.word_to_count:
            prob_in_backoff = self.GetProb(hist[1:], word)
            backoff_prob = float(counts_for_hist.word_to_count[self.backoff_symbol]) / total_count
            try:
                prob += backoff_prob * prob_in_backoff
            except:
                sys.exit("problem, hist is {0}, word is {1}".format(hist, word))
        return prob

    def PruneEmptyStates(self):
        # Removes history-states that have no counts.

        # It's possible in principle for history-states to have no counts and
        # yet they cannot be pruned away because a higher-order version of the
        # state exists with nonzero counts, so we have to keep track of this.
        protected_histories = set()

        states_removed_per_hist_len = [ 0 ] * args.ngram_order

        for n in reversed(list(range(args.no_backoff_ngram_order,
                                args.ngram_order))):
            num_states_removed = 0
            for hist, counts_for_hist in self.counts[n].items():
                l = len(counts_for_hist.word_to_count)
                assert l > 0 and self.backoff_symbol in counts_for_hist.word_to_count
                if l == 1 and not hist in protected_histories:  # only the backoff symbol has a count.
                    del self.counts[n][hist]
                    num_states_removed += 1
                else:
                    # if this state was not pruned away, then the state that
                    # it backs off to may not be pruned away either.
                    backoff_hist = hist[1:]
                    protected_histories.add(backoff_hist)
            states_removed_per_hist_len[n] = num_states_removed
        if args.verbose >= 1:
            print("make_phone_lm.py: in PruneEmptyStates(), num states removed for "
                  "each history-length was: " + str(states_removed_per_hist_len),
                  file = sys.stderr)

    def EnsureStructurallyNeededNgramsExist(self):
        # makes sure that if an n-gram like (6, 7, 8) -> 9 exists,
        # then counts exist for (7, 8) -> 9 and (8,) -> 9.  It does so
        # by adding zero counts where such counts were absent.
        # [note: () -> 9 is guaranteed anyway by the backoff method, if
        # we have a unigram state].
        if args.verbose >= 1:
            num_ngrams_initial = self.GetNumNgrams()
        for n in reversed(list(range(args.no_backoff_ngram_order,
                                args.ngram_order))):

            for hist, counts_for_hist in self.counts[n].items():
                # This loop ensures that if we have an n-gram like (6, 7, 8) -> 9,
                # then, say, (7, 8) -> 9 and (8) -> 9 exist.
                reduced_hist = hist
                for m in reversed(list(range(args.no_backoff_ngram_order, n))):
                    reduced_hist = reduced_hist[1:]  # shift an element off
                                                     # the history.
                    counts_for_backoff_hist = self.counts[m][reduced_hist]
                    for word in counts_for_hist.word_to_count.keys():
                        counts_for_backoff_hist.word_to_count[word] += 0
                # This loop ensures that if we have an n-gram like (6, 7, 8) -> 9,
                # then, say, (6, 7) -> 8 and (6) -> 7 exist.  This will be needed
                # for FST representations of the ARPA LM.
                reduced_hist = hist
                for m in reversed(list(range(args.no_backoff_ngram_order, n))):
                    this_word = reduced_hist[-1]
                    reduced_hist = reduced_hist[:-1]  # pop an element off the
                                                      # history
                    counts_for_backoff_hist = self.counts[m][reduced_hist]
                    counts_for_backoff_hist.word_to_count[this_word] += 0
        if args.verbose >= 1:
            print("make_phone_lm.py: in EnsureStructurallyNeededNgramsExist(), "
                  "added {0} n-grams".format(self.GetNumNgrams() - num_ngrams_initial),
                  file = sys.stderr)



    # This function prints the estimated language model as an FST.
    def PrintAsFst(self, word_disambig_symbol):
        # n is the history-length (== order + 1).  We iterate over the
        # history-length in the order 1, 0, 2, 3, and then iterate over the
        # histories of each order in sorted order.  Putting order 1 first
        # and sorting on the histories
        # ensures that the bigram state with <s> as the left context comes first.
        # (note: self.bos_symbol is the most negative symbol)

        # History will map from history (as a tuple) to integer FST-state.
        hist_to_state = self.GetHistToStateMap()

        for n in [ 1, 0 ] + list(range(2, args.ngram_order)):
            this_order_counts = self.counts[n]
            # For order 1, make sure the keys are sorted.
            keys = this_order_counts.keys() if n != 1 else sorted(this_order_counts.keys())
            for hist in keys:
                word_to_count = this_order_counts[hist].word_to_count
                this_fst_state = hist_to_state[hist]

                for word in word_to_count.keys():
                    # work out this_cost.  Costs in OpenFst are negative logs.
                    this_cost = -math.log(self.GetProb(hist, word))

                    if word > 0: # a real word.
                        next_hist = hist + (word,)  # appending tuples
                        while not next_hist in hist_to_state:
                            next_hist = next_hist[1:]
                        next_fst_state = hist_to_state[next_hist]
                        print(this_fst_state, next_fst_state, word, word,
                              this_cost)
                    elif word == self.eos_symbol:
                        # print final-prob for this state.
                        print(this_fst_state, this_cost)
                    else:
                        assert word == self.backoff_symbol
                        backoff_fst_state = hist_to_state[hist[1:len(hist)]]
                        print(this_fst_state, backoff_fst_state,
                              word_disambig_symbol, 0, this_cost)

    # This function returns a set of n-grams that cannot currently be pruned
    # away, either because a higher-order form of the same n-gram already exists,
    # or because the n-gram leads to an n-gram state that exists.
    # [Note: as we prune, we remove any states that can be removed; see that
    # PruneToIntermediateTarget() calls PruneEmptyStates().

    def GetProtectedNgrams(self):
        ans = set()
        for n in range(args.no_backoff_ngram_order + 1, args.ngram_order):
            for hist, counts_for_hist in self.counts[n].items():
                # If we have an n-gram (6, 7, 8) -> 9, the following loop will
                # add the backed-off n-grams (7, 8) -> 9 and (8) -> 9 to
                # 'protected-ngrams'.
                reduced_hist = hist
                for m in reversed(list(range(args.no_backoff_ngram_order, n))):
                    reduced_hist = reduced_hist[1:]  # shift an element off
                                                     # the history.

                    for word in counts_for_hist.word_to_count.keys():
                        if word != self.backoff_symbol:
                            ans.add(reduced_hist + (word,))
                # The following statement ensures that if we are in a
                # history-state (6, 7, 8), then n-grams (6, 7, 8) and (6, 7) are
                # protected.  This assures that the FST states are accessible.
                reduced_hist = hist
                for m in reversed(list(range(args.no_backoff_ngram_order, n))):
                    ans.add(reduced_hist)
                    reduced_hist = reduced_hist[:-1]  # pop an element off the
                                                      # history
        return ans

    def PruneNgram(self, hist, word):
        counts_for_hist = self.counts[len(hist)][hist]
        assert word != self.backoff_symbol and word in counts_for_hist.word_to_count
        count = counts_for_hist.word_to_count[word]
        del counts_for_hist.word_to_count[word]
        counts_for_hist.word_to_count[self.backoff_symbol] += count
        # the next call adds the count to the symbol 'word' in the backoff
        # history-state, and also updates its 'total_count'.
        self.counts[len(hist) - 1][hist[1:]].AddCount(word, count)

    # The function PruningLogprobChange is the same as the same-named
    # function in float-counts-prune.cc in pocolm.  Note, it doesn't access
    # any class members.

    # This function computes the log-likelihood change (<= 0) from backing off
    # a particular symbol to the lower-order state.
    # The value it returns can be interpreted as a lower bound on the actual
    # log-likelihood change.  By "the actual log-likelihood change" we mean of
    # data generated by the model itself before making the change, then modeled
    # with the changed model [and comparing the log-like with the log-like
    # before changing the model].  That is, it's a K-L divergence, but with the
    # caveat that we don't normalize by the overall count of the data, so it's
    # a K-L divergence multiplied by the training-data count.

    #  'count' is the count of the word (call it 'a') in this state.  It's an integer.
    #  'discount' is the discount-count in this state (represented as the count
    #         for the symbol self.backoff_symbol).  It's an integer.
    #  [note: we don't care about the total-count in this state, it cancels out.]
    #  'backoff_count' is the count of word 'a' in the lower-order state.
    #                 [actually it is the augmented count, treating any
    #                  extra probability from even-lower-order states as
    #                  if it were a count].  It's a float.
    #  'backoff_total' is the total count in the lower-order state.  It's a float.
    def PruningLogprobChange(self, count, discount, backoff_count, backoff_total):
        """Returns an estimate of the log-likelihood change from completely
        backing off one symbol (call it 'a') of a history-state to the
        lower-order state.

        Args:
          count: the count of symbol 'a' in this state.
          discount: the discount-count (backoff count) of this state; must
            be > 0.
          backoff_count: the (augmented) count of 'a' in the lower-order
            state; must not exceed backoff_total.
          backoff_total: the total count in the lower-order state.

        Returns 0.0 if count == 0; otherwise the sum of four terms, each of
        the form count-of-symbol * log(new count / old count).  The result
        is expected to be <= 0, which is asserted up to a small tolerance.
        If args.verbose >= 4 the inputs and the result are printed to stderr.
        This function does not read or modify any members of 'self'.
        """
        if count == 0:
            return 0.0

        # Sanity checks on the inputs (backoff_total may be slightly less than
        # the discount, up to a 1% tolerance).
        assert discount > 0 and backoff_total >= backoff_count and backoff_total >= 0.99 * discount


        # augmented_count is like 'count', but with the extra count for symbol
        # 'a' due to backoff included.
        augmented_count = count + discount * backoff_count / backoff_total

        # We imagine a phantom symbol 'b' that represents all symbols other than
        # 'a' appearing in this history-state that are accessed via backoff.  We
        # treat these as being distinct symbols from the same symbol if accessed
        # not-via-backoff.  (Treating same symbols as distinct gives an upper bound
        # on the divergence).  We also treat them as distinct from the same symbols
        # that are being accessed via backoff from other states.  b_count is the
        # observed count of symbol 'b' in this state (the backed-off count is
        # zero).  b_count is also the count of symbol 'b' in the backoff state.
        # Note: b_count will not be negative because backoff_total >= backoff_count.
        b_count = discount * ((backoff_total - backoff_count) / backoff_total)
        assert b_count >= -0.001 * backoff_total

        # We imagine a phantom symbol 'c' that represents all symbols other than
        # 'a' and 'b' appearing in the backoff state, which got there from
        # backing off other states (other than 'this' state).  Again, we imagine
        # the symbols are distinct even though they may not be (i.e. that c and
        # b represent disjoint sets of symbol, even though they might not really
        # be disjoint), and this gives us an upper bound on the divergence.
        c_count = backoff_total - backoff_count - b_count
        assert c_count >= -0.001 * backoff_total

        # a_other is the count of 'a' in the backoff state that comes from
        # 'other sources', i.e. it was backed off from history-states other than
        # the current history state.
        a_other_count = backoff_count - discount * backoff_count / backoff_total
        assert a_other_count >= -0.001 * backoff_count

        # the following sub-expressions are the 'new' versions of certain
        # quantities after we assign the total count 'count' to backoff.  it
        # increases the backoff count in 'this' state, and also the total count
        # in the backoff state, and the count of symbol 'a' in the backoff
        # state.
        new_backoff_count = backoff_count + count  # new count of symbol 'a' in
                                                    # backoff state
        new_backoff_total = backoff_total + count  # new total count in
                                                    # backoff state.
        new_discount = discount + count  # new discount-count in 'this' state.


        # all the loglike changes below are of the form
        # count-of-symbol * log(new prob / old prob)
        # which can be more conveniently written (by canceling the denominators),
        # count-of-symbol * log(new count / old count).

        # this_a_change is the log-like change of symbol 'a' coming from 'this'
        # state.  bear in mind that
        # augmented_count = count + discount * backoff_count / backoff_total,
        # and the 'count' term is zero in the numerator part of the log expression,
        # because symbol 'a' is completely backed off in 'this' state.
        this_a_change = augmented_count * \
            math.log((new_discount * new_backoff_count / new_backoff_total)/ \
                         augmented_count)

        # other_a_change is the log-like change of symbol 'a' coming from all
        # other states than 'this'.  For speed reasons we don't examine the
        # direct (non-backoff) counts of symbol 'a' in all other states than
        # 'this' that back off to the backoff state-- it would be slower.
        # Instead we just treat the direct part of the prob for symbol 'a' as a
        # distinct symbol when it comes from those other states... as usual,
        # doing so gives us an upper bound on the divergence.
        other_a_change = \
            a_other_count * math.log((new_backoff_count / new_backoff_total) / \
                                         (backoff_count / backoff_total)) 

        # b_change is the log-like change of phantom symbol 'b' coming from
        # 'this' state (and note: it only comes from this state, that's how we
        # defined it).
        # note: the expression below could be more directly written as a
        # ratio of pseudo-counts as follows, by converting the backoff probabilities
        # into pseudo-counts in 'this' state:
        #  b_count * logf((new_discount * b_count / new_backoff_total) /
        #                 (discount * b_count / backoff_total)),
        # but we cancel b_count to give us the expression below.
        b_change = b_count * math.log((new_discount / new_backoff_total) / \
                                          (discount / backoff_total))

        # c_change is the log-like change of phantom symbol 'c' coming from
        # all other states that back off to the backoff state (and all prob. mass of
        # 'c' comes from those other states).  The expression below could be more
        # directly written as a ratio of counts, as c_count * logf((c_count /
        # new_backoff_total) / (c_count / backoff_total)), but we simplified it to
        # the expression below.
        c_change = c_count * math.log(backoff_total / new_backoff_total)

        ans = this_a_change + other_a_change + b_change + c_change
        # the answer should not be positive (a small tolerance, relative to the
        # size of the inputs, allows for roundoff).
        assert ans <= 0.0001 * (count + discount + backoff_count + backoff_total)
        if args.verbose >= 4:
            print("pruning-logprob-change for {0},{1},{2},{3} is {4}".format(
                    count, discount, backoff_count, backoff_total, ans),
                  file = sys.stderr)
        return ans


    def GetLikeChangeFromPruningNgram(self, hist, word):
        # Returns the value of self.PruningLogprobChange() for removing the
        # n-gram hist -> word, i.e. for letting 'word' be predicted only via
        # the backoff history-state hist[1:].
        #  hist: tuple of symbols identifying the history-state.
        #  word: the predicted symbol; must not be the backoff symbol and must
        #        have a count in the history-state 'hist'.
        # If the backoff probability cannot be computed, a message is printed
        # to stderr and the program exits with status 1.
        counts_for_hist = self.counts[len(hist)][hist]
        counts_for_backoff_hist = self.counts[len(hist) - 1][hist[1:]]
        assert word != self.backoff_symbol and word in counts_for_hist.word_to_count
        count = counts_for_hist.word_to_count[word]
        # the count stored for the backoff symbol in this state is what gets
        # passed on as the 'discount' argument.
        discount = counts_for_hist.word_to_count[self.backoff_symbol]
        backoff_total = counts_for_backoff_hist.total_count
        # backoff_count is a pseudo-count: it's like the count of 'word' in the
        # backoff history-state, but adding something to account for further
        # levels of backoff.
        try:
            backoff_count = self.GetProb(hist[1:], word) * backoff_total
        except Exception:
            # Deliberately not a bare 'except:': KeyboardInterrupt and
            # SystemExit must propagate instead of being reported as a
            # backoff-count problem.
            print("problem getting backoff count: hist = {0}, word = {1}".format(hist, word),
                  file = sys.stderr)
            sys.exit(1)

        return self.PruningLogprobChange(float(count), float(discount),
                                         backoff_count, float(backoff_total))

    # note: returns loglike change per word.
    def PruneToIntermediateTarget(self, num_extra_ngrams):
        # Prunes the n-grams whose removal gives the least-negative likelihood
        # change, aiming to bring self.GetNumExtraNgrams() down to
        # 'num_extra_ngrams' (which must be less than the current value).
        # N-grams listed by self.GetProtectedNgrams() are never candidates, so
        # fewer n-grams than requested may end up being pruned.
        # Returns the total log-likelihood change divided by
        # self.total_num_words.
        protected_ngrams = self.GetProtectedNgrams()
        initial_num_extra_ngrams = self.GetNumExtraNgrams()
        num_ngrams_to_prune = initial_num_extra_ngrams - num_extra_ngrams
        assert num_ngrams_to_prune > 0

        # both of these are indexed by history length.
        num_candidates_per_order = [ 0 ] * args.ngram_order
        num_pruned_per_order = [ 0 ] * args.ngram_order

        # like_change_and_ngrams will be a list of tuples, each consisting
        # of the likelihood change as a float followed by the words of the
        # n-gram that we're considering pruning,
        # e.g. (-0.164, 7, 8, 9)
        # meaning that pruning the n-gram (7, 8) -> 9 leads to
        # a likelihood change of -0.164.  We'll later sort this list
        # so we can prune the n-grams that made the least-negative
        # likelihood change.
        like_change_and_ngrams = []
        for n in range(args.no_backoff_ngram_order, args.ngram_order):
            for hist, counts_for_hist in self.counts[n].items():
                for word in counts_for_hist.word_to_count:
                    if word == self.backoff_symbol:
                        continue
                    ngram = hist + (word,)
                    if ngram in protected_ngrams:
                        continue
                    like_change = self.GetLikeChangeFromPruningNgram(hist, word)
                    like_change_and_ngrams.append((like_change,) + ngram)
                    num_candidates_per_order[len(hist)] += 1

        # least-negative likelihood changes come first.
        like_change_and_ngrams.sort(reverse = True)

        if num_ngrams_to_prune > len(like_change_and_ngrams):
            print('make_phone_lm.py: aimed to prune {0} n-grams but could only '
                  'prune {1}'.format(num_ngrams_to_prune, len(like_change_and_ngrams)),
                  file = sys.stderr)
            num_ngrams_to_prune = len(like_change_and_ngrams)

        total_loglike_change = 0.0

        for entry in like_change_and_ngrams[:num_ngrams_to_prune]:
            total_loglike_change += entry[0]
            hist = entry[1:-1]  # all but 1st and last elements
            word = entry[-1]  # last element
            num_pruned_per_order[len(hist)] += 1
            self.PruneNgram(hist, word)

        like_change_per_word = total_loglike_change / self.total_num_words

        if args.verbose >= 1:
            # the threshold is the likelihood change of the last n-gram we
            # pruned.  num_ngrams_to_prune can be zero here if there were no
            # candidates at all, in which case there is no such n-gram.
            effective_threshold = (like_change_and_ngrams[num_ngrams_to_prune - 1][0]
                                   if num_ngrams_to_prune > 0 else 0.0)
            print("Pruned from {0} ngrams to {1}, with threshold {2}.  Candidates per order were {3}, "
                  "num-ngrams pruned per order were {4}.  Like-change per word was {5}".format(
                    initial_num_extra_ngrams,
                    initial_num_extra_ngrams - num_ngrams_to_prune,
                    '%.4f' % effective_threshold,
                    num_candidates_per_order,
                    num_pruned_per_order,
                    like_change_per_word), file = sys.stderr)

        if args.verbose >= 3:
            print("Pruning: like_change_and_ngrams is:\n" +
                  '\n'.join([str(x) for x in like_change_and_ngrams[:num_ngrams_to_prune]]) +
                  "\n-------- stop pruning here: ----------\n" +
                  '\n'.join([str(x) for x in like_change_and_ngrams[num_ngrams_to_prune:]]),
                  file = sys.stderr)
            self.Print("Counts after pruning to num-extra-ngrams={0}".format(
                    initial_num_extra_ngrams - num_ngrams_to_prune))

        self.PruneEmptyStates()
        if args.verbose >= 3:
            # print this object's own counts rather than relying on a global
            # variable that happens to refer to it.
            self.Print("Counts after removing empty states [inside pruning algorithm]:")
        return like_change_per_word



    def PruneToFinalTarget(self, num_extra_ngrams):
        # Prunes down to 'num_extra_ngrams' extra n-grams, where 'extra' means
        # n-grams of order higher than args.no_backoff_ngram_order.  Rather
        # than pruning in one go we walk through a decreasing sequence of
        # intermediate targets: some n-grams only become prunable once other
        # n-grams have gone (because they lead to a state that must be kept,
        # or because another n-gram backs off to them).
        # Does nothing except print a message if the target is not below the
        # current number of extra n-grams.
        starting_num_extra_ngrams = self.GetNumExtraNgrams()

        if starting_num_extra_ngrams <= num_extra_ngrams:
            print('make_phone_lm.py: not pruning since target num-extra-ngrams={0} is >= '
                  'current num-extra-ngrams={1}'.format(num_extra_ngrams, starting_num_extra_ngrams),
                  file=sys.stderr)
            return

        # Build the targets from the final one upwards.  The last two steps
        # differ by a factor of 1.1, the two before those by 1.2, and all
        # earlier ones by 1.3; a step is only added while it stays below the
        # number of n-grams we currently have.
        targets = [num_extra_ngrams]
        for factor in (1.1, 1.1, 1.2, 1.2):
            candidate = int((targets[-1] + 1) * factor)
            if candidate < starting_num_extra_ngrams:
                targets.append(candidate)
        candidate = int((targets[-1] + 1) * 1.3)
        while candidate < starting_num_extra_ngrams:
            targets.append(candidate)
            candidate = int((targets[-1] + 1) * 1.3)

        # unique targets only, largest first.
        targets = sorted(set(targets), reverse = True)

        print('make_phone_lm.py: current num-extra-ngrams={0}, pruning with '
              'following sequence of targets: {1}'.format(starting_num_extra_ngrams,
                                                          targets),
              file = sys.stderr)

        accumulated_like_change = 0.0
        for intermediate_target in targets:
            accumulated_like_change += self.PruneToIntermediateTarget(intermediate_target)

        if args.verbose >= 1:
            print('make_phone_lm.py: K-L divergence from pruning (upper bound) is '
                  '%.4f' % accumulated_like_change, file = sys.stderr)


    # returns the number of n-grams on top of those that can't be pruned away
    # because their order is <= args.no_backoff_ngram_order.
    def GetNumExtraNgrams(self):
        # Adds up self.GetNumNgrams(hist_len) over the history lengths from
        # args.no_backoff_ngram_order up to (but excluding) args.ngram_order.
        # note: hist_len + 1 is the actual n-gram order.
        prunable_hist_lens = range(args.no_backoff_ngram_order, args.ngram_order)
        return sum(self.GetNumNgrams(hist_len) for hist_len in prunable_hist_lens)


    def GetNumNgrams(self, hist_len = None):
        # Returns the number of n-grams with history length 'hist_len' (i.e. of
        # order hist_len + 1), or, if hist_len is None, the number summed over
        # all history lengths below args.ngram_order.
        if hist_len is None:
            return sum(self.GetNumNgrams(this_hist_len)
                       for this_hist_len in range(args.ngram_order))

        ans = 0
        for counts_for_hist in self.counts[hist_len].values():
            word_to_count = counts_for_hist.word_to_count
            ans += len(word_to_count)
            if self.backoff_symbol in word_to_count:
                ans -= 1  # don't count the backoff symbol, it doesn't produce
                          # its own n-gram line.
        return ans


    # this function, used in PrintAsArpa, converts an integer to
    # a string by either printing it as a string, or for self.bos_symbol
    # and self.eos_symbol, printing them as "<s>" and "</s>" respectively.
    def IntToString(self, i):
        # Maps a symbol to its printed form: the begin- and end-of-sentence
        # symbols become "<s>" and "</s>", anything else is printed as str(i).
        # The backoff symbol has no printed form and must not be passed in.
        if i == self.bos_symbol:
            return '<s>'
        if i == self.eos_symbol:
            return '</s>'
        assert i != self.backoff_symbol
        return str(i)



    def PrintAsArpa(self):
        # Prints the language model to standard output in ARPA format: the
        # \data\ header with per-order n-gram counts, then one section per
        # order with lines of the form
        #   log10(prob) <tab> words [<tab> log10(backoff-prob)]
        # and finally the \end\ marker.
        assert args.no_backoff_ngram_order == 1  # without unigrams we couldn't
                                                 # print as ARPA format.

        print('\\data\\')
        for hist_len in range(args.ngram_order):
            # print the number of n-grams.  Add 1 for the 1-gram
            # section because of <s>, we print -99 as the prob so we
            # have a place to put the backoff prob.
            num_ngrams = self.GetNumNgrams(hist_len) + (1 if hist_len == 0 else 0)
            print('ngram {0}={1}'.format(hist_len + 1, num_ngrams))

        print('')

        for hist_len in range(args.ngram_order):
            print('\\{0}-grams:'.format(hist_len + 1))

            # print fake n-gram for <s>, for its backoff prob.
            if hist_len == 0:
                backoff_prob = self.GetProb((self.bos_symbol,), self.backoff_symbol)
                if backoff_prob is not None:
                    print('-99\t<s>\t{0}'.format('%.5f' % math.log10(backoff_prob)))

            for hist, counts_for_hist in self.counts[hist_len].items():
                for word in counts_for_hist.word_to_count:
                    if word == self.backoff_symbol:
                        continue  # the backoff symbol has no n-gram line of its own.
                    ngram = hist + (word,)
                    prob = self.GetProb(hist, word)
                    assert prob is not None and prob > 0
                    # the backoff prob of the state that this n-gram leads
                    # to; a backoff field is only written when GetProb()
                    # returns something other than None.
                    backoff_prob = self.GetProb(ngram, self.backoff_symbol)
                    line = '{0}\t{1}'.format('%.5f' % math.log10(prob),
                                             ' '.join(self.IntToString(x) for x in ngram))
                    if backoff_prob is not None:
                        line += '\t{0}'.format('%.5f' % math.log10(backoff_prob))
                    print(line)
            print('')
        print('\\end\\')



# Main program: read the raw counts from standard input, apply the
# discounting, prune down to the requested number of extra n-grams, and write
# the result either as an ARPA file or as an FST.  With --verbose >= 3 the
# counts are printed after each stage.
ngram_counts = NgramCounts(args.ngram_order)
ngram_counts.AddRawCountsFromStandardInput()

if args.verbose >= 3:
    ngram_counts.Print("Raw counts:")
ngram_counts.ApplyBackoff()
if args.verbose >= 3:
    ngram_counts.Print("Counts after applying Kneser-Ney discounting:")
ngram_counts.EnsureStructurallyNeededNgramsExist()
if args.verbose >= 3:
    ngram_counts.Print("Counts after adding structurally-needed n-grams (1st time):")
ngram_counts.PruneEmptyStates()
if args.verbose >= 3:
    ngram_counts.Print("Counts after removing empty states:")
ngram_counts.PruneToFinalTarget(args.num_extra_ngrams)

# this is called a second time, after pruning.
ngram_counts.EnsureStructurallyNeededNgramsExist()
if args.verbose >= 3:
    ngram_counts.Print("Counts after adding structurally-needed n-grams (2nd time):")


if args.print_as_arpa == "true":
    ngram_counts.PrintAsArpa()
else:
    # the disambiguation symbol is only needed when writing the FST.
    if args.phone_disambig_symbol is None:
        sys.exit("make_phone_lm.py: --phone-disambig-symbol must be provided (unless "
                 "you are writing as ARPA)")
    ngram_counts.PrintAsFst(args.phone_disambig_symbol)


## Below are some little test commands that can be used to look at the detailed stats
## for a kind of sanity check.
# test command:
# (echo 6 7 8 4; echo 7 8 9; echo 7 8; echo 7 4; echo 8 4 ) | utils/lang/make_phone_lm.py --phone-disambig-symbol=400  --verbose=3
#  (echo 6 7 8 4; echo 7 8 9; echo 7 8; echo 7 4; echo 8 4 ) | utils/lang/make_phone_lm.py --phone-disambig-symbol=400  --verbose=3 --num-extra-ngrams=0
# (echo 6 7 8 4; echo 6 7 ) | utils/lang/make_phone_lm.py --print-as-arpa=true --no-backoff-ngram-order=1  --verbose=3


## The following shows how we created some data suitable to do comparisons with
## other language modeling toolkits.  Note: we're running in a configuration
## where --no-backoff-ngram-order=1 (i.e. we have a unigram LM state) because
## it's the only way to get perplexity calculations and to write an ARPA file.
##
# cd egs/tedlium/s5_r2
# . ./path.sh
# mkdir -p lm_test
# ali-to-phones exp/tri3/final.mdl "ark:gunzip -c exp/tri3/ali.*.gz|" ark,t:-  | awk '{$1 = ""; print}' > lm_test/phone_seqs
# wc lm_test/phone_seqs
# 92464  8409563 27953288 lm_test/phone_seqs
# head -n 20000 lm_test/phone_seqs > lm_test/train.txt
# tail -n 1000 lm_test/phone_seqs > lm_test/test.txt

## This shows make_phone_lm.py with the default number of extra-lm-states (20k)
## You have to have SRILM on your path to get perplexities [note: it should be on the
## path if you installed it and you sourced the tedlium s5b path.sh, as above.]
# utils/lang/make_phone_lm.py --print-as-arpa=true --no-backoff-ngram-order=1 --verbose=1 < lm_test/train.txt > lm_test/arpa_pr20k
# ngram -order 4 -unk -lm lm_test/arpa_pr20k -ppl lm_test/test.txt
# file lm_test/test.txt: 1000 sentences, 86489 words, 3 OOVs
# 0 zeroprobs, logprob= -80130.1 ppl=*8.23985* ppl1= 8.44325
# on training data: 0 zeroprobs, logprob= -1.6264e+06 ppl= 7.46947 ppl1= 7.63431

## This shows make_phone_lm.py without any pruning (make --num-extra-ngrams very large).
# utils/lang/make_phone_lm.py --print-as-arpa=true --num-extra-ngrams=1000000 --no-backoff-ngram-order=1 --verbose=1 < lm_test/train.txt > lm_test/arpa
# ngram -order 4 -unk -lm lm_test/arpa -ppl lm_test/test.txt
# file lm_test/test.txt: 1000 sentences, 86489 words, 3 OOVs
# 0 zeroprobs, logprob= -74976 ppl=*7.19459* ppl1= 7.36064
# on training data: 0 zeroprobs, logprob= -1.44198e+06 ppl= 5.94659 ppl1= 6.06279

## This is SRILM without pruning (c.f. the 7.19 above, it's slightly better).
# ngram-count -text lm_test/train.txt -order 4 -kndiscount2 -kndiscount3 -kndiscount4 -interpolate -lm lm_test/arpa_srilm
# ngram -order 4 -unk -lm lm_test/arpa_srilm -ppl lm_test/test.txt
# file lm_test/test.txt: 1000 sentences, 86489 words, 3 OOVs
# 0 zeroprobs, logprob= -74742.2 ppl= *7.15044* ppl1= 7.31494


## This is SRILM with a pruning beam tuned to get 20k n-grams above unigram
##  (c.f. the 8.23 above, it's a lot worse).
# ngram-count -text lm_test/train.txt -order 4 -kndiscount2 -kndiscount3 -kndiscount4 -interpolate -prune 1.65e-05 -lm lm_test/arpa_srilm.pr1.65e-5
# the model has 20249 n-grams above unigram [c.f. our 20k]
# ngram -order 4 -unk -lm lm_test/arpa_srilm.pr1.65e-5 -ppl lm_test/test.txt
# file lm_test/test.txt: 1000 sentences, 86489 words, 3 OOVs
# 0 zeroprobs, logprob= -86803.7 ppl=*9.82202* ppl1= 10.0849


## This is pocolm..
## Note: we have to hold out some of the training data as dev to
## estimate the hyperparameters, but we'll fold it back in before
## making the final LM. [--fold-dev-into=train]
# mkdir -p lm_test/data/text
# head -n 1000 lm_test/train.txt > lm_test/data/text/dev.txt
# tail -n +1001 lm_test/train.txt > lm_test/data/text/train.txt
## give it a 'large' num-words so it picks them all.
# export PATH=$PATH:../../../tools/pocolm/scripts
# train_lm.py --num-word=100000 --fold-dev-into=train lm_test/data/text 4 lm_test/data/lm_unpruned
# get_data_prob.py lm_test/test.txt lm_test/data/lm_unpruned/100000_4.pocolm
## compute-probs: average log-prob per word was -1.95956 (perplexity = *7.0962*) over 87489 words.
## Note: we can compare this perplexity with 7.15 with SRILM and 7.19 with make_phone_lm.py.

#   pruned_lm_dir=${lm_dir}/${num_word}_${order}_prune${threshold}.pocolm
# prune_lm_dir.py --target-num-ngrams=20100 lm_test/data/lm_unpruned/100000_4.pocolm lm_test/data/lm_unpruned/100000_4_pr20k.pocolm
# get_data_prob.py lm_test/test.txt lm_test/data/lm_unpruned/100000_4_pr20k.pocolm
## compute-probs: average log-prob per word was -2.0409 (perplexity = 7.69757) over 87489 words.
## note: the 7.69 can be compared with 9.82 from SRILM and 8.23 from make_phone_lm.py.
## format_arpa_lm.py lm_test/data/lm_unpruned/100000_4_pr20k.pocolm | head
## .. it has 20488 n-grams above unigram.  More than 20k but not enough to explain the difference
## .. in perplexity.

## OK... if I reran after modifying prune_lm_dir.py to comment out the line
## 'steps += 'EM EM'.split()' which adds the two EM stages per step, and got the
## perplexity again, I got the following:
## compute-probs: average log-prob per word was -2.09722 (perplexity = 8.14353) over 87489 words.
## .. so it turns out the E-M is actually important.