// See www.openfst.org for extensive documentation on this weighted
// finite-state transducer library.
//
// Regression test for FST weights.
//
// NOTE(review): in this copy every #include target and every template
// argument/parameter list ("<...>") appears to have been stripped -- see the
// bare `#include` lines, `template void ...`, `WeightGenerate>` and
// `WeightTester, ...>` below. As shown, the file cannot compile. Presumably
// these need to be restored from the upstream OpenFst sources before building.
// The comments below describe only what the remaining code demonstrates.

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "./weight-tester.h"

// Command-line flags: the value passed to srand() in main, and the repetition
// count handed to most WeightTester::Test calls.
DEFINE_int32(seed, -1, "random seed");
DEFINE_int32(repeat, 10000, "number of test repetitions");

namespace {

using fst::Adder;
using fst::ExpectationWeight;
using fst::GALLIC;
using fst::GallicWeight;
using fst::LexicographicWeight;
using fst::LogWeight;
using fst::LogWeightTpl;
using fst::MinMaxWeight;
using fst::MinMaxWeightTpl;
using fst::NaturalLess;
using fst::PowerWeight;
using fst::ProductWeight;
using fst::SetWeight;
using fst::SET_INTERSECT_UNION;
using fst::SET_UNION_INTERSECT;
using fst::SET_BOOLEAN;
using fst::SignedLogWeight;
using fst::SignedLogWeightTpl;
using fst::SparsePowerWeight;
using fst::StringWeight;
using fst::STRING_LEFT;
using fst::STRING_RIGHT;
using fst::TropicalWeight;
using fst::TropicalWeightTpl;
using fst::UnionWeight;
using fst::WeightConvert;
using fst::WeightGenerate;
using fst::WeightTester;

// Runs the generic WeightTester suite, `repeat` times each, over the tropical,
// log, min-max and signed-log weight templates. The (stripped) template
// parameter presumably selects the underlying value type (e.g. float vs.
// double) -- TODO confirm against upstream.
template void TestTemplatedWeights(int repeat) {
  using TropicalWeightGenerate = WeightGenerate>;
  TropicalWeightGenerate tropical_generate;
  WeightTester, TropicalWeightGenerate> tropical_tester(
      tropical_generate);
  tropical_tester.Test(repeat);

  using LogWeightGenerate = WeightGenerate>;
  LogWeightGenerate log_generate;
  WeightTester, LogWeightGenerate> log_tester(log_generate);
  log_tester.Test(repeat);

  using MinMaxWeightGenerate = WeightGenerate>;
  // NOTE(review): the meaning of the `true` constructor argument is defined by
  // WeightGenerate (weight-tester.h) and is not visible here.
  MinMaxWeightGenerate minmax_generate(true);
  WeightTester, MinMaxWeightGenerate> minmax_tester(
      minmax_generate);
  minmax_tester.Test(repeat);

  using SignedLogWeightGenerate = WeightGenerate>;
  SignedLogWeightGenerate signedlog_generate;
  WeightTester, SignedLogWeightGenerate>
      signedlog_tester(signedlog_generate);
  signedlog_tester.Test(repeat);
}

// Adds Weight::One() n times, once with repeated Plus() and once through an
// Adder accumulator, and CHECKs that both sums are ApproxEqual.
template void TestAdder(int n) {
  Weight sum = Weight::Zero();
  Adder adder;
  for (int i = 0; i < n; ++i) {
    sum = Plus(sum, Weight::One());
    adder.Add(Weight::One());
  }
  CHECK(ApproxEqual(sum, adder.Sum()));
}

// Like TestAdder, but mixes in subtraction. For i in [0, n/4) and (3n/4, n)
// it adds One. For the middle range it subtracts One: via Minus() on the
// reference sum, and by adding Minus(Zero, One) to the Adder. The weight type
// must therefore support Minus().
template void TestSignedAdder(int n) {
  Weight sum = Weight::Zero();
  Adder adder;
  const Weight minus_one = Minus(Weight::Zero(), Weight::One());
  for (int i = 0; i < n; ++i) {
    if (i < n/4 || i > 3*n/4) {
      sum = Plus(sum, Weight::One());
      adder.Add(Weight::One());
    } else {
      sum = Minus(sum, Weight::One());
      adder.Add(minus_one);
    }
  }
  CHECK(ApproxEqual(sum, adder.Sum()));
}

// Converts w1 to the second weight type and back with WeightConvert. CHECKs
// that the round trip reproduces w1 exactly (CHECK_EQ, not ApproxEqual).
template void TestWeightConversion(Weight1 w1) {
  // Tests round-trip conversion.
  WeightConvert to_w1_;
  WeightConvert to_w2_;
  Weight2 w2 = to_w2_(w1);
  Weight1 nw1 = to_w1_(w2);
  CHECK_EQ(w1, nw1);
}

// CHECKs that copy construction and copy assignment from FromWeight to
// ToWeight, and back again, preserve the value of w.
template void TestWeightCopy(FromWeight w) {
  // Test copy constructor.
  const ToWeight to_copied(w);
  const FromWeight roundtrip_copied(to_copied);
  CHECK_EQ(w, roundtrip_copied);

  // Test copy assign.
  ToWeight to_copy_assigned;
  to_copy_assigned = w;
  CHECK_EQ(to_copied, to_copy_assigned);

  FromWeight roundtrip_copy_assigned;
  roundtrip_copy_assigned = to_copy_assigned;
  CHECK_EQ(w, roundtrip_copy_assigned);
}

// Same as TestWeightCopy but with move construction and move assignment. Each
// result is compared against a copy (`orig`) taken before any move.
template void TestWeightMove(FromWeight w) {
  // Assume FromWeight -> FromWeight copy works.
  const FromWeight orig(w);
  ToWeight to_moved(std::move(w));
  const FromWeight roundtrip_moved(std::move(to_moved));
  CHECK_EQ(orig, roundtrip_moved);

  // Test move assign.
  // `w` was moved from above, so it is reassigned before being moved again.
  w = orig;
  ToWeight to_move_assigned;
  to_move_assigned = std::move(w);
  FromWeight roundtrip_move_assigned;
  roundtrip_move_assigned = std::move(to_move_assigned);
  CHECK_EQ(orig, roundtrip_move_assigned);
}

// CHECKs that float and double literals convert implicitly to Weight in
// operator==, Times() and Plus(), with the literal on either side. The expected
// results use Zero as Times' annihilator and as Plus' identity.
template void TestImplicitConversion() {
  // Only test a few of the operations; assumes they are implemented with the
  // same pattern.
  CHECK(Weight(2.0f) == 2.0f);
  CHECK(Weight(2.0) == 2.0);
  CHECK(2.0f == Weight(2.0f));
  CHECK(2.0 == Weight(2.0));

  CHECK_EQ(Weight::Zero(), Times(Weight::Zero(), 3.0f));
  CHECK_EQ(Weight::Zero(), Times(Weight::Zero(), 3.0));
  CHECK_EQ(Weight::Zero(), Times(3.0, Weight::Zero()));

  CHECK_EQ(Weight(3.0), Plus(Weight::Zero(), 3.0f));
  CHECK_EQ(Weight(3.0), Plus(Weight::Zero(), 3.0));
  CHECK_EQ(Weight(3.0), Plus(3.0, Weight::Zero()));
}

// Sets components 0 and 1 of a PowerWeight of LogWeights and reads them back.
void TestPowerWeightGetSetValue() {
  PowerWeight w;
  // LogWeight has unspecified initial value, so don't check it.
  w.SetValue(0, LogWeight(2));
  w.SetValue(1, LogWeight(3));
  CHECK_EQ(LogWeight(2), w.Value(0));
  CHECK_EQ(LogWeight(3), w.Value(1));
}

// Exercises SparsePowerWeight::SetValue/Value/Size. Covered cases:
//  - unset indices return the default value;
//  - entries inserted below, between and above existing indices;
//  - overwriting an entry;
//  - setting an entry to the default value, which removes it (Size() drops).
void TestSparsePowerWeightGetSetValue() {
  const LogWeight default_value(17);
  SparsePowerWeight w;
  w.SetDefaultValue(default_value);

  // All gets should be the default.
  CHECK_EQ(default_value, w.Value(0));
  CHECK_EQ(default_value, w.Value(100));

  // First set should fill first_.
  w.SetValue(10, LogWeight(10));
  CHECK_EQ(LogWeight(10), w.Value(10));
  w.SetValue(10, LogWeight(20));
  CHECK_EQ(LogWeight(20), w.Value(10));

  // Add a smaller index.
  w.SetValue(5, LogWeight(5));
  CHECK_EQ(LogWeight(5), w.Value(5));
  CHECK_EQ(LogWeight(20), w.Value(10));

  // Add some larger indices.
  w.SetValue(30, LogWeight(30));
  CHECK_EQ(LogWeight(5), w.Value(5));
  CHECK_EQ(LogWeight(20), w.Value(10));
  CHECK_EQ(LogWeight(30), w.Value(30));

  w.SetValue(29, LogWeight(29));
  CHECK_EQ(LogWeight(5), w.Value(5));
  CHECK_EQ(LogWeight(20), w.Value(10));
  CHECK_EQ(LogWeight(29), w.Value(29));
  CHECK_EQ(LogWeight(30), w.Value(30));

  w.SetValue(31, LogWeight(31));
  CHECK_EQ(LogWeight(5), w.Value(5));
  CHECK_EQ(LogWeight(20), w.Value(10));
  CHECK_EQ(LogWeight(29), w.Value(29));
  CHECK_EQ(LogWeight(30), w.Value(30));
  CHECK_EQ(LogWeight(31), w.Value(31));

  // Replace a value.
  w.SetValue(30, LogWeight(60));
  CHECK_EQ(LogWeight(60), w.Value(30));

  // Replace a value with the default.
  // Five non-default entries are stored at this point: 5, 10, 29, 30, 31.
  CHECK_EQ(5, w.Size());
  w.SetValue(30, default_value);
  CHECK_EQ(default_value, w.Value(30));
  CHECK_EQ(4, w.Size());

  // Replace lowest index by the default value.
  w.SetValue(5, default_value);
  CHECK_EQ(default_value, w.Value(5));
  CHECK_EQ(3, w.Size());

  // Clear out everything.
  w.SetValue(31, default_value);
  w.SetValue(29, default_value);
  w.SetValue(10, default_value);
  CHECK_EQ(0, w.Size());

  CHECK_EQ(default_value, w.Value(5));
  CHECK_EQ(default_value, w.Value(10));
  CHECK_EQ(default_value, w.Value(29));
  CHECK_EQ(default_value, w.Value(30));
  CHECK_EQ(default_value, w.Value(31));
}

}  // namespace

// Test driver. Statement order matters:
//  - the RNG is seeded once from --seed, so the sequence of generated weights
//    depends on the order of the Test calls below (presumably the generators
//    draw from rand() -- TODO confirm in weight-tester.h);
//  - the global FLAGS_fst_weight_parentheses is switched between passes to
//    cover weight I/O with and without parentheses.
// Prints "PASS" and returns 0 if no CHECK fails.
int main(int argc, char **argv) {
  std::set_new_handler(FailedNewHandler);
  SET_FLAGS(argv[0], &argc, &argv, true);

  LOG(INFO) << "Seed = " << FLAGS_seed;
  // srand() takes an unsigned seed, so the default of -1 is converted to an
  // unsigned value rather than meaning "random".
  srand(FLAGS_seed);

  // Two instantiations with the flag's initial value, then two more with "()".
  TestTemplatedWeights(FLAGS_repeat);
  TestTemplatedWeights(FLAGS_repeat);
  FLAGS_fst_weight_parentheses = "()";
  TestTemplatedWeights(FLAGS_repeat);
  TestTemplatedWeights(FLAGS_repeat);
  FLAGS_fst_weight_parentheses = "";

  // Makes sure type names for templated weights are consistent.
  CHECK(TropicalWeight::Type() == "tropical");
  // NOTE(review): the two Tpl instantiations compared with != presumably
  // differed in their (stripped) template argument, e.g. float vs. double.
  CHECK(TropicalWeightTpl::Type() != TropicalWeightTpl::Type());
  CHECK(LogWeight::Type() == "log");
  CHECK(LogWeightTpl::Type() != LogWeightTpl::Type());
  // NOTE(review): `w` and `tw` are never used below. Presumably they only
  // check that these types are constructible from a double literal.
  TropicalWeightTpl w(2.0);
  TropicalWeight tw(2.0);

  TestAdder(1000);
  TestAdder(1000);
  TestSignedAdder(1000);

  TestImplicitConversion();
  TestImplicitConversion();
  TestImplicitConversion();

  TestWeightConversion(2.0);

  using LeftStringWeight = StringWeight;
  using LeftStringWeightGenerate = WeightGenerate;
  LeftStringWeightGenerate left_string_generate;
  WeightTester left_string_tester(
      left_string_generate);
  left_string_tester.Test(FLAGS_repeat);

  using RightStringWeight = StringWeight;
  using RightStringWeightGenerate = WeightGenerate;
  RightStringWeightGenerate right_string_generate;
  WeightTester
      right_string_tester(right_string_generate);
  right_string_tester.Test(FLAGS_repeat);

  // STRING_RESTRICT not tested since it requires equal strings,
  // so would fail.

  using IUSetWeight = SetWeight;
  using IUSetWeightGenerate = WeightGenerate;
  IUSetWeightGenerate iu_set_generate;
  WeightTester
      iu_set_tester(iu_set_generate);
  iu_set_tester.Test(FLAGS_repeat);

  using UISetWeight = SetWeight;
  using UISetWeightGenerate = WeightGenerate;
  UISetWeightGenerate ui_set_generate;
  WeightTester
      ui_set_tester(ui_set_generate);
  ui_set_tester.Test(FLAGS_repeat);

  // SET_INTERSECT_UNION_RESTRICT not tested since it requires equal sets,
  // so would fail.

  using BoolSetWeight = SetWeight;
  using BoolSetWeightGenerate = WeightGenerate;
  BoolSetWeightGenerate bool_set_generate;
  WeightTester
      bool_set_tester(bool_set_generate);
  bool_set_tester.Test(FLAGS_repeat);

  // Conversion, copy and move round trips between set-weight types, on freshly
  // generated values. Each call advances the generators' random sequence.
  TestWeightConversion(iu_set_generate());

  TestWeightCopy(iu_set_generate());
  TestWeightCopy(iu_set_generate());
  TestWeightCopy(ui_set_generate());
  TestWeightCopy(ui_set_generate());
  TestWeightCopy(bool_set_generate());
  TestWeightCopy(bool_set_generate());

  TestWeightMove(iu_set_generate());
  TestWeightMove(iu_set_generate());
  TestWeightMove(ui_set_generate());
  TestWeightMove(ui_set_generate());
  TestWeightMove(bool_set_generate());
  TestWeightMove(bool_set_generate());

  // COMPOSITE WEIGHTS AND TESTERS - DEFINITIONS
  // Testers are built here and run below, once per parenthesis setting.

  using TropicalGallicWeight = GallicWeight;
  using TropicalGallicWeightGenerate = WeightGenerate;
  TropicalGallicWeightGenerate tropical_gallic_generate(true);
  WeightTester
      tropical_gallic_tester(tropical_gallic_generate);

  using TropicalGenGallicWeight = GallicWeight;
  using TropicalGenGallicWeightGenerate =
      WeightGenerate;
  TropicalGenGallicWeightGenerate tropical_gen_gallic_generate(false);
  WeightTester
      tropical_gen_gallic_tester(tropical_gen_gallic_generate);

  using TropicalProductWeight = ProductWeight;
  using TropicalProductWeightGenerate = WeightGenerate;
  TropicalProductWeightGenerate tropical_product_generate;
  WeightTester
      tropical_product_tester(tropical_product_generate);

  using TropicalLexicographicWeight =
      LexicographicWeight;
  using TropicalLexicographicWeightGenerate =
      WeightGenerate;
  TropicalLexicographicWeightGenerate tropical_lexicographic_generate;
  WeightTester
      tropical_lexicographic_tester(tropical_lexicographic_generate);

  using TropicalCubeWeight = PowerWeight;
  using TropicalCubeWeightGenerate = WeightGenerate;
  TropicalCubeWeightGenerate tropical_cube_generate;
  WeightTester
      tropical_cube_tester(tropical_cube_generate);

  using FirstNestedProductWeight =
      ProductWeight;
  using FirstNestedProductWeightGenerate =
      WeightGenerate;
  FirstNestedProductWeightGenerate first_nested_product_generate;
  WeightTester
      first_nested_product_tester(first_nested_product_generate);

  using SecondNestedProductWeight =
      ProductWeight;
  using SecondNestedProductWeightGenerate =
      WeightGenerate;
  SecondNestedProductWeightGenerate second_nested_product_generate;
  WeightTester
      second_nested_product_tester(second_nested_product_generate);

  using NestedProductCubeWeight = PowerWeight;
  using NestedProductCubeWeightGenerate =
      WeightGenerate;
  NestedProductCubeWeightGenerate nested_product_cube_generate;
  WeightTester
      nested_product_cube_tester(nested_product_cube_generate);

  using SparseNestedProductCubeWeight =
      SparsePowerWeight;
  using SparseNestedProductCubeWeightGenerate =
      WeightGenerate;
  SparseNestedProductCubeWeightGenerate sparse_nested_product_cube_generate;
  WeightTester
      sparse_nested_product_cube_tester(sparse_nested_product_cube_generate);

  using LogSparsePowerWeight = SparsePowerWeight;
  using LogSparsePowerWeightGenerate = WeightGenerate;
  LogSparsePowerWeightGenerate log_sparse_power_generate;
  WeightTester
      log_sparse_power_tester(log_sparse_power_generate);

  using LogLogExpectationWeight = ExpectationWeight;
  using LogLogExpectationWeightGenerate =
      WeightGenerate;
  LogLogExpectationWeightGenerate log_log_expectation_generate;
  WeightTester
      log_log_expectation_tester(log_log_expectation_generate);

  using LogLogSparseExpectationWeight =
      ExpectationWeight;
  using LogLogSparseExpectationWeightGenerate =
      WeightGenerate;
  LogLogSparseExpectationWeightGenerate log_log_sparse_expectation_generate;
  WeightTester
      log_log_sparse_expectation_tester(log_log_sparse_expectation_generate);

  // Options for a UnionWeight over TropicalWeight. Compare is NaturalLess, and
  // Merge keeps its first argument (w1), discarding w2.
  struct UnionWeightOptions {
    using Compare = NaturalLess;
    struct Merge {
      TropicalWeight operator()(const TropicalWeight &w1,
                                const TropicalWeight &w2) const {
        return w1;
      }
    };
    using ReverseOptions = UnionWeightOptions;
  };

  using TropicalUnionWeight = UnionWeight;
  using TropicalUnionWeightGenerate = WeightGenerate;
  TropicalUnionWeightGenerate
      tropical_union_generate;
  WeightTester
      tropical_union_tester(tropical_union_generate);

  // COMPOSITE WEIGHTS AND TESTERS - TESTING
  // NOTE(review): the meaning of the `false` second argument to Test() is
  // defined in weight-tester.h and is not visible here.

  // Tests composite weight I/O with parentheses.
  FLAGS_fst_weight_parentheses = "()";

  // Unnested composite.
  tropical_gallic_tester.Test(FLAGS_repeat);
  tropical_gen_gallic_tester.Test(FLAGS_repeat);
  tropical_product_tester.Test(FLAGS_repeat);
  tropical_lexicographic_tester.Test(FLAGS_repeat);
  tropical_cube_tester.Test(FLAGS_repeat);
  log_sparse_power_tester.Test(FLAGS_repeat);
  log_log_expectation_tester.Test(FLAGS_repeat, false);
  tropical_union_tester.Test(FLAGS_repeat, false);

  // Nested composite.
  first_nested_product_tester.Test(FLAGS_repeat);
  // NOTE(review): fixed at 5 repetitions (not FLAGS_repeat) in this pass; the
  // reason is not visible here.
  second_nested_product_tester.Test(5);
  nested_product_cube_tester.Test(FLAGS_repeat);
  sparse_nested_product_cube_tester.Test(FLAGS_repeat);
  log_log_sparse_expectation_tester.Test(FLAGS_repeat, false);

  // ... and tests composite weight I/O without parentheses.
  // NOTE(review): this pass runs a subset of the testers above. It omits
  // tropical_gen_gallic_tester, first_nested_product_tester,
  // nested_product_cube_tester and sparse_nested_product_cube_tester.
  FLAGS_fst_weight_parentheses = "";

  // Unnested composite.
  tropical_gallic_tester.Test(FLAGS_repeat);
  tropical_product_tester.Test(FLAGS_repeat);
  tropical_lexicographic_tester.Test(FLAGS_repeat);
  tropical_cube_tester.Test(FLAGS_repeat);
  log_sparse_power_tester.Test(FLAGS_repeat);
  log_log_expectation_tester.Test(FLAGS_repeat, false);
  tropical_union_tester.Test(FLAGS_repeat, false);

  // Nested composite.
  second_nested_product_tester.Test(FLAGS_repeat);
  log_log_sparse_expectation_tester.Test(FLAGS_repeat, false);

  TestPowerWeightGetSetValue();
  TestSparsePowerWeightGetSetValue();

  std::cout << "PASS" << std::endl;

  return 0;
}