// rnnlm/sampler.h
// Copyright 2017 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_RNNLM_SAMPLER_H_
#define KALDI_RNNLM_SAMPLER_H_

#include <vector>

#include "base/kaldi-common.h"

namespace kaldi {
namespace rnnlm {
/**
   Sample without replacement from a distribution, with provided first-order
   inclusion probabilities. For example, if probs[i] is 1.0, i will definitely
   be included in the list 'sample', and if probs[i] is 0.0, i will definitely
   not be included.

     @param [in] probs    The input vector of inclusion probabilities, with
                          0.0 <= probs[i] <= 1.0, and a sum that is close to
                          an integer (specifically: within 1.0e-03 of a whole
                          number; this should be easy to ensure in double
                          precision). Let 'k' be this sum, rounded to the
                          nearest integer.
     @param [out] sample  The vector 'sample' will be set to an unsorted list
                          of 'k' distinct indices, with first-order inclusion
                          probabilities given by 'probs'.
 */
void SampleWithoutReplacement(const std::vector<double> &probs,
                              std::vector<int32> *sample);
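
// Example (a minimal usage sketch; the probabilities are hypothetical).
// The probs below sum to 2.0, so 'sample' will contain exactly 2 distinct
// indices, and index 0 is always included since probs[0] == 1.0:
//
//   std::vector<double> probs = { 1.0, 0.5, 0.3, 0.2 };
//   std::vector<int32> sample;
//   SampleWithoutReplacement(probs, &sample);  // now sample.size() == 2.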

/**
   This function samples from part of a distribution expressed as a cdf
   (cumulative distribution function). It is a utility function used in class
   Sampler, but we make it a namespace function so that we can test it.
   It does the sampling in O(log(cdf_end - cdf_start)) time, using
   binary search.

     @param [in] cdf_start  The start of a range of a cdf.
     @param [in] cdf_end    The end of a range of a cdf. We require that
                      cdf_end > cdf_start and *cdf_end > *cdf_start,
                      and the values in between must be nondecreasing.
                      As an example, suppose cdf_end - cdf_start == 3,
                      and the full range including the ends
                      is [ 0.50, 0.55, 0.65, 0.70 ].
                      This is interpreted as representing the probabilities
                      of 3 elements, given by the differences between
                      successive values, i.e. the probabilities
                      [ 0.05, 0.10, 0.05 ]. What this function does is
                      renormalize those values to sum to 1.0 and then sample
                      from the resulting distribution, i.e. the distribution
                      [ 0.25, 0.5, 0.25 ] in this example; so we'd return
                      'cdf_start' with probability 0.25, 'cdf_start + 1' with
                      probability 0.5, and 'cdf_start + 2' with
                      probability 0.25.
     @return   Returns a pointer p with cdf_start <= p < cdf_end, chosen
               with probability proportional to p[1] - p[0].
 */
const double* SampleFromCdf(const double *cdf_start,
                            const double *cdf_end);
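
// Example (a sketch of the case described above): with these cdf values,
// SampleFromCdf() renormalizes the differences [ 0.05, 0.10, 0.05 ] to
// [ 0.25, 0.5, 0.25 ] and samples from that distribution:
//
//   double cdf[] = { 0.50, 0.55, 0.65, 0.70 };
//   const double *p = SampleFromCdf(cdf, cdf + 3);
//   int32 i = p - cdf;  // i is 0, 1 or 2, with probability 0.25, 0.5, 0.25.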

/**
   This class allows us to sample a set of words from a distribution over
   words, where the distribution (which ultimately comes from an ARPA-style
   language model) is given as a combination of a unigram distribution
   with a sparse component represented as a list of (word-index, probability)
   pairs.
 */
class Sampler {
 public:
  // Initialize the class. The unigram probabilities (which you can think of
  // as the probability of each word if we don't know the history) are given.
  // Each element of 'unigram_probs' should be >= 0, and they should sum to
  // a value close to 1.
  // This class does not retain a reference to 'unigram_probs' after
  // the constructor exits.
  explicit Sampler(const std::vector<BaseFloat> &unigram_probs);

  /// Sample words from the supplied distribution, appropriately scaled.
  /// Let the unnormalized distribution be as follows:
  ///    p(i) = unigram_weight * u(i) + h(i)
  /// where u(i) is the 'unigram_probs' vector this class was constructed
  /// with, and h(i) is the probability that word i is given (if any) in
  /// the sparse vector that 'higher_order_probs' represents.
  /// Notice that we are adding to the unigram distribution, not backing
  /// off to it; doing it this way makes a lot of things simpler.
  ///
  /// We define the first-order inclusion probabilities:
  ///    q(i) = min(alpha p(i), 1.0)
  /// where alpha is chosen so that the sum of q(i) equals
  /// 'num_words_to_sample'. Then we generate a sample whose first-order
  /// inclusion probabilities are q(i). We do all this without explicitly
  /// iterating over the unigram distribution, so this is fairly fast.
  ///
  ///  @param [in] num_words_to_sample  The number of words that we are
  ///                 directed to sample; must be > 0 and less than the
  ///                 number of nonzero elements of the 'unigram_probs'
  ///                 that this class was constructed with.
  ///  @param [in] unigram_weight  Must be > 0.0. Search above for p(i) to
  ///                 see what effect it has.
  ///  @param [in] higher_order_probs  A vector of pairs (i, p) where
  ///                 0 <= i < unigram_probs.size() (referring to the
  ///                 'unigram_probs' vector used in the constructor)
  ///                 and p > 0.0. This vector must be sorted and
  ///                 unique w.r.t. i. Note: the probabilities here
  ///                 will be added to the unigram probabilities of
  ///                 the words concerned.
  ///  @param [out] sample  The sampled list of words, represented as pairs
  ///                 (i, p), where 0 <= i < unigram_probs.size() is the
  ///                 word index and 0 < p <= 1 is the probability with
  ///                 which that word was included in the set.
  ///                 The list will not be sorted, but it will be unique
  ///                 on the word index i. Its size will equal
  ///                 num_words_to_sample.
  void SampleWords(int32 num_words_to_sample,
                   BaseFloat unigram_weight,
                   const std::vector<std::pair<int32, BaseFloat> > &higher_order_probs,
                   std::vector<std::pair<int32, BaseFloat> > *sample) const;
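
  // Example (a minimal usage sketch; the sizes and values are hypothetical):
  //
  //   int32 vocab_size = 10000;
  //   std::vector<BaseFloat> unigram(vocab_size, 1.0 / vocab_size);
  //   Sampler sampler(unigram);
  //   std::vector<std::pair<int32, BaseFloat> > higher_order;
  //   higher_order.push_back(std::make_pair(42, 0.5));  // boost word 42.
  //   std::vector<std::pair<int32, BaseFloat> > sample;
  //   sampler.SampleWords(64, 1.0, higher_order, &sample);
  //   // sample.size() == 64; each pair is (word-index, inclusion-prob).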

  /// This is an alternative version of SampleWords() which allows you to
  /// specify a list of words that must be sampled (i.e., after scaling, they
  /// must have probability 1.0). It does this by adding them to the
  /// distribution with sufficiently large probability and then calling the
  /// other version of SampleWords().
  ///
  /// The vector 'words_we_must_sample' must be sorted and unique, and all
  /// elements i must satisfy 0 <= i < unigram_probs.size(), where
  /// 'unigram_probs' was the vector supplied to the constructor.
  /// See the comments for the other version of SampleWords() to understand
  /// the interface, which is otherwise the same.
  void SampleWords(int32 num_words_to_sample,
                   BaseFloat unigram_weight,
                   const std::vector<std::pair<int32, BaseFloat> > &higher_order_probs,
                   const std::vector<int32> &words_we_must_sample,
                   std::vector<std::pair<int32, BaseFloat> > *sample) const;
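
  // Example (a sketch continuing the one above): force words 3 and 42 to be
  // included in the sample with probability 1.0:
  //
  //   std::vector<int32> must_sample;
  //   must_sample.push_back(3);
  //   must_sample.push_back(42);  // must be sorted and unique.
  //   sampler.SampleWords(64, 1.0, higher_order, must_sample, &sample);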

 private:
  // This struct represents a contiguous range of symbols; 'start' and 'end'
  // are pointers into the contents of 'unigram_cdf_'. Let
  //    start_i = start - &(unigram_cdf_[0])
  //    end_i = end - &(unigram_cdf_[0]),
  // where we require that end_i > start_i.
  // Then this struct Interval represents the set of words from
  // start_i to end_i - 1, with a total probability mass given
  // by 'prob'. If end > start + 1, the probability mass is
  // apportioned among the words in proportion to the 'unigram_probs'
  // vector passed to the constructor, whose cdf is unigram_cdf_.
  //
  // Note: although we call the total mass 'prob', and we require that
  // prob > 0.0, there is not necessarily a constraint that prob <= 1.0
  // in the early stages of processing. Ultimately these will be processed
  // into inclusion probabilities, which will sum to num_words_to_sample.
  // Search for q(i) above for more explanation.
  struct Interval {
    double prob;
    const double *start;
    const double *end;
    // This operator < allows us to put Interval in a priority_queue
    // from which we can select the largest one (used when computing alpha).
    bool operator < (const Interval &other) const {
      return prob < other.prob;
    }
    Interval(double p, const double *s, const double *e):
        prob(p), start(s), end(e) { }
  };

  /// Given a distribution over words which we will write as p(i)
  /// (this is the initial state of 'intervals'), this function produces
  ///    q(i) = min(alpha p(i), 1.0),
  /// where alpha is chosen so that sum_i q(i) == num_words_to_sample.
  /// Choosing alpha involves an algorithm that uses a queue of Intervals;
  /// we may have to split some Intervals.
  /// After this function is called, 'intervals' will contain q(i),
  /// so each Interval will have prob <= 1.0 and the sum of the
  /// probs will equal num_words_to_sample.
  /// 'total_p' is the sum of (*intervals)[i].prob; it is provided
  /// to this function so that it doesn't have to compute it itself.
  static void NormalizeIntervals(int32 num_words_to_sample,
                                 double total_p,
                                 std::vector<Interval> *intervals);
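
  // Worked example (with hypothetical numbers): suppose the intervals hold
  // probs p = [ 0.6, 0.3, 0.1 ] (total_p = 1.0) and num_words_to_sample = 2.
  // Plain scaling by alpha = 2.0 would give [ 1.2, 0.6, 0.2 ], where the
  // first entry exceeds 1.0; so we clamp it to 1.0 and re-choose alpha for
  // the rest so that 0.3 * alpha + 0.1 * alpha = 1.0, i.e. alpha = 2.5,
  // giving q = [ 1.0, 0.75, 0.25 ], which sums to 2 with each q(i) <= 1.0.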

  /// Sample from the distribution q(i) represented by 'intervals'.
  ///  @param [in] intervals  The vector of Intervals to sample from.
  ///          The number of words to sample is given by the sum of
  ///          intervals[i].prob, and will equal the original
  ///          'num_words_to_sample' that was passed into the function
  ///          SampleWords(). Below, we refer to this number as
  ///          'num_words_to_sample', although it is not passed in
  ///          explicitly.
  ///  @param [out] sample  The sampled words, and the probabilities
  ///          0 < p <= 1 with which they were included in the sample,
  ///          are written to here. The size of this vector will equal
  ///          'num_words_to_sample' at exit. This vector will not
  ///          be sorted.
  void SampleFromIntervals(const std::vector<Interval> &intervals,
                           std::vector<std::pair<int32, BaseFloat> > *sample) const;

  // This helper function, used inside SampleWords(), combines the unigram and
  // higher-order portions of the distribution into a single unified format
  // based on Intervals. The intervals that it outputs represent the p(i) in
  // the comment above, i.e. this is before we compute alpha and normalize
  // so that the probs sum to num_words_to_sample.
  // Returns the sum of the probabilities of the intervals.
  double GetInitialIntervals(BaseFloat unigram_weight,
                             const std::vector<std::pair<int32, BaseFloat> > &higher_order_probs,
                             std::vector<Interval> *intervals) const;

  // The cdf (cumulative distribution function) of the unigram distribution,
  // indexed from 0 to N where N is unigram_probs.size() given in the
  // constructor (so the dimension of unigram_cdf_ is N+1). We make sure the
  // unigram distribution is normalized, so unigram_cdf_[0] == 0.0 and
  // unigram_cdf_.back() == 1.0.
  std::vector<double> unigram_cdf_;
};

// A 'Distribution' represents an unnormalized distribution over a discrete
// space. E.g., [ (5, 0.5), (10, 0.4) ] represents: choose 5 with probability
// 0.5, choose 10 with probability 0.4. (It's unnormalized, so the
// probabilities won't necessarily sum to one.)
// A valid Distribution must be sorted and unique on the .first
// elements, and all the .second elements must be > 0.
typedef std::vector<std::pair<int32, BaseFloat> > Distribution;

// Checks that a Distribution is sorted and unique on its
// .first values, and that all of its .second values are > 0.
void CheckDistribution(const Distribution &d);

// Weights a Distribution by multiplying all the .second elements by
// 'weight'. 'weight' must be > 0.0.
void WeightDistribution(BaseFloat weight,
                        Distribution *d);

// Returns the sum of the .second elements of a Distribution.
BaseFloat TotalOfDistribution(const Distribution &d);

// Merges two Distributions, summing the probabilities of any elements that
// occur in both.
void MergeDistributions(const Distribution &d1,
                        const Distribution &d2,
                        Distribution *d);
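
// Example (with hypothetical values): merging
//   d1 = [ (1, 0.5), (3, 0.2) ]  and  d2 = [ (3, 0.1), (4, 0.4) ]
// yields
//   d = [ (1, 0.5), (3, 0.3), (4, 0.4) ],
// since the probabilities for word 3 are summed.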

}  // namespace rnnlm
}  // namespace kaldi

#endif  // KALDI_RNNLM_SAMPLER_H_