Blame view

src/lat/compose-lattice-pruned.cc 44 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
  // lat/compose-lattice-pruned.cc
  
  // Copyright 2009-2012  Microsoft Corporation
  //           2012-2013  Johns Hopkins University (Author: Daniel Povey)
  //                2014  Guoguo Chen
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #include "lat/compose-lattice-pruned.h"
  #include "lat/lattice-functions.h"
  
  namespace kaldi {
  
  /**
     PrunedCompactLatticeComposer implements an algorithm for pruned composition.
     It uses a heuristic (like the heuristics used in A*) to estimate the
     cost to the end of a graph, of the best path that we might get if
     we expand a particular transition out of a particular state.  This enables
     us to use a priority queue to expand arcs in the composed result in an
     order roughly from most promising to least promising.
  
     Because some of the quantities used in the heuristic are hard to efficiently
     keep updated as the composed output is incrementally added to, we
     periodically recompute these quantities (c.f. RecomputePruningInfo()).
     In order to prevent this periodic recomputation from dominating the time
     taken to produce the lattice, we recompute these things on a schedule where,
     between each computation, we allow the size of the output to grow by
     a constant factor (default: 1.5).  Since the time taken to do the
     recomputation of quantities used in the heuristic takes time linear in the
     size of the so-far existing composed output, doing so on this type of schedule
     will add no more than a constant factor to the runtime.
  
   */
  class PrunedCompactLatticeComposer {
   public:
    // Constructor.  'clat' is the input lattice, 'det_fst' is the on-demand
    // FST it is composed with, and 'composed_clat' is where the composed
    // output is written; 'opts' holds the pruning/scheduling options.
    PrunedCompactLatticeComposer(
        const ComposeLatticePrunedOptions &opts,
        const CompactLattice &clat,
        fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
        CompactLattice* composed_clat);

    // Does the composition.  You must call this just once per object.
    void Compose();

   private:

    // Gets the num-arcs limit for this iteration of the algorithm, which will be
    // opts_.initial_num_arcs if there are currently no arcs; or otherwise
    // opts_.growth_ratio * the current number of arcs (subject to the
    // opts_.max_arcs limit if we have already reached a final-state).  This helps
    // ensure that we call RecomputePruningInfo() on an appropriate schedule.
    int32 GetCurrentArcLimit() const;

    // This function, called just once at the start, computes all the static
    // information about the input lattice 'clat', in lat_state_info_.  (However,
    // the 'composed_states' members are just set to the empty vector for now.)
    void ComputeLatticeStateInfo();

    // Called just once at the start, this sets up the first state in the
    // composed output.
    void AddFirstState();

    // This function processes the next un-expanded transition (or final-state)
    // out of the composed state numbered 'composed_state_to_expand'.
    void ProcessQueueElement(int32 composed_state_to_expand);

    // This is a part of ProcessQueueElement() that has been broken out
    // for clarity.  It processes the arc_index'th arc out of this source state.
    void ProcessTransition(int32 composed_src_state,
                           int32 arc_index);

    // This function recomputes certain members of the ComposedStateInfo relating
    // to the output states: namely, 'forward_cost', 'backward_cost' and
    // 'delta_backward_cost'.  In between calls to this function, we try to
    // keep those quantities as accurate as possible, but they aren't
    // completely accurate (see comments by their declarations for more info).
    void RecomputePruningInfo();

    // Sets '*composed_states' to a list of the states that currently
    // exist in the composed output, in topologically sorted order.
    // At exit, *composed_states will be a permutation of numbers
    // [0, 1, ...  clat_out_->NumStates() - 1], beginning with the
    // start-state 0.
    void GetTopsortedStateList(std::vector<int32> *composed_states) const;

    // Called from RecomputePruningInfo(), this computes all the 'forward_cost'
    // and 'prev_composed_state' members of the ComposedStateInfo (and, along
    // the way, the 'depth' members).
    //   @param [in] composed_states  This is expected to be a list,
    //         in topological order, of all currently existing composed states,
    //         as produced by GetTopsortedStateList().
    void ComputeForwardCosts(const std::vector<int32> &composed_states);

    // Called from RecomputePruningInfo(), this computes all the 'backward_cost'
    // members of the ComposedStateInfo.  It also sets 'output_best_cost_' and
    // 'current_cutoff_'.
    // 'composed_states' is expected to be a list, in topological order, of all
    // currently existing composed states, as produced by GetTopsortedStateList().
    void ComputeBackwardCosts(const std::vector<int32> &composed_states);

    // Called from RecomputePruningInfo(), this computes all the
    // 'delta_backward_cost' members of the ComposedStateInfo.  'composed_states'
    // is expected to be a list, in topological order, of all currently existing
    // composed states, as produced by GetTopsortedStateList().  It also computes
    // the 'expected_cost_offset' values for all states, and uses them to
    // recreate 'composed_state_queue_'.
    void ComputeDeltaBackwardCosts(const std::vector<int32> &composed_states);


    // This struct contains information about a state of the input lattice.
    struct LatticeStateInfo {
      // 'backward_cost' is the total cost of the best path from this state to
      //  the final state in the source lattice, including the final-prob.
      double backward_cost;

      // 'arc_delta_costs' is an array, one for each arc (and the final-prob, if
      // present), showing how much the cost to the final-state for the best path
      // starting in this state and exiting through each arc (or final-prob),
      // differs from 'backward_cost'.  Specifically, it contains pairs
      // (delta_cost, arc_index), where delta_cost >= 0 and arc_index is
      // either the index into this state's array of arcs (for arcs), or -1
      // if this represents the final-prob.
      //
      // 'arc_delta_costs' will be sorted, so that the first element has
      // .first=0.0 and the delta-costs will be in increasing order.  This means
      // that we expand them from the start of the array, in order to process the
      // best arcs first.
      // lat_state_info_[i].arc_delta_costs.size() will equal the number of arcs
      // leaving state i of the input lattice, plus one if the final-prob of
      // state i is not Zero().
      std::vector<std::pair<BaseFloat, int32> > arc_delta_costs;


      // 'composed_states' is a list of the state-ids in the composed output
      // that correspond to this state in the lattice, so we expect
      // that composed_state_info_[composed_states[i]].lat_state
      // equals the index of this lattice state.  This is helpful in
      // accessing the states in the output lattice in topological
      // order.
      std::vector<int32> composed_states;
    };

    // This struct contains information about a state of the composed
    // output.
    struct ComposedStateInfo {
      // 'lat_state' and 'lm_state' form the pair of states in the two FSTs
      // that this state corresponds to.  The unordered map 'pair_to_state_' maps these
      // state-pairs to the index of the composed state (the state-index in clat_out_).
      int32 lat_state;
      int32 lm_state;

      // the number of arcs on the path from the start state to this state, in the
      // composed lattice, by which this state was first reached.
      // (ComputeForwardCosts() re-derives it along the best-forward-cost path.)
      int32 depth;

      // If you have just called RecomputePruningInfo(), then
      // 'forward_cost' will equal the cost of the best path from the start-state
      // to this state, in the composed output.
      //
      // In between calls to RecomputePruningInfo() it may not always be fully up
      // to date; instead it will be an upper bound on what it would be if you had
      // just called RecomputePruningInfo(); it will be the cost of some path but
      // not necessarily the best path.
      double forward_cost;

      // 'backward_cost' relates to the cost from this state to the final-state in
      // the composed FST.  (By this we mean, more precisely, the cost of the best
      // path from this state to any final state, including the final-prob in that
      // final state).
      //
      // If we have just called RecomputePruningInfo(), then the following rules
      // specify what the value of 'backward_cost' will be:
      //   - If a final state is reachable from this state, backward_cost
      //     will contain the cost of the best path from this state to the
      //     final state (including the corresponding final-prob).
      //   - Otherwise, it will contain +infinity.
      //
      // If RecomputePruningInfo() has not just been called, it may contain any
      // value that is >= the value that the rules above specify (since, for
      // existing states, we don't modify it between calls to
      // RecomputePruningInfo()).  For states that have been added since
      // RecomputePruningInfo() was last called, it will be infinity.
      double backward_cost;

      // 'delta_backward_cost' is a quantity that is used in our heuristic of the
      // cost to an end-state from expanding a previously un-expanded arc.  It is
      // an estimate of the difference between the backward cost in this struct
      // (this->backward_cost) and the backward cost in the input lattice
      // (LatticeStateInfo::backward_cost).  This term reflects the anticipated
      // extra costs from 'det_fst_', which, while fairly close to zero, may be
      // substantial enough to want to correct for.
      //
      // The following is the value that 'delta_backward_cost' will have if
      // RecomputePruningInfo() has just been called:
      //   - If backward_cost is finite (this state in the composed result can
      //    reach the final state via currently expanded states), then
      //    delta_backward_cost is this->backward_cost minus
      //    lat_state_info_[lat_state].backward_cost.  (It will mostly, but
      //    not always, be <= 0, reflecting that the new LM is better than
      //    the old LM).
      //  - On the other hand, if backward_cost is infinite: delta_backward_cost
      //     is set to the delta_backward_cost of the previous state on the best
      //     path from the start state of the composed result to this state (or
      //     zero if this is the start state).
      // Note: ComputeDeltaBackwardCosts() additionally folds a depth-dependent
      // term into this value (depth * depth_penalty_ in the finite case, and one
      // extra depth_penalty_ per inherited step in the infinite case); see the
      // comment by 'depth_penalty_'.
      //
      // If RecomputePruningInfo() has not just been called, then:
      //  - For states created since RecomputePruningInfo() was last called,
      //    delta_backward_cost will be inherited from the source state from
      //    which the new state was expanded.
      //  - For other states, delta_backward_cost will be unchanged since
      //    RecomputePruningInfo() was last called.
      // The above rules may make the delta_backward_cost a less accurate, but
      // still probably reasonable, heuristic.  What it is a heuristic for,
      // is: if we were to successfully reach an end-state of the composed output
      // from this state, what would be the resulting backward_cost
      // minus lat_state_info_[lat_state].backward_cost.
      BaseFloat delta_backward_cost;

      // 'prev_composed_state' is the previous state on the best path from
      // the start-state to the current state (or -1 if this is the start state).
      // It is computed in RecomputePruningInfo() when setting up 'forward_cost',
      // and then used to compute delta_backward_cost.  It is not otherwise
      // used.
      int32 prev_composed_state;

      // 'sorted_arc_index' is an index into the 'arc_delta_costs' array which is
      // a member of the LatticeStateInfo object corresponding to the lattice
      // state 'lat_state'.  It corresponds to the next arc (or final-prob) out of
      // the input lattice that we have yet to expand in the composition; or -1 if
      // we have expanded all of them.  When we first reach a composed state,
      // 'sorted_arc_index' will be zero; then it will increase one at a time as
      // we expand arcs until either the composition terminates or we have
      // expanded all the arcs and it becomes -1.
      int32 sorted_arc_index;

      // 'arc_delta_cost' is a derived quantity that we store here for easier
      // access.  Suppose this_lat_info is lat_state_info_[lat_state]; then
      // if sorted_arc_index >= 0, then:
      //    arc_delta_cost == this_lat_info.arc_delta_costs[sorted_arc_index].first
      // else: arc_delta_cost == +infinity.
      //
      // what 'arc_delta_cost' represents (or is a heuristic for), is the expected
      // cost of a path to the final-state leaving through the arc we're about to
      // expand, minus the expected cost of any path to the final-state starting
      // from this state.
      BaseFloat arc_delta_cost;

      // View 'expected_cost_offset' as a phantom field of this struct, that has
      // been optimized out.  It's clearer if we act like it's a field, but
      // actually it's not stored.
      //
      // 'expected_cost_offset' is a derived quantity that reflects the expected
      // cost (according to our heuristic) of the best path we might encounter
      // when expanding the next previously unseen arc (or final-prob),
      // corresponding to 'sorted_arc_index'.  (This is the expected cost of a
      // successful path, from the beginning to the end of the lattice, but
      // constrained to be a path that contains the arc we're about to expand).
      //
      // The 'offset' part is about subtracting the best cost of the lattice, so we
      // can cast to float without too much loss of accuracy:
      //   expected_cost_offset = expected_cost - lat_best_cost_.
      //
      // We define expected_cost_offset by defining the 'expected_cost' part;
      // for clarity:
      //   First, let lat_backward_cost equal the backward_cost of the LatticeStateInfo
      //   corresponding to 'lat_state', i.e.
      //   lat_backward_cost = lat_state_info_[lat_state].backward_cost.  Then:
      //  expected_cost = forward_cost + lat_backward_cost +
      //                  delta_backward_cost + arc_delta_cost.
      // expected_cost_offset will always equal the above minus lat_best_cost_.
      //
      // The formula for expected_cost above is a pretty good heuristic for what
      // the cost to the end-state will be.  If the costs in det_fst_ were zero,
      // then the expression (forward_cost + lat_backward_cost + arc_delta_cost)
      // would be exact, and we would expand things in the ideal, best-first
      // order.  "delta_backward_cost" is a reasonable approximation for the extra
      // costs from 'det_fst_'.
      // BaseFloat expected_cost_offset;
    };

    // This bool variable is initialized to false, and will be updated to true
    // the first time a Final() function is called on the det_fst_. Then we will
    // immediately call RecomputePruningInfo() so that the output_best_cost_ is
    // changed from +inf to a finite value, to be used in beam search. This is the
    // only time the RecomputePruningInfo() function is called manually; otherwise
    // it always follows an automatic schedule based on the num-arcs of the output
    // lattice.
    bool output_reached_final_;

    // This variable, which we set initially to -1000, makes sure that in the
    // beginning of the algorithm, we always prioritize exploring the lattice
    // in a depth-first way. Once we find a path reaching a final state, this
    // variable will be reset to 0.
    // The reason we do this is because the beam-search depends on a good estimate
    // of the composed-best-cost, which before we reach a final state, we instead
    // borrow the value from best-cost from the input lattice, which is usually
    // systematically worse than the RNNLM scores, and makes the algorithm spend
    // a lot of time before reaching any final state, especially if the input
    // lattices are large.
    // It enters the heuristic through 'delta_backward_cost' (see
    // ComputeDeltaBackwardCosts()).
    float depth_penalty_;
    // The options, input lattice, on-demand FST and output lattice; presumably
    // these simply hold the constructor arguments (the constructor body is not
    // shown alongside this declaration).  Note that opts_ and clat_in_ are
    // references, so the referenced objects must outlive this object.
    const ComposeLatticePrunedOptions &opts_;
    const CompactLattice &clat_in_;
    fst::DeterministicOnDemandFst<fst::StdArc> *det_fst_;
    CompactLattice *clat_out_;

    // This counter keeps track of the number of arcs in the output lattice
    // clat_out_.  GetCurrentArcLimit() reads it to derive the arc limit for the
    // next iteration (and to apply the opts_.max_arcs cap once a final state
    // has been reached).
    int32 num_arcs_out_;

    // Per-state information about the input lattice, indexed by lattice state.
    std::vector<LatticeStateInfo> lat_state_info_;

    // 'lat_best_cost_' is the cost of the best path in the input lattice,
    // equal to lat_state_info_[0].backward_cost (we check that 0 is the
    // start state in the input lattice).
    double lat_best_cost_;

    // 'output_best_cost_' is the cost of the best successful path in the output
    // lattice 'clat_out_'; or +infinity if 'clat_out_' does not yet have any
    // successful paths.  It is updated only when RecomputePruningInfo() is
    // called.
    double output_best_cost_;


    // current_cutoff_ is a value used in deciding which composed states
    // need to be included in the queue.  Each time RecomputePruningInfo() is
    // called, current_cutoff_ is set to
    //    (output_best_cost_ - lat_best_cost_ + opts_.lattice_compose_beam).
    // It will be +infinity if the output lattice doesn't yet have any
    // successful paths.  It decreases with time.  You can compare the
    // phantom 'expected_cost_offset' members of ComposedStateInfo with this
    // value; if they are more than this value, then there is no need
    // to enter the corresponding state into the queue.
    BaseFloat current_cutoff_;

    // Min-heap on (expected_cost_offset, composed_state_index): std::greater
    // makes top() the element with the lowest cost.
    typedef std::priority_queue<std::pair<BaseFloat, int32>,
                        std::vector<std::pair<BaseFloat, int32> >,
                        std::greater<std::pair<BaseFloat, int32> > > QueueType;

    // composed_state_queue_ is a priority queue of the composed states
    // that we are intending to expand.  It contains pairs
    //   (expected_cost_offset, composed_state_index),
    // where expected_cost_offset == the phantom variable
    //       composed_state_info_[composed_state_index].expected_cost_offset.
    // We process the states from lowest cost first.
    // Every time RecomputePruningInfo() is called, this is cleared and repopulated
    // (since the states' expected_cost_offset values may change), and in between
    // calls to RecomputePruningInfo(), we do insert elements for newly created
    // states.
    QueueType composed_state_queue_;


    // Per-state information about the composed output, indexed by the state-id
    // in clat_out_.
    std::vector<ComposedStateInfo> composed_state_info_;

    // This maps a pair (lat_state, lm_state) to the index of the
    // state in the composed FST.  That would correspond to a state-id in
    // clat_out_, and also to an index into 'composed_state_info_'.
    unordered_map<std::pair<int32,int32>,
                  int32, PairHasher<int32> > pair_to_state_;

    // This contains the set of state-indexes of the input lattice that already
    // have states in the composed output (i.e. i is in accessed_lat_states_ if
    // and only if !lat_state_info_[i].composed_states.empty()).  The point is to
    // be able to enumerate, in order or in reverse order, just those states in
    // the lattice that appear in the composed output (it's an efficiency thing
    // that will matter more for early iterations of the composition, when we
    // need to access the output lattice in topological order).
    std::set<int32> accessed_lat_states_;
  };
  
  
  void PrunedCompactLatticeComposer::GetTopsortedStateList(
      std::vector<int32> *composed_states) const {
    composed_states->clear();
    composed_states->reserve(clat_out_->NumStates());
    std::set<int32>::const_iterator iter = accessed_lat_states_.begin(),
        end = accessed_lat_states_.end();
    for (; iter != end; ++iter) {
      int32 lat_state = *iter;
      const LatticeStateInfo &input_lat_info = lat_state_info_[lat_state];
      composed_states->insert(composed_states->end(),
                              input_lat_info.composed_states.begin(),
                              input_lat_info.composed_states.end());
    }
    KALDI_ASSERT((*composed_states)[0] == 0 &&
                 static_cast<int32>(composed_states->size()) ==
                 clat_out_->NumStates());
  }
  
  // Returns the number of output arcs we are allowed to reach before the next
  // recomputation of the pruning information:
  //   - opts_.initial_num_arcs, if the output has no arcs yet;
  //   - otherwise the current arc count times opts_.growth_ratio (truncated
  //     to int32, but always at least one more than the current count),
  //     capped at opts_.max_arcs once output_best_cost_ is finite, i.e. once
  //     some path in the output has reached a final state.
  int32 PrunedCompactLatticeComposer::GetCurrentArcLimit() const {
    const int32 arcs_so_far = num_arcs_out_;
    if (arcs_so_far == 0)
      return opts_.initial_num_arcs;

    KALDI_ASSERT(opts_.growth_ratio > 1.0);
    int32 limit = static_cast<int32>(arcs_so_far * opts_.growth_ratio);
    // Truncation may leave the value unchanged for small counts; make sure
    // the target strictly increases.
    if (limit == arcs_so_far)
      limit = arcs_so_far + 1;

    // x - x is 0 only for finite x (it is NaN for +-infinity and NaN), so this
    // tests whether a final state has already been reached.
    const bool best_cost_is_finite =
        (output_best_cost_ - output_best_cost_ == 0.0);
    if (best_cost_is_finite && limit > opts_.max_arcs)
      limit = opts_.max_arcs;
    return limit;
  }
  
  
  void PrunedCompactLatticeComposer::RecomputePruningInfo() {
    std::vector<int32> all_composed_states;
    GetTopsortedStateList(&all_composed_states);
    ComputeForwardCosts(all_composed_states);
    ComputeBackwardCosts(all_composed_states);
    ComputeDeltaBackwardCosts(all_composed_states);
  }
  
  void PrunedCompactLatticeComposer::ComputeForwardCosts(
      const std::vector<int32> &composed_states) {
    KALDI_ASSERT(composed_states[0] == 0);
  
    // Note: when we initialized composed_state_info_[0]
    // we set forward_cost = 0.0, prev_composed_state = -1.
  
    std::vector<ComposedStateInfo>::iterator
        state_iter = composed_state_info_.begin(),
        state_end = composed_state_info_.end();
  
    state_iter->depth = 0;  // start state has depth 0
    ++state_iter;  // Skip over the start state.
    // Set all other forward_cost fields to infinity and prev_composed_state to
    // -1.
    for (; state_iter != state_end; ++state_iter) {
      state_iter->forward_cost = std::numeric_limits<double>::infinity();
      state_iter->prev_composed_state = -1;
    }
  
    std::vector<int32>::const_iterator state_index_iter = composed_states.begin(),
        state_index_end = composed_states.end();
    for (; state_index_iter != state_index_end; ++state_index_iter) {
      int32 composed_state_index = *state_index_iter;
      const ComposedStateInfo &info = composed_state_info_[
          composed_state_index];
      double forward_cost = info.forward_cost;
      // The next line is a check for infinity.  If infinities have appeared, it
      // either means there is a bug in the algorithm or there were infinities or
      // NaN's in the lattice.
      KALDI_ASSERT(forward_cost - forward_cost == 0.0);
      fst::ArcIterator<CompactLattice> aiter(*clat_out_,
                                             composed_state_index);
      for (; !aiter.Done(); aiter.Next()) {
        const CompactLatticeArc &arc = aiter.Value();
        double arc_cost = ConvertToCost(arc.weight),
            next_forward_cost = forward_cost + arc_cost;
        ComposedStateInfo &next_info = composed_state_info_[arc.nextstate];
        if (next_info.forward_cost > next_forward_cost) {
          next_info.forward_cost = next_forward_cost;
          next_info.prev_composed_state = composed_state_index;
          next_info.depth = composed_state_info_[composed_state_index].depth + 1;
        }
      }
    }
  }
  
  void PrunedCompactLatticeComposer::ComputeBackwardCosts(
      const std::vector<int32> &composed_states) {
    // Access the composed states in reverse topological order from latest to
    // earliest.
    std::vector<int32>::const_reverse_iterator iter = composed_states.rbegin(),
        end = composed_states.rend();
    for (; iter != end; ++iter) {
      int32 composed_state_index = *iter;
      ComposedStateInfo &info = composed_state_info_[composed_state_index];
      double backward_cost =
          ConvertToCost(clat_out_->Final(composed_state_index));
      fst::ArcIterator<CompactLattice> aiter(*clat_out_,
                                             composed_state_index);
      for (; !aiter.Done(); aiter.Next()) {
        const CompactLatticeArc &arc = aiter.Value();
        double arc_cost = ConvertToCost(arc.weight),
           next_backward_cost = composed_state_info_[arc.nextstate].backward_cost,
           this_backward_cost = arc_cost + next_backward_cost;
        if (this_backward_cost < backward_cost)
          backward_cost = this_backward_cost;
      }
      // It's OK if at this point, backward_cost is still +infinity.  This means
      // that this state cannot reach the end yet, which means we have not yet
      // expanded any path from this state all the way to a final-state of the
      // output.
      info.backward_cost = backward_cost;
    }
    output_best_cost_ = composed_state_info_[0].backward_cost;
    // See the declaration of current_cutoff_ for more information.  Note: on
    // early iterations, before any path reaches a final state of the composed
    // lattice, current_cutoff_ may be +infinity, and this is OK.
    current_cutoff_ =
        output_best_cost_ - lat_best_cost_ + opts_.lattice_compose_beam;
  }
  
  void PrunedCompactLatticeComposer::ComputeDeltaBackwardCosts(
      const std::vector<int32> &composed_states) {
  
    int32 num_states = clat_out_->NumStates();
    for (int32 composed_state_index = 0; composed_state_index < num_states;
         ++composed_state_index) {
      ComposedStateInfo &info = composed_state_info_[composed_state_index];
      int32 lat_state = info.lat_state;
      // Note: delta_backward_cost will be +infinity at this stage if the
      // backward_cost was +infinity.  This is OK; we'll set them all to
      // finite values later in this function.
      info.delta_backward_cost =
          info.backward_cost - lat_state_info_[lat_state].backward_cost + info.depth * depth_penalty_;
    }
  
    // 'queue_elements' is a list of items (expected_cost_offset,
    // composed_state_index) that we are going to add to composed_state_queue_,
    // after clearing it.  It's more efficient to accumulate them as a vector
    // and add them all at once, than adding them one by one (search online for
    // "heapify" if this seems confusing).
    std::vector<std::pair<BaseFloat, int32> > queue_elements;
    queue_elements.reserve(num_states);
  
    double lat_best_cost = lat_best_cost_;
    BaseFloat current_cutoff = current_cutoff_;
    std::vector<int32>::const_iterator iter = composed_states.begin(),
        end = composed_states.end();
    for (; iter != end; ++iter) {
      int32 composed_state_index = *iter;
      ComposedStateInfo &info = composed_state_info_[composed_state_index];
      if (info.delta_backward_cost - info.delta_backward_cost != 0) {
        // if info.delta_backward_cost is +infinity...
        int32 prev_composed_state = info.prev_composed_state;
        if (prev_composed_state < 0) {
          KALDI_ASSERT(composed_state_index == 0);
          info.delta_backward_cost = 0.0;
        } else {
          const ComposedStateInfo &prev_info =
              composed_state_info_[prev_composed_state];
          // Check that prev_info.delta_backward_cost is finite.
          KALDI_ASSERT(prev_info.delta_backward_cost -
                       prev_info.delta_backward_cost == 0.0);
          info.delta_backward_cost = prev_info.delta_backward_cost + depth_penalty_;
        }
      }
      double lat_backward_cost = lat_state_info_[info.lat_state].backward_cost;
      // See the formula by where expected_cost_offset is declared in the
      // struct for explanation.
      BaseFloat expected_cost_offset =
          info.forward_cost + lat_backward_cost + info.delta_backward_cost +
          info.arc_delta_cost - lat_best_cost;
      // If info.expected_cost_offset were real, we'd set it here:
      //info.expected_cost_offset = expected_cost_offset;
  
      // At this point expected_cost_offset may be infinite, if arc_delta_cost was
      // infinite (reflecting that we processed all the arcs, and the final-state
      // if applicable, of the lattice state corresponding to this composed state.
      if (expected_cost_offset < current_cutoff) {
        queue_elements.push_back(std::pair<BaseFloat, int32>(
            expected_cost_offset, composed_state_index));
      }
    }
  
    // Reinitialize composed_state_queue_ from 'queue_elements'.
    QueueType temp_queue(queue_elements.begin(), queue_elements.end());
    composed_state_queue_.swap(temp_queue);
  }
  
  void PrunedCompactLatticeComposer::ComputeLatticeStateInfo() {
    KALDI_ASSERT(clat_in_.Properties(fst::kTopSorted, true) ==
                 fst::kTopSorted && clat_in_.NumStates() > 0 &&
                 clat_in_.Start()  == 0);
    int32 num_lat_states = clat_in_.NumStates();
    lat_state_info_.resize(num_lat_states);
  
    for (int32 s = num_lat_states - 1; s >= 0; s--) {
      LatticeStateInfo &info = lat_state_info_[s];
      std::vector<std::pair<double, int32> > arc_costs;
      double backward_cost = ConvertToCost(clat_in_.Final(s));
      if (backward_cost != std::numeric_limits<double>::infinity())
        arc_costs.push_back(std::pair<BaseFloat,int32>(backward_cost, -1));
      fst::ArcIterator<CompactLattice> aiter(clat_in_, s);
      int32 arc_index = 0;
      for (; !aiter.Done(); aiter.Next(), ++arc_index)  {
        const CompactLatticeArc &arc = aiter.Value();
        KALDI_ASSERT(arc.nextstate > s);
        backward_cost = lat_state_info_[arc.nextstate].backward_cost +
            ConvertToCost(arc.weight);
        KALDI_ASSERT(backward_cost - backward_cost == 0.0 &&
                     "Possibly not all states of input lattice are co-accessible?");
        arc_costs.push_back(std::pair<BaseFloat,int32>(backward_cost, arc_index));
      }
      std::sort(arc_costs.begin(), arc_costs.end());
      KALDI_ASSERT(!arc_costs.empty() &&
                   "Possibly not all states of input lattice are co-accessible?");
      backward_cost = arc_costs[0].first;
      info.backward_cost = backward_cost;  // this is the state's backward_cost,
                                           // reflecting the best path to the end.
      info.arc_delta_costs.resize(arc_costs.size());
      std::vector<std::pair<double, int32> >::const_iterator
          src_iter = arc_costs.begin(), src_end = arc_costs.end();
      std::vector<std::pair<BaseFloat, int32> >::iterator
          dest_iter = info.arc_delta_costs.begin();
      for (; src_iter != src_end; ++src_iter, ++dest_iter) {
        dest_iter->first = BaseFloat(src_iter->first - backward_cost);
        dest_iter->second = src_iter->second;
      }
    }
    lat_best_cost_ = lat_state_info_[0].backward_cost;
  }
  
  // Constructor.  Records the options, the input lattice, the LM FST and the
  // output lattice pointer, and resets the search statistics.  No composition
  // work is done here.
  PrunedCompactLatticeComposer::PrunedCompactLatticeComposer(
      const ComposeLatticePrunedOptions &opts,
      const CompactLattice &clat_in,
      fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
      CompactLattice *composed_clat)
      : output_reached_final_(false),
        opts_(opts),
        clat_in_(clat_in),
        det_fst_(det_fst),
        clat_out_(composed_clat),
        num_arcs_out_(0),
        // No path through the output yet, and no pruning cutoff yet.
        output_best_cost_(std::numeric_limits<double>::infinity()),
        current_cutoff_(std::numeric_limits<double>::infinity()) {
    // Starts at -1000; ProcessQueueElement() resets it to 0.0 the first time
    // a final-prob is added to the output lattice.
    depth_penalty_ = -1000;
    // Make sure we start from an empty output lattice.
    clat_out_->DeleteStates();
  }
  
  
  void PrunedCompactLatticeComposer::AddFirstState() {
    int32 state_id = clat_out_->AddState();
    clat_out_->SetStart(state_id);
    KALDI_ASSERT(state_id == 0);
    composed_state_info_.resize(1);
    ComposedStateInfo &composed_state = composed_state_info_[0];
    composed_state.lat_state = 0;
    composed_state.lm_state = det_fst_->Start();
    composed_state.depth = 0;
    composed_state.forward_cost = 0.0;
    composed_state.backward_cost = std::numeric_limits<double>::infinity();
    composed_state.delta_backward_cost = 0.0;
    composed_state.prev_composed_state = -1;
    composed_state.sorted_arc_index = 0;
    composed_state.arc_delta_cost = 0.0; // the first arc_delta_cost is always 0.0
                                         // due to sorting; no need to look it up.
    lat_state_info_[0].composed_states.push_back(state_id);
    accessed_lat_states_.insert(state_id);
    pair_to_state_[std::pair<int32, int32>(0, det_fst_->Start())] = state_id;
  
    BaseFloat expected_cost_offset = 0.0;  // the formula simplifies to zero
                                           // in this case.
    composed_state_queue_.push(
        std::pair<BaseFloat, int32>(expected_cost_offset,
                                    state_id));  // actually (0.0, 0).
  
  }
  
  
  // Processes one item taken from the queue: expands the next not-yet-processed
  // exit (arc or final-prob, in sorted order) of composed state
  // 'src_composed_state'.  Before doing so, it advances the state's
  // 'sorted_arc_index' / 'arc_delta_cost' and re-queues the state if its next
  // exit is still within the current cutoff.
  void PrunedCompactLatticeComposer::ProcessQueueElement(
      int32 src_composed_state) {
    KALDI_ASSERT(static_cast<size_t>(src_composed_state) <
                 composed_state_info_.size());

    ComposedStateInfo &src_info = composed_state_info_[src_composed_state];
    const int32 lat_state = src_info.lat_state;
    const LatticeStateInfo &lat_info = lat_state_info_[lat_state];

    // 'num_exits' is the number of arcs leaving the lattice state, plus one if
    // it has a final-prob.
    const int32 this_exit = src_info.sorted_arc_index;
    const int32 num_exits = lat_info.arc_delta_costs.size();
    KALDI_ASSERT(this_exit >= 0);

    const BaseFloat kInf = std::numeric_limits<BaseFloat>::infinity();

    // Update the state to reflect that, by the time this function returns, the
    // exit numbered 'this_exit' will have been processed.
    const int32 next_exit = this_exit + 1;
    BaseFloat next_cost_offset = kInf;
    if (next_exit != num_exits) {
      src_info.sorted_arc_index = next_exit;
      src_info.arc_delta_cost = lat_info.arc_delta_costs[next_exit].first;
      next_cost_offset =
          (src_info.forward_cost +
           lat_info.backward_cost +
           src_info.delta_backward_cost +
           src_info.arc_delta_cost - lat_best_cost_);
    } else {
      // That was the last exit: mark the state as exhausted.
      src_info.sorted_arc_index = -1;
      src_info.arc_delta_cost = kInf;
    }

    // The comparison is strict ('<', not '<=') on purpose: when both the
    // offset and current_cutoff_ are infinite (state exhausted, no cutoff yet)
    // the state must not go back on the queue.  Note that if the exit we
    // process below leads to a new state, that state may get queued as well.
    if (next_cost_offset < current_cutoff_)
      composed_state_queue_.emplace(next_cost_offset, src_composed_state);

    const int32 arc_index = lat_info.arc_delta_costs[this_exit].second;
    if (arc_index >= 0) {
      // A real arc; the handling is complicated enough to have its own
      // function.
      ProcessTransition(src_composed_state, arc_index);
      return;
    }

    // arc_index == -1: this exit is the final-prob of the lattice state.
    const int32 lm_state = src_info.lm_state;
    const BaseFloat lm_final_cost = det_fst_->Final(lm_state).Value();
    if (lm_final_cost == kInf)
      return;  // the LM state is not final (not expected for conventional
               // language models), so the composed state is not final either.

    // 'final_weight' is assumed not to be Zero(); otherwise the final-prob
    // would not have been present in 'arc_delta_costs'.
    CompactLattice::Weight final_weight = clat_in_.Final(lat_state);
    Lattice::Weight final_lat_weight = final_weight.Weight();
    final_lat_weight.SetValue1(final_lat_weight.Value1() +
                               lm_final_cost);
    final_weight.SetWeight(final_lat_weight);
    clat_out_->SetFinal(src_composed_state, final_weight);

    const double final_cost = ConvertToCost(final_lat_weight);
    if (final_cost < src_info.backward_cost)
      src_info.backward_cost = final_cost;

    if (output_reached_final_)
      return;
    // First time the output lattice has a final state: switch off the depth
    // penalty and recompute the pruning information.
    output_reached_final_ = true;
    depth_penalty_ = 0.0;
    RecomputePruningInfo();
  }
  
  // Expands arc number 'arc_index' leaving the input-lattice state that
  // composed state 'src_composed_state' corresponds to.  It looks up the
  // matching arc in det_fst_ (an epsilon label on the lattice arc leaves the LM
  // state unchanged, at zero cost), finds or creates the composed state for the
  // pair (destination lattice state, destination LM state), and adds the
  // composed arc to clat_out_.  A newly created destination state is
  // initialized and, if its expected cost offset is below current_cutoff_, put
  // on the queue.  If det_fst_ has no arc for the label, nothing is done.
  void PrunedCompactLatticeComposer::ProcessTransition(int32 src_composed_state,
                                                       int32 arc_index) {
    // Make src_composed_state a const pointer not a reference, as we may have to
    // modify the pointer if composed_state_info_ is resized.
    const ComposedStateInfo *src_info = &(composed_state_info_[
        src_composed_state]);
    int32 src_lat_state = src_info->lat_state;
    // Get the arc we are going to expand.
    fst::ArcIterator<CompactLattice> aiter(clat_in_, src_lat_state);
    aiter.Seek(arc_index);
    const CompactLatticeArc &lat_arc = aiter.Value();
    // Note: this code is for CompactLatticeArc, in which the ilabel and olabel
    // are the same, but we're writing it in such a way that it will naturally
    // generalize to LatticeArc, so there are separate variables for the ilabel
    // and the olabel.
    int32 dest_lat_state = lat_arc.nextstate,
        ilabel = lat_arc.ilabel,
        olabel = lat_arc.olabel;
    // Note: we expect that ilabel == olabel, since this is a CompactLattice, but this
    // may not be so if we extend this to work with Lattice.
    fst::StdArc lm_arc;
  
    // the input lattice might have epsilons
    if (olabel == 0) {
      // Synthesize an epsilon self-loop on the LM side: same LM state, zero
      // cost.
      lm_arc.ilabel = 0;
      lm_arc.olabel = 0;
      lm_arc.nextstate = src_info->lm_state;
      lm_arc.weight = fst::StdArc::Weight(0.0);
    } else if (!det_fst_->GetArc(src_info->lm_state, olabel, &lm_arc)) {
      // for normal language models we don't expect this to happen, but the
      // appropriate behavior is to do nothing; the composed arc does not exist,
      // so there is no arc to add and no new state to create.
      return;
    }
    int32 dest_lm_state = lm_arc.nextstate;
    // The following assertion is necessary because CompactLattice cannot support
    // different ilabel vs. olabel; and also it's an expectation about
    // language-models.
    KALDI_ASSERT(lm_arc.ilabel == lm_arc.olabel);
  
    LatticeStateInfo &dest_lat_state_info =
        lat_state_info_[dest_lat_state];
  
    int32 dest_composed_state;
    ComposedStateInfo *dest_info;
  
    { // The next block works out 'dest_composed_state' and
      // 'dest_info', and if the destination state did not already
      // exist, creates a new composed state.
      typedef std::unordered_map<std::pair<int32,int32>, int32,
          PairHasher<int32> > MapType;
      // Tentatively map the (lattice state, LM state) pair to the index the
      // next new output state would get; the insertion only succeeds if the
      // pair was not already in the map.
      int32 new_composed_state = clat_out_->NumStates();
      std::pair<const std::pair<int32,int32>, int32> value(
          std::pair<int32,int32>(dest_lat_state, dest_lm_state), new_composed_state);
      std::pair<MapType::iterator, bool> ret =
          pair_to_state_.insert(value);
      if (ret.second) {
        // Successfully inserted: this dest-state did not already exist.  Most of
        // the rest of this block deals with the consequences of adding a new
        // state.
        int32 ans = clat_out_->AddState();
        KALDI_ASSERT(ans == new_composed_state);
        dest_composed_state = new_composed_state;
        composed_state_info_.resize(dest_composed_state + 1);
        dest_info = &(composed_state_info_[dest_composed_state]);
        // Re-assign src_composed_state as the vector might have been reallocated.
        src_info = &(composed_state_info_[src_composed_state]);
        // Record the lattice state as accessed the first time any composed
        // state is created for it.
        if (dest_lat_state_info.composed_states.empty())
          accessed_lat_states_.insert(dest_lat_state);
        dest_lat_state_info.composed_states.push_back(new_composed_state);
        dest_info->lat_state = dest_lat_state;
        dest_info->lm_state = dest_lm_state;
        dest_info->depth = src_info->depth + 1;
        dest_info->forward_cost =
            src_info->forward_cost +
            ConvertToCost(lat_arc.weight) + lm_arc.weight.Value();
        // No path from the new state to a final state of the output is known
        // yet.
        dest_info->backward_cost =
            std::numeric_limits<double>::infinity();
        dest_info->delta_backward_cost =
            src_info->delta_backward_cost + dest_info->depth * depth_penalty_;
        // The 'prev_composed_state' field will not be read again until after it's
        // overwritten; we set it as below only for debugging purposes (the
        // negation is also for debugging purposes).
        dest_info->prev_composed_state = -src_composed_state;
        dest_info->sorted_arc_index = 0;
        dest_info->arc_delta_cost = 0.0;
        // Note: in the expression below, which can be understood with reference
        // to the comment by the declaration of the phantom variable
        // 'expected_cost_offset', 'arc_delta_cost' is known to equal 0.0 so it
        // has been removed.
        BaseFloat expected_cost_offset =
            (dest_info->forward_cost +
             dest_lat_state_info.backward_cost +
             dest_info->delta_backward_cost -
             lat_best_cost_);
        if (expected_cost_offset < current_cutoff_) {
          // the following call should be equivalent to
          // composed_state_queue_.push(std::pair<BaseFloat,int32>(...)) with
          // the same pair of args.
          composed_state_queue_.emplace(expected_cost_offset,
                                        dest_composed_state);
        }
      } else { // the destination composed state already existed.
        dest_composed_state = ret.first->second;
        // Note: 'dest_info' is set here for completeness; it is not read again
        // in the rest of this function, and an existing state's info is left
        // unmodified.
        dest_info = &(composed_state_info_[dest_composed_state]);
      }
    }
    // Add the arc from the src to dest state in the composed output.
    CompactLatticeArc new_arc;
    new_arc.nextstate = dest_composed_state;
    // Actually the ilabel and olabel are the same, but writing it this way will
    // generalize better to type Lattice if we need that later.
    new_arc.ilabel = ilabel;
    new_arc.olabel = olabel;
    new_arc.weight = lat_arc.weight;
    // 'weight' is the weight part, as opposed to the string part.
    LatticeArc::Weight weight = new_arc.weight.Weight();
    // include the LM-arc's weight in the weight of the new arc.
    weight.SetValue1(fst::Times(weight.Value1(), lm_arc.weight).Value());
    new_arc.weight.SetWeight(weight);
    clat_out_->AddArc(src_composed_state, new_arc);
    num_arcs_out_++;
  }
  
  // Returns the number of arcs in 'clat', summed over all of its states.
  static int32 TotalNumArcs(const CompactLattice &clat) {
    int32 total = 0;
    const int32 state_count = clat.NumStates();
    int32 state = 0;
    while (state < state_count) {
      total += clat.NumArcs(state);
      ++state;
    }
    return total;
  }
  
  
  void PrunedCompactLatticeComposer::Compose() {
    if (clat_in_.NumStates() == 0) {
      KALDI_WARN << "Input lattice to composition is empty.";
      return;
    }
    ComputeLatticeStateInfo();
    AddFirstState();
    // while (we have not reached final state  ||
    //        num-arcs produced < target num-arcs) { ...
    while (output_best_cost_ == std::numeric_limits<double>::infinity() ||
           num_arcs_out_ < opts_.max_arcs) {
      RecomputePruningInfo();
      int32 this_iter_arc_limit = GetCurrentArcLimit();
      while (num_arcs_out_ < this_iter_arc_limit &&
             !composed_state_queue_.empty()) {
        int32 src_composed_state = composed_state_queue_.top().second;
        composed_state_queue_.pop();
        ProcessQueueElement(src_composed_state);
      }
      if (composed_state_queue_.empty())
        break;
    }
  
    fst::Connect(clat_out_);
    TopSortCompactLatticeIfNeeded(clat_out_);
  
    if (GetVerboseLevel() >= 2) {
      int32 num_arcs_in = TotalNumArcs(clat_in_),
          orig_num_arcs_out = num_arcs_out_,
          num_arcs_out = TotalNumArcs(*clat_out_),
          num_states_in = clat_in_.NumStates(),
          orig_num_states_out = composed_state_info_.size(),
          num_states_out = clat_out_->NumStates();
      std::ostringstream os;
      os << "Input lattice had " << num_arcs_in << '/' << num_states_in
         << " arcs/states; output lattice has " << num_arcs_out << '/'
         << num_states_out;
      if (num_arcs_out != orig_num_arcs_out) {
        os << " (before pruning: " << orig_num_arcs_out << '/'
           << orig_num_states_out << ")";
      }
      if (!composed_state_queue_.empty()) {
        // Below, composed_state_queue_.top().first + lat_best_cost is an
        // expected-cost of the best path from the composed output that we *did
        // not* expand.  This, minus the best cost in the output compact lattice,
        // can be interpreted as the beam that we effecctively pruned the output
        // lattice to.
        BaseFloat effective_beam =
            composed_state_queue_.top().first + lat_best_cost_ - output_best_cost_;
        os << ". Effective beam was " << effective_beam;
      }
      KALDI_VLOG(2) << os.str();
    }
  
    if (clat_out_->NumStates() == 0) {
      KALDI_WARN << "Composed lattice has no states: something went wrong.";
    }
  }
  
  void ComposeCompactLatticePruned(
      const ComposeLatticePrunedOptions &opts,
      const CompactLattice &clat,
      fst::DeterministicOnDemandFst<fst::StdArc> *det_fst,
      CompactLattice* composed_clat) {
    PrunedCompactLatticeComposer composer(opts, clat, det_fst, composed_clat);
    composer.Compose();
  }
  
  } // namespace kaldi