// nnet3/nnet-parse.cc // Copyright 2015 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include #include #include #include "nnet3/nnet-parse.h" #include "cudamatrix/cu-vector.h" #include "cudamatrix/cu-matrix.h" namespace kaldi { namespace nnet3 { bool DescriptorTokenize(const std::string &input, std::vector *tokens) { KALDI_ASSERT(tokens != NULL); size_t start = input.find_first_not_of(" \t"), size = input.size(); tokens->clear(); while (start < size) { KALDI_ASSERT(!isspace(input[start])); if (input[start] == '(' || input[start] == ')' || input[start] == ',') { tokens->push_back(std::string(input, start, 1)); start = input.find_first_not_of(" \t", start + 1); } else { size_t found = input.find_first_of(" \t(),", start); KALDI_ASSERT(found != start); if (found == std::string::npos) { std::string str(input, start, input.size() - start); BaseFloat tmp; if (!IsValidName(str) && !ConvertStringToReal(str, &tmp)) { KALDI_WARN << "Could not tokenize line " << ErrorContext(std::string(input, start)); return false; } tokens->push_back(str); break; } else { if (input[found] == '(' || input[found] == ')' || input[found] == ',') { std::string str(input, start, found - start); BaseFloat tmp; if (!IsValidName(str) && !ConvertStringToReal(str, &tmp)) { KALDI_WARN << "Could not tokenize line " << ErrorContext(std::string(input, start)); return false; } tokens->push_back(str); start = found; } else { std::string str(input, start, found - start); BaseFloat tmp; if (!IsValidName(str) && !ConvertStringToReal(str, &tmp)) { KALDI_WARN << "Could not tokenize line " << ErrorContext(std::string(input, start)); return false; } tokens->push_back(str); start = input.find_first_not_of(" \t", found); } } } } return true; } std::string ErrorContext(std::istream &is) { if (!is.good()) return "end of line"; char buf[21]; is.read(buf, 21); if (is) { return (std::string(buf, 20) + "..."); } return std::string(buf, is.gcount()); } std::string ErrorContext(const std::string &str) { if (str.size() == 0) return "end of line"; if (str.size() <= 20) return str; return std::string(str, 0, 20) + "..."; } static void PrintFloatSuccinctly(std::ostream &os, BaseFloat f) { if (fabs(f) < 10000.0 && fabs(f) >= 10.0) { os << std::fixed << std::setprecision(0) << f; } else if (fabs(f) >= 0.995) { os << std::fixed << std::setprecision(1) << f; } else if (fabs(f) >= 0.01) { os << std::fixed << std::setprecision(2) << f; } else { os << std::setprecision(1) << f; } os.unsetf(std::ios_base::floatfield); os << std::setprecision(6); // Restore the default. } // Returns a string that summarizes a vector fairly succintly, for // printing stats in info lines. std::string SummarizeVector(const VectorBase &vec) { std::ostringstream os; if (vec.Dim() < 10) { os << "[ "; for (int32 i = 0; i < vec.Dim(); i++) { PrintFloatSuccinctly(os, vec(i)); os << ' '; } os << "]"; } else { // print out mean and standard deviation, and some selected values. BaseFloat mean = vec.Sum() / vec.Dim(), stddev = sqrt(VecVec(vec, vec) / vec.Dim() - mean * mean); std::string percentiles_str = "0,1,2,5 10,20,50,80,90 95,98,99,100"; std::vector percentiles; bool ans = SplitStringToIntegers(percentiles_str, ", ", false, &percentiles); KALDI_ASSERT(ans); os << "[percentiles(" << percentiles_str << ")=("; Vector vec_sorted(vec); std::sort(vec_sorted.Data(), vec_sorted.Data() + vec_sorted.Dim()); int32 n = vec.Dim() - 1; for (size_t i = 0; i < percentiles.size(); i++) { int32 percentile = percentiles[i]; BaseFloat value = vec_sorted((n * percentile) / 100); PrintFloatSuccinctly(os, value); if (i + 1 < percentiles.size()) os << (i == 3 || i == 8 ? ' ' : ','); } os << std::setprecision(3); os << "), mean=" << mean << ", stddev=" << stddev << "]"; } return os.str(); } std::string SummarizeVector(const VectorBase &vec) { Vector vec_copy(vec); return SummarizeVector(vec_copy); } std::string SummarizeVector(const CuVectorBase &cu_vec) { Vector vec(cu_vec); return SummarizeVector(vec); } void PrintParameterStats(std::ostringstream &os, const std::string &name, const CuVectorBase ¶ms, bool include_mean) { os << std::setprecision(4); os << ", " << name << '-'; if (include_mean) { BaseFloat mean = params.Sum() / params.Dim(), stddev = std::sqrt(VecVec(params, params) / params.Dim() - mean * mean); os << "{mean,stddev}=" << mean << ',' << stddev; } else { BaseFloat rms = std::sqrt(VecVec(params, params) / params.Dim()); os << "rms=" << rms; } os << std::setprecision(6); // restore the default precision. } void PrintParameterStats(std::ostringstream &os, const std::string &name, const CuMatrix ¶ms, bool include_mean, bool include_row_norms, bool include_column_norms, bool include_singular_values) { os << std::setprecision(4); os << ", " << name << '-'; int32 dim = params.NumRows() * params.NumCols(); if (include_mean) { BaseFloat mean = params.Sum() / dim, stddev = std::sqrt(TraceMatMat(params, params, kTrans) / dim - mean * mean); os << "{mean,stddev}=" << mean << ',' << stddev; } else { BaseFloat rms = std::sqrt(TraceMatMat(params, params, kTrans) / dim); os << "rms=" << rms; } os << std::setprecision(6); // restore the default precision. if (include_row_norms) { CuVector row_norms(params.NumRows()); row_norms.AddDiagMat2(1.0, params, kNoTrans, 0.0); row_norms.ApplyPow(0.5); Vector row_norms_cpu; row_norms.Swap(&row_norms_cpu); os << ", " << name << "-row-norms=" << SummarizeVector(row_norms_cpu); } if (include_column_norms) { CuVector col_norms(params.NumCols()); col_norms.AddDiagMat2(1.0, params, kTrans, 0.0); col_norms.ApplyPow(0.5); Vector col_norms_cpu; col_norms.Swap(&col_norms_cpu); os << ", " << name << "-col-norms=" << SummarizeVector(col_norms_cpu); } if (include_singular_values) { Matrix params_cpu(params); Vector s(std::min(params.NumRows(), params.NumCols())); params_cpu.Svd(&s); std::string singular_values_str = SummarizeVector(s); os << ", " << name << "-singular-values=" << singular_values_str; std::ostringstream name_os; } } void ParseConfigLines(const std::vector &lines, std::vector *config_lines) { config_lines->resize(lines.size()); for (size_t i = 0; i < lines.size(); i++) { bool ret = (*config_lines)[i].ParseLine(lines[i]); if (!ret) { KALDI_ERR << "Error parsing config line: " << lines[i]; } } } bool NameMatchesPattern(const char *name, const char *pattern) { if (*pattern == '*') { return NameMatchesPattern(name, pattern + 1) || (*name != '\0' && NameMatchesPattern(name + 1, pattern)); } else if (*name == *pattern) { return (*name == '\0' || NameMatchesPattern(name + 1, pattern + 1)); } else { return false; } } } // namespace nnet3 } // namespace kaldi