// lm/kaldi-rnnlm.cc // Copyright 2015 Guoguo Chen // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include #include "lm/kaldi-rnnlm.h" #include "util/stl-utils.h" #include "util/text-utils.h" namespace kaldi { KaldiRnnlmWrapper::KaldiRnnlmWrapper( const KaldiRnnlmWrapperOpts &opts, const std::string &unk_prob_rspecifier, const std::string &word_symbol_table_rxfilename, const std::string &rnnlm_rxfilename) { rnnlm_.setRnnLMFile(rnnlm_rxfilename); rnnlm_.setRandSeed(1); rnnlm_.setUnkSym(opts.unk_symbol); rnnlm_.setUnkPenalty(unk_prob_rspecifier); rnnlm_.restoreNet(); // Reads symbol table. fst::SymbolTable *word_symbols = NULL; if (!(word_symbols = fst::SymbolTable::ReadText(word_symbol_table_rxfilename))) { KALDI_ERR << "Could not read symbol table from file " << word_symbol_table_rxfilename; } label_to_word_.resize(word_symbols->NumSymbols() + 1); for (int32 i = 0; i < label_to_word_.size() - 1; ++i) { label_to_word_[i] = word_symbols->Find(i); if (label_to_word_[i] == "") { KALDI_ERR << "Could not find word for integer " << i << "in the word " << "symbol table, mismatched symbol table or you have discoutinuous " << "integers in your symbol table?"; } } label_to_word_[label_to_word_.size() - 1] = opts.eos_symbol; eos_ = label_to_word_.size() - 1; } BaseFloat KaldiRnnlmWrapper::GetLogProb( int32 word, const std::vector &wseq, const std::vector &context_in, std::vector *context_out) { std::vector wseq_symbols(wseq.size()); for (int32 i = 0; i < wseq_symbols.size(); ++i) { KALDI_ASSERT(wseq[i] < label_to_word_.size()); wseq_symbols[i] = label_to_word_[wseq[i]]; } return rnnlm_.computeConditionalLogprob(label_to_word_[word], wseq_symbols, context_in, context_out); } RnnlmDeterministicFst::RnnlmDeterministicFst(int32 max_ngram_order, KaldiRnnlmWrapper *rnnlm) { KALDI_ASSERT(rnnlm != NULL); max_ngram_order_ = max_ngram_order; rnnlm_ = rnnlm; // Uses empty history for . std::vector