tree-accu.cc
6.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// hmm/tree-accu.cc
// Copyright 2009-2011 Microsoft Corporation
// 2013 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "util/kaldi-io.h"
#include "hmm/tree-accu.h"
#include "hmm/hmm-utils.h"
namespace kaldi {
static int32 MapPhone(const std::vector<int32> &phone_map,
int32 phone) {
if (phone == 0 || phone_map.empty()) return phone;
else if (phone < 0 || phone >= phone_map.size()) {
KALDI_ERR << "Out-of-range phone " << phone << " bad --phone-map option?";
}
return phone_map[phone];
}
void AccumulateTreeStats(const TransitionModel &trans_model,
const AccumulateTreeStatsInfo &info,
const std::vector<int32> &alignment,
const Matrix<BaseFloat> &features,
std::map<EventType, GaussClusterable*> *stats) {
std::vector<std::vector<int32> > split_alignment;
bool ans = SplitToPhones(trans_model, alignment, &split_alignment);
if (!ans) {
KALDI_WARN << "AccumulateTreeStats: alignment appears to be bad, not using it";
return;
}
int32 cur_pos = 0;
int32 dim = features.NumCols();
KALDI_ASSERT(features.NumRows() == static_cast<int32>(alignment.size()));
for (int32 i = -info.context_width; i < static_cast<int32>(split_alignment.size()); i++) {
// consider window starting at i, only if i+info.central_position is within
// list of phones.
if (i + info.central_position >= 0 &&
i + info.central_position < static_cast<int32>(split_alignment.size())) {
int32 central_phone =
MapPhone(info.phone_map,
trans_model.TransitionIdToPhone(
split_alignment[i+info.central_position][0]));
bool is_ctx_dep = !std::binary_search(info.ci_phones.begin(),
info.ci_phones.end(),
central_phone);
EventType evec;
for (int32 j = 0; j < info.context_width; j++) {
int32 phone;
if (i + j >= 0 && i + j < static_cast<int32>(split_alignment.size()))
phone =
MapPhone(info.phone_map,
trans_model.TransitionIdToPhone(split_alignment[i+j][0]));
else
phone = 0; // ContextDependency class uses 0 to mean "out of window";
// we also set the phone arbitrarily to 0
// Don't add stuff to the event that we don't "allow" to be asked, due
// to the central phone being context-independent: check "is_ctx_dep".
// Why not just set the value to zero in this
// case? It's for safety. By omitting the key from the event, we
// ensure that there is no way a question can ever be asked that might
// give an inconsistent answer in tree-training versus graph-building.
// [setting it to zero would have the same effect given the "normal"
// recipe but might be less robust to changes in tree-building recipe].
if (is_ctx_dep || j == info.central_position)
evec.push_back(std::make_pair(static_cast<EventKeyType>(j), static_cast<EventValueType>(phone)));
}
for (int32 j = 0; j < static_cast<int32>(split_alignment[i+info.central_position].size());j++) {
// for central phone of this window...
EventType evec_more(evec);
int32 pdf_class = trans_model.TransitionIdToPdfClass(
split_alignment[i+info.central_position][j]);
// pdf_class will normally by 0, 1 or 2 for 3-state HMM.
std::pair<EventKeyType, EventValueType> pr(kPdfClass, pdf_class);
evec_more.push_back(pr);
std::sort(evec_more.begin(), evec_more.end()); // these must be sorted!
if (stats->count(evec_more) == 0)
(*stats)[evec_more] = new GaussClusterable(dim, info.var_floor);
BaseFloat weight = 1.0;
(*stats)[evec_more]->AddStats(features.Row(cur_pos), weight);
cur_pos++;
}
}
}
KALDI_ASSERT(cur_pos == static_cast<int32>(alignment.size()));
}
void ReadPhoneMap(std::string phone_map_rxfilename,
std::vector<int32> *phone_map) {
phone_map->clear();
// phone map file has format e.g.:
// 1 1
// 2 1
// 3 2
// 4 2
std::vector<std::vector<int32> > vec; // vector of vectors, each with two elements
// (if file has right format). first is old phone, second is new phone
if (!ReadIntegerVectorVectorSimple(phone_map_rxfilename, &vec))
KALDI_ERR << "Error reading phone map from " <<
PrintableRxfilename(phone_map_rxfilename);
for (size_t i = 0; i < vec.size(); i++) {
if (vec[i].size() != 2 || vec[i][0]<=0 || vec[i][1]<=0 ||
(vec[i][0]<static_cast<int32>(phone_map->size()) &&
(*phone_map)[vec[i][0]] != -1))
KALDI_ERR << "Error reading phone map from "
<< PrintableRxfilename(phone_map_rxfilename)
<< " (bad line " << i << ")";
if (vec[i][0]>=static_cast<int32>(phone_map->size()))
phone_map->resize(vec[i][0]+1, -1);
KALDI_ASSERT((*phone_map)[vec[i][0]] == -1);
(*phone_map)[vec[i][0]] = vec[i][1];
}
if (phone_map->empty()) {
KALDI_ERR << "Read empty phone map from "
<< PrintableRxfilename(phone_map_rxfilename);
}
}
AccumulateTreeStatsInfo::AccumulateTreeStatsInfo(
const AccumulateTreeStatsOptions &opts):
var_floor(opts.var_floor),
context_width(opts.context_width),
central_position(opts.central_position) {
if (central_position < 0 || context_width <= central_position)
KALDI_ERR << "Invalid options: --central-position=" << central_position
<< ", --context-width=" << context_width;
if (!opts.phone_map_rxfilename.empty())
ReadPhoneMap(opts.phone_map_rxfilename, &phone_map);
if (!opts.ci_phones_str.empty()) {
SplitStringToIntegers(opts.ci_phones_str, ":", false, &ci_phones);
std::sort(ci_phones.begin(), ci_phones.end());
if (!IsSortedAndUniq(ci_phones) || ci_phones.empty() || ci_phones[0] == 0)
KALDI_ERR << "Invalid --ci-phones option: " << opts.ci_phones_str;
}
}
} // end namespace kaldi