Blame view
src/tfrnnlm/tensorflow-rnnlm.h
7.14 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
// tensorflow-rnnlm.h // Copyright (C) 2017 Intellisist, Inc. (Author: Hainan Xu) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_TFRNNLM_TENSORFLOW_RNNLM_H_ #define KALDI_TFRNNLM_TENSORFLOW_RNNLM_H_ #include <string> #include <vector> #include <unordered_map> #include "util/stl-utils.h" #include "base/kaldi-common.h" #include "fstext/deterministic-fst.h" #include "util/common-utils.h" // Following macros are defined in both OpenFst and Tensorflow headers. Here we // undef them before including "tensorflow/core/public/session.h" to silence // compiler warnings. Note that this is not a panacea. We should still pay // attention to the order of includes in other places in the codebase to avoid // using the wrong macro definitions. Any OpenFst header or any header including // an OpenFst header should be included before tfrnnlm/tensorflow-rnnlm.h. Also, // to avoid macro redefinitions, any Tensorflow header should be included after // tfrnnlm/tensorflow-rnnlm.h. #undef LOG #undef VLOG #undef CHECK #undef CHECK_EQ #undef CHECK_LT #undef CHECK_GT #undef CHECK_LE #undef CHECK_GE #undef CHECK_NE #undef DCHECK #undef DCHECK_EQ #undef DCHECK_LT #undef DCHECK_GT #undef DCHECK_LE #undef DCHECK_GE #undef DCHECK_NE #include "tensorflow/core/public/session.h" using tensorflow::Session; using tensorflow::Tensor; namespace kaldi { namespace tf_rnnlm { struct KaldiTfRnnlmWrapperOpts { std::string unk_symbol; int32 num_threads; // 0 means unlimited KaldiTfRnnlmWrapperOpts() : unk_symbol("<oos>"), num_threads(1) {} void Register(OptionsItf *opts) { opts->Register("unk-symbol", &unk_symbol, "Symbol for out-of-vocabulary " "words in rnnlm."); opts->Register("num-threads", &num_threads, "Number of threads for TF computation; " "0 means unlimited."); } }; /** This class wraps the TensorFlow based RNNLM, and provides a set of interfaces to be used for class TfRnnlmDeterministicFst, implemented later in this file */ class KaldiTfRnnlmWrapper { public: /// constructor /// opts specify symbol for <unk> and num-threads for computation /// rnn_wordlist specifies a wordlist file with format /// [int-word-id] [word] /// the word <oos> must appear in this file // /// word_symbol_table_rxfilename points to a standard word-list file in OpenFST style /// unk_prob_file has the format /// [word] [prob or count] (it auto-normalizes the probabilities) /// tf_model_path is the location of the TensorFlow model KaldiTfRnnlmWrapper(const KaldiTfRnnlmWrapperOpts &opts, const std::string &rnn_wordlist, const std::string &word_symbol_table_rxfilename, const std::string &unk_prob_file, const std::string &tf_model_path); ~KaldiTfRnnlmWrapper() { session_->Close(); } int32 GetEos() const { return eos_; } /// get an all-zero Tensor of the size that matches the hidden state of the TF model const Tensor& GetInitialContext() const; /// get the 2nd-to-last layer of RNN when feeding input of /// (initial-context, sentence-boundary) /// "cell" is short for "(last)cell-output"; calling it "cell" here because in /// later functions we have function GetLogProb() where we need to pass in /// one "cell" as input and another as output; to avoid confusing we use a single /// word "cell" for that instead of things like cell_out_in and cell_out_out. const Tensor& GetInitialCell() const; /// compute p(word | wseq) and return the log of that /// the computation used the input cell, /// which is the 2nd-to-last layer of the RNNLM associated with history wseq; /// /// and we generate (context_out, new_cell) by passing (context_in, word) /// into the TensorFlow session that manages the RNNLM /// if the last 2 pointers are NULL we don't query them in TF session /// e.g. in the case of computing p(</s>|some history) BaseFloat GetLogProb(int32 word, // word id in RNN wordlist int32 fst_word, // FST word label, only for computing OOS cost const Tensor &context_in, const Tensor &cell_in, Tensor *context_out, Tensor *cell_out); /// takes in a word-id for FST and return the word-id for RNNLM /// return the word-id for <oos> if not found int FstLabelToRnnLabel(int i) const; private: /// read the TensorFlow model and create the session for computation /// num-threads need to be specified in creating the session void ReadTfModel(const std::string &tf_model_path, int32 num_threads); /// do queries on the session to get the initial tensors (cell + context) void AcquireInitialTensors(); /// since usually we have a smaller vocab in RNN than the whole vocab, /// we use this mapping during rescoring std::vector<int> fst_label_to_rnn_label_; std::vector<std::string> rnn_label_to_word_; std::vector<std::string> fst_label_to_word_; KaldiTfRnnlmWrapperOpts opts_; Tensor initial_context_; Tensor initial_cell_; // this corresponds to the FST symbol table int32 num_total_words; // this corresponds to the RNNLM symbol table int32 num_rnn_words; Session* session_; // for TF computation; pointer owned here int32 eos_; int32 oos_; std::vector<float> unk_costs_; // extra cost for OOS symbol in RNNLM KALDI_DISALLOW_COPY_AND_ASSIGN(KaldiTfRnnlmWrapper); }; class TfRnnlmDeterministicFst: public fst::DeterministicOnDemandFst<fst::StdArc> { public: typedef fst::StdArc::Weight Weight; typedef fst::StdArc::StateId StateId; typedef fst::StdArc::Label Label; // Does not take ownership. TfRnnlmDeterministicFst(int32 max_ngram_order, KaldiTfRnnlmWrapper *rnnlm); ~TfRnnlmDeterministicFst(); void Clear(); // We cannot use "const" because the pure virtual function in the interface is // not const. virtual StateId Start() { return start_state_; } // We cannot use "const" because the pure virtual function in the interface is // not const. virtual Weight Final(StateId s); virtual bool GetArc(StateId s, Label ilabel, fst::StdArc* oarc); private: typedef unordered_map<std::vector<Label>, StateId, VectorHasher<Label> > MapType; StateId start_state_; MapType wseq_to_state_; std::vector<std::vector<Label> > state_to_wseq_; KaldiTfRnnlmWrapper *rnnlm_; int32 max_ngram_order_; std::vector<Tensor*> state_to_context_; std::vector<Tensor*> state_to_cell_; }; } // namespace tf_rnnlm } // namespace kaldi #endif // KALDI_TFRNNLM_TENSORFLOW_RNNLM_H_ |