Blame view
src/tree/clusterable-classes.h
5.58 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
// tree/clusterable-classes.h // Copyright 2009-2011 Microsoft Corporation; Saarland University // 2014 Daniel Povey // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #ifndef KALDI_TREE_CLUSTERABLE_CLASSES_H_ #define KALDI_TREE_CLUSTERABLE_CLASSES_H_ 1 #include <string> #include "itf/clusterable-itf.h" #include "matrix/matrix-lib.h" namespace kaldi { /// \addtogroup clustering_group /// @{ /// ScalarClusterable clusters scalars with x^2 loss. class ScalarClusterable: public Clusterable { public: ScalarClusterable(): x_(0), x2_(0), count_(0) {} explicit ScalarClusterable(BaseFloat x): x_(x), x2_(x*x), count_(1) {} virtual std::string Type() const { return "scalar"; } virtual BaseFloat Objf() const; virtual void SetZero() { count_ = x_ = x2_ = 0.0; } virtual void Add(const Clusterable &other_in); virtual void Sub(const Clusterable &other_in); virtual Clusterable* Copy() const; virtual BaseFloat Normalizer() const { return static_cast<BaseFloat>(count_); } // Function to write data to stream. Will organize input later [more complex] virtual void Write(std::ostream &os, bool binary) const; virtual Clusterable* ReadNew(std::istream &is, bool binary) const; std::string Info(); // For debugging. BaseFloat Mean() { return (count_ != 0 ? x_/count_ : 0.0); } private: BaseFloat x_; BaseFloat x2_; BaseFloat count_; void Read(std::istream &is, bool binary); }; /// GaussClusterable wraps Gaussian statistics in a form accessible /// to generic clustering algorithms. class GaussClusterable: public Clusterable { public: GaussClusterable(): count_(0.0), var_floor_(0.0) {} GaussClusterable(int32 dim, BaseFloat var_floor): count_(0.0), stats_(2, dim), var_floor_(var_floor) {} GaussClusterable(const Vector<BaseFloat> &x_stats, const Vector<BaseFloat> &x2_stats, BaseFloat var_floor, BaseFloat count); virtual std::string Type() const { return "gauss"; } void AddStats(const VectorBase<BaseFloat> &vec, BaseFloat weight = 1.0); virtual BaseFloat Objf() const; virtual void SetZero(); virtual void Add(const Clusterable &other_in); virtual void Sub(const Clusterable &other_in); virtual BaseFloat Normalizer() const { return count_; } virtual Clusterable *Copy() const; virtual void Scale(BaseFloat f); virtual void Write(std::ostream &os, bool binary) const; virtual Clusterable *ReadNew(std::istream &is, bool binary) const; virtual ~GaussClusterable() {} BaseFloat count() const { return count_; } // The next two functions are not const-correct, because of SubVector. SubVector<double> x_stats() const { return stats_.Row(0); } SubVector<double> x2_stats() const { return stats_.Row(1); } private: double count_; Matrix<double> stats_; // two rows: sum, then sum-squared. double var_floor_; // should be common for all objects created. void Read(std::istream &is, bool binary); }; /// @} end of "addtogroup clustering_group" inline void GaussClusterable::SetZero() { count_ = 0; stats_.SetZero(); } inline GaussClusterable::GaussClusterable(const Vector<BaseFloat> &x_stats, const Vector<BaseFloat> &x2_stats, BaseFloat var_floor, BaseFloat count): count_(count), stats_(2, x_stats.Dim()), var_floor_(var_floor) { stats_.Row(0).CopyFromVec(x_stats); stats_.Row(1).CopyFromVec(x2_stats); } /// VectorClusterable wraps vectors in a form accessible to generic clustering /// algorithms. Each vector is associated with a weight; these could be 1.0. /// The objective function (to be maximized) is the negated sum of squared /// distances from the cluster center to each vector, times that vector's /// weight. class VectorClusterable: public Clusterable { public: VectorClusterable(): weight_(0.0), sumsq_(0.0) {} VectorClusterable(const Vector<BaseFloat> &vector, BaseFloat weight); virtual std::string Type() const { return "vector"; } // Objf is negated weighted sum of squared distances. virtual BaseFloat Objf() const; virtual void SetZero() { weight_ = 0.0; sumsq_ = 0.0; stats_.Set(0.0); } virtual void Add(const Clusterable &other_in); virtual void Sub(const Clusterable &other_in); virtual BaseFloat Normalizer() const { return weight_; } virtual Clusterable *Copy() const; virtual void Scale(BaseFloat f); virtual void Write(std::ostream &os, bool binary) const; virtual Clusterable *ReadNew(std::istream &is, bool binary) const; virtual ~VectorClusterable() {} private: double weight_; // sum of weights of the source vectors. Never negative. Vector<double> stats_; // Equals the weighted sum of the source vectors. double sumsq_; // Equals the sum over all sources, of weight_ * vec.vec, // where vec = stats_ / weight_. Used in computing // the objective function. void Read(std::istream &is, bool binary); }; } // end namespace kaldi. #endif // KALDI_TREE_CLUSTERABLE_CLASSES_H_ |