Blame view

src/online2bin/online2-wav-nnet2-am-compute.cc 7.82 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
  // online2bin/online2-wav-nnet2-am-compute.cc
  
  // Copyright 2014  Johns Hopkins University (author: Daniel Povey)
  //           2014  David Snyder
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #include "feat/wave-reader.h"
  #include "online2/online-nnet2-decoding.h"
  #include "online2/online-nnet2-feature-pipeline.h"
  #include "online2/onlinebin-util.h"
  
  int main(int argc, char *argv[]) {
    try {
      using namespace kaldi;
      using namespace kaldi::nnet2;
      typedef kaldi::int32 int32;
      typedef kaldi::int64 int64;
      
      const char *usage =
          "Simulates the online neural net computation for each file of input
  " 
          "features, and outputs as a matrix the result, with optional
  "
          "iVector-based speaker adaptation. Note: some configuration values
  "
          "and inputs are set via config files whose filenames are passed as
  "
          "options.  Used mostly for debugging.
  "
          "Note: if you want it to apply a log (e.g. for log-likelihoods), use
  "
          "--apply-log=true.
  "
          "
  "
          "Usage:  online2-wav-nnet2-am-compute [options] <nnet-in>
  "
          "<spk2utt-rspecifier> <wav-rspecifier> <feature-or-loglikes-wspecifier>
  "
          "The spk2utt-rspecifier can just be <utterance-id> <utterance-id> if
  "
          "you want to compute utterance by utterance.
  ";
      
      BaseFloat chunk_length_secs = 0.05;
      bool apply_log = false;
      bool pad_input = true;
      bool online = true;
  
      // feature_config includes configuration for the iVector adaptation,
      // as well as the basic features.
      OnlineNnet2FeaturePipelineConfig feature_config;  
      ParseOptions po(usage);
      po.Register("apply-log", &apply_log, "Apply a log to the result of the computation "
                  "before outputting.");
      po.Register("pad-input", &pad_input, "If true, duplicate the first and last frames "
                  "of input features as required for temporal context, to prevent #frames "
                  "of output being less than those of input.");
      po.Register("chunk-length", &chunk_length_secs,
                  "Length of chunk size in seconds, that we process.");
      po.Register("online", &online,
                  "You can set this to false to disable online iVector estimation "
                  "and have all the data for each utterance used, even at "
                  "utterance start.  This is useful where you just want the best "
                  "results and don't care about online operation.  Setting this to "
                  "false has the same effect as setting "
                  "--use-most-recent-ivector=true and --greedy-ivector-extractor=true "
                  "in the file given to --ivector-extraction-config, and "
                  "--chunk-length=-1.");
      
      feature_config.Register(&po);
      po.Read(argc, argv);
      if (po.NumArgs() != 4) {
        po.PrintUsage();
        return 1;
      }
      
      std::string nnet2_rxfilename = po.GetArg(1),
          spk2utt_rspecifier = po.GetArg(2),
          wav_rspecifier = po.GetArg(3),
          features_or_loglikes_wspecifier = po.GetArg(4);
      
      OnlineNnet2FeaturePipelineInfo feature_info(feature_config);
      if (!online) {
        feature_info.ivector_extractor_info.use_most_recent_ivector = true;
        feature_info.ivector_extractor_info.greedy_ivector_extractor = true;
        chunk_length_secs = -1.0;
      }
  
      TransitionModel trans_model;
      AmNnet am_nnet;
      {
        bool binary;
        Input ki(nnet2_rxfilename, &binary);
        trans_model.Read(ki.Stream(), binary);
        am_nnet.Read(ki.Stream(), binary);
      }
      Nnet &nnet = am_nnet.GetNnet();
      
      int64 num_done = 0, num_frames = 0;
      SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
      RandomAccessTableReader<WaveHolder> wav_reader(wav_rspecifier);
      BaseFloatCuMatrixWriter writer(features_or_loglikes_wspecifier);
      
      for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
        std::string spk = spk2utt_reader.Key();
        const std::vector<std::string> &uttlist = spk2utt_reader.Value();
        OnlineIvectorExtractorAdaptationState adaptation_state(
            feature_info.ivector_extractor_info);
        for (size_t i = 0; i < uttlist.size(); i++) {
          std::string utt = uttlist[i];
          if (!wav_reader.HasKey(utt)) {
            KALDI_WARN << "Did not find audio for utterance " << utt;
            continue;
          }
          const WaveData &wave_data = wav_reader.Value(utt);
          // get the data for channel zero (if the signal is not mono, we only
          // take the first channel).
          SubVector<BaseFloat> data(wave_data.Data(), 0);
  
          OnlineNnet2FeaturePipeline feature_pipeline(feature_info);
          feature_pipeline.SetAdaptationState(adaptation_state);
          
          BaseFloat samp_freq = wave_data.SampFreq();
          int32 chunk_length;
          if (chunk_length_secs > 0) {
            chunk_length = int32(samp_freq * chunk_length_secs);
            if (chunk_length == 0) chunk_length = 1;
          } else {
            chunk_length = std::numeric_limits<int32>::max();
          }
          
          int32 samp_offset = 0;
          while (samp_offset < data.Dim()) {
            int32 samp_remaining = data.Dim() - samp_offset;
            int32 num_samp = chunk_length < samp_remaining ? chunk_length
                                                           : samp_remaining;
            
            SubVector<BaseFloat> wave_part(data, samp_offset, num_samp);
            feature_pipeline.AcceptWaveform(samp_freq, wave_part);
            
            samp_offset += num_samp;
            if (samp_offset == data.Dim()) {
              // no more input. flush out last frames
              feature_pipeline.InputFinished();
            }
          }
          
          int32 feats_num_frames = feature_pipeline.NumFramesReady(),
                feats_dim = feature_pipeline.Dim();
          Matrix<BaseFloat> feats(feats_num_frames, feats_dim);
  
          for (int32 i = 0; i < feats_num_frames; i++) {
            SubVector<BaseFloat> frame_vector(feats, i);
            feature_pipeline.GetFrame(i, &frame_vector);
          }
  
          // In an application you might avoid updating the adaptation state if
          // you felt the utterance had low confidence.  See lat/confidence.h
          feature_pipeline.GetAdaptationState(&adaptation_state);
  
          int32 output_frames = feats.NumRows(), 
                output_dim = nnet.OutputDim();
          CuMatrix<BaseFloat> output(output_frames, output_dim),
                              feats_cu(feats);
  
          if (!pad_input)
            output_frames -= nnet.LeftContext() + nnet.RightContext();
          if (output_frames <= 0) {
            KALDI_WARN << "Skipping utterance " << utt << " because output "
                       << "would be empty.";
            continue;
          }
          
          NnetComputation(nnet, feats_cu, pad_input, &output);
  
          if (apply_log) {
            output.ApplyFloor(1.0e-20);
            output.ApplyLog();
          }
  
          writer.Write(utt, output);
          num_frames += feats.NumRows();
          num_done++;
  
          KALDI_LOG << "Processed data for utterance " << utt;
        }
      }
  
      KALDI_LOG << "Processed " << num_done << " feature files, "
                << num_frames << " frames of input were processed.";
  
      return (num_done != 0 ? 0 : 1);
    } catch(const std::exception& e) {
      std::cerr << e.what() << '
  ';
      return -1;
    }
  } // main()