// fstext/table-matcher-test.cc // Copyright 2009-2011 Microsoft Corporation // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "fstext/table-matcher.h" #include "fstext/fst-test-utils.h" #include "base/kaldi-math.h" namespace fst{ // Don't instantiate with log semiring, as RandEquivalent may fail. template void TestTableMatcher(bool connect, bool left) { typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; VectorFst *fst1 = RandFst(); VectorFst *fst2 = RandFst(); ILabelCompare ilabel_comp; OLabelCompare olabel_comp; TableComposeOptions opts; if (left) opts.table_match_type = MATCH_OUTPUT; else opts.table_match_type = MATCH_INPUT; opts.min_table_size = 1 + kaldi::Rand() % 5; opts.table_ratio = 0.25 * (kaldi::Rand() % 5); opts.connect = connect; ArcSort(fst1, olabel_comp); ArcSort(fst2, ilabel_comp); VectorFst composed; TableCompose(*fst1, *fst2, &composed, opts); if (!connect) Connect(&composed); VectorFst composed_baseline; Compose(*fst1, *fst2, &composed_baseline); std::cout << "Connect = "<< (connect?"True\n":"False\n"); std::cout <<"Table-Composed FST\n"; { FstPrinter fstprinter(composed, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } std::cout <<" Baseline-Composed FST\n"; { FstPrinter fstprinter(composed_baseline, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } if ( !RandEquivalent(composed, composed_baseline, 3/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 20/*path length-- max?*/)) { VectorFst diff1; Difference(composed, composed_baseline, &diff1); std::cout <<" Diff1 (composed - baseline) \n"; { FstPrinter fstprinter(diff1, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } VectorFst diff2; Difference(composed_baseline, composed, &diff2); std::cout <<" Diff2 (baseline - composed) \n"; { FstPrinter fstprinter(diff2, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } assert(0); } delete fst1; delete fst2; } // Don't instantiate with log semiring, as RandEquivalent may fail. template void TestTableMatcherCacheLeft(bool connect) { typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; VectorFst *fst1 = RandFst(); TableComposeOptions opts; opts.table_match_type = MATCH_OUTPUT; opts.min_table_size = 1 + kaldi::Rand() % 5; opts.table_ratio = 0.25 * (kaldi::Rand() % 5); opts.connect = connect; TableComposeCache > cache(opts); for (size_t i = 0; i < 3; i++) { VectorFst *fst2 = RandFst(); ILabelCompare ilabel_comp; OLabelCompare olabel_comp; ArcSort(fst1, olabel_comp); ArcSort(fst2, ilabel_comp); VectorFst composed; TableCompose(*fst1, *fst2, &composed, &cache); if (!connect) Connect(&composed); VectorFst composed_baseline; Compose(*fst1, *fst2, &composed_baseline); std::cout << "Connect = "<< (connect?"True\n":"False\n"); if ( !RandEquivalent(composed, composed_baseline, 3/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 100/*path length-- max?*/)) { VectorFst diff1; Difference(composed, composed_baseline, &diff1); std::cout <<" Diff1 (composed - baseline) \n"; { FstPrinter fstprinter(diff1, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } VectorFst diff2; Difference(composed_baseline, composed, &diff2); std::cout <<" Diff2 (baseline - composed) \n"; { FstPrinter fstprinter(diff2, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } assert(0); } delete fst2; } delete fst1; } template void TestTableMatcherCacheRight(bool connect) { typedef typename Arc::Label Label; typedef typename Arc::StateId StateId; typedef typename Arc::Weight Weight; VectorFst *fst2 = RandFst(); ILabelCompare ilabel_comp; ArcSort(fst2, ilabel_comp); TableComposeOptions opts; opts.table_match_type = MATCH_INPUT; opts.min_table_size = 1 + kaldi::Rand() % 5; opts.table_ratio = 0.25 * (kaldi::Rand() % 5); opts.connect = connect; TableComposeCache > cache(opts); for (size_t i = 0; i < 2; i++) { VectorFst *fst1 = RandFst(); OLabelCompare olabel_comp; ArcSort(fst1, olabel_comp); VectorFst composed; TableCompose(*fst1, *fst2, &composed, &cache); if (!connect) Connect(&composed); VectorFst composed_baseline; Compose(*fst1, *fst2, &composed_baseline); std::cout << "Connect = "<< (connect?"True\n":"False\n"); if ( !RandEquivalent(composed, composed_baseline, 5/*paths*/, 0.01/*delta*/, kaldi::Rand()/*seed*/, 20/*path length-- max?*/)) { VectorFst diff1; Difference(composed, composed_baseline, &diff1); std::cout <<" Diff1 (composed - baseline) \n"; { FstPrinter fstprinter(diff1, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } VectorFst diff2; Difference(composed_baseline, composed, &diff2); std::cout <<" Diff2 (baseline - composed) \n"; { FstPrinter fstprinter(diff2, NULL, NULL, NULL, false, true, "\t"); fstprinter.Print(&std::cout, "standard output"); } assert(0); } delete fst1; } delete fst2; } } // namespace fst int main() { using namespace fst; for (int i = 0;i < 1;i++) { TestTableMatcher(true, true); TestTableMatcher(false, true); TestTableMatcher(true, false); TestTableMatcher(false, false); TestTableMatcherCacheLeft(true); TestTableMatcherCacheLeft(false); TestTableMatcherCacheRight(true); TestTableMatcherCacheRight(false); } }