Blame view

src/base/kaldi-math.h 9.99 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
  // base/kaldi-math.h
  
  // Copyright 2009-2011  Ondrej Glembek;  Microsoft Corporation;  Yanmin Qian;
  //                      Jan Silovsky;  Saarland University
  //
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #ifndef KALDI_BASE_KALDI_MATH_H_
  #define KALDI_BASE_KALDI_MATH_H_ 1
  
  #ifdef _MSC_VER
  #include <float.h>
  #endif
  
  #include <cmath>
  #include <limits>
  #include <vector>
  
  #include "base/kaldi-types.h"
  #include "base/kaldi-common.h"
  
  
  #ifndef DBL_EPSILON
  #define DBL_EPSILON 2.2204460492503131e-16
  #endif
  #ifndef FLT_EPSILON
  #define FLT_EPSILON 1.19209290e-7f
  #endif
  
  #ifndef M_PI
  #define M_PI 3.1415926535897932384626433832795
  #endif
  
  #ifndef M_SQRT2
  #define M_SQRT2 1.4142135623730950488016887
  #endif
  
  #ifndef M_2PI
  #define M_2PI 6.283185307179586476925286766559005
  #endif
  
  #ifndef M_SQRT1_2
  #define M_SQRT1_2 0.7071067811865475244008443621048490
  #endif
  
  #ifndef M_LOG_2PI
  #define M_LOG_2PI 1.8378770664093454835606594728112
  #endif
  
  #ifndef M_LN2
  #define M_LN2 0.693147180559945309417232121458
  #endif
  
  #ifndef M_LN10
  #define M_LN10 2.302585092994045684017991454684
  #endif
  
  
  #define KALDI_ISNAN std::isnan
  #define KALDI_ISINF std::isinf
  #define KALDI_ISFINITE(x) std::isfinite(x)
  
  #if !defined(KALDI_SQR)
  # define KALDI_SQR(x) ((x) * (x))
  #endif
  
  namespace kaldi {
  
  #if !defined(_MSC_VER) || (_MSC_VER >= 1900)
  inline double Exp(double x) { return exp(x); }
  #ifndef KALDI_NO_EXPF
  inline float Exp(float x) { return expf(x); }
  #else
  inline float Exp(float x) { return exp(static_cast<double>(x)); }
  #endif  // KALDI_NO_EXPF
  #else
  inline double Exp(double x) { return exp(x); }
  #if !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
  // Microsoft CL v18.0 buggy 64-bit implementation of
  // expf() incorrectly returns -inf for exp(-inf).
  inline float Exp(float x) { return exp(static_cast<double>(x)); }
  #else
  inline float Exp(float x) { return expf(x); }
  #endif  // !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
  #endif  // !defined(_MSC_VER) || (_MSC_VER >= 1900)
  
  inline double Log(double x) { return log(x); }
  inline float Log(float x) { return logf(x); }
  
  #if !defined(_MSC_VER) || (_MSC_VER >= 1700)
  inline double Log1p(double x) {  return log1p(x); }
  inline float Log1p(float x) {  return log1pf(x); }
  #else
  inline double Log1p(double x) {
    const double cutoff = 1.0e-08;
    if (x < cutoff)
      return x - 0.5 * x * x;
    else
      return Log(1.0 + x);
  }
  
  inline float Log1p(float x) {
    const float cutoff = 1.0e-07;
    if (x < cutoff)
      return x - 0.5 * x * x;
    else
      return Log(1.0 + x);
  }
  #endif
  
  static const double kMinLogDiffDouble = Log(DBL_EPSILON);  // negative!
  static const float kMinLogDiffFloat = Log(FLT_EPSILON);  // negative!
  
  // -infinity
  const float kLogZeroFloat = -std::numeric_limits<float>::infinity();
  const double kLogZeroDouble = -std::numeric_limits<double>::infinity();
  const BaseFloat kLogZeroBaseFloat = -std::numeric_limits<BaseFloat>::infinity();
  
  // Returns a random integer between 0 and RAND_MAX, inclusive
  int Rand(struct RandomState* state = NULL);
  
  // State for thread-safe random number generator
  struct RandomState {
    RandomState();
    unsigned seed;
  };
  
  // Returns a random integer between first and last inclusive.
  int32 RandInt(int32 first, int32 last, struct RandomState* state = NULL);
  
  // Returns true with probability "prob",
  bool WithProb(BaseFloat prob, struct RandomState* state = NULL);
  // with 0 <= prob <= 1 [we check this].
  // Internally calls Rand().  This function is carefully implemented so
  // that it should work even if prob is very small.
  
  /// Returns a random number strictly between 0 and 1.
  inline float RandUniform(struct RandomState* state = NULL) {
    return static_cast<float>((Rand(state) + 1.0) / (RAND_MAX+2.0));
  }
  
  inline float RandGauss(struct RandomState* state = NULL) {
    return static_cast<float>(sqrtf (-2 * Log(RandUniform(state)))
                              * cosf(2*M_PI*RandUniform(state)));
  }
  
  // Returns poisson-distributed random number.  Uses Knuth's algorithm.
  // Take care: this takes time proportional
  // to lambda.  Faster algorithms exist but are more complex.
  int32 RandPoisson(float lambda, struct RandomState* state = NULL);
  
  // Returns a pair of gaussian random numbers. Uses Box-Muller transform
  void RandGauss2(float *a, float *b, RandomState *state = NULL);
  void RandGauss2(double *a, double *b, RandomState *state = NULL);
  
  // Also see Vector<float,double>::RandCategorical().
  
  // This is a randomized pruning mechanism that preserves expectations,
  // that we typically use to prune posteriors.
  template<class Float>
  inline Float RandPrune(Float post, BaseFloat prune_thresh,
                         struct RandomState* state = NULL) {
    KALDI_ASSERT(prune_thresh >= 0.0);
    if (post == 0.0 || std::abs(post) >= prune_thresh)
      return post;
    return (post >= 0 ? 1.0 : -1.0) *
        (RandUniform(state) <= fabs(post)/prune_thresh ? prune_thresh : 0.0);
  }
  
  
  inline double LogAdd(double x, double y) {
    double diff;
  
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff is negative.  x is now the larger one.
  
    if (diff >= kMinLogDiffDouble) {
      double res;
      res = x + Log1p(Exp(diff));
      return res;
    } else {
      return x;  // return the larger one.
    }
  }
  
  
  inline float LogAdd(float x, float y) {
    float diff;
  
    if (x < y) {
      diff = x - y;
      x = y;
    } else {
      diff = y - x;
    }
    // diff is negative.  x is now the larger one.
  
    if (diff >= kMinLogDiffFloat) {
      float res;
      res = x + Log1p(Exp(diff));
      return res;
    } else {
      return x;  // return the larger one.
    }
  }
  
  
  // returns exp(x) - exp(y).
  inline double LogSub(double x, double y) {
    if (y >= x) {  // Throws exception if y>=x.
      if (y == x)
        return kLogZeroDouble;
      else
        KALDI_ERR << "Cannot subtract a larger from a smaller number.";
    }
  
    double diff = y - x;  // Will be negative.
    double res = x + Log(1.0 - Exp(diff));
  
    // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
    if (KALDI_ISNAN(res))
      return kLogZeroDouble;
    return res;
  }
  
  
  // returns exp(x) - exp(y).
  inline float LogSub(float x, float y) {
    if (y >= x) {  // Throws exception if y>=x.
      if (y == x)
        return kLogZeroDouble;
      else
        KALDI_ERR << "Cannot subtract a larger from a smaller number.";
    }
  
    float diff = y - x;  // Will be negative.
    float res = x + Log(1.0f - Exp(diff));
  
    // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
    if (KALDI_ISNAN(res))
      return kLogZeroFloat;
    return res;
  }
  
  /// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)).
  static inline bool ApproxEqual(float a, float b,
                                 float relative_tolerance = 0.001) {
    // a==b handles infinities.
    if (a == b) return true;
    float diff = std::abs(a-b);
    if (diff == std::numeric_limits<float>::infinity()
        || diff != diff) return false;  // diff is +inf or nan.
    return (diff <= relative_tolerance*(std::abs(a)+std::abs(b)));
  }
  
  /// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
  static inline void AssertEqual(float a, float b,
                                 float relative_tolerance = 0.001) {
    // a==b handles infinities.
    KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance));
  }
  
  
  // RoundUpToNearestPowerOfTwo does the obvious thing. It crashes if n <= 0.
  int32 RoundUpToNearestPowerOfTwo(int32 n);
  
  /// Returns a / b, rounding towards negative infinity in all cases.
  static inline int32 DivideRoundingDown(int32 a, int32 b) {
    KALDI_ASSERT(b != 0);
    if (a * b >= 0)
      return a / b;
    else if (a < 0)
      return (a - b + 1) / b;
    else
      return (a - b - 1) / b;
  }
  
  template<class I> I  Gcd(I m, I n) {
    if (m == 0 || n == 0) {
      if (m == 0 && n == 0) {  // gcd not defined, as all integers are divisors.
        KALDI_ERR << "Undefined GCD since m = 0, n = 0.";
      }
      return (m == 0 ? (n > 0 ? n : -n) : ( m > 0 ? m : -m));
      // return absolute value of whichever is nonzero
    }
    // could use compile-time assertion
    // but involves messing with complex template stuff.
    KALDI_ASSERT(std::numeric_limits<I>::is_integer);
    while (1) {
      m %= n;
      if (m == 0) return (n > 0 ? n : -n);
      n %= m;
      if (n == 0) return (m > 0 ? m : -m);
    }
  }
  
  /// Returns the least common multiple of two integers.  Will
  /// crash unless the inputs are positive.
  template<class I> I  Lcm(I m, I n) {
    KALDI_ASSERT(m > 0 && n > 0);
    I gcd = Gcd(m, n);
    return gcd * (m/gcd) * (n/gcd);
  }
  
  
  template<class I> void Factorize(I m, std::vector<I> *factors) {
    // Splits a number into its prime factors, in sorted order from
    // least to greatest,  with duplication.  A very inefficient
    // algorithm, which is mainly intended for use in the
    // mixed-radix FFT computation (where we assume most factors
    // are small).
    KALDI_ASSERT(factors != NULL);
    KALDI_ASSERT(m >= 1);  // Doesn't work for zero or negative numbers.
    factors->clear();
    I small_factors[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 };
  
    // First try small factors.
    for (I i = 0; i < 10; i++) {
      if (m == 1) return;  // We're done.
      while (m % small_factors[i] == 0) {
        m /= small_factors[i];
        factors->push_back(small_factors[i]);
      }
    }
    // Next try all odd numbers starting from 31.
    for (I j = 31;; j += 2) {
      if (m == 1) return;
      while (m % j == 0) {
        m /= j;
        factors->push_back(j);
      }
    }
  }
  
  inline double Hypot(double x, double y) {  return hypot(x, y); }
  inline float Hypot(float x, float y) {  return hypotf(x, y); }
  
  
  
  
  }  // namespace kaldi
  
  
  #endif  // KALDI_BASE_KALDI_MATH_H_