// online2bin/online2-wav-nnet2-latgen-threaded.cc // Copyright 2014-2015 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "feat/wave-reader.h" #include "online2/online-nnet2-decoding-threaded.h" #include "online2/online-nnet2-feature-pipeline.h" #include "online2/onlinebin-util.h" #include "online2/online-timing.h" #include "online2/online-endpoint.h" #include "fstext/fstext-lib.h" #include "lat/lattice-functions.h" #include "util/kaldi-thread.h" namespace kaldi { void GetDiagnosticsAndPrintOutput(const std::string &utt, const fst::SymbolTable *word_syms, const CompactLattice &clat, int64 *tot_num_frames, double *tot_like) { if (clat.NumStates() == 0) { KALDI_WARN << "Empty lattice."; return; } CompactLattice best_path_clat; CompactLatticeShortestPath(clat, &best_path_clat); Lattice best_path_lat; ConvertLattice(best_path_clat, &best_path_lat); double likelihood; LatticeWeight weight; int32 num_frames; std::vector alignment; std::vector words; GetLinearSymbolSequence(best_path_lat, &alignment, &words, &weight); num_frames = alignment.size(); likelihood = -(weight.Value1() + weight.Value2()); *tot_num_frames += num_frames; *tot_like += likelihood; KALDI_VLOG(2) << "Likelihood per frame for utterance " << utt << " is " << (likelihood / num_frames) << " over " << num_frames << " frames."; if (word_syms != NULL) { std::cerr << utt << ' '; for (size_t i = 0; i < words.size(); i++) { std::string s = word_syms->Find(words[i]); if (s == "") KALDI_ERR << "Word-id " << words[i] << " not in symbol table."; std::cerr << s << ' '; } std::cerr << std::endl; } } } int main(int argc, char *argv[]) { try { using namespace kaldi; using namespace fst; typedef kaldi::int32 int32; typedef kaldi::int64 int64; const char *usage = "Reads in wav file(s) and simulates online decoding with neural nets\n" "(nnet2 setup), with optional iVector-based speaker adaptation and\n" "optional endpointing. This version uses multiple threads for decoding.\n" "Note: some configuration values and inputs are set via config files\n" "whose filenames are passed as options\n" "\n" "Usage: online2-wav-nnet2-latgen-threaded [options] " " \n" "The spk2utt-rspecifier can just be if\n" "you want to decode utterance by utterance.\n" "See egs/rm/s5/local/run_online_decoding_nnet2.sh for example\n" "See also online2-wav-nnet2-latgen-faster\n"; ParseOptions po(usage); std::string word_syms_rxfilename; OnlineEndpointConfig endpoint_config; // feature_config includes configuration for the iVector adaptation, // as well as the basic features. OnlineNnet2FeaturePipelineConfig feature_config; OnlineNnet2DecodingThreadedConfig nnet2_decoding_config; BaseFloat chunk_length_secs = 0.05; bool do_endpointing = false; bool modify_ivector_config = false; bool simulate_realtime_decoding = true; po.Register("chunk-length", &chunk_length_secs, "Length of chunk size in seconds, that we provide each time to the " "decoder. The actual chunk sizes it processes for various stages " "of decoding are dynamically determinated, and unrelated to this"); po.Register("word-symbol-table", &word_syms_rxfilename, "Symbol table for words [for debug output]"); po.Register("do-endpointing", &do_endpointing, "If true, apply endpoint detection"); po.Register("modify-ivector-config", &modify_ivector_config, "If true, modifies the iVector configuration from the config files " "by setting --use-most-recent-ivector=true and --greedy-ivector-extractor=true. " "This will give the best possible results, but the results may become dependent " "on the speed of your machine (slower machine -> better results). Compare " "to the --online option in online2-wav-nnet2-latgen-faster"); po.Register("simulate-realtime-decoding", &simulate_realtime_decoding, "If true, simulate real-time decoding scenario by providing the " "data incrementally, calling sleep() until each piece is ready. " "If false, don't sleep (so it will be faster)."); po.Register("num-threads-startup", &g_num_threads, "Number of threads used when initializing iVector extractor. "); feature_config.Register(&po); nnet2_decoding_config.Register(&po); endpoint_config.Register(&po); po.Read(argc, argv); if (po.NumArgs() != 5) { po.PrintUsage(); return 1; } std::string nnet2_rxfilename = po.GetArg(1), fst_rxfilename = po.GetArg(2), spk2utt_rspecifier = po.GetArg(3), wav_rspecifier = po.GetArg(4), clat_wspecifier = po.GetArg(5); OnlineNnet2FeaturePipelineInfo feature_info(feature_config); if (modify_ivector_config) { feature_info.ivector_extractor_info.use_most_recent_ivector = true; feature_info.ivector_extractor_info.greedy_ivector_extractor = true; } TransitionModel trans_model; nnet2::AmNnet am_nnet; { bool binary; Input ki(nnet2_rxfilename, &binary); trans_model.Read(ki.Stream(), binary); am_nnet.Read(ki.Stream(), binary); } fst::Fst *decode_fst = ReadFstKaldiGeneric(fst_rxfilename); fst::SymbolTable *word_syms = NULL; if (word_syms_rxfilename != "") if (!(word_syms = fst::SymbolTable::ReadText(word_syms_rxfilename))) KALDI_ERR << "Could not read symbol table from file " << word_syms_rxfilename; int32 num_done = 0, num_err = 0; double tot_like = 0.0; int64 num_frames = 0; Timer global_timer; SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier); RandomAccessTableReader wav_reader(wav_rspecifier); CompactLatticeWriter clat_writer(clat_wspecifier); OnlineTimingStats timing_stats; for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { std::string spk = spk2utt_reader.Key(); const std::vector &uttlist = spk2utt_reader.Value(); OnlineIvectorExtractorAdaptationState adaptation_state( feature_info.ivector_extractor_info); for (size_t i = 0; i < uttlist.size(); i++) { std::string utt = uttlist[i]; if (!wav_reader.HasKey(utt)) { KALDI_WARN << "Did not find audio for utterance " << utt; num_err++; continue; } const WaveData &wave_data = wav_reader.Value(utt); // get the data for channel zero (if the signal is not mono, we only // take the first channel). SubVector data(wave_data.Data(), 0); SingleUtteranceNnet2DecoderThreaded decoder( nnet2_decoding_config, trans_model, am_nnet, *decode_fst, feature_info, adaptation_state); OnlineTimer decoding_timer(utt); BaseFloat samp_freq = wave_data.SampFreq(); int32 chunk_length; KALDI_ASSERT(chunk_length_secs > 0); chunk_length = int32(samp_freq * chunk_length_secs); if (chunk_length == 0) chunk_length = 1; int32 samp_offset = 0; while (samp_offset < data.Dim()) { int32 samp_remaining = data.Dim() - samp_offset; int32 num_samp = chunk_length < samp_remaining ? chunk_length : samp_remaining; SubVector wave_part(data, samp_offset, num_samp); // The endpointing code won't work if we let the waveform be given to // the decoder all at once, because we'll exit this while loop, and // the endpointing happens inside this while loop. The next statement // is intended to prevent this from happening. while (do_endpointing && decoder.NumWaveformPiecesPending() * chunk_length_secs > 2.0) Sleep(0.5f); decoder.AcceptWaveform(samp_freq, wave_part); samp_offset += num_samp; if (simulate_realtime_decoding) { // Note: the next call may actually call sleep(). decoding_timer.SleepUntil(samp_offset / samp_freq); } if (samp_offset == data.Dim()) { // no more input. flush out last frames decoder.InputFinished(); } if (do_endpointing && decoder.EndpointDetected(endpoint_config)) { decoder.TerminateDecoding(); break; } } Timer timer; decoder.Wait(); if (simulate_realtime_decoding) { KALDI_VLOG(1) << "Waited " << timer.Elapsed() << " seconds for decoder to " << "finish after giving it last chunk."; } decoder.FinalizeDecoding(); CompactLattice clat; bool end_of_utterance = true; decoder.GetLattice(end_of_utterance, &clat, NULL); GetDiagnosticsAndPrintOutput(utt, word_syms, clat, &num_frames, &tot_like); decoding_timer.OutputStats(&timing_stats); // In an application you might avoid updating the adaptation state if // you felt the utterance had low confidence. See lat/confidence.h decoder.GetAdaptationState(&adaptation_state); // we want to output the lattice with un-scaled acoustics. BaseFloat inv_acoustic_scale = 1.0 / nnet2_decoding_config.acoustic_scale; ScaleLattice(AcousticLatticeScale(inv_acoustic_scale), &clat); if (simulate_realtime_decoding) { KALDI_VLOG(1) << "Adding the various end-of-utterance tasks took the " << "total latency to " << timer.Elapsed() << " seconds."; } clat_writer.Write(utt, clat); KALDI_LOG << "Decoded utterance " << utt; num_done++; } } bool online = true; if (simulate_realtime_decoding) { timing_stats.Print(online); } else { BaseFloat frame_shift = 0.01; BaseFloat real_time_factor = global_timer.Elapsed() / (frame_shift * num_frames); if (num_frames > 0) KALDI_LOG << "Real-time factor was " << real_time_factor << " assuming frame shift of " << frame_shift; } KALDI_LOG << "Decoded " << num_done << " utterances, " << num_err << " with errors."; KALDI_LOG << "Overall likelihood per frame was " << (tot_like / num_frames) << " per frame over " << num_frames << " frames."; delete decode_fst; delete word_syms; // will delete if non-NULL. return (num_done != 0 ? 0 : 1); } catch(const std::exception& e) { std::cerr << e.what(); return -1; } } // main()