// kwsbin/compute-atwv.cc // Copyright (c) 2015, Johns Hopkins University (Yenda Trmal) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include #include // std::setw #include "base/kaldi-common.h" #include "util/common-utils.h" #include "util/stl-utils.h" #include "kws/kws-scoring.h" int main(int argc, char *argv[]) { try { using namespace kaldi; typedef kaldi::int32 int32; typedef kaldi::uint32 uint32; typedef kaldi::uint64 uint64; const char *usage = "Computes the Actual Term-Weighted Value and prints it." "\n" "Usage: \n" " compute-atwv [options] " " [alignment-csv-filename]\n" "e.g.: \n" " compute-atwv 32485.4 ark:ref.1 ark:hyp.1 ali.csv\n" "or: \n" " compute-atwv 32485.4 ark:ref.1 ark:hyp.1\n" "\n" "NOTES: \n" " a) the number of trials is usually equal to the size of the searched\n" " collection in seconds\n" " b the ref-rspecifier/hyp-rspecifier are the kaldi IO specifiers \n" " for both the reference and the hypotheses (found hits), " " respectively The format is the same for both of them. Each line\n" " is of the following format\n" "\n" " \n\n" " e.g.:\n\n" " KW106-189 348 459 560 0.8\n" "\n" " b) the alignment-csv-filename is an optional parameter. \n" " If present, the alignment i.e. detailed information about what \n" " hypotheses match up with which reference entries will be \n" " generated. The alignemnt file format is equivalent to \n" " the alignment file produced using the F4DE tool. However, we do" " not set some fields and the utterance identifiers are numeric.\n" " You can use the script utils/int2sym.pl and the utterance and \n" " keyword maps to convert the numerical ids into text form\n" " c) the scores are expected to be probabilities. Please note that\n" " the output from the kws-search is in -log(probability).\n" " d) compute-atwv does not perform any score normalization (it's just\n" " for scoring purposes). Without score normalization/calibration\n" " the performance of the search will be quite poor.\n"; ParseOptions po(usage); KwsTermsAlignerOptions ali_opts; TwvMetricsOptions twv_opts; int frames_per_sec = 100; ali_opts.Register(&po); twv_opts.Register(&po); po.Register("frames-per-sec", &frames_per_sec, "Number of feature vector frames per second. This is used only when" "writing the alignment to a file"); po.Read(argc, argv); if (po.NumArgs() < 3 || po.NumArgs() > 4) { po.PrintUsage(); exit(1); } if (!kaldi::ConvertStringToReal(po.GetArg(1), &twv_opts.audio_duration)) { KALDI_ERR << "The duration parameter is not a number"; } if (twv_opts.audio_duration <= 0) { KALDI_ERR << "The duration is either negative or zero"; } KwsTermsAligner aligner(ali_opts); TwvMetrics twv_scores(twv_opts); std::string ref_rspecifier = po.GetArg(2), hyp_rspecifier = po.GetArg(3), ali_output = po.GetOptArg(4); kaldi::SequentialTableReader< kaldi::BasicVectorHolder > ref_reader(ref_rspecifier); for (; !ref_reader.Done(); ref_reader.Next()) { std::string kwid = ref_reader.Key(); std::vector vals = ref_reader.Value(); if (vals.size() != 4) { KALDI_ERR << "Incorrect format of the reference file" << " -- 4 entries expected, " << vals.size() << " given!\n" << "Key: " << kwid; } KwsTerm inst(kwid, vals); aligner.AddRef(inst); } kaldi::SequentialTableReader< kaldi::BasicVectorHolder > hyp_reader(hyp_rspecifier); for (; !hyp_reader.Done(); hyp_reader.Next()) { std::string kwid = hyp_reader.Key(); std::vector vals = hyp_reader.Value(); if (vals.size() != 4) { KALDI_ERR << "Incorrect format of the hypotheses file" << " -- 4 entries expected, " << vals.size() << " given!\n" << "Key: " << kwid; } KwsTerm inst(kwid, vals); aligner.AddHyp(inst); } KALDI_LOG << "Read " << aligner.nof_hyps() << " hypotheses"; KALDI_LOG << "Read " << aligner.nof_refs() << " references"; KwsAlignment ali = aligner.AlignTerms(); if (ali_output != "") { std::fstream fs; fs.open(ali_output.c_str(), std::fstream::out); ali.WriteCsv(fs, frames_per_sec); fs.close(); } TwvMetrics scores(twv_opts); scores.AddAlignment(ali); std::cout << "aproximate ATWV = " << std::fixed << std::setprecision(4) << scores.Atwv() << std::endl; std::cout << "aproximate STWV = " << std::fixed << std::setprecision(4) << scores.Stwv() << std::endl; float mtwv, mtwv_threshold, otwv; scores.GetOracleMeasures(&mtwv, &mtwv_threshold, &otwv); std::cout << "aproximate MTWV = " << std::fixed << std::setprecision(4) << mtwv << std::endl; std::cout << "aproximate MTWV threshold = " << std::fixed << std::setprecision(4) << mtwv_threshold << std::endl; std::cout << "aproximate OTWV = " << std::fixed << std::setprecision(4) << otwv << std::endl; } catch(const std::exception &e) { std::cerr << e.what(); return -1; } }