Blame view
src/nnet3bin/nnet3-copy.cc
4.34 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
// nnet3bin/nnet3-copy.cc // Copyright 2012 Johns Hopkins University (author: Daniel Povey) // 2015 Xingyu Na // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include <typeinfo> #include "base/kaldi-common.h" #include "util/common-utils.h" #include "hmm/transition-model.h" #include "nnet3/am-nnet-simple.h" #include "nnet3/nnet-utils.h" int main(int argc, char *argv[]) { try { using namespace kaldi; using namespace kaldi::nnet3; typedef kaldi::int32 int32; const char *usage = "Copy 'raw' nnet3 neural network to standard output " "Also supports setting all the learning rates to a value " "(the --learning-rate option) " " " "Usage: nnet3-copy [options] <nnet-in> <nnet-out> " "e.g.: " " nnet3-copy --binary=false 0.raw text.raw "; bool binary_write = true; BaseFloat learning_rate = -1; std::string nnet_config, edits_config, edits_str; BaseFloat scale = 1.0; bool prepare_for_test = false; ParseOptions po(usage); po.Register("binary", &binary_write, "Write output in binary mode"); po.Register("learning-rate", &learning_rate, "If supplied, all the learning rates of updatable components" "are set to this value."); po.Register("nnet-config", &nnet_config, "Name of nnet3 config file that can be used to add or replace " "components or nodes of the neural network (the same as you " "would give to nnet3-init)."); po.Register("edits-config", &edits_config, "Name of edits-config file that can be used to modify the network " "(applied after nnet-config). See comments for ReadEditConfig()" "in nnet3/nnet-utils.h to see currently supported commands."); po.Register("edits", &edits_str, "Can be used as an inline alternative to edits-config; semicolons " "will be converted to newlines before parsing. E.g. " "'--edits=remove-orphans'."); po.Register("scale", &scale, "The parameter matrices are scaled" " by the specified value."); po.Register("prepare-for-test", &prepare_for_test, "If true, prepares the model for test time (may reduce model size " "slightly. Involves setting test mode in dropout and batch-norm " "components, and calling CollapseModel() which may remove some " "components."); po.Read(argc, argv); if (po.NumArgs() != 2) { po.PrintUsage(); exit(1); } std::string raw_nnet_rxfilename = po.GetArg(1), raw_nnet_wxfilename = po.GetArg(2); Nnet nnet; ReadKaldiObject(raw_nnet_rxfilename, &nnet); if (!nnet_config.empty()) { Input ki(nnet_config); nnet.ReadConfig(ki.Stream()); } if (learning_rate >= 0) SetLearningRate(learning_rate, &nnet); if (scale != 1.0) ScaleNnet(scale, &nnet); if (!edits_config.empty()) { Input ki(edits_config); ReadEditConfig(ki.Stream(), &nnet); } if (!edits_str.empty()) { for (size_t i = 0; i < edits_str.size(); i++) if (edits_str[i] == ';') edits_str[i] = ' '; std::istringstream is(edits_str); ReadEditConfig(is, &nnet); } if (prepare_for_test) { SetBatchnormTestMode(true, &nnet); SetDropoutTestMode(true, &nnet); CollapseModel(CollapseModelConfig(), &nnet); } WriteKaldiObject(nnet, raw_nnet_wxfilename, binary_write); KALDI_LOG << "Copied raw neural net from " << raw_nnet_rxfilename << " to " << raw_nnet_wxfilename; return 0; } catch(const std::exception &e) { std::cerr << e.what() << ' '; return -1; } } |