// rnnlm/rnnlm-test-utils.cc

// Copyright 2017  Daniel Povey
//           2017  Hossein Hadian
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include <cmath>
#include <numeric>
#include "rnnlm/rnnlm-test-utils.h"

namespace kaldi {
namespace rnnlm {

void GetForbiddenSymbols(std::set<std::string> *forbidden_symbols) {
  *forbidden_symbols = {"<eps>", "<s>", "<brk>", "</s>"};
}

///  Reads all the lines from a text file, splits each line into
///  whitespace-separated words, and appends them to the "sentences" vector.
void ReadAllLines(const std::string &filename,
                  std::vector<std::vector<std::string> > *sentences) {
  std::ifstream is(filename.c_str());
  std::string line;
  while (std::getline(is, line)) {
    std::vector<std::string> split_line;
    SplitStringToVector(line, "\t\r\n ", true, &split_line);
    sentences->push_back(split_line);
  }
  if (sentences->empty())
    KALDI_ERR << "No lines could be read from file " << filename;
}

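// Builds a small test corpus by reading several source files from this
// directory, treating each line as one sentence; words that clash with the
// special RNNLM symbols are escaped with a leading backslash.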
void GetTestSentences(const std::set<std::string> &forbidden_symbols,
                      std::vector<std::vector<std::string> > *sentences) {
  sentences->clear();
  ReadAllLines("sampling-lm-test.cc", sentences);
  ReadAllLines("rnnlm-example-test.cc", sentences);
  ReadAllLines("rnnlm-example.cc", sentences);
  ReadAllLines("rnnlm-example-utils.cc", sentences);

  // find and escape forbidden symbols
  for (int i = 0; i < sentences->size(); i++)
    for (int j = 0; j < (*sentences)[i].size(); j++)
      if (forbidden_symbols.find((*sentences)[i][j]) != forbidden_symbols.end())
        (*sentences)[i][j] = "\\" + (*sentences)[i][j];
}

fst::SymbolTable *GetSymbolTable(
    const std::vector<std::vector<std::string> > &sentences) {
  fst::SymbolTable* table = new fst::SymbolTable();
  table->AddSymbol("<eps>", 0);
  table->AddSymbol("<s>", 1);
  table->AddSymbol("</s>", 2);
  table->AddSymbol("<brk>", 3);
  for (int i = 0; i < sentences.size(); i++)
    for (int j = 0; j < sentences[i].size(); j++)
      table->AddSymbol(sentences[i][j]);
  return table;
}

void ConvertToInteger(
    const std::vector<std::vector<std::string> > &string_sentences,
    const fst::SymbolTable &symbol_table,
    std::vector<std::vector<int32> > *int_sentences) {
  int_sentences->resize(string_sentences.size());
  for (int i = 0; i < string_sentences.size(); i++) {
    (*int_sentences)[i].resize(string_sentences[i].size());
    for (int j = 0; j < string_sentences[i].size(); j++) {
      kaldi::int64 key = symbol_table.Find(string_sentences[i][j]);
      KALDI_ASSERT(key != -1); // fst::kNoSymbol
      (*int_sentences)[i][j] = static_cast<int32>(key);
    }
  }
}

/** Simple implementation of interpolated Kneser-Ney smoothing;
    see the formulas in "A Bit of Progress in Language Modeling".

    Note that we don't follow the procedure of the SRILM implementation
    (collect counts, then modify counts, then discount to get probs).
    Instead, we accumulate the context-occurrence counts directly as we pass
    through the training text.  We translate the perl code in the appendix
    of the paper and extend it to arbitrary n-gram orders.  Also, as in
    SRILM, we use the original (unmodified) counts for n-grams starting
    with <s>.  We don't do any min-count pruning of n-grams.
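
    Concretely, for an n-gram (h, w) whose history h is non-empty, the code
    below computes

      p(w | h) = (count(h, w) - D) / count(h)  +  bow(h) * p(w | h')
      bow(h)   = D * N1+(h) / count(h)

    where D is the fixed discount, h' is h with its first word removed,
    count(h) is the accumulated denominator for h, and N1+(h) is the number
    of distinct words observed after h (non_zero_count below).  Unigram
    probabilities are plain relative frequencies of the accumulated unigram
    counts.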
*/
class InterpolatedKneserNeyLM {
 public:
  struct LMState {
    int32 numerator;
    int32 denominator;
    int32 non_zero_count;
    BaseFloat prob;
    BaseFloat bow;

    LMState() : numerator(0), denominator(0), non_zero_count(0),
                prob(0.0), bow(0.0) {};
  };
  typedef unordered_map<std::vector<int32>, LMState, VectorHasher<int32> > Ngrams;

  /** Constructor.
       @param [in] ngram_order  The n-gram order of the language model to
                                be estimated.
       @param [in] bos_symbol   The integer id of the beginning-of-sentence
                                symbol.
       @param [in] eos_symbol   The integer id of the end-of-sentence
                                symbol.
       @param [in] discount     Fixed discount value, i.e. the D in the
                                formulas above.
   */
  InterpolatedKneserNeyLM(int32 ngram_order, int32 bos_symbol,
                          int32 eos_symbol, double discount) :
      unigram_denominator_(0) {
    ngram_order_ = ngram_order;
    discount_ = discount;
    bos_symbol_ = bos_symbol;
    eos_symbol_ = eos_symbol;
    ngrams_.resize(ngram_order + 1); // ngrams_[0] unused
  }

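  /* Fills "*words" with the n-gram of length "order" whose first word is at
     position "pos" of "sentence"; position -1 (just before the sentence
     start) maps to the BOS symbol and positions past the end of the sentence
     map to the EOS symbol. */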
  void FillWords(const std::vector<int32> &sentence,
                 int32 pos, int32 order,
                 std::vector<int32> *words) {
    KALDI_ASSERT(pos >= -1 && pos <= static_cast<int32>(sentence.size()));

    words->resize(order);
    for (int32 k = 0; k < order; k++, pos++) {
      if (pos < 0) {
        (*words)[k] = bos_symbol_;
      } else if (pos >= sentence.size()) {
        (*words)[k] = eos_symbol_;
      } else {
        (*words)[k] = sentence[pos];
      }
    }
  }

  /* Collects the n-gram counts from the corpus. */
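  // For the highest-order n-grams (and the n-grams starting with <s>) we
  // accumulate raw token counts; for each lower order the numerator is only
  // incremented the first time a given higher-order extension is seen, so
  // lower-order numerators end up being Kneser-Ney "continuation" (type)
  // counts.  The context denominators and non_zero_counts are accumulated
  // in the same pass.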
  void CollectCounts(const std::vector<std::vector<int32> > &sentences) {
    std::vector<int32> words;
    std::vector<int32> subwords;

    for (int32 i = 0; i < sentences.size(); i++) {
      for (int32 j = 0; j < sentences[i].size() + 1; j++) {
        // The longest n-gram ending at position j is either a full
        // ngram_order_ n-gram or, near the start of the sentence, the
        // n-gram that begins with <s> (which has order j + 2).
        int32 max_order = ngram_order_;
        if (j + 2 < ngram_order_)
          max_order = j + 2;

        // In the following for-loop only the max_order n-grams get their
        // actual (token) counts; these are the n-grams of order ngram_order_
        // and the n-grams starting with <s>.
        for (int32 order = max_order; order >= 1; order--) {
          FillWords(sentences[i], j - order + 1, order, &words);
          // accumulate numerator
          LMState& this_ngram = ngrams_[order][words];
          this_ngram.numerator++;

          if (order == 1) {
            unigram_denominator_++;
          } else {
            // accumulate denominator for context
            subwords.assign(words.begin(), words.end() - 1);
            LMState &context = ngrams_[order - 1][subwords];
            context.denominator++;
            if (this_ngram.numerator <= 1) { // first insertion
              // accumulate for context
              context.non_zero_count++;
            } else {
              // For lower-order n-grams we only count distinct occurrences;
              // this n-gram is already in the map, so there is nothing more
              // to add for the lower orders and we stop descending.
              break;
            }
          }
        }
      }
    }
  }

  /* Compute ngram probs and bows with the counts. */
  void EstimateProbAndBow() {
    for (int32 order = 1; order <= ngram_order_; order++) {
      Ngrams::iterator it = ngrams_[order].begin();
      for (; it != ngrams_[order].end(); it++) {
        LMState& state = it->second;
        if (order == 1) {
          // Here we assume every word in the vocabulary appears in the
          // training text; otherwise we would have to discount the unigram
          // probabilities too.  Since we don't see the symbol table until
          // WriteToARPA(), we don't know the vocabulary size and couldn't
          // work out that discount anyway.
          state.prob = 1.0 * state.numerator / unigram_denominator_;
        } else {
          std::vector<int32> subwords;
          Ngrams::const_iterator context, lower_order;

          subwords.assign(it->first.begin(), it->first.end() - 1);
          context = ngrams_[order - 1].find(subwords);
          KALDI_ASSERT(context != ngrams_[order - 1].end());
          state.prob = (state.numerator - discount_)
                       / context->second.denominator;

          // interpolate with the lower-order estimate
          subwords.assign(it->first.begin() + 1, it->first.end());
          lower_order = ngrams_[order - 1].find(subwords);
          KALDI_ASSERT(lower_order != ngrams_[order - 1].end());

          state.prob += context->second.bow * lower_order->second.prob;
        }

        if (state.denominator > 0) {
          state.bow = state.non_zero_count * discount_ / state.denominator;
        }
      }
    }
  }

  /** Estimate the language model with corpus in sentences.

      @param [in] sentences   The sentences of input data.  These will contain
                              just the actual words, not the BOS or EOS symbols.
   */
  void Estimate(const std::vector<std::vector<int32> > &sentences) {
    CollectCounts(sentences);
    EstimateProbAndBow();
  }

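  // ARPA files store base-10 log-probabilities; by convention -99 is written
  // in place of log(0) for an n-gram with zero probability.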
  static BaseFloat ProbToLogProb(BaseFloat prob) {
    if (prob == 0.0) {
      return -99.0;
    } else {
      return log10(prob);
    }
  }

  static void WriteNgram(const std::vector<int32> &words,
                  BaseFloat prob, BaseFloat bow,
                  const fst::SymbolTable &symbol_table, std::ostream &os) {
    os << ProbToLogProb(prob) << "\t";
    for (int32 i = 0; i < words.size() - 1; i++) {
      os << symbol_table.Find(words[i]) << " ";
    }
    os << symbol_table.Find(words[words.size() - 1]);
    if (bow != 0.0) {
      os << "\t" << ProbToLogProb(bow);
    }
    os << "\n";
  }

  /** Writes the language model to the ostream in ARPA format. Throws on error.
       @param [in] symbol_table  The OpenFst symbol table. It's needed
                                 because the ARPA file we write out
                                 is in text form.
       @param [out] os       The stream to which this function will write
                             the language model in ARPA format.
   */
  void WriteToARPA(const fst::SymbolTable &symbol_table,
                   std::ostream &os) const {
    // We write out only the words that appeared in the training text, rather
    // than all words in symbol_table, since symbol_table contains some
    // special symbols and our unigram distribution was computed without
    // considering any such extra words.
    os << "\\data\\\n";
    for (int32 order = 1; order <= ngram_order_; order++) {
      os << "ngram " << order << "=" << ngrams_[order].size() << "\n";
    }

    for (int32 order = 1; order <= ngram_order_; order++) {
      os << "\n\\" << order << "-grams:\n";
      Ngrams::const_iterator it = ngrams_[order].begin();
      for (; it != ngrams_[order].end(); it++) {
        WriteNgram(it->first, it->second.prob, it->second.bow,
                   symbol_table, os);
      }
    }

    os << "\n\\end\\\n";
  }

  // The context n-gram must already exist in the LM.
  double GetNgramProb(const std::vector<int32> &context, int32 word) {
    KALDI_ASSERT(context.size() < ngrams_.size() - 1);

    std::vector<int32> words(context);
    words.push_back(word);
    Ngrams::const_iterator it = ngrams_[words.size()].find(words);
    if (it != ngrams_[words.size()].end()) {
      return it->second.prob;
    }

    double prob = 1.0;
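    // The full n-gram was not found: back off one order at a time,
    // multiplying in the backoff weight of the progressively shortened
    // context, until a shortened n-gram with an explicit estimate is found.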
    for (int32 o = 1; o < words.size(); o++) {
      std::vector<int32> subwords;

      subwords.assign(words.begin() + o - 1, words.end() - 1);
      it = ngrams_[words.size() - o].find(subwords);
      KALDI_ASSERT(it != ngrams_[words.size() - o].end());
      prob *= it->second.bow;

      subwords.assign(words.begin() + o, words.end());
      it = ngrams_[words.size() - o].find(subwords);
      if (it != ngrams_[words.size() - o].end()) {
        prob *= it->second.prob;
        break;
      }
    }

    return prob;
  }

  /** Checks two properties of the LM:
      1. it is properly normalized (the probabilities over each context
         sum to 1);
      2. for any explicitly stored n-gram, the probability obtained by
         backing off is smaller than the explicitly estimated probability.
      Only a random fraction (roughly spot_check_prob) of the n-grams is
      checked.
   */
  bool Check(BaseFloat spot_check_prob) {
    KALDI_ASSERT (spot_check_prob > 0.0 && spot_check_prob <= 1.0);

    std::vector<int32> all_words;

    double total_prob = 0.0;
    Ngrams::const_iterator it = ngrams_[1].begin();
    for (; it != ngrams_[1].end(); it++) {
      total_prob += it->second.prob;
      all_words.push_back(it->first[0]);
    }
    if (std::fabs(total_prob - 1.0) > 1e-6) {
      KALDI_WARN << "total probability of unigram is not 1.0: " << total_prob;
      return false;
    }

    for (int32 order = 1; order <= ngram_order_; order++) {
      it = ngrams_[order].begin();
      for (; it != ngrams_[order].end(); it++) {
        if (RandUniform() > spot_check_prob) {
          continue;
        }

        if (order < ngram_order_ && it->first[order - 1] != eos_symbol_) {
          // check it is normalized with current ngram as context
          total_prob = 0.0;
          for (int32 i = 0; i < all_words.size(); i++) {
            total_prob += GetNgramProb(it->first, all_words[i]);
          }
          if (std::fabs(total_prob - 1.0) > 1e-6) {
            std::string str;
            for (int32 i = 0; i < it->first.size() - 1; i++) {
              str.append(std::to_string(it->first[i]));
              str.append(" ");
            }
            str.append(std::to_string(it->first[it->first.size() - 1]));
            KALDI_WARN << "total probability of context [" << str
                       << "] is not 1.0: " << total_prob;
            return false;
          }
        }

        if (order > 1) {
          // check that the explicit prob is larger than the backoff prob
          std::vector<int32> subwords;
          subwords.assign(it->first.begin(), it->first.end() - 1);
          Ngrams::const_iterator context = ngrams_[order - 1].find(subwords);
          KALDI_ASSERT(context != ngrams_[order - 1].end());

          subwords.assign(it->first.begin() + 1, it->first.end());
          Ngrams::const_iterator lower_order = ngrams_[order - 1].find(subwords);
          KALDI_ASSERT(lower_order != ngrams_[order - 1].end());

          if (context->second.bow * lower_order->second.prob >= it->second.prob) {
            std::string str;
            for (int32 i = 0; i < it->first.size() - 1; i++) {
              str.append(std::to_string(it->first[i]));
              str.append(" ");
            }
            str.append(std::to_string(it->first[it->first.size() - 1]));
            KALDI_WARN << "backoff probability of ngram [" << str
                       << "] is larger than explicitly probability.";
            return false;
          }
        }
      }
    }

    return true;
  }

 private:

  // Ngram order
  int32 ngram_order_;

  // Fixed discount value (the D in the formulas above)
  double discount_;

  // ngrams for each order
  std::vector<Ngrams> ngrams_;

  // denominator for unigrams
  int32 unigram_denominator_;

  // The integer id of the beginning-of-sentence symbol
  int32 bos_symbol_;

  // The integer id of the end-of-sentence symbol
  int32 eos_symbol_;
};

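// Estimates an interpolated Kneser-Ney LM of the given n-gram order on
// "sentences", using a fixed discount of 0.6, and writes it to "os" in
// ARPA format.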
void EstimateAndWriteLanguageModel(
    int32 ngram_order,
    const fst::SymbolTable &symbol_table,
    const std::vector<std::vector<int32> > &sentences,
    int32 bos_symbol, int32 eos_symbol,
    std::ostream &os) {
  InterpolatedKneserNeyLM lm(ngram_order, bos_symbol, eos_symbol, 0.6);
  lm.Estimate(sentences);
#ifdef _KALDI_RNNLM_TEST_CHECK_LM_
  KALDI_ASSERT(lm.Check(1.0));
#endif
  lm.WriteToARPA(symbol_table, os);
}

}  // namespace rnnlm
}  // namespace kaldi