grammar-fst.cc
41.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
// decoder/grammar-fst.cc
// Copyright 2018 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "decoder/grammar-fst.h"
#include "fstext/grammar-context-fst.h"
namespace fst {
GrammarFst::GrammarFst(
int32 nonterm_phones_offset,
std::shared_ptr<const ConstFst<StdArc> > top_fst,
const std::vector<std::pair<Label, std::shared_ptr<const ConstFst<StdArc> > > > &ifsts):
nonterm_phones_offset_(nonterm_phones_offset),
top_fst_(top_fst),
ifsts_(ifsts) {
Init();
}
void GrammarFst::Init() {
KALDI_ASSERT(nonterm_phones_offset_ > 1);
InitNonterminalMap();
entry_arcs_.resize(ifsts_.size());
if (!ifsts_.empty()) {
// We call this mostly so that if something is wrong with the input FSTs, the
// problem will be detected sooner rather than later.
// There would be no problem if we were to call InitEntryArcs(i)
// for all 0 <= i < ifsts_size(), but we choose to call it
// lazily on demand, to save startup time if the number of nonterminals
// is large.
InitEntryArcs(0);
}
InitInstances();
}
GrammarFst::~GrammarFst() {
Destroy();
}
void GrammarFst::Destroy() {
for (size_t i = 0; i < instances_.size(); i++) {
FstInstance &instance = instances_[i];
std::unordered_map<BaseStateId, ExpandedState*>::const_iterator
iter = instance.expanded_states.begin(),
end = instance.expanded_states.end();
for (; iter != end; ++iter) {
ExpandedState *e = iter->second;
delete e;
}
}
top_fst_ = NULL;
ifsts_.clear();
nonterminal_map_.clear();
entry_arcs_.clear();
instances_.clear();
}
void GrammarFst::DecodeSymbol(Label label,
int32 *nonterminal_symbol,
int32 *left_context_phone) {
// encoding_multiple will normally equal 1000 (but may be a multiple of 1000
// if there are a lot of phones); kNontermBigNumber is 10000000.
int32 big_number = static_cast<int32>(kNontermBigNumber),
nonterm_phones_offset = nonterm_phones_offset_,
encoding_multiple = GetEncodingMultiple(nonterm_phones_offset);
// The following assertion should be optimized out as the condition is
// statically known.
KALDI_ASSERT(big_number % static_cast<int32>(kNontermMediumNumber) == 0);
*nonterminal_symbol = (label - big_number) / encoding_multiple;
*left_context_phone = label % encoding_multiple;
if (*nonterminal_symbol <= nonterm_phones_offset ||
*left_context_phone == 0 || *left_context_phone >
nonterm_phones_offset + static_cast<int32>(kNontermBos))
KALDI_ERR << "Decoding invalid label " << label
<< ": code error or invalid --nonterm-phones-offset?";
}
void GrammarFst::InitNonterminalMap() {
nonterminal_map_.clear();
for (size_t i = 0; i < ifsts_.size(); i++) {
int32 nonterminal = ifsts_[i].first;
if (nonterminal_map_.count(nonterminal))
KALDI_ERR << "Nonterminal symbol " << nonterminal
<< " is paired with two FSTs.";
if (nonterminal < GetPhoneSymbolFor(kNontermUserDefined))
KALDI_ERR << "Nonterminal symbol " << nonterminal
<< " in input pairs, was expected to be >= "
<< GetPhoneSymbolFor(kNontermUserDefined);
nonterminal_map_[nonterminal] = static_cast<int32>(i);
}
}
void GrammarFst::InitEntryArcs(int32 i) {
KALDI_ASSERT(static_cast<size_t>(i) < ifsts_.size());
const ConstFst<StdArc> &fst = *(ifsts_[i].second);
InitEntryOrReentryArcs(fst, fst.Start(),
GetPhoneSymbolFor(kNontermBegin),
&(entry_arcs_[i]));
}
void GrammarFst::InitInstances() {
KALDI_ASSERT(instances_.empty());
instances_.resize(1);
instances_[0].ifst_index = -1;
instances_[0].fst = top_fst_.get();
instances_[0].parent_instance = -1;
instances_[0].parent_state = -1;
}
void GrammarFst::InitEntryOrReentryArcs(
const ConstFst<StdArc> &fst,
int32 entry_state,
int32 expected_nonterminal_symbol,
std::unordered_map<int32, int32> *phone_to_arc) {
phone_to_arc->clear();
ArcIterator<ConstFst<StdArc> > aiter(fst, entry_state);
int32 arc_index = 0;
for (; !aiter.Done(); aiter.Next(), ++arc_index) {
const StdArc &arc = aiter.Value();
int32 nonterminal, left_context_phone;
if (arc.ilabel <= (int32)kNontermBigNumber) {
if (entry_state == fst.Start()) {
KALDI_ERR << "There is something wrong with the graph; did you forget to "
"add #nonterm_begin and #nonterm_end to the non-top-level FSTs "
"before compiling?";
} else {
KALDI_ERR << "There is something wrong with the graph; re-entry state is "
"not as anticipated.";
}
}
DecodeSymbol(arc.ilabel, &nonterminal, &left_context_phone);
if (nonterminal != expected_nonterminal_symbol) {
KALDI_ERR << "Expected arcs from this state to have nonterminal-symbol "
<< expected_nonterminal_symbol << ", but got "
<< nonterminal;
}
std::pair<int32, int32> p(left_context_phone, arc_index);
if (!phone_to_arc->insert(p).second) {
// If it was not successfully inserted in the phone_to_arc map, it means
// there were two arcs with the same left-context phone, which does not
// make sense; that's an error, likely a code error (or an error when the
// input FSTs were generated).
KALDI_ERR << "Two arcs had the same left-context phone.";
}
}
}
GrammarFst::ExpandedState *GrammarFst::ExpandState(
int32 instance_id, BaseStateId state_id) {
int32 big_number = kNontermBigNumber;
const ConstFst<StdArc> &fst = *(instances_[instance_id].fst);
ArcIterator<ConstFst<StdArc> > aiter(fst, state_id);
KALDI_ASSERT(!aiter.Done() && aiter.Value().ilabel > big_number &&
"Something is not right; did you call PrepareForGrammarFst()?");
const StdArc &arc = aiter.Value();
int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_),
nonterminal = (arc.ilabel - big_number) / encoding_multiple;
if (nonterminal == GetPhoneSymbolFor(kNontermBegin) ||
nonterminal == GetPhoneSymbolFor(kNontermReenter)) {
KALDI_ERR << "Encountered unexpected type of nonterminal while "
"expanding state.";
} else if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) {
return ExpandStateEnd(instance_id, state_id);
} else if (nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) {
return ExpandStateUserDefined(instance_id, state_id);
} else {
KALDI_ERR << "Encountered unexpected type of nonterminal "
<< nonterminal << " while expanding state.";
}
return NULL; // Suppress compiler warning
}
// static inline
void GrammarFst::CombineArcs(const StdArc &leaving_arc,
const StdArc &arriving_arc,
float cost_correction,
StdArc *arc) {
// The following assertion shouldn't fail; we ensured this in
// PrepareForGrammarFst(), search for 'olabel_problem'.
KALDI_ASSERT(leaving_arc.olabel == 0);
// 'leaving_arc' leaves one fst, and 'arriving_arcs', conceptually arrives in
// another. This code merges the information of the two arcs to make a
// cross-FST arc. The ilabel information is discarded as it was only intended
// for the consumption of the GrammarFST code.
arc->ilabel = 0;
arc->olabel = arriving_arc.olabel;
// conceptually, arc->weight =
// Times(Times(leaving_arc.weight, arriving_arc.weight), Weight(cost_correction)).
// The below might be a bit faster, I hope-- avoiding checking.
arc->weight = Weight(cost_correction + leaving_arc.weight.Value() +
arriving_arc.weight.Value());
arc->nextstate = arriving_arc.nextstate;
}
GrammarFst::ExpandedState *GrammarFst::ExpandStateEnd(
int32 instance_id, BaseStateId state_id) {
if (instance_id == 0)
KALDI_ERR << "Did not expect #nonterm_end symbol in FST-instance 0.";
const FstInstance &instance = instances_[instance_id];
int32 parent_instance_id = instance.parent_instance;
const ConstFst<StdArc> &fst = *(instance.fst);
const FstInstance &parent_instance = instances_[parent_instance_id];
const ConstFst<StdArc> &parent_fst = *(parent_instance.fst);
ExpandedState *ans = new ExpandedState;
ans->dest_fst_instance = parent_instance_id;
// parent_aiter is the arc-iterator in the state we return to. We'll Seek()
// to a different position 'parent_aiter' for each arc leaving this state.
// (actually we expect just one arc to leave this state).
ArcIterator<ConstFst<StdArc> > parent_aiter(parent_fst,
instance.parent_state);
// for explanation of cost_correction, see documentation for CombineArcs().
float num_reentry_arcs = instances_[instance_id].parent_reentry_arcs.size(),
cost_correction = -log(num_reentry_arcs);
ArcIterator<ConstFst<StdArc> > aiter(fst, state_id);
for (; !aiter.Done(); aiter.Next()) {
const StdArc &leaving_arc = aiter.Value();
int32 this_nonterminal, left_context_phone;
DecodeSymbol(leaving_arc.ilabel, &this_nonterminal,
&left_context_phone);
KALDI_ASSERT(this_nonterminal == GetPhoneSymbolFor(kNontermEnd) &&
">1 nonterminals from a state; did you use "
"PrepareForGrammarFst()?");
std::unordered_map<int32, int32>::const_iterator reentry_iter =
instances_[instance_id].parent_reentry_arcs.find(left_context_phone),
reentry_end = instances_[instance_id].parent_reentry_arcs.end();
if (reentry_iter == reentry_end) {
KALDI_ERR << "FST with index " << instance.ifst_index
<< " ends with left-context-phone " << left_context_phone
<< " but parent FST does not support that left-context "
"at the return point.";
}
size_t parent_arc_index = static_cast<size_t>(reentry_iter->second);
parent_aiter.Seek(parent_arc_index);
const StdArc &arriving_arc = parent_aiter.Value();
// 'arc' will combine the information on 'leaving_arc' and 'arriving_arc',
// except that the ilabel will be set to zero.
if (leaving_arc.olabel != 0) {
// If the following fails it would maybe indicate you hadn't called
// PrepareForGrammarFst(), or there was an error in that, because
// we made sure the leaving arc does not have an olabel. Search
// in that code for 'olabel_problem' for more details.
KALDI_ERR << "Leaving arc has zero olabel.";
}
StdArc arc;
CombineArcs(leaving_arc, arriving_arc, cost_correction, &arc);
ans->arcs.push_back(arc);
}
return ans;
}
int32 GrammarFst::GetChildInstanceId(int32 instance_id, int32 nonterminal,
int32 state) {
int64 encoded_pair = (static_cast<int64>(nonterminal) << 32) + state;
// 'new_instance_id' is the instance-id we'd assign if we had to create a new one.
// We try to add it at once, to avoid having to do an extra map lookup in case
// it wasn't there and we did need to add it.
int32 child_instance_id = instances_.size();
{
std::pair<int64, int32> p(encoded_pair, child_instance_id);
std::pair<std::unordered_map<int64, int32>::const_iterator, bool> ans =
instances_[instance_id].child_instances.insert(p);
if (!ans.second) {
// The pair was not inserted, which means the key 'encoded_pair' did exist in the
// map. Return the value in the map.
child_instance_id = ans.first->second;
return child_instance_id;
}
}
// If we reached this point, we did successfully insert 'child_instance_id' into
// the map, because the key didn't exist. That means we have to actually create
// the instance.
instances_.resize(child_instance_id + 1);
const FstInstance &parent_instance = instances_[instance_id];
FstInstance &child_instance = instances_[child_instance_id];
// Work out the ifst_index for this nonterminal.
std::unordered_map<int32, int32>::const_iterator iter =
nonterminal_map_.find(nonterminal);
if (iter == nonterminal_map_.end()) {
KALDI_ERR << "Nonterminal " << nonterminal << " was requested, but "
"there is no FST for it.";
}
int32 ifst_index = iter->second;
child_instance.ifst_index = ifst_index;
child_instance.fst = ifsts_[ifst_index].second.get();
child_instance.parent_instance = instance_id;
child_instance.parent_state = state;
InitEntryOrReentryArcs(*(parent_instance.fst), state,
GetPhoneSymbolFor(kNontermReenter),
&(child_instance.parent_reentry_arcs));
return child_instance_id;
}
GrammarFst::ExpandedState *GrammarFst::ExpandStateUserDefined(
int32 instance_id, BaseStateId state_id) {
const ConstFst<StdArc> &fst = *(instances_[instance_id].fst);
ArcIterator<ConstFst<StdArc> > aiter(fst, state_id);
ExpandedState *ans = new ExpandedState;
int32 dest_fst_instance = -1; // We'll set it in the loop.
// and->dest_fst_instance will be set to this.
for (; !aiter.Done(); aiter.Next()) {
const StdArc &leaving_arc = aiter.Value();
int32 nonterminal, left_context_phone;
DecodeSymbol(leaving_arc.ilabel, &nonterminal,
&left_context_phone);
int32 child_instance_id = GetChildInstanceId(instance_id,
nonterminal,
leaving_arc.nextstate);
if (dest_fst_instance < 0) {
dest_fst_instance = child_instance_id;
} else if (dest_fst_instance != child_instance_id) {
KALDI_ERR << "Same state leaves to different FST instances "
"(Did you use PrepareForGrammarFst()?)";
}
const FstInstance &child_instance = instances_[child_instance_id];
const ConstFst<StdArc> &child_fst = *(child_instance.fst);
int32 child_ifst_index = child_instance.ifst_index;
std::unordered_map<int32, int32> &entry_arcs = entry_arcs_[child_ifst_index];
if (entry_arcs.empty())
InitEntryArcs(child_ifst_index);
// for explanation of cost_correction, see documentation for CombineArcs().
float num_entry_arcs = entry_arcs.size(),
cost_correction = -log(num_entry_arcs);
// Get the arc-index for the arc leaving the start-state of child FST that
// corresponds to this phonetic context.
std::unordered_map<int32, int32>::const_iterator entry_iter =
entry_arcs.find(left_context_phone);
if (entry_iter == entry_arcs.end()) {
KALDI_ERR << "FST for nonterminal " << nonterminal
<< " does not have an entry point for left-context-phone "
<< left_context_phone;
}
int32 arc_index = entry_iter->second;
ArcIterator<ConstFst<StdArc> > child_aiter(child_fst, child_fst.Start());
child_aiter.Seek(arc_index);
const StdArc &arriving_arc = child_aiter.Value();
StdArc arc;
CombineArcs(leaving_arc, arriving_arc, cost_correction, &arc);
ans->arcs.push_back(arc);
}
ans->dest_fst_instance = dest_fst_instance;
return ans;
}
void GrammarFst::Write(std::ostream &os, bool binary) const {
using namespace kaldi;
if (!binary)
KALDI_ERR << "GrammarFst::Write only supports binary mode.";
int32 format = 1,
num_ifsts = ifsts_.size();
WriteToken(os, binary, "<GrammarFst>");
WriteBasicType(os, binary, format);
WriteBasicType(os, binary, num_ifsts);
WriteBasicType(os, binary, nonterm_phones_offset_);
std::string stream_name("unknown");
FstWriteOptions wopts(stream_name);
top_fst_->Write(os, wopts);
for (int32 i = 0; i < num_ifsts; i++) {
int32 nonterminal = ifsts_[i].first;
WriteBasicType(os, binary, nonterminal);
ifsts_[i].second->Write(os, wopts);
}
WriteToken(os, binary, "</GrammarFst>");
}
static ConstFst<StdArc> *ReadConstFstFromStream(std::istream &is) {
fst::FstHeader hdr;
std::string stream_name("unknown");
if (!hdr.Read(is, stream_name))
KALDI_ERR << "Reading FST: error reading FST header";
FstReadOptions ropts("<unspecified>", &hdr);
ConstFst<StdArc> *ans = ConstFst<StdArc>::Read(is, ropts);
if (!ans)
KALDI_ERR << "Could not read ConstFst from stream.";
return ans;
}
void GrammarFst::Read(std::istream &is, bool binary) {
using namespace kaldi;
if (!binary)
KALDI_ERR << "GrammarFst::Read only supports binary mode.";
if (top_fst_ != NULL)
Destroy();
int32 format = 1, num_ifsts;
ExpectToken(is, binary, "<GrammarFst>");
ReadBasicType(is, binary, &format);
if (format != 1)
KALDI_ERR << "This version of the code cannot read this GrammarFst, "
"update your code.";
ReadBasicType(is, binary, &num_ifsts);
ReadBasicType(is, binary, &nonterm_phones_offset_);
top_fst_ = std::shared_ptr<const ConstFst<StdArc> >(ReadConstFstFromStream(is));
for (int32 i = 0; i < num_ifsts; i++) {
int32 nonterminal;
ReadBasicType(is, binary, &nonterminal);
std::shared_ptr<const ConstFst<StdArc> >
this_fst(ReadConstFstFromStream(is));
ifsts_.push_back(std::pair<int32, std::shared_ptr<const ConstFst<StdArc> > >(
nonterminal, this_fst));
}
Init();
}
/**
This utility function input-determinizes a specified state s of the FST
'fst'. (This input-determinizes while treating epsilon as a real symbol,
although for the application we expect to use it, there won't be epsilons).
What this function does is: for any symbol i that appears as the ilabel of
more than one arc leaving state s of FST 'fst', it creates an additional
state, it creates a new state t with epsilon-input transitions leaving it for
each of those multiple arcs leaving state s; it deletes the original arcs
leaving state s; and it creates a single arc leaving state s to the newly
created state with the ilabel i on it. It sets the weights as necessary to
preserve equivalence and also to ensure that if, prior to this modification,
the FST was stochastic when cast to the log semiring (see
IsStochasticInLog()), it still will be. I.e. when interpreted as
negative logprobs, the weight from state s to t would be the sum of
the weights on the original arcs leaving state s.
This is used as a very cheap solution when preparing FSTs for the grammar
decoder, to ensure that there is only one entry-state to the sub-FST for each
phonetic left-context; this keeps the grammar-FST code (i.e. the code that
stitches them together) simple. Of course it will tend to introduce
unnecessary epsilons, and if we were careful we might be able to remove
some of those, but this wouldn't have a substantial impact on overall
decoder performance so we don't bother.
*/
static void InputDeterminizeSingleState(StdArc::StateId s,
VectorFst<StdArc> *fst) {
bool was_input_deterministic = true;
typedef StdArc Arc;
typedef Arc::StateId StateId;
typedef Arc::Label Label;
typedef Arc::Weight Weight;
struct InfoForIlabel {
std::vector<size_t> arc_indexes; // indexes of all arcs with this ilabel
float tot_cost; // total cost of all arcs leaving state s for this
// ilabel, summed as if they were negative log-probs.
StateId new_state; // state-id of new state, if any, that we have created
// to remove duplicate symbols with this ilabel.
InfoForIlabel(): new_state(-1) { }
};
std::unordered_map<Label, InfoForIlabel> label_map;
size_t arc_index = 0;
for (ArcIterator<VectorFst<Arc> > aiter(*fst, s);
!aiter.Done(); aiter.Next(), ++arc_index) {
const Arc &arc = aiter.Value();
InfoForIlabel &info = label_map[arc.ilabel];
if (info.arc_indexes.empty()) {
info.tot_cost = arc.weight.Value();
} else {
info.tot_cost = -kaldi::LogAdd(-info.tot_cost, -arc.weight.Value());
was_input_deterministic = false;
}
info.arc_indexes.push_back(arc_index);
}
if (was_input_deterministic)
return; // Nothing to do.
// 'new_arcs' will contain the modified list of arcs
// leaving state s
std::vector<Arc> new_arcs;
new_arcs.reserve(arc_index);
arc_index = 0;
for (ArcIterator<VectorFst<Arc> > aiter(*fst, s);
!aiter.Done(); aiter.Next(), ++arc_index) {
const Arc &arc = aiter.Value();
Label ilabel = arc.ilabel;
InfoForIlabel &info = label_map[ilabel];
if (info.arc_indexes.size() == 1) {
new_arcs.push_back(arc); // no changes needed
} else {
if (info.new_state < 0) {
info.new_state = fst->AddState();
// add arc from state 's' to newly created state.
new_arcs.push_back(Arc(ilabel, 0, Weight(info.tot_cost),
info.new_state));
}
// add arc from new state to original destination of this arc.
fst->AddArc(info.new_state, Arc(0, arc.olabel,
Weight(arc.weight.Value() - info.tot_cost),
arc.nextstate));
}
}
fst->DeleteArcs(s);
for (size_t i = 0; i < new_arcs.size(); i++)
fst->AddArc(s, new_arcs[i]);
}
// This class contains the implementation of the function
// PrepareForGrammarFst(), which is declared in grammar-fst.h.
class GrammarFstPreparer {
public:
using FST = VectorFst<StdArc>;
using Arc = StdArc;
using StateId = Arc::StateId;
using Label = Arc::Label;
using Weight = Arc::Weight;
GrammarFstPreparer(int32 nonterm_phones_offset,
VectorFst<StdArc> *fst):
nonterm_phones_offset_(nonterm_phones_offset),
fst_(fst), orig_num_states_(fst->NumStates()),
simple_final_state_(kNoStateId) { }
void Prepare() {
if (fst_->Start() == kNoStateId) {
KALDI_ERR << "FST has no states.";
}
for (StateId s = 0; s < fst_->NumStates(); s++) {
if (IsSpecialState(s)) {
if (NeedEpsilons(s)) {
InsertEpsilonsForState(s);
// This state won't be treated as a 'special' state any more;
// all 'special' arcs (arcs with ilabels >= kNontermBigNumber)
// have been moved and now leave from newly created states that
// this state transitions to via epsilons arcs.
} else {
// OK, state s is a special state.
FixArcsToFinalStates(s);
MaybeAddFinalProbToState(s);
// The following ensures that the start-state of sub-FSTs only has
// a single arc per left-context phone (the graph-building recipe can
// end up creating more than one if there were disambiguation symbols,
// e.g. for langauge model backoff).
if (s == fst_->Start() && IsEntryState(s))
InputDeterminizeSingleState(s, fst_);
}
}
}
StateId num_new_states = fst_->NumStates() - orig_num_states_;
KALDI_LOG << "Added " << num_new_states << " new states while "
"preparing for grammar FST.";
}
private:
// Returns true if state 's' has at least one arc coming out of it with a
// special nonterminal-related ilabel on it (i.e. an ilabel >=
// kNontermBigNumber), and false otherwise.
bool IsSpecialState(StateId s) const;
// This function verifies that state s does not currently have any
// final-prob (crashes if that fails); then, if the arcs leaving s have
// nonterminal symbols kNontermEnd or user-defined nonterminals (>=
// kNontermUserDefined), it adds a final-prob with cost given by
// KALDI_GRAMMAR_FST_SPECIAL_WEIGHT to the state.
//
// State s is required to be a 'special state', i.e. have special symbols on
// arcs leaving it, and the function assumes (since it will already
// have been checked) that the arcs leaving s, if there are more than
// one, all correspond to the same nonterminal symbol.
void MaybeAddFinalProbToState(StateId s);
// This function does some checking for 'special states', that they have
// certain expected properties, and also detects certain problematic
// conditions that we need to fix. It returns true if we need to
// modify this state (by adding input-epsilon arcs), and false otherwise.
bool NeedEpsilons(StateId s) const;
// Returns true if state s (which is expected to be the start state, although we
// don't check this) has arcs with nonterminal symbols #nonterm_begin.
bool IsEntryState(StateId s) const;
// Fixes any final-prob-related problems with this state. The problem we aim
// to fix is that there may be arcs with nonterminal symbol #nonterm_end which
// transition from this state to a state with non-unit final prob. This
// function assimilates that final-prob into the arc leaving from this state,
// by making the arc transition to a new state with unit final-prob, and
// incorporating the original final-prob into the arc's weight.
//
// The purpose of this is to keep the GrammarFst code simple.
//
// It would have been more efficient to do this in CheckProperties(), but
// doing it this way is clearer; and the extra time taken here will be tiny.
void FixArcsToFinalStates(StateId s);
// This struct represents a category of arcs that are allowed to leave from
// the same 'special state'. If a special state has arcs leaving it that
// are in more than one category, it will need to be split up into
// multiple states connected by epsilons.
//
// The 'nonterminal' and 'nextstate' have to do with ensuring that all
// arcs leaving a particular FST state transition to the same FST instance
// (which, in turn, helps to keep the ArcIterator code efficient).
//
// The 'olabel' has to do with ensuring that arcs with user-defined
// nonterminals or kNontermEnd have no olabels on them. This is a requirement
// of the CombineArcs() function of GrammarFst, because it needs to combine
// two olabels into one so we need to know that at least one of the olabels is
// always epsilon.
struct ArcCategory {
int32 nonterminal; // The nonterminal symbol #nontermXXX encoded into the ilabel,
// or 0 if the ilabel was <kNontermBigNumber.
StateId nextstate; // If 'nonterminal' is a user-defined nonterminal like
// #nonterm:foo,
// then the destination state of the arc, else kNoStateId (-1).
Label olabel; // If 'nonterminal' is #nonterm_end or is a user-defined
// nonterminal (e.g. #nonterm:foo), then the olabel on the
// arc; else, 0.
bool operator < (const ArcCategory &other) const {
if (nonterminal < other.nonterminal) return true;
else if (nonterminal > other.nonterminal) return false;
if (nextstate < other.nextstate) return true;
else if (nextstate > other.nextstate) return false;
return olabel < other.olabel;
}
};
// This function, which is used in CheckProperties() and
// InsertEpsilonsForState(), works out the categrory of the arc; see
// documentation of struct ArcCategory for more details.
void GetCategoryOfArc(const Arc &arc,
ArcCategory *arc_category) const;
// This will be called for 'special states' that need to be split up.
// Non-special arcs leaving this state will stay here. For each
// category of special arcs (see ArcCategory for details), a new
// state will be created and those arcs will leave from that state
// instead; for each such state, an input-epsilon arc will leave this state
// for that state. For more details, see the code.
void InsertEpsilonsForState(StateId s);
inline int32 GetPhoneSymbolFor(enum NonterminalValues n) const {
return nonterm_phones_offset_ + static_cast<int32>(n);
}
int32 nonterm_phones_offset_;
VectorFst<StdArc> *fst_;
StateId orig_num_states_;
// If needed we may add a 'simple final state' to fst_, which has unit
// final-prob. This is used when we ensure that states with kNontermExit on
// them transition to a state with unit final-prob, so we don't need to
// look at the final-prob when expanding states.
StateId simple_final_state_;
};
bool GrammarFstPreparer::IsSpecialState(StateId s) const {
if (fst_->Final(s).Value() == KALDI_GRAMMAR_FST_SPECIAL_WEIGHT) {
// TODO: find a way to detect if it was a coincidence, or not make it an
// error, because in principle a user-defined grammar could contain this
// special cost.
KALDI_WARN << "It looks like you are calling PrepareForGrammarFst twice.";
}
for (ArcIterator<FST> aiter(*fst_, s ); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel >= kNontermBigNumber) // 1 million
return true;
}
return false;
}
bool GrammarFstPreparer::IsEntryState(StateId s) const {
int32 big_number = kNontermBigNumber,
encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_);
for (ArcIterator<FST> aiter(*fst_, s ); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
int32 nonterminal = (arc.ilabel - big_number) /
encoding_multiple;
// we check that at least one has label with nonterminal equal to #nonterm_begin...
// in fact they will all have this value if at least one does, and this was checked
// in NeedEpsilons().
if (nonterminal == GetPhoneSymbolFor(kNontermBegin))
return true;
}
return false;
}
bool GrammarFstPreparer::NeedEpsilons(StateId s) const {
// See the documentation for GetCategoryOfArc() for explanation of what these are.
std::set<ArcCategory> categories;
if (fst_->Final(s) != Weight::Zero()) {
// A state having a final-prob is considered the same as it having
// a non-nonterminal arc out of it.. this would be like a transition
// within the same FST.
ArcCategory category;
category.nonterminal = 0;
category.nextstate = kNoStateId;
category.olabel = 0;
categories.insert(category);
}
int32 big_number = kNontermBigNumber,
encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_);
for (ArcIterator<FST> aiter(*fst_, s ); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
ArcCategory category;
GetCategoryOfArc(arc, &category);
categories.insert(category);
// the rest of this block is just checking.
int32 nonterminal = category.nonterminal;
if (nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) {
// Check that the destination state of this arc has arcs with
// kNontermReenter on them. We'll separately check that such states
// don't have other types of arcs leaving them (search for
// kNontermReenter below), so it's sufficient to check the first arc.
ArcIterator<FST> next_aiter(*fst_, arc.nextstate);
if (next_aiter.Done())
KALDI_ERR << "Destination state of a user-defined nonterminal "
"has no arcs leaving it.";
const Arc &next_arc = next_aiter.Value();
int32 next_nonterminal = (next_arc.ilabel - big_number) /
encoding_multiple;
if (next_nonterminal != GetPhoneSymbolFor(kNontermReenter)) {
KALDI_ERR << "Expected arcs with user-defined nonterminals to be "
"followed by arcs with kNontermReenter.";
}
}
if (nonterminal == GetPhoneSymbolFor(kNontermBegin) &&
s != fst_->Start()) {
KALDI_ERR << "#nonterm_begin symbol is present but this is not the "
"first state. Did you do fstdeterminizestar while compiling?";
}
if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) {
if (fst_->NumArcs(arc.nextstate) != 0 ||
fst_->Final(arc.nextstate) == Weight::Zero()) {
KALDI_ERR << "Arc with kNontermEnd is not the final arc.";
}
}
}
if (categories.size() > 1) {
// This state has arcs leading to multiple FST instances.
// Do some checking to see that there is nothing really unexpected in
// there.
for (std::set<ArcCategory>::const_iterator
iter = categories.begin();
iter != categories.end(); ++iter) {
int32 nonterminal = iter->nonterminal;
if (nonterminal == nonterm_phones_offset_ + kNontermBegin ||
nonterminal == nonterm_phones_offset_ + kNontermReenter)
// we don't expect any state which has symbols like (kNontermBegin:p1)
// on arcs coming out of it, to also have other types of symbol. The
// same goes for kNontermReenter.
KALDI_ERR << "We do not expect states with arcs of type "
"kNontermBegin/kNontermReenter coming out of them, to also have "
"other types of arc.";
}
}
// the first half of the || below relates to olabels on arcs with either
// user-defined nonterminals or #nonterm_end (which would become 'leaving_arc'
// in the CombineArcs() function of GrammarFst). That function does not allow
// nonzero olabels on 'leaving_arc', which would be a problem if the
// 'arriving' arc had nonzero olabels, so we solve this by introducing
// input-epsilon arcs and putting the olabels on them instead.
bool need_epsilons = (categories.size() == 1 &&
categories.begin()->olabel != 0) ||
categories.size() > 1;
return need_epsilons;
}
void GrammarFstPreparer::FixArcsToFinalStates(StateId s) {
int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_),
big_number = kNontermBigNumber;
for (MutableArcIterator<FST> aiter(fst_, s ); !aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value();
if (arc.ilabel < big_number)
continue;
int32 nonterminal = (arc.ilabel - big_number) / encoding_multiple;
if (nonterminal == GetPhoneSymbolFor(kNontermEnd)) {
KALDI_ASSERT(fst_->NumArcs(arc.nextstate) == 0 &&
fst_->Final(arc.nextstate) != Weight::Zero());
if (fst_->Final(arc.nextstate) == Weight::One())
continue; // There is no problem to fix.
if (simple_final_state_ == kNoStateId) {
simple_final_state_ = fst_->AddState();
fst_->SetFinal(simple_final_state_, Weight::One());
}
arc.weight = Times(arc.weight, fst_->Final(arc.nextstate));
arc.nextstate = simple_final_state_;
aiter.SetValue(arc);
}
}
}
void GrammarFstPreparer::MaybeAddFinalProbToState(StateId s) {
if (fst_->Final(s) != Weight::Zero()) {
// Something went wrong and it will require some debugging. In Prepare(),
// if we detected that the special state had a nonzero final-prob, we
// would have inserted epsilons to remove it, so there may be a bug in
// this class's code.
KALDI_ERR << "State already final-prob.";
}
ArcIterator<FST> aiter(*fst_, s );
KALDI_ASSERT(!aiter.Done());
const Arc &arc = aiter.Value();
int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_),
big_number = kNontermBigNumber,
nonterminal = (arc.ilabel - big_number) / encoding_multiple;
KALDI_ASSERT(nonterminal >= GetPhoneSymbolFor(kNontermBegin));
if (nonterminal == GetPhoneSymbolFor(kNontermEnd) ||
nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) {
fst_->SetFinal(s, Weight(KALDI_GRAMMAR_FST_SPECIAL_WEIGHT));
}
}
void GrammarFstPreparer::GetCategoryOfArc(
const Arc &arc, ArcCategory *arc_category) const {
int32 encoding_multiple = GetEncodingMultiple(nonterm_phones_offset_),
big_number = kNontermBigNumber;
int32 ilabel = arc.ilabel;
if (ilabel < big_number) {
arc_category->nonterminal = 0;
arc_category->nextstate = kNoStateId;
arc_category->olabel = 0;
} else {
int32 nonterminal = (ilabel - big_number) / encoding_multiple;
arc_category->nonterminal = nonterminal;
if (nonterminal <= nonterm_phones_offset_) {
KALDI_ERR << "Problem decoding nonterminal symbol "
"(wrong --nonterm-phones-offset option?), ilabel="
<< ilabel;
}
if (nonterminal >= GetPhoneSymbolFor(kNontermUserDefined)) {
// This is a user-defined symbol.
arc_category->nextstate = arc.nextstate;
arc_category->olabel = arc.olabel;
} else {
arc_category->nextstate = kNoStateId;
if (nonterminal == GetPhoneSymbolFor(kNontermEnd))
arc_category->olabel = arc.olabel;
else
arc_category->olabel = 0;
}
}
}
void GrammarFstPreparer::InsertEpsilonsForState(StateId s) {
// Maps from category of arc, to a pair:
// the StateId is the state corresponding to that category.
// the float is the cost on the arc leading to that state;
// we compute the value that corresponds to the sum of the probabilities
// of the leaving arcs, bearing in mind that p = exp(-cost).
// We don't insert the arc-category whose 'nonterminal' is 0 here (i.e. the
// category for normal arcs); those arcs stay at this state.
std::map<ArcCategory, std::pair<StateId, float> > category_to_state;
// This loop sets up 'category_to_state'.
for (fst::ArcIterator<FST> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
ArcCategory category;
GetCategoryOfArc(arc, &category);
int32 nonterminal = category.nonterminal;
if (nonterminal == 0)
continue;
if (nonterminal == GetPhoneSymbolFor(kNontermBegin) ||
nonterminal == GetPhoneSymbolFor(kNontermReenter)) {
KALDI_ERR << "Something went wrong; did not expect to insert epsilons "
"for this type of state.";
}
auto iter = category_to_state.find(category);
if (iter == category_to_state.end()) {
StateId new_state = fst_->AddState();
float cost = arc.weight.Value();
category_to_state[category] = std::pair<StateId, float>(new_state, cost);
} else {
std::pair<StateId, float> &p = iter->second;
p.second = -kaldi::LogAdd(-p.second, -arc.weight.Value());
}
}
KALDI_ASSERT(!category_to_state.empty()); // would be a code error.
// 'arcs_from_this_state' is a place to put arcs that will put on this state
// after we delete all its existing arcs.
std::vector<Arc> arcs_from_this_state;
arcs_from_this_state.reserve(fst_->NumArcs(s) + category_to_state.size());
// add arcs corresponding to transitions to the newly created states, to
// 'arcs_from_this_state'
for (std::map<ArcCategory, std::pair<StateId, float> >::const_iterator
iter = category_to_state.begin(); iter != category_to_state.end();
++iter) {
const ArcCategory &category = iter->first;
StateId new_state = iter->second.first;
float cost = iter->second.second;
Arc arc;
arc.ilabel = 0;
arc.olabel = category.olabel;
arc.weight = Weight(cost);
arc.nextstate = new_state;
arcs_from_this_state.push_back(arc);
}
// Now add to 'arcs_from_this_state', and to the newly created states,
// arcs corresponding to each of the arcs that were originally leaving
// this state.
for (fst::ArcIterator<FST> aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
ArcCategory category;
GetCategoryOfArc(arc, &category);
int32 nonterminal = category.nonterminal;
if (nonterminal == 0) { // this arc remains unchanged; we'll put it back later.
arcs_from_this_state.push_back(arc);
continue;
}
auto iter = category_to_state.find(category);
KALDI_ASSERT(iter != category_to_state.end());
Arc new_arc;
new_arc.ilabel = arc.ilabel;
if (arc.olabel == category.olabel) {
new_arc.olabel = 0; // the olabel went on the epsilon-input arc.
} else {
KALDI_ASSERT(category.olabel == 0);
new_arc.olabel = arc.olabel;
}
StateId new_state = iter->second.first;
float epsilon_arc_cost = iter->second.second;
new_arc.weight = Weight(arc.weight.Value() - epsilon_arc_cost);
new_arc.nextstate = arc.nextstate;
fst_->AddArc(new_state, new_arc);
}
fst_->DeleteArcs(s);
for (size_t i = 0; i < arcs_from_this_state.size(); i++) {
fst_->AddArc(s, arcs_from_this_state[i]);
}
// leave the final-prob on this state as it was before.
}
void PrepareForGrammarFst(int32 nonterm_phones_offset,
VectorFst<StdArc> *fst) {
GrammarFstPreparer p(nonterm_phones_offset, fst);
p.Prepare();
}
void CopyToVectorFst(GrammarFst *grammar_fst,
VectorFst<StdArc> *vector_fst) {
typedef GrammarFstArc::StateId GrammarStateId; // int64
typedef StdArc::StateId StdStateId; // int
typedef StdArc::Label Label;
typedef StdArc::Weight Weight;
std::vector<std::pair<GrammarStateId, StdStateId> > queue;
std::unordered_map<GrammarStateId, StdStateId> state_map;
vector_fst->DeleteStates();
state_map[grammar_fst->Start()] = vector_fst->AddState(); // state 0.
vector_fst->SetStart(0);
queue.push_back(
std::pair<GrammarStateId, StdStateId>(grammar_fst->Start(), 0));
while (!queue.empty()) {
std::pair<GrammarStateId, StdStateId> p = queue.back();
queue.pop_back();
GrammarStateId grammar_state = p.first;
StdStateId std_state = p.second;
vector_fst->SetFinal(std_state, grammar_fst->Final(grammar_state));
ArcIterator<GrammarFst> aiter(*grammar_fst, grammar_state);
for (; !aiter.Done(); aiter.Next()) {
const GrammarFstArc &grammar_arc = aiter.Value();
StdArc std_arc;
std_arc.ilabel = grammar_arc.ilabel;
std_arc.olabel = grammar_arc.olabel;
std_arc.weight = grammar_arc.weight;
GrammarStateId next_grammar_state = grammar_arc.nextstate;
StdStateId next_std_state;
std::unordered_map<GrammarStateId, StdStateId>::const_iterator
state_iter = state_map.find(next_grammar_state);
if (state_iter == state_map.end()) {
next_std_state = vector_fst->AddState();
state_map[next_grammar_state] = next_std_state;
queue.push_back(std::pair<GrammarStateId, StdStateId>(
next_grammar_state, next_std_state));
} else {
next_std_state = state_iter->second;
}
std_arc.nextstate = next_std_state;
vector_fst->AddArc(std_state, std_arc);
}
}
}
} // end namespace fst