// nnetbin/nnet-train-perutt.cc // Copyright 2011-2014 Brno University of Technology (Author: Karel Vesely) // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "nnet/nnet-trnopts.h" #include "nnet/nnet-nnet.h" #include "nnet/nnet-loss.h" #include "nnet/nnet-randomizer.h" #include "base/kaldi-common.h" #include "util/common-utils.h" #include "base/timer.h" #include "cudamatrix/cu-device.h" int main(int argc, char *argv[]) { using namespace kaldi; using namespace kaldi::nnet1; typedef kaldi::int32 int32; try { const char *usage = "Perform one iteration of NN training by SGD with per-utterance updates.\n" "The training targets are represented as pdf-posteriors, usually prepared " "by ali-to-post.\n" "Usage: nnet-train-perutt [options] " " []\n" "e.g.: nnet-train-perutt scp:feature.scp ark:posterior.ark nnet.init nnet.iter1\n"; ParseOptions po(usage); NnetTrainOptions trn_opts; trn_opts.Register(&po); LossOptions loss_opts; loss_opts.Register(&po); bool binary = true; po.Register("binary", &binary, "Write output in binary mode"); bool crossvalidate = false; po.Register("cross-validate", &crossvalidate, "Perform cross-validation (don't backpropagate)"); std::string feature_transform; po.Register("feature-transform", &feature_transform, "Feature transform in Nnet format"); std::string objective_function = "xent"; po.Register("objective-function", &objective_function, "Objective function : xent|mse"); int32 length_tolerance = 5; po.Register("length-tolerance", &length_tolerance, "Allowed length difference of features/targets (frames)"); std::string frame_weights; po.Register("frame-weights", &frame_weights, "Per-frame weights to scale gradients (frame selection/weighting)."); kaldi::int32 max_frames = 6000; // Allow segments maximum of one minute by default po.Register("max-frames",&max_frames, "Maximum number of frames a segment can have to be processed"); std::string use_gpu="yes"; po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA"); //// Add dummy option for compatibility with default scheduler, bool randomize = false; po.Register("randomize", &randomize, "Dummy, for compatibility with 'steps/nnet/train_scheduler.sh'"); //// po.Read(argc, argv); if (po.NumArgs() != 3 + (crossvalidate ? 0 : 1)) { po.PrintUsage(); exit(1); } std::string feature_rspecifier = po.GetArg(1), targets_rspecifier = po.GetArg(2), model_filename = po.GetArg(3); std::string target_model_filename; if (!crossvalidate) { target_model_filename = po.GetArg(4); } using namespace kaldi; using namespace kaldi::nnet1; typedef kaldi::int32 int32; #if HAVE_CUDA == 1 CuDevice::Instantiate().SelectGpuId(use_gpu); #endif Nnet nnet_transf; if (feature_transform != "") { nnet_transf.Read(feature_transform); } Nnet nnet; nnet.Read(model_filename); nnet.SetTrainOptions(trn_opts); if (crossvalidate) { nnet_transf.SetDropoutRate(0.0); nnet.SetDropoutRate(0.0); } kaldi::int64 total_frames = 0; SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier); RandomAccessPosteriorReader targets_reader(targets_rspecifier); RandomAccessBaseFloatVectorReader weights_reader; if (frame_weights != "") { weights_reader.Open(frame_weights); } Xent xent(loss_opts); Mse mse(loss_opts); MultiTaskLoss multitask(loss_opts); if (0 == objective_function.compare(0, 9, "multitask")) { // objective_function contains something like : // 'multitask,xent,2456,1.0,mse,440,0.001' // // the meaning is following: // 'multitask,,,,...,,,' multitask.InitFromString(objective_function); } CuMatrix feats, feats_transf, nnet_out, obj_diff; Timer time; KALDI_LOG << (crossvalidate?"CROSS-VALIDATION":"TRAINING") << " STARTED"; int32 num_done = 0, num_no_tgt_mat = 0, num_other_error = 0; // main loop, for ( ; !feature_reader.Done(); feature_reader.Next()) { std::string utt = feature_reader.Key(); KALDI_VLOG(3) << "Reading " << utt; // check that we have targets if (!targets_reader.HasKey(utt)) { KALDI_WARN << utt << ", missing targets"; num_no_tgt_mat++; continue; } // check we have per-frame weights if (frame_weights != "" && !weights_reader.HasKey(utt)) { KALDI_WARN << utt << ", missing per-frame weights"; num_other_error++; feature_reader.Next(); continue; } // get feature / target pair Matrix mat = feature_reader.Value(); Posterior nnet_tgt = targets_reader.Value(utt); // skip the sentence if it is too long, if (mat.NumRows() > max_frames) { KALDI_WARN << "Skipping " << utt << " that has " << mat.NumRows() << " frames," << " it is longer than '--max-frames'" << max_frames; num_other_error++; continue; } // get per-frame weights Vector frm_weights; if (frame_weights != "") { frm_weights = weights_reader.Value(utt); } else { // all per-frame weights are 1.0 frm_weights.Resize(mat.NumRows()); frm_weights.Set(1.0); } // correct small length mismatch ... or drop sentence { // add lengths to vector std::vector length; length.push_back(mat.NumRows()); length.push_back(nnet_tgt.size()); length.push_back(frm_weights.Dim()); // find min, max int32 min = *std::min_element(length.begin(), length.end()); int32 max = *std::max_element(length.begin(), length.end()); // fix or drop ? if (max - min < length_tolerance) { if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData); if (nnet_tgt.size() != min) nnet_tgt.resize(min); if (frm_weights.Dim() != min) frm_weights.Resize(min, kCopyData); } else { KALDI_WARN << utt << ", length mismatch of targets " << nnet_tgt.size() << " and features " << mat.NumRows(); num_other_error++; continue; } } // apply optional feature transform nnet_transf.Feedforward(CuMatrix(mat), &feats_transf); // forward pass nnet.Propagate(feats_transf, &nnet_out); // evaluate objective function we've chosen, if (objective_function == "xent") { // gradients are re-scaled by weights inside Eval, xent.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff); } else if (objective_function == "mse") { // gradients are re-scaled by weights inside Eval, mse.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff); } else if (0 == objective_function.compare(0, 9, "multitask")) { // gradients re-scaled by weights in Eval, multitask.Eval(frm_weights, nnet_out, nnet_tgt, &obj_diff); } else { KALDI_ERR << "Unknown objective function code : " << objective_function; } if (!crossvalidate) { // backpropagate and update, nnet.Backpropagate(obj_diff, NULL); } // 1st minibatch : show what happens in network, if (total_frames == 0) { KALDI_LOG << "### After " << total_frames << " frames,"; KALDI_LOG << nnet.InfoPropagate(); if (!crossvalidate) { KALDI_LOG << nnet.InfoBackPropagate(); KALDI_LOG << nnet.InfoGradient(); } } // VERBOSE LOG // monitor the NN training (--verbose=2), if (GetVerboseLevel() >= 2) { static int32 counter = 0; counter += mat.NumRows(); // print every 25k frames, if (counter >= 25000) { KALDI_VLOG(2) << "### After " << total_frames << " frames,"; KALDI_VLOG(2) << nnet.InfoPropagate(); if (!crossvalidate) { KALDI_VLOG(2) << nnet.InfoBackPropagate(); KALDI_VLOG(2) << nnet.InfoGradient(); } counter = 0; } } num_done++; total_frames += frm_weights.Sum(); } // main loop, // after last minibatch : show what happens in network, KALDI_LOG << "### After " << total_frames << " frames,"; KALDI_LOG << nnet.InfoPropagate(); if (!crossvalidate) { KALDI_LOG << nnet.InfoBackPropagate(); KALDI_LOG << nnet.InfoGradient(); } if (!crossvalidate) { nnet.Write(target_model_filename, binary); } KALDI_LOG << "Done " << num_done << " files, " << num_no_tgt_mat << " with no tgt_mats, " << num_other_error << " with other errors. " << "[" << (crossvalidate ? "CROSS-VALIDATION" : "TRAINING") << ", " << (randomize ? "RANDOMIZED" : "NOT-RANDOMIZED") << ", " << time.Elapsed() / 60 << " min, processing " << total_frames / time.Elapsed() << " frames per sec.]"; if (objective_function == "xent") { KALDI_LOG << xent.ReportPerClass(); KALDI_LOG << xent.Report(); } else if (objective_function == "mse") { KALDI_LOG << mse.Report(); } else if (0 == objective_function.compare(0, 9, "multitask")) { KALDI_LOG << multitask.Report(); } else { KALDI_ERR << "Unknown objective function code : " << objective_function; } #if HAVE_CUDA == 1 CuDevice::Instantiate().PrintProfile(); #endif return 0; } catch(const std::exception &e) { std::cerr << e.what(); return -1; } }