am-diag-gmm-test.cc
4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// gmm/am-diag-gmm-test.cc
// Copyright 2009-2011 Saarland University
// Author: Arnab Ghoshal
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "gmm/model-test-common.h"
#include "gmm/am-diag-gmm.h"
#include "util/kaldi-io.h"
using kaldi::AmDiagGmm;
using kaldi::int32;
using kaldi::BaseFloat;
namespace ut = kaldi::unittest;
// Tests the Read() and Write() methods, in both binary and ASCII mode, as well
// as Check(), CopyFromSgmm(), and methods in likelihood computations.
void TestAmDiagGmmIO(const AmDiagGmm &am_gmm) {
int32 dim = am_gmm.Dim();
kaldi::Vector<BaseFloat> feat(dim);
for (int32 d = 0; d < dim; d++) {
feat(d) = kaldi::RandGauss();
}
BaseFloat loglike = 0.0;
for (int32 i = 0; i < am_gmm.NumPdfs(); i++)
loglike += am_gmm.LogLikelihood(i, feat);
// First, non-binary write
am_gmm.Write(kaldi::Output("tmpf", false).Stream(), false);
bool binary_in;
AmDiagGmm *am_gmm1 = new AmDiagGmm();
// Non-binary read
kaldi::Input ki1("tmpf", &binary_in);
am_gmm1->Read(ki1.Stream(), binary_in);
BaseFloat loglike1 = 0.0;
for (int32 i = 0; i < am_gmm1->NumPdfs(); i++)
loglike1 += am_gmm1->LogLikelihood(i, feat);
kaldi::AssertEqual(loglike, loglike1, 1e-4);
// Next, binary write
am_gmm1->Write(kaldi::Output("tmpfb", true).Stream(), true);
delete am_gmm1;
AmDiagGmm *am_gmm2 = new AmDiagGmm();
// Binary read
kaldi::Input ki2("tmpfb", &binary_in);
am_gmm2->Read(ki2.Stream(), binary_in);
BaseFloat loglike2 = 0.0;
for (int32 i = 0; i < am_gmm2->NumPdfs(); i++)
loglike2 += am_gmm2->LogLikelihood(i, feat);
kaldi::AssertEqual(loglike, loglike2, 1e-4);
delete am_gmm2;
unlink("tmpf");
unlink("tmpfb");
}
// Tests SplitByCount() on a copy of the given acoustic model.
//
// Random positive occupancies are drawn for every pdf, the copy is asked to
// split towards twice the original total number of Gaussians, and the
// log-likelihood of pdf 0 for one random feature vector is then compared
// between the original and the split model, using AssertEqual() with a
// tolerance argument of 1e-2.
//
// The copy is a stack object, so it is destroyed on every exit path
// (including an exception propagating out of SplitByCount() or
// AssertEqual()).
void TestSplitStates(const AmDiagGmm &am_gmm) {
  int32 target_comp = 2 * am_gmm.NumGauss();

  // One random, non-negative occupancy per pdf.
  kaldi::Vector<BaseFloat> occs(am_gmm.NumPdfs());
  for (int32 i = 0; i < occs.Dim(); i++)
    occs(i) = std::fabs(kaldi::RandGauss()) * (kaldi::RandUniform()+1) * 4;

  AmDiagGmm am_gmm1;
  am_gmm1.CopyFromAmDiagGmm(am_gmm);
  am_gmm1.SplitByCount(occs, target_comp, 0.01, 0.2, 0.0);

  // The feature vector is drawn after the split, as in the original test, so
  // the sequence of random-number calls is unchanged.
  int32 dim = am_gmm.Dim();
  kaldi::Vector<BaseFloat> feat(dim);
  for (int32 d = 0; d < dim; d++)
    feat(d) = kaldi::RandGauss();

  BaseFloat loglike = am_gmm.LogLikelihood(0, feat);
  BaseFloat loglike1 = am_gmm1.LogLikelihood(0, feat);
  kaldi::AssertEqual(loglike, loglike1, 1e-2);
}
void TestClustering(const AmDiagGmm &am_gmm) {
int32 target_comp = am_gmm.NumGauss() / 5,
interm_comp = am_gmm.NumGauss() / 2;
kaldi::Vector<BaseFloat> occs(am_gmm.NumPdfs());
for (int32 i = 0; i < occs.Dim(); i++)
occs(i) = std::fabs(kaldi::RandGauss()) * (kaldi::RandUniform()+1) * 4;
kaldi::UbmClusteringOptions ubm_opts(target_comp, 0.2, interm_comp, 0.01, 30);
kaldi::DiagGmm ubm;
ClusterGaussiansToUbm(am_gmm, occs, ubm_opts, &ubm);
}
void UnitTestAmDiagGmm() {
int32 dim = 1 + kaldi::RandInt(0, 9), // random dimension of the gmm
num_pdfs = 5 + kaldi::RandInt(0, 9); // random number of states
AmDiagGmm am_gmm;
for (int32 i = 0; i < num_pdfs; i++) {
int32 num_comp = 1 + kaldi::RandInt(0, 9); // random number of mixtures
kaldi::DiagGmm gmm;
ut::InitRandDiagGmm(dim, num_comp, &gmm);
am_gmm.AddPdf(gmm);
}
TestAmDiagGmmIO(am_gmm);
TestSplitStates(am_gmm);
TestClustering(am_gmm);
}
// Entry point: runs the randomized unit test five times, then reports
// success on standard output.
int main() {
  const int kNumRuns = 5;
  int run = 0;
  while (run < kNumRuns) {
    UnitTestAmDiagGmm();
    ++run;
  }
  std::cout << "Test OK.\n";
  return 0;
}