  // cudadecoder/batched-threaded-nnet3-cuda-pipeline.h
  //
  // Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
  // Hugo Braun, Justin Luitjens, Ryan Leary
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //     http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_DECODER_H_
  #define KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_DECODER_H_
  
  #include <atomic>
  #include <condition_variable>
  #include <functional>
  #include <map>
  #include <memory>
  #include <mutex>
  #include <string>
  #include <thread>
  #include <unordered_map>
  #include <vector>
  
  #include "cudadecoder/cuda-decoder.h"
  #include "decodable-cumatrix.h"
  #include "feat/wave-reader.h"
  #include "lat/determinize-lattice-pruned.h"
  #include "nnet3/nnet-batch-compute.h"
  #include "online2/online-nnet2-feature-pipeline.h"
  #include "cudafeat/online-cuda-feature-pipeline.h"
  #include "thread-pool.h"
  
  // If num_channels is set to automatic (-1),
  // num_channels = [this define] * max_batch_size
  #define KALDI_CUDA_DECODER_CHANNELS_BATCH_SIZE_RATIO 1.3
  
  namespace kaldi {
  namespace cuda_decoder {
  
  /* BatchedThreadedNnet3CudaPipelineConfig
   * This class is a common configuration class for the various components
   * of a batched cuda multi-threaded pipeline.  It defines a single place
   * to control all operations and ensures that the various components
   * use matching configurations.
   */
  // Configuration options common to the BatchedThreadedNnet3CudaPipeline
  // and its components.
  struct BatchedThreadedNnet3CudaPipelineConfig {
    BatchedThreadedNnet3CudaPipelineConfig()
        : max_batch_size(200),
          num_channels(-1),
          batch_drain_size(10),
          num_control_threads(2),
          num_worker_threads(20),
          determinize_lattice(true),
          max_pending_tasks(4000),
          num_decoder_copy_threads(2),
          gpu_feature_extract(true) {};
    void Register(OptionsItf *po) {
      po->Register("max-batch-size", &max_batch_size,
                   "The maximum batch size to be used by the decoder. "
                   "This is also the number of lanes in the CudaDecoder. "
                   "Larger = Faster and more GPU memory used.");
      std::ostringstream num_channels_desc;
      num_channels_desc
          << "The number of channels "
             "allocated to the cuda decoder.  This should be larger "
             "than max_batch_size.  Each channel consumes a small "
             "amount of memory but also allows us to better overlap "
             "computation"
             " (-1 = set to "
          << KALDI_CUDA_DECODER_CHANNELS_BATCH_SIZE_RATIO << "*max-batch-size).";
      po->Register("num-channels", &num_channels, num_channels_desc.str());
      po->Register("batch-drain-size", &batch_drain_size,
                   "How far to drain the batch before refilling work. This "
                   "batches pre/post decode work.");
    po->Register("cuda-control-threads", &num_control_threads,
                 "The number of pipeline control threads for the CUDA work. "
                 "e.g. 2 control threads -> 2 independent CUDA pipelines "
                 "(nnet3 and decoder).");
      po->Register(
          "cuda-worker-threads", &num_worker_threads,
          "The total number of CPU threads launched to process CPU tasks.");
      po->Register("determinize-lattice", &determinize_lattice,
                   "Determinize the lattice before output.");
    po->Register("max-outstanding-queue-length", &max_pending_tasks,
                 "Maximum number of files allowed to be outstanding at a "
                 "time. When the number of files exceeds this limit, handles "
                 "will be closed before new ones are opened, in FIFO order.");
      po->Register("cuda-decoder-copy-threads", &num_decoder_copy_threads,
                   "Advanced - Number of worker threads used in the decoder for "
                   "the host to host copies.");
      po->Register("gpu-feature-extract", &gpu_feature_extract,
                   "Extract features on the GPU.  This reduces CPU overhead "
                   "leading to better scalability but may reduce overall "
                   "performance for a single GPU.");
  
      feature_opts.Register(po);
      decoder_opts.Register(po);
      det_opts.Register(po);
      compute_opts.Register(po);
    }
    int max_batch_size;
    int num_channels;
    int batch_drain_size;
    int num_control_threads;
    int num_worker_threads;
    bool determinize_lattice;
    int max_pending_tasks;
    int num_decoder_copy_threads;
    bool gpu_feature_extract;
  
    void ComputeConfig() {
      if (num_channels == -1)
        num_channels =
            max_batch_size * KALDI_CUDA_DECODER_CHANNELS_BATCH_SIZE_RATIO;
    }
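
    // Worked example (illustrative, using the defaults above): with
    // max_batch_size = 200 and num_channels left at -1, ComputeConfig()
    // sets num_channels = 200 * 1.3 = 260.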
  
    OnlineNnet2FeaturePipelineConfig feature_opts;      // constant readonly
    CudaDecoderConfig decoder_opts;                     // constant readonly
    fst::DeterminizeLatticePhonePrunedOptions det_opts; // constant readonly
    nnet3::NnetBatchComputerOptions compute_opts;       // constant readonly
  };
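
  // A minimal sketch of wiring this config to a command line parser
  // (assumes a kaldi::ParseOptions `po`; the flag names come from
  // Register() above):
  //
  //   BatchedThreadedNnet3CudaPipelineConfig config;
  //   config.Register(&po);
  //   po.Read(argc, argv);     // e.g. --max-batch-size=100
  //   config.ComputeConfig();  // resolves num_channels if left at -1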
  
  /*
   * BatchedThreadedNnet3CudaPipeline uses multiple levels of parallelism in order to
   * decode quickly on CUDA GPUs. This is the primary interface for cuda decoding.
   * For examples of how to use this decoder see cudadecoder/README and
   * cudadecoderbin/batched-wav-nnet3-cuda.cc
   */
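  /* A minimal usage sketch (assuming `config`, `decode_fst`, `am_nnet` and
   * `trans_model` are already loaded; cudadecoderbin/batched-wav-nnet3-cuda.cc
   * is the complete, authoritative example):
   *
   *   BatchedThreadedNnet3CudaPipeline pipeline(config);
   *   pipeline.Initialize(decode_fst, am_nnet, trans_model);
   *   pipeline.OpenDecodeHandle("utt1", wave_data);  // enqueue one utterance
   *   pipeline.WaitForAllTasks();
   *   CompactLattice clat;
   *   pipeline.GetLattice("utt1", &clat);
   *   pipeline.CloseDecodeHandle("utt1");
   *   pipeline.Finalize();
   */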
  class BatchedThreadedNnet3CudaPipeline {
  public:
   BatchedThreadedNnet3CudaPipeline(
       const BatchedThreadedNnet3CudaPipelineConfig &config)
       : config_(config), all_group_tasks_not_done_(0) {
     config_.ComputeConfig();
   };
  
   // allocates reusable objects that are common across all decodings
   void Initialize(const fst::Fst<fst::StdArc> &decode_fst,
                   const nnet3::AmNnetSimple &nnet,
                   const TransitionModel &trans_model);
  
   // deallocates reusable objects
   void Finalize();
  
   // query a specific key to see if compute on it is complete
   bool isFinished(const std::string &key);
  
   // remove an audio file from the decoding and clean up resources
   void CloseDecodeHandle(const std::string &key);
   void CloseAllDecodeHandlesForGroup(const std::string &group);
   void CloseAllDecodeHandles();
  
   // Adds a decoding task to the decoder
   // When passing in a vector of data, the caller must ensure the data
   // exists until CloseDecodeHandle/WaitForAllTasks is called.
   // The callback is called once the task is done, and is passed the final
   // lattice. It can be used for lattice rescoring, finding the best path
   // in the lattice, writing the lattice to disk, etc.
   // Important: callback is launched in the threadpool. It must be threadsafe.
   // For instance, if writing to disk, or to stdout,
   // use a lock:
   // e.g. :
   // {
   // 	std::lock_guard<std::mutex> lock(global_mutex);
   // 	// write lattice to disk
   //    // lock is released in the destructor of lock_guard<>
   // }
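   // A hedged sketch of such a callback (`global_mutex` and the use of
   // CompactLatticeShortestPath are illustrative choices, not requirements
   // of this API):
   //
   //   std::mutex global_mutex;
   //   auto callback = [&global_mutex](CompactLattice &clat) {
   //     CompactLattice best_path;
   //     CompactLatticeShortestPath(clat, &best_path);
   //     std::lock_guard<std::mutex> lock(global_mutex);
   //     // ... write best_path somewhere (safe under the lock)
   //   };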
   void OpenDecodeHandle(
       const std::string &key, const WaveData &wave_data,
       const std::string &group = std::string(),
       const std::function<void(CompactLattice &clat)> &callback =
           std::function<void(CompactLattice &clat)>());
   // When passing in a vector of data, the caller must ensure the data exists
   // until the CloseDecodeHandle is called
   void OpenDecodeHandle(
       const std::string &key, const VectorBase<BaseFloat> &wave_data,
       float sample_rate, const std::string &group = std::string(),
       const std::function<void(CompactLattice &clat)> &callback =
           std::function<void(CompactLattice &clat)>());
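
   // Hedged sketch of the raw-vector overload above (the caller must keep
   // `samples` alive until CloseDecodeHandle/WaitForAllTasks; 16 kHz is
   // only an example sample rate):
   //
   //   Vector<BaseFloat> samples;  // filled with mono audio elsewhere
   //   pipeline.OpenDecodeHandle("utt2", samples, 16000.0f);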
  
   // Copies the raw lattice for decoded handle "key" into lat
   bool GetRawLattice(const std::string &key, Lattice *lat);
   // Determinizes raw lattice and returns a compact lattice
   bool GetLattice(const std::string &key, CompactLattice *lat);
  
   int32 GetNumberOfTasksPending();
  
   // Wait for all tasks to complete
   void WaitForAllTasks();
   // Wait for all tasks in the group to complete
   void WaitForGroup(const std::string &group);
   // Checks if a group has completed. Returns false if it has not.
   bool IsGroupCompleted(const std::string &group);
   // Wait for any group to complete, then returns which group completed
   std::string WaitForAnyGroup();
   // Checks if any group has completed. If one has, sets its name in
   // *group and returns true.
   bool IsAnyGroupCompleted(std::string *group);
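   // Hedged sketch of the group API above ("batch0" is an arbitrary label):
   //
   //   pipeline.OpenDecodeHandle(key, wave_data, "batch0");
   //   pipeline.WaitForGroup("batch0");  // block until the group is done
   //   pipeline.CloseAllDecodeHandlesForGroup("batch0");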
   inline int NumPendingTasks() {
     return (tasks_back_ - tasks_front_ + config_.max_pending_tasks + 1) %
            (config_.max_pending_tasks + 1);
    };
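   // Worked example of the circular-buffer arithmetic above (illustrative):
   // with max_pending_tasks = 4000 the queue has 4001 slots; if
   // tasks_back_ = 5 and tasks_front_ = 4000 (the queue has wrapped),
   // then (5 - 4000 + 4001) % 4001 = 6 tasks are pending.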
  
  private:
   // Task data used during computation
   // Is cleared when task is completed
   struct TaskData {
     Vector<BaseFloat> raw_data;  // Wave input data when a WaveData is passed
     std::shared_ptr<SubVector<BaseFloat>>
         wave_samples;  // Used as a pointer to either the raw
                        // data or the samples passed
     float sample_frequency;
     Vector<BaseFloat> ivector_features_cpu;
     Matrix<BaseFloat> input_features_cpu;
     CuVector<BaseFloat> ivector_features;
     CuMatrix<BaseFloat> input_features;
     CuMatrix<BaseFloat> posteriors;
  
     TaskData(const WaveData &wave_data_in)
         : wave_samples(NULL), sample_frequency(0) {
       raw_data.Resize(
           wave_data_in.Data().NumRows() * wave_data_in.Data().NumCols(),
           kUndefined);
       memcpy(raw_data.Data(), wave_data_in.Data().Data(),
              raw_data.Dim() * sizeof(BaseFloat));
       wave_samples =
           std::make_shared<SubVector<BaseFloat>>(raw_data, 0, raw_data.Dim());
       sample_frequency = wave_data_in.SampFreq();
     };
  
     // Init when raw data is passed in.  This data is shallow copied.
     TaskData(const VectorBase<BaseFloat> &wave_data_in, float sample_rate) {
       wave_samples = std::make_shared<SubVector<BaseFloat>>(wave_data_in, 0,
                                                             wave_data_in.Dim());
       sample_frequency = sample_rate;
     }
   };
  
   // State needed for each decode task.
   // This state can be passed around by reference or pointer safely
   // and provides a convenient way to store all decoding state.
   struct TaskState {
     std::string key;
     std::string group;  // group for that task. "" is default
     bool error;
     std::string error_string;
  
     std::shared_ptr<TaskData> task_data;
  
     int32 ichannel;              // associated CudaDecoder channel
     Lattice lat;                 // Raw Lattice output
     CompactLattice dlat;         // Determinized lattice output.  Only set if
                                  // determinize-lattice=true
     std::atomic<bool> finished;  // Tells master thread if task has finished
                                  // execution
  
     bool determinized;
  
     // (optional) callback is called when the task is finished and we have
     // a lattice ready; that way we can run all CPU work in the threadpool
     // (lattice rescoring, finding the best path in the lattice, etc.)
     std::function<void(CompactLattice &clat)> callback;
  
     TaskState() : error(false), finished(false), determinized(false) {}
  
     // Init when wave data is passed directly in.  This data is deep copied.
     void Init(const std::string &key_in, const WaveData &wave_data_in) {
       task_data = std::make_shared<TaskData>(wave_data_in);
       key = key_in;
     };
     // Init when raw data is passed in.  This data is shallow copied.
     void Init(const std::string &key_in,
               const VectorBase<BaseFloat> &wave_data_in, float sample_rate) {
       task_data = std::make_shared<TaskData>(wave_data_in, sample_rate);
       key = key_in;
     }
    };
  
    // Creates a new task in the hashmaps
    TaskState *AddTask(const std::string &key, const std::string &group);
  
    // Holds the current channel state for a worker
    struct ChannelState {
      std::vector<ChannelId> channels;
      std::vector<ChannelId> free_channels;
      std::vector<ChannelId> completed_channels;
      std::mutex free_channels_mutex;
    };
  
    // Adds task to the PendingTaskQueue
    void AddTaskToPendingTaskQueue(TaskState *task);
  
    // Attempts to fill the batch from the task queue.  May not fully fill the
    // batch.
    void AquireAdditionalTasks(CudaDecoder &cuda_decoder,
                               ChannelState &channel_state,
                               std::vector<TaskState *> &tasks);
  
    // Computes Features for a single decode instance.
    void ComputeOneFeatureCPU(TaskState *task);
  
    // Computes features across the tasks in [first,tasks.size())
    void ComputeBatchFeatures(int32 first,
                              std::vector<TaskState *> &tasks,
                              OnlineCudaFeaturePipeline &feature_pipeline);
  
    // Computes Nnet across the current decode batch
    void ComputeBatchNnet(nnet3::NnetBatchComputer &computer, int32 first,
                          std::vector<TaskState *> &tasks);
  
    // Allocates decodables for tasks in the range
    // [first,tasks.size())
    void AllocateDecodables(int32 first, std::vector<TaskState *> &tasks,
                            std::vector<CudaDecodableInterface *> &decodables);
  
    // Removes all completed channels from the channel list.
    // Also enqueues up work for post processing
    void
    RemoveCompletedChannels(CudaDecoder &cuda_decoder,
                            ChannelState &channel_state,
                            std::vector<CudaDecodableInterface *> &decodables,
                            std::vector<TaskState *> &tasks);
  
    // For each completed decode perform post processing work and clean up
    void PostDecodeProcessing(CudaDecoder &cuda_decoder,
                              ChannelState &channel_state,
                              std::vector<CudaDecodableInterface *> &decodables,
                              std::vector<TaskState *> &tasks);
  
    // Calls ConcurrentGetRawLatticeSingleChannel and Determinize
    // on a dedicated CPU worker thread at the end of the decode
    void CompleteTask(CudaDecoder *cuda_decoder, ChannelState *channel_state,
                      TaskState *state);
  
    // Determinize one lattice
    void DeterminizeOneLattice(TaskState *task);
    // Thread execution function.  This is a single worker thread which processes
    // input.
    void ExecuteWorker(int threadId);
  
    BatchedThreadedNnet3CudaPipelineConfig config_;
  
    CudaFst cuda_fst_;
    const TransitionModel *trans_model_;
    const nnet3::AmNnetSimple *am_nnet_;
    nnet3::DecodableNnetSimpleLoopedInfo *decodable_info_;
    OnlineNnet2FeaturePipelineInfo *feature_info_;
  
    std::mutex tasks_mutex_; // protects tasks_front_ and pending_task_queue_ for
                             // workers
    std::mutex tasks_add_mutex_; // protects OpenDecodeHandle if multiple
                                 // threads access it
    std::mutex tasks_lookup_mutex_; // protects the tasks_lookup_ map
    std::condition_variable tasks_lookup_cv_;
    std::atomic<int> tasks_front_, tasks_back_;
    TaskState **pending_task_queue_;
  
    std::atomic<bool> exit_;      // signals threads to exit
    std::atomic<int> numStarted_; // signals master how many threads have started
  
    ThreadPool *work_pool_; // thread pool for CPU work
    std::map<std::string, int32> group_tasks_not_done_;
    int32 all_group_tasks_not_done_;
    std::mutex group_tasks_mutex_;
    std::condition_variable group_done_cv_;
    std::unordered_multimap<std::string, TaskState *>
        tasks_group_lookup_;  // group -> list of tasks
    std::unordered_map<std::string, TaskState>
        tasks_lookup_;                              // Contains a map of
                                                    // utterance to TaskState
    std::vector<std::thread> thread_contexts_;      // A list of thread contexts
  };
  
  }  // end namespace cuda_decoder
  }  // end namespace kaldi
  
  #endif  // KALDI_CUDA_DECODER_BATCHED_THREADED_CUDA_DECODER_H_