Blame view

src/transform/regression-tree.h 3.75 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
  // transform/regression-tree.h
  
  // Copyright 2009-2011  Saarland University
  // Author:  Arnab Ghoshal
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  
  #ifndef KALDI_TRANSFORM_REGRESSION_TREE_H_
  #define KALDI_TRANSFORM_REGRESSION_TREE_H_
  
  #include <utility>
  #include <vector>
  
  #include "base/kaldi-common.h"
  #include "tree/cluster-utils.h"
  #include "gmm/am-diag-gmm.h"
  #include "transform/transform-common.h"
  
  namespace kaldi {
  
  /** \class RegressionTree
   *  A regression tree is a clustering of Gaussian densities in an acoustic
   *  model, such that the group of Gaussians at each node of the tree are
   *  transformed by the same transform. Each node is thus called a regression
   *  class.
   *
   *  The leaves of the tree are called baseclasses; each baseclass is a list
   *  of Gaussians, each identified by a (pdf, gaussian) index pair.
   */
  class RegressionTree {
   public:
    /// Creates an empty tree. Use BuildTree() or Read() to populate it.
    /// The node/baseclass counts are zero-initialized so that accessors such
    /// as NumBaseclasses() are well-defined (return 0) on an empty tree,
    /// instead of reading uninitialized memory.
    RegressionTree() : num_nodes_(0), num_baseclasses_(0) {}

    /// Top-down clustering of the Gaussians in a model based on their means.
    /// If sil_indices is nonempty, will put silence in a special class
    /// using a top-level split.
    /// @param state_occs    Per-state occupancy counts (presumably indexed by
    ///                      pdf-id -- confirm against the implementation).
    /// @param sil_indices   Indices identifying the silence states; may be
    ///                      empty, in which case no silence split is made.
    /// @param am            The acoustic model whose Gaussians are clustered.
    /// @param max_clusters  Upper limit on the number of clusters produced
    ///                      (presumably the number of baseclasses -- confirm).
    void BuildTree(const Vector<BaseFloat> &state_occs,
                   const std::vector<int32> &sil_indices,
                   const AmDiagGmm &am,
                   int32 max_clusters);

    /// Parses the regression tree and finds the nodes whose occupancies (read
    /// from stats_in) are greater than min_count. The regclasses_out vector
    /// has size equal to number of baseclasses, and contains the regression
    /// class index for each baseclass. The stats_out vector has size equal to
    /// number of regression classes. Return value is true if at least one
    /// regression class passed the count cutoff, false otherwise.
    /// NOTE(review): ownership of the pointers placed in stats_out is not
    /// visible from this header -- confirm in the .cc whether the caller is
    /// responsible for deleting them.
    bool GatherStats(const std::vector<AffineXformStats*> &stats_in,
                     double min_count,
                     std::vector<int32> *regclasses_out,
                     std::vector<AffineXformStats*> *stats_out) const;

    /// Serializes the tree to `out`; `binary` selects binary or text mode.
    void Write(std::ostream &out, bool binary) const;
    /// Reads a tree previously written with Write(). `am` is the model the
    /// tree refers to; presumably it is used to rebuild the
    /// (pdf, gaussian) -> baseclass map via MakeGauss2Bclass() -- confirm.
    void Read(std::istream &in, bool binary, const AmDiagGmm &am);

    /// Accessors (const)
    /// Number of leaf nodes (baseclasses) in the tree.
    int32 NumBaseclasses() const { return num_baseclasses_; }
    /// Returns the (pdf, gaussian) index pairs belonging to baseclass
    /// `bclass`. The index is not bounds-checked.
    const std::vector< std::pair<int32, int32> >& GetBaseclass(int32 bclass)
        const { return baseclasses_[bclass]; }
    /// Returns the baseclass index of Gaussian `gauss_id` in pdf `pdf_id`.
    /// Neither index is bounds-checked.
    int32 Gauss2BaseclassId(size_t pdf_id, size_t gauss_id) const {
      return gauss2bclass_[pdf_id][gauss_id];
    }

   private:
    int32 num_nodes_;  ///< Total (non-leaf+leaf) nodes

    /// For each node, index of its parent: size = num_nodes_
    /// If 0 <= i < num_baseclasses_, then i is a leaf of the tree (a base class);
    /// else a non-leaf node.  parents_[i] > i, except for the top node
    /// (last-numbered one), for which parents_[i] == i.
    std::vector<int32> parents_;
    int32 num_baseclasses_;  ///< Number of leaf nodes
    /// Each baseclass (leaf of regression tree) is a vector of Gaussian indices.
    /// Each Gaussian in the model is indexed by (pdf, gaussian) indices pair.
    std::vector< std::vector< std::pair<int32, int32> > > baseclasses_;
    /// Mapping from (pdf, gaussian) indices to baseclasses
    std::vector< std::vector<int32> > gauss2bclass_;

    /// Builds gauss2bclass_ for the given model (presumably from
    /// baseclasses_ -- confirm in the implementation).
    void MakeGauss2Bclass(const AmDiagGmm &am);

    // Copying is disallowed: no copy constructor or assignment operator.
    KALDI_DISALLOW_COPY_AND_ASSIGN(RegressionTree);
  };
  
  }  // namespace kaldi
  
  #endif  // KALDI_TRANSFORM_REGRESSION_TREE_H_