// rnnlmbin/rnnlm-train.cc // Copyright 2015-2017 Johns Hopkins University (author: Daniel Povey) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "base/kaldi-common.h" #include "util/common-utils.h" #include "rnnlm/rnnlm-training.h" #include "rnnlm/rnnlm-example-utils.h" #include "nnet3/nnet-utils.h" #include "cudamatrix/cu-allocator.h" int main(int argc, char *argv[]) { try { using namespace kaldi; using namespace kaldi::rnnlm; typedef kaldi::int32 int32; typedef kaldi::int64 int64; // rnnlm_rxfilename must be supplied, via --read-rnnlm option. std::string rnnlm_rxfilename; // For now, rnnlm_wxfilename must be supplied (later we could make it possible // to train the embedding matrix without training the RNNLM itself, if there // is a need). std::string rnnlm_wxfilename; // embedding_rxfilename must be supplied, via --read-embedding option. std::string embedding_rxfilename; std::string embedding_wxfilename; std::string word_features_rxfilename; // binary mode for writing output. bool binary = true; RnnlmCoreTrainerOptions core_config; RnnlmEmbeddingTrainerOptions embedding_config; RnnlmObjectiveOptions objective_config; const char *usage = "Train nnet3-based RNNLM language model (reads minibatches prepared\n" "by rnnlm-get-egs). Supports various modes depending which parameters\n" "we are training.\n" "Usage:\n" " rnnlm-train [options] \n" "e.g.:\n" " rnnlm-get-egs ... ark:- | \\\n" " rnnlm-train --read-rnnlm=foo/0.raw --write-rnnlm=foo/1.raw --read-embedding=foo/0.embedding \\\n" " --write-embedding=foo/1.embedding --read-sparse-word-features=foo/word_feats.txt ark:-\n" "See also: rnnlm-get-egs\n"; std::string use_gpu = "yes"; ParseOptions po(usage); po.Register("use-gpu", &use_gpu, "yes|no|optional|wait, only has effect if compiled with CUDA"); po.Register("read-rnnlm", &rnnlm_rxfilename, "Read RNNLM from this location (e.g. 0.raw). Must be supplied."); po.Register("write-rnnlm", &rnnlm_wxfilename, "Write RNNLM to this location (e.g. 1.raw)." "If not supplied, the core RNNLM is not trained " "(but other parts of the model might be."); po.Register("read-embedding", &embedding_rxfilename, "Location to read dense (feature or word) embedding matrix, " "of dimension (num-words or num-features) by (embedding-dim)."); po.Register("write-embedding", &embedding_wxfilename, "Location to write embedding matrix (c.f. --read-embedding). " "If not provided, the embedding will not be trained."); po.Register("read-sparse-word-features", &word_features_rxfilename, "Location to read sparse word-feature matrix, e.g. " "word_feats.txt. Format is lines like: '1 30 1.0 516 1.0':" "starting with word-index, then a list of pairs " "(feature-index, value) only including nonzero features. " "This will usually be determined in an ad-hoc way based on " "letters and other hand-built features; it's not trainable." " If present, the embedding matrix read via --read-embedding " "will be interpreted as a feature-embedding matrix."); po.Register("binary", &binary, "If true, write outputs in binary form."); objective_config.Register(&po); RegisterCuAllocatorOptions(&po); // register the core RNNLM training options options with the prefix "rnnlm", // so they will appear as --rnnlm.max-change and the like. This is done // with a prefix because later we may add a neural net to transform the word // embedding, and it would have options that would have a name conflict with // some of these options. ParseOptions core_opts("rnnlm", &po); core_config.Register(&core_opts); // ... and register the embedding options with the prefix "embedding". ParseOptions embedding_opts("embedding", &po); embedding_config.Register(&embedding_opts); po.Read(argc, argv); if (po.NumArgs() != 1) { po.PrintUsage(); exit(1); } if (rnnlm_rxfilename == "" || rnnlm_wxfilename == "" || embedding_rxfilename == "") { KALDI_WARN << "--read-rnnlm, --write-rnnlm and --read-embedding " "options are required."; po.PrintUsage(); exit(1); } std::string examples_rspecifier = po.GetArg(1); #if HAVE_CUDA==1 CuDevice::Instantiate().SelectGpuId(use_gpu); CuDevice::Instantiate().AllowMultithreading(); #endif kaldi::nnet3::Nnet rnnlm; ReadKaldiObject(rnnlm_rxfilename, &rnnlm); if (!IsSimpleNnet(rnnlm)) KALDI_ERR << "Input RNNLM in " << rnnlm_rxfilename << " is not the type of neural net we were looking for; " "failed IsSimpleNnet()."; CuMatrix embedding_mat; ReadKaldiObject(embedding_rxfilename, &embedding_mat); CuSparseMatrix word_feature_mat; if (word_features_rxfilename != "") { // binary mode is not supported here; it's a text format. Input input(word_features_rxfilename); int32 feature_dim = embedding_mat.NumRows(); SparseMatrix cpu_word_feature_mat; ReadSparseWordFeatures(input.Stream(), feature_dim, &cpu_word_feature_mat); word_feature_mat.Swap(&cpu_word_feature_mat); // copy to GPU, if we have // one. } { bool train_embedding = (embedding_wxfilename != ""); RnnlmTrainer trainer( train_embedding, core_config, embedding_config, objective_config, (word_features_rxfilename != "" ? &word_feature_mat : NULL), &embedding_mat, &rnnlm); SequentialRnnlmExampleReader example_reader(examples_rspecifier); for (; !example_reader.Done(); example_reader.Next()) trainer.Train(&(example_reader.Value())); if (trainer.NumMinibatchesProcessed() == 0) KALDI_ERR << "There was no data to train on."; // The destructor of 'trainer' trains on the last minibatch // and writes out anything we need to write out. } WriteKaldiObject(rnnlm, rnnlm_wxfilename, binary); KALDI_LOG << "Wrote RNNLM to " << PrintableWxfilename(rnnlm_wxfilename); if (embedding_wxfilename != "") { WriteKaldiObject(embedding_mat, embedding_wxfilename, binary); KALDI_LOG << "Wrote embedding matrix to " << PrintableWxfilename(embedding_wxfilename); } #if HAVE_CUDA==1 CuDevice::Instantiate().PrintProfile(); #endif return 0; } catch(const std::exception &e) { std::cerr << e.what() << '\n'; return -1; } }