Blame view

src/gmm/model-common.h 3.56 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
  // gmm/model-common.h
  
  // Copyright 2009-2012  Saarland University;  Microsoft Corporation;
  //                      Johns Hopkins University (author: Daniel Povey)
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  
  #ifndef KALDI_GMM_MODEL_COMMON_H_
  #define KALDI_GMM_MODEL_COMMON_H_
  #include "matrix/matrix-lib.h"
  
  namespace kaldi {
  
  /// Flags naming the groups of GMM parameters.  Every flag except kGmmAll
  /// occupies its own bit; kGmmAll is the union of all of them.  The letter
  /// noted beside each flag is its one-character code.
  enum GmmUpdateFlags {
    kGmmMeans       = 1 << 0,  ///< m
    kGmmVariances   = 1 << 1,  ///< v
    kGmmWeights     = 1 << 2,  ///< w
    kGmmTransitions = 1 << 3,  ///< t ... not really part of GMM.
    /// a
    kGmmAll = kGmmMeans | kGmmVariances | kGmmWeights | kGmmTransitions
  };
  typedef uint16 GmmFlagsType;  ///< Bitwise OR of GmmUpdateFlags values.
  /// Convert a string of one-letter flag codes to the corresponding
  /// GmmFlagsType value.  The codes annotated on GmmUpdateFlags are
  /// m, v, w, t and a; an earlier version of this comment said "mSwa",
  /// which does not match those codes.
  /// NOTE(review): the set of accepted characters is defined in the
  /// implementation, which is not visible from this header -- confirm.
  GmmFlagsType StringToGmmFlags(std::string str);
  
  /// Convert GMM flags to a string.
  /// NOTE(review): presumably the inverse of StringToGmmFlags, i.e. the
  /// one-letter codes of the flags that are set -- confirm against the
  /// definition.
  std::string GmmFlagsToString(GmmFlagsType gmm_flags);
  
  /// Make sure that the flags make sense, i.e. if there is variance
  /// accumulation then there is also mean accumulation.
  /// Takes the flags by value and returns a GmmFlagsType; presumably the
  /// result is the input with any required flags added -- confirm against
  /// the definition.
  GmmFlagsType AugmentGmmFlags(GmmFlagsType flags);
  
  /// Flags naming the groups of SGMM parameters.  Every flag except kSgmmAll
  /// occupies its own bit.  The letters correspond to the variable names.
  enum SgmmUpdateFlags {
    kSgmmPhoneVectors             = 1 << 0,  ///< v
    kSgmmPhoneProjections         = 1 << 1,  ///< M
    kSgmmPhoneWeightProjections   = 1 << 2,  ///< w
    kSgmmCovarianceMatrix         = 1 << 3,  ///< S
    kSgmmSubstateWeights          = 1 << 4,  ///< c
    kSgmmSpeakerProjections       = 1 << 5,  ///< N
    kSgmmTransitions              = 1 << 6,  ///< t .. not really part of SGMM.
    kSgmmSpeakerWeightProjections = 1 << 7,  ///< u [ for SSGMM ]
    /// a (won't normally use this).  Union of all the flags above.
    kSgmmAll = kSgmmPhoneVectors | kSgmmPhoneProjections |
               kSgmmPhoneWeightProjections | kSgmmCovarianceMatrix |
               kSgmmSubstateWeights | kSgmmSpeakerProjections |
               kSgmmTransitions | kSgmmSpeakerWeightProjections
  };
  
  typedef uint16 SgmmUpdateFlagsType;  ///< Bitwise OR of SgmmUpdateFlags values.
  /// Convert a string to an SgmmUpdateFlagsType value.
  /// NOTE(review): presumably the string consists of the one-letter codes
  /// annotated on SgmmUpdateFlags (v, M, w, S, c, N, t, u, a); the accepted
  /// characters are defined in the implementation -- confirm.
  SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str);
  
  /// Flags naming the parts of an SGMM for writing.  Every flag except
  /// kSgmmWriteAll occupies its own bit; kSgmmWriteAll is the union of all
  /// of them.  The letter noted beside each flag is its one-character code.
  enum SgmmWriteFlags {
    kSgmmGlobalParams   = 1 << 0,  ///< g
    kSgmmStateParams    = 1 << 1,  ///< s
    kSgmmNormalizers    = 1 << 2,  ///< n
    kSgmmBackgroundGmms = 1 << 3,  ///< u
    /// a
    kSgmmWriteAll = kSgmmGlobalParams | kSgmmStateParams |
                    kSgmmNormalizers | kSgmmBackgroundGmms
  };
  
  typedef uint16 SgmmWriteFlagsType;  ///< Bitwise OR of SgmmWriteFlags values.
  
  /// Convert a string to an SgmmWriteFlagsType value.
  /// NOTE(review): presumably the string consists of the one-letter codes
  /// annotated on SgmmWriteFlags (g, s, n, u, a); the accepted characters
  /// are defined in the implementation -- confirm.
  SgmmWriteFlagsType StringToSgmmWriteFlags(std::string str);
  
  /// Get Gaussian-mixture or substate-mixture splitting targets,
  /// according to a power rule (e.g. typically power = 0.2).
  /// Returns targets for the number of mixture components (Gaussians,
  /// or sub-states), allocating them according to a power of occupancy
  /// in order to achieve the total supplied target ("target_components").
  /// During splitting we ensure that each Gaussian [or sub-state] would
  /// get a count of at least "min_count", assuming counts were evenly
  /// distributed between Gaussians in a state.
  /// The vector "targets" will be resized to the appropriate dimension;
  /// its value at input is ignored.
  /// NOTE(review): presumably "state_occs" holds one occupancy per state and
  /// "targets" receives one entry per state -- confirm against the definition.
  void GetSplitTargets(const Vector<BaseFloat> &state_occs,
                       int32 target_components,
                       BaseFloat power,
                       BaseFloat min_count,
                       std::vector<int32> *targets);
  
  }  // End namespace kaldi
  
  
  #endif  // KALDI_GMM_MODEL_COMMON_H_