  // cudadecoder/cuda-decoder.h
  //
  // Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
  // Hugo Braun, Justin Luitjens, Ryan Leary
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //     http://www.apache.org/licenses/LICENSE-2.0
  //
  // Unless required by applicable law or agreed to in writing, software
  // distributed under the License is distributed on an "AS IS" BASIS,
  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  // See the License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_CUDA_DECODER_CUDA_DECODER_H_
  #define KALDI_CUDA_DECODER_CUDA_DECODER_H_
  
  #include "cudadecoder/cuda-decodable-itf.h"
  #include "cudadecoder/cuda-decoder-common.h"
  #include "cudadecoder/cuda-fst.h"
  #include "nnet3/decodable-online-looped.h"
  #include "thread-pool.h"
  
  #include <cuda_runtime_api.h>
  #include <mutex>
  #include <tuple>
  #include <vector>
  namespace kaldi {
  namespace cuda_decoder {
  
  struct CudaDecoderConfig {
    BaseFloat default_beam;
    BaseFloat lattice_beam;
    int32 ntokens_pre_allocated;
    int32 main_q_capacity, aux_q_capacity;
    int32 max_active;
  
    CudaDecoderConfig()
        : default_beam(15.0),
          lattice_beam(10.0),
          ntokens_pre_allocated(2000000),
          main_q_capacity(-1),
          aux_q_capacity(-1),
          max_active(10000) {}
  
    void Register(OptionsItf *opts) {
      opts->Register("beam", &default_beam,
                     "Decoding beam. Larger->slower, more accurate. If "
                     "aux-q-capacity is too small, we may decrease the beam "
                     "dynamically to avoid overflow (adaptive beam, see "
                     "aux-q-capacity parameter)");
      opts->Register("lattice-beam", &lattice_beam,
                     "The width of the lattice beam");
      opts->Register("max-active", &max_active,
                     "At the end of each frame computation, we keep only its "
                     "best max-active tokens. One token is the instantiation of "
                     "a single arc. Typical values are within the 5k-10k range.");
      opts->Register("ntokens-pre-allocated", &ntokens_pre_allocated,
                     "Advanced - Number of tokens pre-allocated in host buffers. "
                     "If this size is exceeded the buffer will reallocate, "
                     "reducing performance.");
      std::ostringstream main_q_capacity_desc;
      main_q_capacity_desc
          << "Advanced - Capacity of the main queue: the maximum number of "
             "tokens that can be stored *after* pruning for each frame. "
             "Lower -> less memory usage, higher -> more accurate. "
             "Tokens stored in the main queue have already been selected "
             "through a max-active pre-selection, which means that for each "
             "emitting/non-emitting iteration, we can add at most "
             "~max-active tokens to the main queue. Typically only the "
             "emitting iteration creates a large number of tokens. Using "
             "main-q-capacity=k*max-active with k=4..10 should be safe. "
             "If main-q-capacity is too small, we will print a warning "
             "but prevent the overflow. The computation can safely "
             "continue, but the quality of the output may decrease "
             "(-1 = set to "
          << KALDI_CUDA_DECODER_MAX_ACTIVE_MAIN_Q_CAPACITY_FACTOR
          << "*max-active).";
      opts->Register("main-q-capacity", &main_q_capacity,
                     main_q_capacity_desc.str());
      std::ostringstream aux_q_capacity_desc;
      aux_q_capacity_desc
          << "Advanced - Capacity of the auxiliary queue: the maximum "
             "number of raw tokens that can be stored *before* pruning "
             "for each frame. Lower -> less memory usage, higher -> more "
             "accurate. During token generation, if we detect that we "
             "are getting close to saturating that capacity, we reduce "
             "the beam dynamically (adaptive beam) to keep only "
             "the best tokens in the remaining space. If the aux queue "
             "is still too small, we will print an overflow warning, but "
             "prevent the overflow. The computation can safely continue, "
             "but the quality of the output may decrease. We strongly "
             "recommend keeping aux-q-capacity large (>400k), to avoid "
             "triggering the adaptive beam and/or the overflow "
             "(-1 = set to "
          << KALDI_CUDA_DECODER_AUX_Q_MAIN_Q_CAPACITIES_FACTOR
          << "*main-q-capacity).";
      opts->Register("aux-q-capacity", &aux_q_capacity,
                     aux_q_capacity_desc.str());
    }
  
    void Check() const {
      KALDI_ASSERT(default_beam > 0.0 && ntokens_pre_allocated >= 0 &&
                   lattice_beam >= 0.0f && max_active > 0);
    }
  
    void ComputeConfig() {
      if (main_q_capacity == -1)
        main_q_capacity =
            max_active * KALDI_CUDA_DECODER_MAX_ACTIVE_MAIN_Q_CAPACITY_FACTOR;
      if (aux_q_capacity == -1)
        aux_q_capacity =
            main_q_capacity * KALDI_CUDA_DECODER_AUX_Q_MAIN_Q_CAPACITIES_FACTOR;
    }
  };
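
  // Typical configuration usage (a minimal sketch; the option parser po and
  // the call order are illustrative, not prescriptive):
  //
  //   CudaDecoderConfig config;
  //   config.Register(&po);    // exposes --beam, --max-active, etc.
  //   po.Read(argc, argv);
  //   config.ComputeConfig();  // resolves -1 capacities to their defaults
  //   config.Check();          // validates beams, max-active, ...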
  
  // Forward declarations.
  // These classes contain CUDA code, so we don't want to include their
  // definitions in this header.
  class DeviceParams;
  class KernelParams;
  
  class CudaDecoder {
   public:
    // Creates a new CudaDecoder associated with the FST fst.
    // nlanes and nchannels are defined as follows:
    //
    // A decoder channel is linked to one utterance.
    // When we need to decode an utterance, we pick an available channel,
    // call InitDecoding on that channel (with that ChannelId in the channels
    // vector argument), then call AdvanceDecoding whenever frames are ready
    // for the decoder for that utterance (passing the same ChannelId to
    // AdvanceDecoding).
    //
    // A decoder lane is where the computation actually happens:
    // a lane is assigned a channel, and performs the actual decoding
    // of that channel.
    // If we have 200 lanes, we can compute 200 utterances (channels)
    // at the same time. We need many lanes in parallel to saturate big GPUs.
    //
    // An analogy: lane -> a CPU core, channel -> a software thread.
    // A channel saves the current state of the decoding for a given
    // utterance. It can be kept idle until more frames are ready to be
    // processed.
    //
    // We will use as many lanes as necessary to saturate the GPU, but not
    // more: a lane has a higher memory usage than a channel. If you just
    // want to keep more audio channels open at the same time (when I/O is
    // the bottleneck, for instance, typically in the context of online
    // decoding), you should instead use more channels.
    //
    // A channel is typically much smaller in terms of memory usage, and can
    // be used to oversubscribe lanes in the context of online decoding.
    // For instance, we could choose nlanes=200 because it gives us good
    // performance on a given GPU, say an end-to-end throughput of 3000 XRTF.
    // We are doing online decoding, so we only receive audio at real-time
    // speed for a given utterance/channel. We then decide to receive audio
    // from 2500 audio channels at the same time (each at real-time speed),
    // and as soon as frames are ready for nlanes=200 channels, we call
    // AdvanceDecoding on those channels.
    // In that configuration, we have nlanes=200 (for performance) and
    // nchannels=2500 (to have enough audio available at any given time).
    // Using nlanes=2500 in that configuration would first not be possible
    // (out of memory), but would also not be necessary: increasing the
    // number of lanes is only useful if it increases performance. If the
    // GPU is saturated at nlanes=200, you should not increase that number.
    CudaDecoder(const CudaFst &fst, const CudaDecoderConfig &config, int32 nlanes,
                int32 nchannels);
  
    // Reads the config from config
    void ReadConfig(const CudaDecoderConfig &config);
    // Special constructor for nlanes = nchannels, intended for the
    // non-advanced user. Here nchannels can be seen as the batch size: to
    // decode 10 utterances at a time, use nchannels = 10.
    CudaDecoder(const CudaFst &fst, const CudaDecoderConfig &config,
                int32 nchannels)
        : CudaDecoder(fst, config, nchannels, nchannels) {}
    ~CudaDecoder();
  
    // InitDecoding initializes the decoding, and should only be used if you
    // intend to call AdvanceDecoding() on the channels listed in channels
    void InitDecoding(const std::vector<ChannelId> &channels);
    // Computes the heavy H2H copies of InitDecoding. Usually launched on
    // the thread pool.
    void InitDecodingH2HCopies(ChannelId ichannel);
    // AdvanceDecoding on a given batch.
    // A batch is defined by the channels vector: we can compute N channels
    // at the same time (in the same batch), where N = number of lanes, as
    // defined in the constructor.
    // AdvanceDecoding computes as many frames as possible while running the
    // full batch; it returns once at least one channel has no more frames
    // ready to be computed. The user then decides what to do, i.e.:
    //
    // 1) either remove the empty channel from the channels list
    // and call AdvanceDecoding again, or
    // 2) swap the empty channel with another one that has frames ready
    // and call AdvanceDecoding again.
    //
    // Solution 2) should be preferred, because we need to run full, big
    // batches to saturate the GPU; a usage sketch follows the declaration
    // below.
    //
    // If max_num_frames is >= 0, it will decode no more than that many
    // frames.
    void AdvanceDecoding(const std::vector<ChannelId> &channels,
                         std::vector<CudaDecodableInterface *> &decodables,
                         int32 max_num_frames = -1);
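
    // A minimal driving loop for AdvanceDecoding (a hypothetical sketch;
    // decoder, channels, decodables and RefillBatch() are illustrative
    // names, not part of this API):
    //
    //   decoder.InitDecoding(channels);
    //   while (!channels.empty()) {
    //     decoder.AdvanceDecoding(channels, decodables);
    //     // Swap channels that have no frames ready for channels that do
    //     // (solution 2 above), keeping the batch as full as possible.
    //     RefillBatch(&channels, &decodables);
    //   }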
  
    // Returns the number of frames already decoded in a given channel
    int32 NumFramesDecoded(ChannelId ichannel) const;
    // GetBestPath gets the one-best decoding traceback. If "use_final_probs"
    // is true AND we reached a final state, it limits itself to final
    // states; otherwise it gets the most likely token, not taking
    // final-probs into account.
    void GetBestPath(const std::vector<ChannelId> &channels,
                     std::vector<Lattice *> &fst_out_vec,
                     bool use_final_probs = true);
    // It is possible to use a thread-safe version of GetRawLattice:
    // ConcurrentGetRawLatticeSingleChannel(), which does the heavy CPU work
    // associated with GetRawLattice.
    // It is necessary to first call PrepareForGetRawLattice *on the main
    // thread* for those channels. The main thread is the one we use to call
    // all the other functions, like InitDecoding or AdvanceDecoding.
    // We usually call it the "cuda control thread", but it is a CPU thread.
    // For example:
    // on the main cpu thread: call PrepareForGetRawLattice on channels 8,6,3
    // then:
    // on some cpu thread: call ConcurrentGetRawLatticeSingleChannel on channel 3
    // on some cpu thread: call ConcurrentGetRawLatticeSingleChannel on channel 8
    // on some cpu thread: call ConcurrentGetRawLatticeSingleChannel on channel 6
    // A sketch of that pattern follows the declarations below.
    void PrepareForGetRawLattice(const std::vector<ChannelId> &channels,
                                 bool use_final_probs);
    void ConcurrentGetRawLatticeSingleChannel(ChannelId ichannel,
                                              Lattice *fst_out);
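
    // A sketch of that pattern (hypothetical; decoder, pool, and lats are
    // illustrative names, and the thread pool enqueue API is an assumption,
    // not part of this class):
    //
    //   std::vector<ChannelId> done = {8, 6, 3};
    //   std::vector<Lattice> lats(done.size());
    //   decoder.PrepareForGetRawLattice(done, true);  // main thread
    //   for (size_t i = 0; i < done.size(); ++i)      // worker threads
    //     pool.enqueue([&decoder, &done, &lats, i] {
    //       decoder.ConcurrentGetRawLatticeSingleChannel(done[i], &lats[i]);
    //     });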
  
    // GetRawLattice gets the lattice decoding traceback (using the
    // lattice-beam from the CudaDecoderConfig parameters).
    // If "use_final_probs" is true AND we reached a final state, it limits
    // itself to final states; otherwise it gets the most likely token, not
    // taking final-probs into account.
    void GetRawLattice(const std::vector<ChannelId> &channels,
                       std::vector<Lattice *> &fst_out_vec, bool use_final_probs);
    // GetBestCost finds the best cost in the last token queue for each
    // channel in channels. If isfinal is true, we also add the final cost
    // to the token costs before finding the minimum cost.
    // We list all tokens that have a cost within [best; best+lattice_beam]
    // in list_lattice_tokens.
    // We also set has_reached_final[ichannel] to true if a token associated
    // with a final state exists in the last token queue of that channel.
    void GetBestCost(
        const std::vector<ChannelId> &channels, bool isfinal,
        std::vector<std::pair<int32, CostType>> *argmins,
        std::vector<std::vector<std::pair<int, float>>> *list_lattice_tokens,
        std::vector<bool> *has_reached_final);
    // (optional) Gives the decoder access to the CPU thread pool.
    // We use it for one-off CPU work, such as InitDecodingH2HCopies.
    // For recurrent CPU work, such as ComputeH2HCopies, we use dedicated
    // CPU threads instead; we launch nworkers of those threads.
    void SetThreadPoolAndStartCPUWorkers(ThreadPool *thread_pool, int32 nworkers);
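
    // Wiring sketch (hypothetical; the ThreadPool constructor argument and
    // the worker count are illustrative):
    //
    //   ThreadPool pool(16);  // 16 CPU threads
    //   decoder.SetThreadPoolAndStartCPUWorkers(&pool, 4);  // nworkers=4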
  
   private:
    // Data allocation. Called in constructor
    void AllocateDeviceData();
    void AllocateHostData();
    void AllocateDeviceKernelParams();
    // Data initialization. Called in constructor
    void InitDeviceData();
    void InitHostData();
    void InitDeviceParams();
    // Computes the initial channel
    // The initial channel is used to initialize a channel
    // when a new utterance starts (we clone it into the given channel)
    void ComputeInitialChannel();
    // Updates *h_kernel_params using channels
    void SetChannelsInKernelParams(const std::vector<ChannelId> &channels);
    void ResetChannelsInKernelParams();
    // Context-switch functions.
    // Used to perform the context switch of loading/saving the state of a
    // channel into a lane. When a channel is about to be executed on a
    // lane, we load that channel into that lane (same idea as loading a
    // software thread's state into the registers of a CPU core).
    void LoadChannelsStateToLanes(const std::vector<ChannelId> &channels);
    void SaveChannelsStateFromLanes();
    // We compute the decodes by batch. Each decodable in the batch may have
    // a different number of frames ready. We compute the min number of
    // frames ready across the batch (so that the full batch keeps
    // executing). If max_num_frames is > 0, we apply it as a ceiling to
    // NumFramesToDecode.
    int32 NumFramesToDecode(const std::vector<ChannelId> &channels,
                            std::vector<CudaDecodableInterface *> &decodables,
                            int32 max_num_frames);
    // Expand the arcs, emitting stage. Must be called after
    // a preprocess_in_place, which happens in PostProcessingMainQueue.
    // ExpandArcsEmitting is called first when decoding a frame,
    // using the preprocessing that happened at the end of the previous frame,
    // in PostProcessingMainQueue
    void ExpandArcsEmitting();
    // ExpandArcs, non-emitting stage. Must be called after PruneAndPreprocess.
    void ExpandArcsNonEmitting();
    // If we have more than max_active_ tokens in the queue (either after an
    // expand, or at the end of the frame), we compute a new beam that keeps
    // a number of tokens as close as possible to max_active_ (that number
    // is >= max_active_) (soft topk).
    // All ApplyMaxActiveAndReduceBeam does is find the right beam for that
    // topk and set it. We then need to call either PruneAndPreprocess
    // (explicitly pruning tokens with cost > beam) or
    // PostProcessingMainQueue (ignoring tokens with cost > beam in the next
    // frame).
    void ApplyMaxActiveAndReduceBeam(enum QUEUE_ID queue_id);
    // Called after an ExpandArcs. Prunes the aux_q (the output of
    // ExpandArcs) and moves the surviving tokens to the main_q, doing the
    // preprocessing at the same time.
    // We don't need it after the last ExpandArcsNonEmitting.
    void PruneAndPreprocess();
    // Once the non-emitting stage is done, the main_q is final for that
    // frame. We now generate all the data associated with that main_q, such
    // as listing the different tokens sharing the same token.next_state.
    // We also preprocess for the ExpandArcsEmitting of the next frame.
    // Once PostProcessingMainQueue is done, all working data is back to its
    // original state, to make sure we're ready for the next context switch.
    void PostProcessingMainQueue();
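
    // Putting the stages above together, decoding one frame looks roughly
    // like this (a simplified sketch of the control flow; the queue_id
    // names and the loop condition are illustrative, and the real loop,
    // with its overflow checks and lane handling, lives in the .cc file):
    //
    //   ExpandArcsEmitting();        // uses preprocessing from frame i-1
    //   while (/* non-emitting arcs remain */) {
    //     ApplyMaxActiveAndReduceBeam(AUX_Q);
    //     PruneAndPreprocess();      // aux_q -> main_q
    //     ExpandArcsNonEmitting();
    //   }
    //   ApplyMaxActiveAndReduceBeam(MAIN_Q);
    //   PostProcessingMainQueue();   // finalize frame i, preprocess i+1
    //   CopyMainQueueDataToHost();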
    // Moves the relevant data to host, i.e. the data that will be needed in
    // GetBestPath/GetRawLattice.
    // Happens once PostProcessingMainQueue is done generating that data.
    void CopyMainQueueDataToHost();
    // CheckOverflow
    // If a kernel sets the flag h_q_overflow, we send a warning to stderr
    // Overflows are detected and prevented on the device. It only means
    // that we've discarded the tokens that were created after the queue was full
    // That's why we only send a warning. It is not a fatal error
    void CheckOverflow();
    // Evaluates the function func for each lane, returning the max of all
    // return values (func returns int32).
    // Used for instance to get the max number of arcs over all lanes.
    // func is called with h_lanes_counters_[ilane] for each lane.
    // h_lanes_counters_ must be ready to be used when calling
    // GetMaxForAllLanes (you may want to call CopyLaneCountersToHost[A]sync
    // first to make sure everything is ready).
    int32 GetMaxForAllLanes(std::function<int32(const LaneCounters &)> func);
    // Copies the lane counters back to host, asynchronously or
    // synchronously.
    // The lane counters contain all the information generated during the
    // computation, such as main_q_end (number of tokens in the main_q) and
    // main_q_narcs (number of arcs). That's why we frequently copy them
    // back to host: to know what to do next.
    void CopyLaneCountersToHostAsync();
    void CopyLaneCountersToHostSync();
    // The selected tokens for each frame are copied back to host. We store
    // them in host memory, and we use them to create the final lattice once
    // we've reached the last frame.
    // We also copy information about those tokens that was generated on the
    // device, such as which tokens are associated with the same FST state
    // in the same frame, or their extra cost.
    // We cannot issue individual device-to-host copies for each channel:
    // it would lead to a lot of small copies, reducing performance. Instead
    // we concatenate all channels' data into a single contiguous array,
    // copy that array to host, then unpack it into the individual channel
    // vectors.
    // The first step (pack, then copy to host, async) is done in
    // ConcatenateData and LaunchD2HCopies.
    // The second step (unpack) is done in LaunchH2HCopies and
    // ComputeH2HCopies.
    // A sync on CUDA stream st has to happen between the two steps, to make
    // sure that the copy is done (a sketch of the whole pipeline follows
    // ComputeH2HCopiesCPUWorker below).
    //
    // Each lane contains X elements to be copied, where X = func(ilane).
    // That data is described by the pair (pointer, X), with
    // pointer = src[ilane].
    // It is concatenated into d_concat on device, then copied
    // asynchronously into h_concat. That copy is launched on stream st.
    // The offset of each lane's data in the concatenated array is saved in
    // *lanes_offsets_ptr; it is used for unpacking in
    // MoveConcatenatedCopyToVector.
    //
    // func is called with h_lanes_counters_[ilane] for each lane.
    // h_lanes_counters_ must be ready to be used when calling these
    // functions (you may want to call CopyLaneCountersToHost[A]sync first).
    //
    // Concatenates the data on device before calling the D2H copies.
    void ConcatenateData();
    // Starts the D2H copies used to send data back to host at the end of
    // each frame.
    void LaunchD2HCopies();
    // ComputeH2HCopies
    // At the end of each frame, we copy data back to host.
    // That data was concatenated into a single contiguous array.
    // We then have to unpack it and move it inside host memory.
    // This is done by ComputeH2HCopies.
    void ComputeH2HCopies();
    // Takes care of preparing the data for ComputeH2HCopies, and checks
    // whether we can use the thread pool or have to do the work on the
    // current thread.
    void LaunchH2HCopies();
    // Function called by the CPU worker threads
    // Calls ComputeH2HCopies when triggered
    void ComputeH2HCopiesCPUWorker();
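
    // The end-of-frame copy pipeline, put together (a simplified sketch of
    // the control flow; the exact stream/event synchronization is
    // illustrative and lives in the .cc file):
    //
    //   ConcatenateData();   // pack per-lane data into the d_*_concat_
    //                        // arrays
    //   LaunchD2HCopies();   // async D2H: d_*_concat_ -> h_*_concat_
    //   // ... sync on the copy stream, so that h_*_concat_ is valid ...
    //   LaunchH2HCopies();   // wakes the CPU workers; ComputeH2HCopies
    //                        // then unpacks h_*_concat_ into the
    //                        // per-channel h_all_tokens_* vectors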
  
    template <typename T>
    void MoveConcatenatedCopyToVector(const LaneId ilane,
                                      const ChannelId ichannel,
                                      const std::vector<int32> &lanes_offsets,
                                      T *h_concat,
                                      std::vector<std::vector<T>> *vecvec);
    void WaitForH2HCopies();
    void WaitForInitDecodingH2HCopies();
    // Computes a set of static asserts on the static values
    // In theory we should do them at compile time
    void CheckStaticAsserts();
    // Can be called in GetRawLattice to do a bunch of deep asserts on the data
    // Slow, so disabled by default
    void DebugValidateLattice();
  
    //
    // Data members
    //
  
    // The CudaFst data structure contains the FST graph
    // in CSR format, in both GPU and CPU memory.
    const CudaFst fst_;
    // Counters used by a decoder lane.
    // Contains all the single values generated during computation, such as
    // the current size of the main_q and the number of arcs currently in
    // that queue. We load data from the channel state during context switch
    // (for instance, the size of the last token queue for that channel).
    HostLaneMatrix<LaneCounters> h_lanes_counters_;
    // Counters of channels.
    // Contains all the single values saved to remember the state of a
    // channel; they are not used during computation. Those values are
    // loaded into / saved from a lane during context switching.
    ChannelCounters *h_channels_counters_;
    // Contains the various counters used by lanes/channels, such as
    // main_q_end and main_q_narcs. In device memory (the device equivalent
    // of h_channels_counters_).
    DeviceChannelMatrix<ChannelCounters> d_channels_counters_;
    DeviceLaneMatrix<LaneCounters> d_lanes_counters_;
    // Number of lanes and channels, as defined in the constructor arguments
    int32 nlanes_, nchannels_;
  
    // We will now define the data used on the GPU
    // The data is mainly linked to two token queues
    // - the main queue
    // - the auxiliary queue
    //
    // The auxiliary queue is used to store the raw output of ExpandArcs.
    // We then prune that aux queue (and apply max-active) and move the
    // surviving tokens to the main queue.
    // Tokens stored in the main_q can then be used to generate new tokens
    // (using ExpandArcs).
    // We also generate more information about what's in the main_q at the
    // end of a frame (in PostProcessingMainQueue).
    //
    // As a reminder, here's the data structure of a token :
    //
    // struct Token { state, cost, prev_token, arc_idx }
    //
    // Please keep in mind that this structure is also used in the context
    // of lattice decoding. We are not storing a list of forward links like
    // in the CPU decoder: a token remains the instantiation of a single
    // arc.
    //
    // For performance reasons, we split the tokens into three parts:
    // { state }, { cost }, { prev_token, arc_idx }
    // Each part has its associated queue.
    // For instance, d_main_q_state[i], d_main_q_cost[i], d_main_q_info[i]
    // all refer to the same token (at index i).
    // The data structure InfoToken contains { prev_token, arc_idx }.
    // We also store the acoustic costs independently in
    // d_main_q_acoustic_cost_. A short access sketch follows below.
    //
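    // Reading token i back from this struct-of-arrays layout looks like
    // this (an illustrative device-side sketch; the .channel()/.lane()
    // indexing and the int2 unpacking are assumptions for the example, not
    // the exact kernel code):
    //
    //   // channel data is indexed by channel, lane data by lane:
    //   int2 sc = d_main_q_state_and_cost_.channel(ichannel)[i];
    //   InfoToken tok = d_main_q_info_.lane(ilane)[i];  // prev_token,
    //                                                   // arc_idx
    //   CostType ac = d_main_q_acoustic_cost_.lane(ilane)[i];
    //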
    // The data is either linked to a channel or to a lane.
    //
    // Channel data (DeviceChannelMatrix):
    //
    // The data linked with a channel contains the data of frame i that we
    // need to remember to compute frame i+1. It is the list of tokens from
    // frame i, with some additional info (i.e. the prefix sum of the
    // emitting arc degrees of those tokens).
    // We only store d_main_q_state_and_cost_ as channel data, because
    // that's all we need from a token to compute frame i+1; we don't need
    // token.arc_idx or token.prev_token.
    // The reason why we also store that prefix sum is that we do the
    // emitting preprocessing at the end of frame i: that preprocessing
    // needs information from the hashmap, and the hashmap is always cleared
    // at the end of a frame. So we do the preprocessing at the end of frame
    // i, and then save d_main_q_degrees_prefix_sum_.
    // d_main_q_arc_offsets is also generated during that preprocessing.
    //
    // Lane data (DeviceLaneMatrix):
    //
    // The lane data is everything we use during computation, but which we
    // reset at the end of each frame.
    // For instance, we use a hashmap at some point during the computation,
    // but at the end of each frame we reset it. That way the hashmap is
    // ready to process whichever channel comes next time AdvanceDecoding is
    // called. The reasons why we do that are:
    //
    // - We use context switching. Before and after every frame, we can do a
    // context switch, which means that a lane cannot save a channel's state
    // in any way once AdvanceDecoding returns. E.g., during one call of
    // AdvanceDecoding, ilane=2 may compute 5 frames from channel=57 (as
    // defined in the std::vector<ChannelId> channels).
    // In the next call, the same ilane=2 may compute 10 frames from
    // channel=231. Lane data has to be reset to its original state at the
    // end of each AdvanceDecoding call.
    // If somehow some data has to be saved, it needs to be declared as
    // channel data.
    //
    // - The reason why we make the distinction between lane and channel
    // data (in theory everything could be considered channel data) is that
    // a lane uses more memory than a channel. In the context of online
    // decoding, we need to create a lot of channels, and we need them to be
    // as small as possible in memory.
    // Everything that can be reused between channels is stored as lane
    // data.
  
    //
    // Channel data members:
    //
  
    DeviceChannelMatrix<int2> d_main_q_state_and_cost_;
    // Prefix sum of the arc degrees in the main_q. Used by ExpandArcs,
    // set in the preprocess stages (either PruneAndPreprocess or
    // preprocess_in_place in PostProcessingMainQueue).
    DeviceChannelMatrix<int32> d_main_q_degrees_prefix_sum_;
    // d_main_q_arc_offsets[i] = fst_.arc_offsets[d_main_q_state[i]]
    // We pay the price of the random memory accesses into fst_.arc_offsets
    // in the preprocess kernel, and cache the results in
    // d_main_q_arc_offsets, which is then read in a coalesced fashion in
    // expand.
    DeviceChannelMatrix<int32> d_main_q_arc_offsets_;
  
    //
    // Lane data members:
    //
  
    // InfoToken
    // Usually contains {prev_token, arc_idx}.
    // If more than one token is associated with a given fst_state, it
    // instead contains where to find the list of those tokens in
    // d_main_q_extra_prev_tokens, i.e. {offset, size} in that list. We
    // differentiate the two situations by calling
    // InfoToken.IsUniqueTokenForStateAndFrame().
    DeviceLaneMatrix<InfoToken> d_main_q_info_;
    // Acoustic cost of a given token
    DeviceLaneMatrix<CostType> d_main_q_acoustic_cost_;
    // At the end of a frame, we use a hashmap to detect the tokens that are
    // associated with the same FST state S.
    // We do it at the very end, so that we only use the hashmap on
    // post-prune, post-max-active tokens.
    DeviceLaneMatrix<HashmapValueT> d_hashmap_values_;
    // Reminder: in the GPU lattice decoder, a token is always associated
    // with a single arc, which means that multiple tokens in the same frame
    // can be associated with the same FST state.
    //
    // We are NOT listing those duplicates as ForwardLinks in a unique
    // meta-token like in the CPU lattice decoder.
    //
    // When more than one token is associated with a single FST state,
    // we list those tokens in another list: d_main_q_extra_prev_tokens.
    // We also save data that is useful in that case, such as the extra_cost
    // of a token compared to the best token for that state.
    DeviceLaneMatrix<InfoToken> d_main_q_extra_prev_tokens_;
    DeviceLaneMatrix<float2> d_main_q_extra_and_acoustic_cost_;
    // Histogram of the token costs in the main_q. Used to compute a soft
    // topk of the main_q (max-active).
    DeviceLaneMatrix<int32> d_histograms_;
    // When filling the hashmap in PostProcessingMainQueue, we create a
    // hashmap value for each FST state present in the main_q (i.e. each
    // state with at least one token associated with it).
    // d_main_q_state_hash_idx_[token_idx] is the index of the state
    // token.state in the hashmap.
    // It is stored as a FSTStateHashIndex, which is actually an int32.
    // FSTStateHashIndex should only be accessed through
    // [Get|Set]FSTStateHashIndex, because it uses the sign bit to also
    // remember whether that token is the representative of that state.
    // If only one token is associated with S, its representative is itself.
    DeviceLaneMatrix<FSTStateHashIndex> d_main_q_state_hash_idx_;
    // local_idx of the extra cost list for a state.
    // For a given state S, the first token associated with S will have
    // local_idx=0, the second one local_idx=1, etc. The order of the
    // local_idxs is arbitrary.
    DeviceLaneMatrix<int32> d_main_q_n_extra_prev_tokens_local_idx_;
    // Where to write the extra_prev_tokens in the d_main_q_extra_prev_tokens_
    // queue
    DeviceLaneMatrix<int32> d_main_q_extra_prev_tokens_prefix_sum_;
    // Used when computing the prefix_sums in preprocess_in_place. Stores
    // the local_sums per CTA
    DeviceLaneMatrix<int2> d_main_q_block_sums_prefix_sum_;
    // Defining the aux_q. Filled by ExpandArcs.
    // The tokens are moved to the main_q by PruneAndPreprocess
    DeviceLaneMatrix<int2> d_aux_q_state_and_cost_;
    DeviceLaneMatrix<InfoToken> d_aux_q_info_;
    // Dedicated space for the concat of extra_cost. We should reuse memory
    DeviceLaneMatrix<float2> d_extra_and_acoustic_cost_concat_matrix_;
    DeviceLaneMatrix<InfoToken> d_extra_prev_tokens_concat_matrix_;
    DeviceLaneMatrix<CostType> d_acoustic_cost_concat_matrix_;
    DeviceLaneMatrix<InfoToken> d_infotoken_concat_matrix_;
    // When calling GetBestCost, we list all tokens within
    // [min_cost; min_cost+lattice_beam] in h_list_final_tokens_in_main_q_.
    // We only use an interface here because on the device we actually reuse
    // the data of d_aux_q_state_and_cost_: we are done using the aux_q when
    // GetBestCost is called, so we can reuse that memory.
    HostLaneMatrix<int2> h_list_final_tokens_in_main_q_;
    // Parameters used by the kernels.
    // DeviceParams contains all the parameters that won't change
    // (e.g., the memory addresses of the main_q).
    // KernelParams contains the information that can change
    // (e.g., which channel is executing on which lane).
    DeviceParams *h_device_params_;
    KernelParams *h_kernel_params_;
    std::vector<ChannelId> channel_to_compute_;
    int32 nlanes_used_;  // number of lanes used in h_kernel_params_
    // Initial channel.
    // When starting a new utterance, the channel at init_channel_id_ is
    // used (cloned) to initialize that utterance's channel.
    int32 init_channel_id_;
    // CUDA streams used by the decoder
    cudaStream_t compute_st_, copy_st_;
    // Parameters extracted from CudaDecoderConfig
    // Those are defined in CudaDecoderConfig
    CostType default_beam_;
    CostType lattice_beam_;
    int32 ntokens_pre_allocated_;
    int32 max_active_;  // Target value from the parameters
    int32 aux_q_capacity_;
    int32 main_q_capacity_;
    // Hashmap capacity. Multiple of max_tokens_per_frame
    int32 hashmap_capacity_;
    // Static segment of the adaptive beam. Cf InitDeviceParams
    int32 adaptive_beam_static_segment_;
    // The first index of all the following vectors (or vector<vector>)
    // is the ChannelId. e.g., to get the number of frames decoded in channel 2,
    // look into num_frames_decoded_[2].
  
    // Keep track of the number of frames decoded in the current file.
    std::vector<int32> num_frames_decoded_;
    // Offsets of each frame in h_all_tokens_info_
    std::vector<std::vector<int32>> frame_offsets_;
    // Data storage. We store on host what we will need in
    // GetRawLattice/GetBestPath
    std::vector<std::vector<InfoToken>> h_all_tokens_info_;
    std::vector<std::vector<CostType>> h_all_tokens_acoustic_cost_;
    std::vector<std::vector<InfoToken>> h_all_tokens_extra_prev_tokens_;
    std::vector<std::vector<float2>>
        h_all_tokens_extra_prev_tokens_extra_and_acoustic_cost_;
    std::vector<std::mutex> channel_lock_;  // at some point we should switch to a
                                            // shared_lock (to be able to compute
                                            // partial lattices while still
                                            // streaming new data for this
                                            // channel)
    bool worker_threads_running_;
    // For each channel, set by PrepareForGetRawLattice:
    // the argmin cost, the list of tokens within
    // [best_cost; best_cost+lattice_beam], and whether we've reached a
    // final token.
    std::vector<std::pair<int32, CostType>> h_all_argmin_cost_;
    std::vector<std::vector<std::pair<int, float>>> h_all_final_tokens_list_;
    std::vector<bool> h_all_has_reached_final_;
  
    // Pinned memory arrays. Used for the DeviceToHost copies
    float2 *h_extra_and_acoustic_cost_concat_, *d_extra_and_acoustic_cost_concat_;
    InfoToken *h_infotoken_concat_, *d_infotoken_concat_;
    CostType *h_acoustic_cost_concat_, *d_acoustic_cost_concat_;
    InfoToken *h_extra_prev_tokens_concat_, *d_extra_prev_tokens_concat_;
    // second memory space used for double buffering
    float2 *h_extra_and_acoustic_cost_concat_tmp_;
    InfoToken *h_infotoken_concat_tmp_;
    CostType *h_acoustic_cost_concat_tmp_;
    InfoToken *h_extra_prev_tokens_concat_tmp_;
    // Offsets used in MoveConcatenatedCopyToVector
    std::vector<int32> h_main_q_end_lane_offsets_;
    std::vector<int32> h_emitting_main_q_end_lane_offsets_;
    std::vector<int32> h_n_extra_prev_tokens_lane_offsets_;
    // Used when calling GetBestCost
    std::vector<std::pair<int32, CostType>> argmins_;
    std::vector<bool> has_reached_final_;
    std::vector<std::vector<std::pair<int32, CostType>>>
        list_finals_token_idx_and_cost_;
    bool compute_max_active_;
    cudaEvent_t nnet3_done_evt_;
    cudaEvent_t d2h_copy_acoustic_evt_;
    cudaEvent_t d2h_copy_infotoken_evt_;
    cudaEvent_t d2h_copy_extra_prev_tokens_evt_;
    cudaEvent_t concatenated_data_ready_evt_;
    cudaEvent_t lane_offsets_ready_evt_;
    // GetRawLattice helpers.
    // Data used when building the lattice in GetRawLattice.

    // A few typedefs to make GetRawLattice easier to understand.
    //
    // A lattice state is defined by the pair (iframe, fst_state).
    // A token is associated with the lattice state
    // (iframe, token.next_state).
    // Multiple tokens in the same frame can be associated with the same
    // lattice state (they all go to the same token.next_state).
    // We need to be able to quickly identify the lattice state of a token.
    // We do that through GetLatticeStateInternalId(token), which returns a
    // unique internal ID for the lattice state of that token.
    //
    // When we build the output lattice, we get a new lattice state
    // output_lattice_state = fst_out->AddState().
    // We call that one an OutputLatticeState.
    // The conversion between the two is done through the maps
    // [curr|prev]_f_raw_lattice_state_.
    typedef int32 LatticeStateInternalId;
    typedef StateId OutputLatticeState;
    typedef int32 TokenId;
    LatticeStateInternalId GetLatticeStateInternalId(int32 total_ntokens,
                                                     TokenId token_idx,
                                                     InfoToken token);
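
    // One plausible encoding (an illustrative sketch only; the actual
    // encoding is defined in the .cc file, and extra_prev_tokens_offset is
    // an illustrative name): a state with a single token reuses its token
    // index, and a state with multiple tokens uses its offset in the
    // extra_prev_tokens list, shifted past the token index range:
    //
    //   if (token.IsUniqueTokenForStateAndFrame()) return token_idx;
    //   return total_ntokens + extra_prev_tokens_offset;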
    // Keeps track of a variety of info about states in the lattice:
    // - token_extra_cost. A path going from the current lattice_state to
    // the end has an extra cost compared to the best path (which has an
    // extra cost of 0). token_extra_cost is the minimum of the extra_costs
    // of all paths going from the current lattice_state to the final frame.
    // - fst_lattice_state is the StateId of the lattice_state in fst_out
    // (in the output lattice). lattice_state is an internal state used in
    // GetRawLattice.
    // - is_state_closed is true if the token_extra_cost has been read by
    // another token. It means that the token_extra_cost value has been
    // used; if we modify token_extra_cost again, we may need to recompute
    // the current frame (so that everyone uses the latest token_extra_cost
    // value).
    struct RawLatticeState {
      CostType token_extra_cost;
      OutputLatticeState fst_lattice_state;
      bool is_state_closed;
    };
    // extra_cost_min_delta_ is used in the must_replay_frame situation.
    // Please read the comments associated with must_replay_frame in
    // GetRawLattice to understand what it does.
    CostType extra_cost_min_delta_;
    ThreadPool *thread_pool_;
    std::vector<std::thread> cpu_dedicated_threads_;
    int32 n_threads_used_;
    std::vector<ChannelId> lanes2channels_todo_;
    std::atomic<int> n_acoustic_h2h_copies_todo_;
    std::atomic<int> n_extra_prev_tokens_h2h_copies_todo_;
    std::atomic<int> n_d2h_copies_ready_;
    std::atomic<int> n_infotoken_h2h_copies_todo_;
    int32 n_h2h_task_not_done_;
    int32 n_init_decoding_h2h_task_not_done_;
    std::atomic<int> n_h2h_main_task_todo_;
    std::mutex n_h2h_task_not_done_mutex_;
    std::mutex n_init_decoding_h2h_task_not_done_mutex_;
    std::mutex n_h2h_main_task_todo_mutex_;
    std::condition_variable n_h2h_main_task_todo_cv_;
    std::condition_variable h2h_done_;
    std::condition_variable init_decoding_h2h_done_;
    std::atomic<bool> active_wait_;
    bool h2h_threads_running_;
    // Using the output of GetBestCost, we add the best tokens (as selected
    // in GetBestCost) from the final frame to the output lattice. We also
    // fill the data structures (such as q_curr_frame_todo_ or
    // curr_f_raw_lattice_state_) accordingly.
    void AddFinalTokensToLattice(
        ChannelId ichannel,
        std::vector<std::pair<TokenId, InfoToken>> *q_curr_frame_todo,
        std::unordered_map<LatticeStateInternalId, RawLatticeState>
            *curr_f_raw_lattice_state,
        Lattice *fst_out);
    // Checks if a token should be added to the lattice. If it should,
    // keep_arc will be set to true.
    void ConsiderTokenForLattice(
        ChannelId ichannel, int32 iprev, int32 total_ntokens, TokenId token_idx,
        OutputLatticeState fst_lattice_start, InfoToken *tok_beg,
        float2 *arc_extra_cost_beg, CostType token_extra_cost,
        TokenId list_prev_token_idx, int32 list_arc_idx,
        InfoToken *list_prev_token, CostType *this_arc_prev_token_extra_cost,
        CostType *acoustic_cost, OutputLatticeState *lattice_src_state,
        bool *keep_arc, bool *dbg_found_zero);
    // Adds the arc to the lattice. Also updates what needs to be updated in
    // the GetRawLattice data structures.
    void AddArcToLattice(
        int32 list_arc_idx, TokenId list_prev_token_idx,
        InfoToken list_prev_token, int32 curr_frame_offset,
        CostType acoustic_cost, CostType this_arc_prev_token_extra_cost,
        LatticeStateInternalId src_state_internal_id,
        OutputLatticeState fst_lattice_start,
        OutputLatticeState to_fst_lattice_state,
        std::vector<std::pair<TokenId, InfoToken>> *q_curr_frame_todo,
        std::vector<std::pair<TokenId, InfoToken>> *q_prev_frame_todo,
        std::unordered_map<LatticeStateInternalId, RawLatticeState>
            *curr_f_raw_lattice_state,
        std::unordered_map<LatticeStateInternalId, RawLatticeState>
            *prev_f_raw_lattice_state,
        std::unordered_set<int32> *f_arc_idx_added, Lattice *fst_out,
        bool *must_replay_frame);
    // Reads a token's information.
    void GetTokenRawLatticeData(
        TokenId token_idx, InfoToken token, int32 total_ntokens,
        std::unordered_map<LatticeStateInternalId, RawLatticeState>
            *curr_f_raw_lattice_state,
        CostType *token_extra_cost, OutputLatticeState *to_fst_lattice_state);
  
    // A token is an instance of an arc. It goes to an FST state
    // (token.next_state). Multiple tokens in the same frame can go to the
    // same FST state. GetSameFSTStateTokenList returns the list of tokens
    // going to that same FST state.
    void GetSameFSTStateTokenList(ChannelId ichannel, InfoToken token,
                                  InfoToken **tok_beg,
                                  float2 **arc_extra_cost_beg, int32 *nprevs);
  
    // Swaps the data structures at the end of a frame: prev becomes curr
    // (we are going backward).
    void SwapPrevAndCurrLatticeMap(
        int32 iframe, bool dbg_found_best_path,
        std::vector<std::pair<TokenId, InfoToken>> *q_curr_frame_todo,
        std::vector<std::pair<TokenId, InfoToken>> *q_prev_frame_todo,
        std::unordered_map<LatticeStateInternalId, RawLatticeState>
            *curr_f_raw_lattice_state,
        std::unordered_map<LatticeStateInternalId, RawLatticeState>
            *prev_f_raw_lattice_state,
        std::unordered_set<int32> *f_arc_idx_added);
    KALDI_DISALLOW_COPY_AND_ASSIGN(CudaDecoder);
  };
  
  }  // end namespace cuda_decoder
  }  // end namespace kaldi
  
  #endif  // KALDI_CUDA_DECODER_CUDA_DECODER_H_