Blame view

src/transform/lvtln.h 3.4 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
  // transform/lvtln.h
  
  // Copyright 2009-2011 Microsoft Corporation
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABLITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  
  #ifndef KALDI_TRANSFORM_LVTLN_H_
  #define KALDI_TRANSFORM_LVTLN_H_
  
  #include <vector>
  
  #include "base/kaldi-common.h"
  #include "util/common-utils.h"
  #include "transform/transform-common.h"
  #include "transform/fmllr-diag-gmm.h"
  
  
  namespace kaldi {
  
  /*
    Class for applying linear approximations to VTLN transforms;
    see \ref transform_lvtln.
  */
  
  
  class LinearVtln {
   public:
    LinearVtln() { } // This initializer will probably be used prior to calling
    // Read().
  
    LinearVtln(int32 dim, int32 num_classes, int32 default_class);
    // This initializer sets up the
    // model; the transforms will initially all be the same.
  
    // SetTransform is used when we initialize it as "normal" VTLN.
    // It's not necessary to ever call this function.  "transform" is "A",
    // the square part of the transform matrix.
    void SetTransform(int32 i, const MatrixBase<BaseFloat> &transform);
  
    void SetWarp(int32 i, BaseFloat warp);
  
    BaseFloat GetWarp(int32 i) const;
  
    // GetTransform gets the transform for class i.  The caller must
    // make sure the output matrix is sized Dim() by Dim().
    void GetTransform(int32 i, MatrixBase<BaseFloat> *transform) const;
  
  
    /// Compute the transform for the speaker.
    void ComputeTransform(const FmllrDiagGmmAccs &accs,
                          std::string norm_type,  // type of regular fMLLR computation: "none", "offset", "diag"
                          BaseFloat logdet_scale,  // scale on logdet (1.0 is "correct" but less may work better)
                          MatrixBase<BaseFloat> *Ws,  // output fMLLR transform, should be size dim x dim+1
                          int32 *class_idx,  // the transform that was chosen...
                          BaseFloat *logdet_out,
                          BaseFloat *objf_impr = NULL,  // versus no transform
                          BaseFloat *count = NULL);
  
    void Read(std::istream &is, bool binary);
  
    void Write(std::ostream &os, bool binary) const;
  
    int32 Dim() const { KALDI_ASSERT(!A_.empty()); return A_[0].NumRows(); }
    int32 NumClasses() const { return A_.size(); }
    // This computes the offset term for this class given these
    // stats.
    void GetOffset(const FmllrDiagGmmAccs &speaker_stats,
                   int32 class_idx,
                   VectorBase<BaseFloat> *offset) const;
  
    friend class LinearVtlnStats;
   protected:
    int32 default_class_;  // transform we return if we have no data.
    std::vector<Matrix<BaseFloat> > A_;  // Square parts of the FMLLR matrices.
    std::vector<BaseFloat> logdets_;
    std::vector<BaseFloat> warps_; // This variable can be used to store the
                                   // warp factors that each transform correspond to.
    
  
  };
  
  
  
  }  // namespace kaldi
  
  #endif  // KALDI_TRANSFORM_LVTLN_H_