// fstext/context-fst-test.cc // Copyright 2009-2011 Microsoft Corporation // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "fstext/context-fst.h" #include "fstext/fst-test-utils.h" #include "tree/context-dep.h" #include "util/kaldi-io.h" #include "base/kaldi-math.h" namespace fst { // GenAcceptorFromSequence generates a linear acceptor (identical input+output symbols) that has this // sequence of symbols, and template static VectorFst *GenAcceptorFromSequence(const vector &symbols, float cost) { typedef typename Arc::Weight Weight; typedef typename Arc::StateId StateId; vector split_cost(symbols.size()+1, 0.0); // for #-arcs + end-state. { // compute split_cost. it must sum to "cost". std::set indices; size_t num_indices = 1 + (kaldi::Rand() % split_cost.size()); while (indices.size() < num_indices) indices.insert(kaldi::Rand() % split_cost.size()); for (std::set::iterator iter = indices.begin(); iter != indices.end(); ++iter) { split_cost[*iter] = cost / num_indices; } } VectorFst *fst = new VectorFst(); StateId cur_state = fst->AddState(); fst->SetStart(cur_state); for (size_t i = 0; i < symbols.size(); i++) { StateId next_state = fst->AddState(); Arc arc; arc.ilabel = symbols[i]; arc.olabel = symbols[i]; arc.nextstate = next_state; arc.weight = (Weight) split_cost[i]; fst->AddArc(cur_state, arc); cur_state = next_state; } fst->SetFinal(cur_state, (Weight)split_cost[symbols.size()]); return fst; } // CheckPhones is used to test the correctness of an FST that is the result of // composition with a ContextFst. template static float CheckPhones(const VectorFst &linear_fst, const vector &phone_ids, const vector &disambig_ids, const vector &phone_seq, const vector > &ilabel_info, int N, int P) { typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; assert(kaldi::IsSorted(phone_ids)); // so we can do binary_search. vector input_syms; vector output_syms; Weight tot_cost; bool ans = GetLinearSymbolSequence(linear_fst, &input_syms, &output_syms, &tot_cost); assert(ans); // should be linear. vector phone_seq_check; for (size_t i = 0; i < output_syms.size(); i++) if (std::binary_search(phone_ids.begin(), phone_ids.end(), output_syms[i])) phone_seq_check.push_back(output_syms[i]); assert(phone_seq_check == phone_seq); vector > input_syms_long; for (size_t i = 0; i < input_syms.size(); i++) { Label isym = input_syms[i]; if (ilabel_info[isym].size() == 0) continue; // epsilon. if ( (ilabel_info[isym].size() == 1 && ilabel_info[isym][0] <= 0) ) continue; // disambig. input_syms_long.push_back(ilabel_info[isym]); } for (size_t i = 0; i < input_syms_long.size(); i++) { vector phone_context_window(N); // phone at pos i will be at pos P in this window. int pos = ((int)i) - P; // pos of first phone in window [ may be out of range] . for (int j = 0; j < N; j++, pos++) { if (static_cast(pos) < phone_seq.size()) phone_context_window[j] = phone_seq[pos]; else phone_context_window[j] = 0; // 0 is a special symbol that context-dep-itf expects to see // when no phone is present due to out-of-window. context-fst knows about this too. } assert(input_syms_long[i] == phone_context_window); } return tot_cost.Value(); } template static VectorFst *GenRandPhoneSeq(vector &phone_syms, vector &disambig_syms, typename Arc::Label subsequential_symbol, int num_subseq_syms, float seq_prob, vector *phoneseq_out) { KALDI_ASSERT(phoneseq_out != NULL); typedef typename Arc::Label Label; // Generate an FST that is a random phone sequence, ending // with "num_subseq_syms" subsequential symbols. It will // have disambiguation symbols randomly interspersed throughout. // The number of phones is random (possibly zero). size_t len = (kaldi::Rand() % 4) * (kaldi::Rand() % 3); // up to 3*2=6 phones. float disambig_prob = 0.33; phoneseq_out->clear(); vector