nnet-parse.h
4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
// nnet3/nnet-parse.h
// Copyright 2015 Johns Hopkins University (author: Daniel Povey)
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#ifndef KALDI_NNET3_NNET_PARSE_H_
#define KALDI_NNET3_NNET_PARSE_H_
#include "util/text-utils.h"
#include "matrix/kaldi-vector.h"
namespace kaldi {
namespace nnet3 {
/**
  Tokenizes 'input' for use when parsing Descriptor configuration values.

  A token in this context is not the same as a generic Kaldi token,
  e.g. as defined in IsToken() in util/text-utils.h, which just means a
  non-empty whitespace-free string.  Here a token is more like a
  programming-language token, and currently the following are allowed as
  tokens:
    "("
    ")"
    ","
    - A nonempty string beginning with A-Za-z_, and containing only -_A-Za-z0-9.
    - An integer, optionally beginning with - or + and then a nonempty
      sequence of 0-9.

  @param [in] input    The string to be tokenized.
  @param [out] tokens  Presumably receives the tokens found in 'input', in
                       order.  NOTE(review): the declaration does not show
                       whether 'tokens' is cleared first, or what it contains
                       when tokenization fails; confirm against the definition.
  @return  This function should return false, and print an informative error
           with local context, if it can't tokenize the input.
*/
bool DescriptorTokenize(const std::string &input,
std::vector<std::string> *tokens);
/**
  Returns true if name 'name' matches pattern 'pattern'.  The pattern
  format is: everything is literal, except '*' matches zero or more
  characters (like filename globbing in UNIX).  By that rule the pattern
  "a*c" matches both of the names "ac" and "abbc".

  @param [in] name     The name to be tested.
  @param [in] pattern  The pattern to test 'name' against.
  @return  True if 'name' matches 'pattern', as described above.

  NOTE(review): both arguments are plain C strings; presumably they must be
  non-NULL and null-terminated -- confirm against the definition.
*/
bool NameMatchesPattern(const char *name,
const char *pattern);
/**
  Return a string used in error messages.  Here, "is" will be from an
  istringstream derived from a single line or part of a line.

  The intended behavior is:
   - if "is" is at EOF or in error state, this should just say "end of line";
   - else if the contents of "is" before EOF is <20 characters, it should
     return it all;
   - else it should return the first 20 characters followed by "...".

  @param [in] is  Stream positioned at the text to be quoted in the message.
                  NOTE(review): the declaration does not show whether the
                  characters that are read get consumed from 'is'; confirm
                  against the definition.
  @return  The context string described above.
*/
std::string ErrorContext(std::istream &is);
// Overload taking the text directly as a string.  This overload is not
// documented here; presumably it applies the same 20-character truncation
// rule to 'str' -- confirm against the definition.
std::string ErrorContext(const std::string &str);
/** Returns a string that summarizes a vector fairly succinctly, for
    printing stats in info lines.  For example:
   "[percentiles(0,1,2,5 10,20,50,80,90 95,98,99,100)=(0.001,0.003,0.003,0.004 \
      0.005,0.01,0.07,0.11,0.14 0.18,0.24,0.29,0.39), mean=0.0745, stddev=0.0611]"

    Overloads are declared for single- and double-precision VectorBase, and
    for CuVectorBase<BaseFloat>.

    @param [in] vec  The vector to be summarized.
    @return  The summary string, in a form like the example above.
*/
std::string SummarizeVector(const VectorBase<float> &vec);
std::string SummarizeVector(const VectorBase<double> &vec);
std::string SummarizeVector(const CuVectorBase<BaseFloat> &vec);
/** Print to 'os' some information about the mean and standard deviation of
    some parameters, used in Info() functions in nnet-simple-component.cc.
    For example:
      PrintParameterStats(os, "bias", bias_params_, true);
    would print to 'os' something like the string
     ", bias-{mean,stddev}=-0.013,0.196".  If 'include_mean = false',
    it will print something like
     ", bias-rms=0.2416", and this represents an uncentered standard deviation.

    @param [in,out] os       The stream that the text is appended to.
    @param [in] name         The parameter name used as the prefix of the
                             printed field (e.g. "bias" in the example above).
    @param [in] params       The parameter vector whose statistics are printed.
    @param [in] include_mean If true, print the mean and stddev; if false
                             (the default), print only the rms value.
 */
void PrintParameterStats(std::ostringstream &os,
                         const std::string &name,
                         const CuVectorBase<BaseFloat> &params,
                         bool include_mean = false);
/** Print to 'os' some information about the mean and standard deviation of
    some parameters, used in Info() functions in nnet-simple-component.cc.
    For example:
      PrintParameterStats(os, "linear-params", linear_params_);
    would print to 'os' something like the string
     ", linear-params-rms=0.239".
    If you set 'include_mean' to true, it will print something like
     ", linear-params-{mean-stddev}=0.103,0.183".
    If you set 'include_row_norms' to true, it will print something
    like
     ", linear-params-row-norms=[percentiles(0,1........, stddev=0.0508]"
    If you set 'include_column_norms' to true, it will print something
    like
     ", linear-params-col-norms=[percentiles(0,1........, stddev=0.0508]"
    If you set 'include_singular_values' to true, it will print something
    like
     ", linear-params-singular-values=[percentiles(0,1........, stddev=0.0508]"

    NOTE(review): the '{mean-stddev}' spelling above is reproduced from the
    existing documentation; the exact field name that is printed should be
    confirmed against the definition.

    @param [in,out] os                  The stream that the text is appended to.
    @param [in] name                    The parameter name used as the prefix
                                        of each printed field.
    @param [in] params                  The parameter matrix whose statistics
                                        are printed.
    @param [in] include_mean            See above; defaults to false.
    @param [in] include_row_norms       See above; defaults to false.
    @param [in] include_column_norms    See above; defaults to false.
    @param [in] include_singular_values See above; defaults to false.
 */
void PrintParameterStats(std::ostringstream &os,
                         const std::string &name,
                         const CuMatrix<BaseFloat> &params,
                         bool include_mean = false,
                         bool include_row_norms = false,
                         bool include_column_norms = false,
                         bool include_singular_values = false);
} // namespace nnet3
} // namespace kaldi
#endif