// nnet3/nnet-batch-compute.h

// Copyright 2012-2018  Johns Hopkins University (author: Daniel Povey)
//                2018       Hang Lyu

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_NNET_BATCH_COMPUTE_H_
#define KALDI_NNET3_NNET_BATCH_COMPUTE_H_

#include <vector>
#include <string>
#include <list>
#include <utility>
#include <condition_variable>
#include "base/kaldi-common.h"
#include "gmm/am-diag-gmm.h"
#include "hmm/transition-model.h"
#include "itf/decodable-itf.h"
#include "nnet3/nnet-optimize.h"
#include "nnet3/nnet-compute.h"
#include "nnet3/am-nnet-simple.h"
#include "nnet3/nnet-am-decodable-simple.h"
#include "decoder/lattice-faster-decoder.h"
#include "util/stl-utils.h"


namespace kaldi {
namespace nnet3 {


/**
   class NnetInferenceTask represents a chunk of an utterance that is
   requested to be computed.  This will be given to NnetBatchComputer, which
   will aggregate the tasks and complete them.
 */
struct NnetInferenceTask {
  // The copy constructor is required to exist because of std::vector's resize()
  // function, but in practice should never be used.
  NnetInferenceTask(const NnetInferenceTask &other) {
    KALDI_ERR << "NnetInferenceTask was not designed to be copied.";
  }
  NnetInferenceTask() { }


  // The input frames, which are treated as being numbered t=0, t=1, etc.  (If
  // the lowest t value was originally nonzero in the 'natural' numbering, this
  // just means we conceptually shift the 't' values; the only real constraint
  // is that the 't' values are contiguous.)
  CuMatrix<BaseFloat> input;

  // The index of the first input frame (in the shifted numbering where the
  // first output frame is numbered zero).  This will typically be less than one,
  // because most network topologies require left context.  If this was an
  // 'interior' chunk of a recurrent topology like LSTMs, first_input_t may be
  // substantially less than zero, due to 'extra_left_context'.
  int32 first_input_t;

  // The stride of output 't' values: e.g., will be 1 for normal-frame-rate
  // models, and 3 for low-frame-rate models such as chain models.
  int32 output_t_stride;

  // The number of output 't' values (they will start from zero and be separated
  // by output_t_stride).  This will be the num-rows of 'output'.
  int32 num_output_frames;

  // 'num_initial_unused_output_frames', which will normally be zero, is the
  // number of rows of the output matrix ('output' or 'output_cpu') which won't
  // actually be needed by the user, usually because they overlap with a
  // previous chunk.  This can happen when the number of output frames in the
  // utterance isn't an exact multiple of the chunk size.
  int32 num_initial_unused_output_frames;

  // 0 < num_used_output_frames <= num_output_frames - num_initial_unused_output_frames
  // is the number of output frames which are actually going to be used by the
  // user.  (Due to edge effects, not all are necessarily used).
  int32 num_used_output_frames;

  // first_used_output_frame_index is provided for the convenience of the user
  // so that they can know how this chunk relates to the utterance which it is
  // a part of.
  // It represents an output frame index in the original utterance-- after
  // subsampling; so not a 't' value but a 't' value divided by
  // frame-subsampling-factor.  Specifically, it tells you the row index in the
  // full utterance's output which corresponds to the first 'used' frame index
  // at the output of this chunk, specifically: the row numbered
  // 'num_initial_unused_output_frames' in the 'output' or 'output_cpu' data
  // member.
  int32 first_used_output_frame_index;

  // True if this chunk is an 'edge' (the beginning or end of an utterance) AND
  // is structurally different somehow from a non-edge chunk, e.g. requires less
  // context.  This is present only so that NnetBatchComputer will know the
  // appropriate minibatch size to use.
  bool is_edge;

  // True if this task represents an irregular-sized chunk.  These can happen
  // only for utterances that are shorter than the requested minibatch size, and
  // it should be quite rare.  We use a minibatch size of 1 in this case.
  bool is_irregular;

  // The i-vector for this chunk, if this network accepts i-vector inputs.
  CuVector<BaseFloat> ivector;

  // A priority (higher is more urgent); may be either sign.  May be updated
  // after this object is provided to class NnetBatchComputer.
  double priority;

  // This semaphore will be incremented by class NnetBatchComputer when this
  // chunk is done.  After this semaphore is incremented, class
  // NnetBatchComputer will no longer hold any pointers to this class.
  Semaphore semaphore;

  // Will be set to true by the caller if they want the output of the neural net
  // to be copied to the CPU (to 'output_cpu').  If false, the output will stay
  // on the GPU (if one is being used), in 'output'.
  bool output_to_cpu;

  // The neural net output, of dimension num_output_frames by the output-dim of
  // the neural net, will be written to 'output_cpu' if 'output_to_cpu' is true.
  // This is expected to be empty when this task is provided to class
  // NnetBatchComputer, and will be nonempty (if output_to_cpu == true) when the
  // task is completed and the semaphore is signaled.
  Matrix<BaseFloat> output_cpu;

  // The output goes here instead if 'output_to_cpu' is false.
  CuMatrix<BaseFloat> output;
};
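
// An illustrative, made-up example of how the indexing members above fit
// together, assuming a low-frame-rate model with output_t_stride == 3 and 13
// frames of left context (the numbers are hypothetical, not taken from any
// particular model):
//
//   task.first_input_t                    = -13;  // input 't' values start at -13
//   task.output_t_stride                  = 3;
//   task.num_output_frames                = 50;   // output 't' values 0, 3, ..., 147
//   task.num_initial_unused_output_frames = 0;    // no overlap with a previous chunk
//   task.num_used_output_frames           = 50;
//   task.first_used_output_frame_index    = 0;    // first chunk of the utterance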


struct NnetBatchComputerOptions: public NnetSimpleComputationOptions {
  int32 minibatch_size;
  int32 edge_minibatch_size;
  bool ensure_exact_final_context;
  BaseFloat partial_minibatch_factor;

  NnetBatchComputerOptions(): minibatch_size(128),
                              edge_minibatch_size(32),
                              ensure_exact_final_context(false),
                              partial_minibatch_factor(0.5) {
  }

  void Register(OptionsItf *po) {
    NnetSimpleComputationOptions::Register(po);
    po->Register("minibatch-size", &minibatch_size, "Number of chunks per "
                 "minibatch (see also edge-minibatch-size)");
    po->Register("edge-minibatch-size", &edge_minibatch_size, "Number of "
                 "chunks per minibatch: this applies to chunks at the "
                 "beginnings and ends of utterances, in cases (such as "
                 "recurrent models) when the computation would be different "
                 "from the usual one.");
    po->Register("ensure-exact-final-context", &ensure_exact_final_context,
                 "If true, for utterances shorter than --frames-per-chunk, "
                 "use exact-length, special computations.  If false, "
                 "pad with repeats of the last frame.  Would only affect "
                 "the output for backwards-recurrent models, but would "
                 "negatively impact speed in all cases.");
    po->Register("partial-minibatch-factor", &partial_minibatch_factor,
                 "Factor that controls how small partial minibatches will be "
                 "they become necessary.  We will potentially do the computation "
                 "for sizes: int(partial_minibatch_factor^n * minibatch_size "
                 ", for n = 0, 1, 2....  Set it to 0.0 if you want to use "
                 "only the specified minibatch sizes.");
  }
};
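
// A minimal usage sketch, assuming the standard Kaldi ParseOptions flow (the
// 'usage' string and variable names here are illustrative):
//
//   #include "util/parse-options.h"
//
//   NnetBatchComputerOptions opts;
//   ParseOptions po(usage);
//   opts.Register(&po);
//   po.Read(argc, argv);
//   // opts.minibatch_size, opts.edge_minibatch_size, etc. now reflect any
//   // --minibatch-size, --edge-minibatch-size, ... flags given on the
//   // command line.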


/**
   Merges together the 'output_cpu' (if the 'output_to_cpu' members are true) or
   the 'output' members of 'tasks' into a single CPU matrix 'output'.  Requires that
   those outputs are nonempty (i.e. that those tasks must have been completed).

   @param [in] tasks  The vector of tasks whose outputs are to be merged.
         The tasks must have already been completed.
   @param [out] output  The spliced-together output matrix

   TODO: in the future, maybe start from GPU and use pinned matrices for the
   transfer.
 */
void MergeTaskOutput(
    const std::vector<NnetInferenceTask> &tasks,
    Matrix<BaseFloat> *output);
void MergeTaskOutput(
    const std::vector<NnetInferenceTask> &tasks,
    CuMatrix<BaseFloat> *output);

/**
   This class does neural net inference in a way that is optimized for GPU use:
   it combines chunks of multiple utterances into minibatches for more efficient
   computation.  It does the computation in one background thread that accesses
   the GPU.  It is thread safe, i.e. you can call it from multiple threads
   without having to worry about data races and the like.
*/
class NnetBatchComputer {
 public:
  /**  Constructor.  It stores references to all the arguments, so don't delete
       them until this object goes out of scope.

       \param [in] opts  Options struct
       \param [in] nnet  The neural net which we'll be doing the computation with
       \param [in] priors Either the empty vector, or a vector of prior
                        probabilities which we'll take the log of and subtract
                        from the neural net outputs (e.g. used in non-chain
                        systems).
   */
  NnetBatchComputer(const NnetBatchComputerOptions &opts,
                    const Nnet &nnet,
                    const VectorBase<BaseFloat> &priors);


  /// Accepts a task, meaning the task will be queued.  (Note: the pointer is
  /// still owned by the caller.)
  /// If max_minibatches_full >= 0, then the calling thread will block until
  /// no more than that many full minibatches are waiting to be computed.  This
  /// is a mechanism to prevent too many requests from piling up in memory.
  void AcceptTask(NnetInferenceTask *task,
                  int32 max_minibatches_full = -1);

  /// Returns the number of full minibatches waiting to be computed.
  int32 NumFullPendingMinibatches() const { return num_full_minibatches_; }


  /**
      Does some kind of computation, choosing the highest-priority thing to
      compute.  It returns true if it did some kind of computation, and false
      otherwise.  This function locks the class, but not for the entire time
      it's being called: only at the beginning and at the end.
        @param [in] allow_partial_minibatch  If false, then this will only
              do the computation if a full minibatch is ready; if true, it
              is allowed to do computation on partial (not-full) minibatches.
   */
  bool Compute(bool allow_partial_minibatch);


  /**
     Split a single utterance into a list of separate tasks which can then
     be given to this class by AcceptTask().

     @param [in] output_to_cpu  Will become the 'output_to_cpu' member of the
             output tasks; this controls whether the computation code should transfer
             the outputs to CPU (which is to save GPU memory).
     @param [in] ivector  If non-NULL, an i-vector for the whole utterance is
             expected to be supplied here (and online_ivectors should be NULL).
             This is relevant if you estimate i-vectors per speaker instead of
             online.
     @param [in] online_ivectors  Matrix of ivectors, one every 'online_ivector_period' frames.
     @param [in] online_ivector_period  Affects the interpretation of 'online_ivectors'.
     @param [out]  tasks       The tasks created will be output to here.  The
                      priorities will be set to zero; setting them to a meaningful
                      value is up to the caller.
  */
  void SplitUtteranceIntoTasks(
      bool output_to_cpu,
      const Matrix<BaseFloat> &input,
      const Vector<BaseFloat> *ivector,
      const Matrix<BaseFloat> *online_ivectors,
      int32 online_ivector_period,
      std::vector<NnetInferenceTask> *tasks);
  void SplitUtteranceIntoTasks(
      bool output_to_cpu,
      const CuMatrix<BaseFloat> &input,
      const CuVector<BaseFloat> *ivector,
      const CuMatrix<BaseFloat> *online_ivectors,
      int32 online_ivector_period,
      std::vector<NnetInferenceTask> *tasks);

  const NnetBatchComputerOptions &GetOptions() { return opts_; }

  ~NnetBatchComputer();

 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchComputer);

  // Information about a specific minibatch size for a group of tasks sharing a
  // specific structure (in terms of left and right context, etc.)
  struct MinibatchSizeInfo {
    // the computation for this minibatch size.
    std::shared_ptr<const NnetComputation> computation;
    int32 num_done;  // The number of minibatches computed: for diagnostics.
    int64 tot_num_tasks;  // The total number of tasks in those minibatches,
    // also for diagnostics... can be used to compute
    // how 'full', on average, these minibatches were.
    double seconds_taken;  // The total time elapsed in computation for this
                          // minibatch type.
    MinibatchSizeInfo(): computation(NULL), num_done(0),
                         tot_num_tasks(0), seconds_taken(0.0) { }
  };


  // A computation group is a group of tasks that have the same structure
  // (number of input and output frames, left and right context).
  struct ComputationGroupInfo {
    // The tasks to be completed.  This array is added-to by AcceptTask(),
    // and removed-from by GetHighestPriorityComputation(), which is called
    // from Compute().
    std::vector<NnetInferenceTask*> tasks;

    // Map from minibatch-size to information specific to this minibatch-size,
    // including the NnetComputation.  This is set up by
    // GetHighestPriorityComputation(), which is called from Compute().
    std::map<int32, MinibatchSizeInfo> minibatch_info;
  };

  // This struct allows us to arrange the tasks into groups that can be
  // computed in the same minibatch.
  struct ComputationGroupKey {
    ComputationGroupKey(const NnetInferenceTask &task):
        num_input_frames(task.input.NumRows()),
        first_input_t(task.first_input_t),
        num_output_frames(task.num_output_frames) {}

    bool operator == (const ComputationGroupKey &other) const {
      return num_input_frames == other.num_input_frames &&
          first_input_t == other.first_input_t &&
          num_output_frames == other.num_output_frames;
    }
    int32 num_input_frames;
    int32 first_input_t;
    int32 num_output_frames;
  };

  struct ComputationGroupKeyHasher {
    int32 operator () (const ComputationGroupKey &key) const {
      return key.num_input_frames + 18043 * key.first_input_t +
          6413 * key.num_output_frames;
    }
  };


  typedef unordered_map<ComputationGroupKey, ComputationGroupInfo,
                        ComputationGroupKeyHasher> MapType;

  // Gets the priority for a group, higher means higher priority.  (A group is a
  // list of tasks that may be computed in the same minibatch).  What this
  // function does is a kind of heuristic.
  // If allow_partial_minibatch == false, it will set the priority for
  // any minibatches that are not full to negative infinity.
  inline double GetPriority(bool allow_partial_minibatch,
                            const ComputationGroupInfo &info) const;

  // Returns the minibatch size for this group of tasks, i.e. the size of a full
  // minibatch for this type of task, which is what we'd ideally like to
  // compute.  Note: the is_edge and is_irregular options should be the same
  // for all tasks in the group.
  //   - If 'tasks' is empty, or tasks[0].is_edge and tasks[0].is_irregular are
  //     both false, then returns opts_.minibatch_size.
  //   - If 'tasks' is nonempty and tasks[0].is_irregular is true, then
  //     returns 1.
  //   - If 'tasks' is nonempty and tasks[0].is_irregular is false and
  //     tasks[0].is_edge is true, then returns opts_.edge_minibatch_size.
  inline int32 GetMinibatchSize(const ComputationGroupInfo &info) const;


  // This function compiles, and returns, a computation for tasks of
  // the structure present in info.tasks[0], and the specified minibatch
  // size.
  std::shared_ptr<const NnetComputation> GetComputation(
      const ComputationGroupInfo &info,
      int32 minibatch_size);


  // Returns the actual minibatch size we'll use for this computation.  In most
  // cases it will be opts_.minibatch_size (or opts_.edge_minibatch_size if
  // appropriate); but if the number of available tasks is much less than the
  // appropriate minibatch size, it may be less.  The minibatch size may be
  // greater than info.tasks.size(); in that case, the remaining 'n' values in
  // the minibatch are not used.  (It may also be less than info.tasks.size(),
  // in which case we only do some of them).
  int32 GetActualMinibatchSize(const ComputationGroupInfo &info) const;


  // This function gets the highest-priority 'num_tasks' tasks from 'info',
  // removes them from the array info->tasks, and puts them into the array
  // 'tasks' (which is assumed to be initially empty).
  // This function also updates the num_full_minibatches_ variable if
  // necessary, and takes care of notifying any related condition variables.
  void GetHighestPriorityTasks(
      int32 num_tasks,
      ComputationGroupInfo *info,
      std::vector<NnetInferenceTask*> *tasks);

  /**
      This function finds and returns the computation corresponding to the
      highest-priority group of tasks.

       @param [in] allow_partial_minibatch  If this is true, then this
             function may return a computation corresponding to a partial
             minibatch-- i.e. the minibatch size in the computation may be
             less than the minibatch size in the options class, and/or
             the number of tasks may not be as many as the minibatch size
             in the computation.
       @param [out] minibatch_size  If this function returns non-NULL, then
             this will be set to the minibatch size that the returned
             computation expects.  This may be greater than tasks->size(),
             in cases where the minibatch was not 'full'.
       @param [out] tasks  The tasks which we'll be doing the computation
             for in this minibatch are put here (and removed from tasks_),
             in cases where this function returns non-NULL.
       @return  This function returns a pointer to the appropriate
             'MinibatchSizeInfo' object corresponding to the computation
             that we'll be doing for this minibatch, or NULL if there is nothing
             to compute.
  */
  MinibatchSizeInfo *GetHighestPriorityComputation(
      bool allow_partial_minibatch,
      int32 *minibatch_size,
      std::vector<NnetInferenceTask*> *tasks);

  /**
     Formats the inputs to the computation and transfers them to the GPU.
        @param [in]  minibatch_size  The number of parallel sequences
            we're doing this computation for.  This will be
            more than tasks.size() in some cases.
        @param [in] tasks  The tasks we're doing the computation for.
            The input comes from here.
        @param [out] input  The main feature input to the computation is
            put into here.
        @param [out] ivector  If we're using i-vectors, the i-vectors are
            put here.
  */
  void FormatInputs(int32 minibatch_size,
                    const std::vector<NnetInferenceTask*> &tasks,
                    CuMatrix<BaseFloat> *input,
                    CuMatrix<BaseFloat> *ivector);


  // Copies 'output', piece by piece, to the 'output_cpu' or 'output'
  // members of 'tasks', depending on their 'output_to_cpu' value.
  void FormatOutputs(const CuMatrix<BaseFloat> &output,
                     const std::vector<NnetInferenceTask*> &tasks);


  // Changes opts_.frames_per_chunk to be a multiple of
  // opts_.frame_subsampling_factor, if needed.
  void CheckAndFixConfigs();

  // This function creates the computation request which is to be compiled
  // (writing it to 'request').
  static void GetComputationRequest(const NnetInferenceTask &task,
                                    int32 minibatch_size,
                                    ComputationRequest *request);

  // Prints some logging information about what we computed, with breakdown by
  // minibatch type.
  void PrintMinibatchStats();

  NnetBatchComputerOptions opts_;
  const Nnet &nnet_;
  CachingOptimizingCompiler compiler_;
  CuVector<BaseFloat> log_priors_;

  // Mutex that guards this object.  It is only held for fairly quick operations
  // (not while the actual computation is being done).
  std::mutex mutex_;

  // tasks_ contains all the queued tasks.
  // Each value in the map holds a vector of NnetInferenceTask* pointers that
  // share the same structure (i.e., that map to the same ComputationGroupKey).
  MapType tasks_;

  // num_full_minibatches_ is a function of the data in tasks_ (and the
  // minibatch sizes specified in opts_).  It is the number of full minibatches
  // of tasks that are pending, meaning: for each group of tasks, the number of
  // pending tasks divided by the minibatch-size for that group in integer
  // arithmetic.  This is kept updated for thread synchronization reasons,
  // because it is the shared variable that threads blocked in AcceptTask()
  // (via 'no_more_than_n_minibatches_full_') are waiting on.
  int32 num_full_minibatches_;

  // a map from 'n' to a condition variable corresponding to the condition:
  // num_full_minibatches_ <= n.  Any time the number of full minibatches drops
  // below n, the corresponding condition variable is notified (if it exists).
  std::unordered_map<int32, std::condition_variable*> no_more_than_n_minibatches_full_;

  // some static information about the neural net, computed at the start.
  int32 nnet_left_context_;
  int32 nnet_right_context_;
  int32 input_dim_;
  int32 ivector_dim_;
  int32 output_dim_;
};
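
// A hedged usage sketch of the interface above (not a verbatim excerpt from
// any program; error handling, the i-vector arguments, and the termination
// condition of the compute loop are simplified):
//
//   NnetBatchComputer computer(opts, nnet, priors);
//
//   // Background thread: repeatedly do the highest-priority computation.
//   // Compute() returns false if there was nothing to compute.
//   while (!finished)
//     if (!computer.Compute(/*allow_partial_minibatch=*/ no_more_input))
//       tasks_ready_semaphore.Wait();   // wait for more tasks (illustrative)
//
//   // Producer thread, for each utterance:
//   std::vector<NnetInferenceTask> tasks;
//   computer.SplitUtteranceIntoTasks(/*output_to_cpu=*/ true, feats, &ivector,
//                                    NULL, 0, &tasks);
//   for (size_t i = 0; i < tasks.size(); i++)
//     computer.AcceptTask(&tasks[i]);
//   for (size_t i = 0; i < tasks.size(); i++)
//     tasks[i].semaphore.Wait();        // wait until each chunk is computed
//   Matrix<BaseFloat> output;
//   MergeTaskOutput(tasks, &output);    // splice the chunks into one matrix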


/**
   This class implements a simplified interface to class NnetBatchComputer,
   which is suitable for programs like 'nnet3-compute' where you want to support
   fast GPU-based inference on a sequence of utterances, and get them back
   from the object in the same order.
 */
class NnetBatchInference {
 public:

  NnetBatchInference(
      const NnetBatchComputerOptions &opts,
      const Nnet &nnet,
      const VectorBase<BaseFloat> &priors);

  /**
    The user should call this one by one for the utterances that this class
    needs to compute (interspersed with calls to GetOutput()).  This call
    will block once enough not-yet-computed data has been queued, to limit
    memory use.

      @param [in] utterance_id  The string representing the utterance-id;
             it will be provided back to the user when GetOutput() is
             called.
      @param [in] input  The input features (e.g. MFCCs)
      @param [in] ivector  If non-NULL, this is expected to be the
             i-vector for this utterance (and 'online_ivectors' should
             be NULL).
      @param [in] online_ivectors  If non-NULL, this is a matrix of online
             i-vectors, one every 'online_ivector_period' frames ('ivector'
             should be NULL in that case).
      @param [in] online_ivector_period  Only relevant if
             'online_ivectors' is non-NULL; this says how many
             frames of 'input' are covered by each row of
             'online_ivectors'.
  */
  void AcceptInput(const std::string &utterance_id,
                   const Matrix<BaseFloat> &input,
                   const Vector<BaseFloat> *ivector,
                   const Matrix<BaseFloat> *online_ivectors,
                   int32 online_ivector_period);

  /**
     The user should call this after the last input has been provided
     via AcceptInput().  This will force the last utterances to be
     flushed out (to be retrieved by GetOutput()), rather than waiting
     until the relevant minibatches are full.
  */
  void Finished();

  /**
      The user should call this to obtain output.  It's guaranteed to
      be in the same order as the input was provided, but it may be
      delayed.  'output' will be the output of the neural net, spliced
      together over the chunks (and with acoustic scaling applied if
      it was specified in the options); the subtraction of priors will
      depend on whether you supplied a non-empty vector of priors to the
      constructor.

      This call does not block (i.e. does not wait on any semaphores) unless you
      have previously called Finished().  It returns true if it actually got any
      output; if none was ready it will return false.
  */
  bool GetOutput(std::string *utterance_id,
                 Matrix<BaseFloat> *output);

  ~NnetBatchInference();
 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchInference);

  // This is the computation thread, which is run in the background.  It will
  // exit once the user calls Finished() and all computation is completed.
  void Compute();
  // static wrapper for Compute().
  static void ComputeFunc(NnetBatchInference *object) { object->Compute(); }


  // This object implements the internals of what this class does.  It is
  // accessed both by the main thread (from where AcceptInput(), Finished() and
  // GetOutput() are called), and from the background thread in which Compute()
  // is called.
  NnetBatchComputer computer_;

  // This is set to true when the user calls Finished(); the computation thread
  // sees it and knows to flush out the remaining, possibly partial, minibatches.
  bool is_finished_;

  // This semaphore is signaled by the main thread (the thread in which
  // AcceptInput() is called) every time a new utterance is added, and waited on
  // in the background thread in which Compute() is called.
  Semaphore tasks_ready_semaphore_;

  struct UtteranceInfo {
    std::string utterance_id;
    // The tasks into which we split this utterance.
    std::vector<NnetInferenceTask> tasks;
    // 'num_tasks_finished' is the number of tasks which are known to be
    // finished, meaning we successfully waited for those tasks' 'semaphore'
    // member.  When this reaches tasks.size(), we are ready to consolidate
    // the output into a single matrix and return it to the user.
    size_t num_tasks_finished;
  };

  // This list is only accessed directly by the main thread, by AcceptInput()
  // and GetOutput().  It is a list of utterances, with more recently added ones
  // at the back.  When utterances are given back to the user by GetOutput(),
  // they are removed from the front of this list.
  std::list<UtteranceInfo*> utts_;

  int32 utterance_counter_;  // counter that increases on every utterance.

  // The thread running the Compute() process.
  std::thread compute_thread_;
};
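
// A hedged sketch of the intended calling pattern for NnetBatchInference (the
// reader/writer objects are illustrative, and i-vectors are omitted):
//
//   NnetBatchInference inference(opts, nnet, priors);
//   for (; !feature_reader.Done(); feature_reader.Next()) {
//     inference.AcceptInput(feature_reader.Key(), feature_reader.Value(),
//                           NULL, NULL, 0);
//     std::string utt;
//     Matrix<BaseFloat> output;
//     while (inference.GetOutput(&utt, &output))   // non-blocking before Finished()
//       matrix_writer.Write(utt, output);
//   }
//   inference.Finished();
//   std::string utt;
//   Matrix<BaseFloat> output;
//   while (inference.GetOutput(&utt, &output))     // drains the remaining output
//     matrix_writer.Write(utt, output);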


/**
   Decoder object that uses multiple CPU threads for the graph search, plus a
   GPU for the neural net inference (that's done by a separate
   NnetBatchComputer object).  The interface of this object should be
   accessed from only one thread, though-- presumably the main thread of the
   program.
 */
class NnetBatchDecoder {
 public:
  /**
     Constructor.
        @param [in] fst    FST that we are decoding with, will be shared between
                           all decoder threads.
        @param [in] decoder_config  Configuration object for the decoders.
        @param [in] trans_model   The transition model-- needed to construct the decoders,
                           and for determinization.
        @param [in] word_syms  A pointer to a symbol table of words, used for printing
                          the decoded words to stderr.  If NULL, the word-level output will not
                          be logged.
        @param [in] allow_partial   If true, in cases where no final-state was reached
                           on the final frame of the decoding, we still output a lattice;
                           it just may contain partial words (words that are cut off in
                           the middle).  If false, we just won't output anything for
                           those lattices.
        @param [in] num_threads  The number of decoder threads to use.  It will use
                          two more threads on top of this: the main thread, for I/O,
                          and a thread for possibly-GPU-based inference.
        @param [in] computer The NnetBatchComputer object, through which the
                           neural net will be evaluated.
   */
  NnetBatchDecoder(const fst::Fst<fst::StdArc> &fst,
                   const LatticeFasterDecoderConfig &decoder_config,
                   const TransitionModel &trans_model,
                   const fst::SymbolTable *word_syms,
                   bool allow_partial,
                   int32 num_threads,
                   NnetBatchComputer *computer);

  /**
    The user should call this one by one for the utterances that
    it needs to compute (interspersed with calls to GetOutput()).  This
    call will block when no threads are ready to start processing this
    utterance.

      @param [in] utterance_id  The string representing the utterance-id;
             it will be provided back to the user when GetOutput() is
             called.
      @param [in] input  The input features (e.g. MFCCs)
      @param [in] ivector  If non-NULL, this is expected to be the
             i-vector for this utterance (and 'online_ivectors' should
             be NULL).
      @param [in] online_ivectors  If non-NULL, this is a matrix of online
             i-vectors, one every 'online_ivector_period' frames ('ivector'
             should be NULL in that case).
      @param [in] online_ivector_period  Only relevant if
             'online_ivectors' is non-NULL; this says how many
             frames of 'input' are covered by each row of
             'online_ivectors'.
  */
  void AcceptInput(const std::string &utterance_id,
                   const Matrix<BaseFloat> &input,
                   const Vector<BaseFloat> *ivector,
                   const Matrix<BaseFloat> *online_ivectors,
                   int32 online_ivector_period);

  /*
    The user should call this function each time there was a problem with an utterance
    prior to being able to call AcceptInput()-- e.g. missing i-vectors.  This will
    update the num-failed-utterances stats which are stored in this class.
   */
  void UtteranceFailed();

  /*
     The user should call this when all input has been provided, e.g.
     when AcceptInput will not be called any more.  It will block until
     all threads have terminated; after that, you can call GetOutput()
     until it returns false, which will guarantee that nothing remains
     to compute.
     It returns the number of utterances that have been successfully decoded.
   */
  int32 Finished();

  /**
      The user should call this to obtain output.  (This version should
      only be called if config.determinize_lattice == true, w.r.t. the
      config provided to the constructor.)  The output is guaranteed to
      be in the same order as the input was provided, but it may be
      delayed, *and* some outputs may be missing, for example because
      of search failures (allow_partial will affect this).

      The acoustic scores in the output lattice will already be divided by
      the acoustic scale we decoded with.

      This call does not block (i.e. does not wait on any semaphores).  It
      returns true if it actually got any output; if none was ready it will
      return false.
         @param [out] utterance_id  If an output was ready, its utterance-id is written to here.
         @param [out] clat  If an output was ready, its compact lattice will be
                            written to here.
         @param [out] sentence  If an output was ready and a nonempty symbol table
                            was provided to the constructor of this class, contains
                            the word-sequence decoded as a string.  Otherwise will
                            be empty.
         @return  Returns true if a decoded output was ready.  (These appear asynchronously
                  as the decoding is done in background threads).
  */
  bool GetOutput(std::string *utterance_id,
                 CompactLattice *clat,
                 std::string *sentence);

  // This version of GetOutput is for where config.determinize_lattice == false
  // (w.r.t. the config provided to the constructor).  It is the same as the
  // other version except it outputs to a normal Lattice, not a CompactLattice.
  bool GetOutput(std::string *utterance_id,
                 Lattice *lat,
                 std::string *sentence);

  ~NnetBatchDecoder();

 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(NnetBatchDecoder);

  struct UtteranceInput {
    std::string utterance_id;
    const Matrix<BaseFloat> *input;
    const Vector<BaseFloat> *ivector;
    const Matrix<BaseFloat> *online_ivectors;
    int32 online_ivector_period;
  };

  // This object is created when a thread finishes an utterance.  For utterances
  // where decoding failed somehow, the relevant lattice (compact_lat, if
  // opts_.determinize == true, or lat otherwise) will be empty (have no
  // states).
  struct UtteranceOutput {
    std::string utterance_id;
    bool finished;
    CompactLattice compact_lat;
    Lattice lat;
    std::string sentence;  // 'sentence' is only nonempty if a non-NULL symbol
                           // table was provided to the constructor of class
                           // NnetBatchDecoder; it's the sentence as a string (a
                           // sequence of words separated by space).  It's used
                           // for printing the sentence to stderr, which we do
                           // in the main thread to keep the order consistent.
  };

  // This is the decoding thread, several copies of which are run in the
  // background.  It will exit once the user calls Finished() and all
  // computation is completed.
  void Decode();
  // static wrapper for Decode().
  static void DecodeFunc(NnetBatchDecoder *object) { object->Decode(); }

  // This is the computation thread; it handles the neural net inference.
  void Compute();
  // static wrapper for Compute().
  static void ComputeFunc(NnetBatchDecoder *object) { object->Compute(); }


  // Sets the priorities of the tasks in a newly provided utterance.
  void SetPriorities(std::vector<NnetInferenceTask> *tasks);

  // In the single-thread case, this sets priority_offset_ to 'priority'.
  // In the multi-threaded case it causes priority_offset_ to approach
  // 'priority' at a rate that depends on the number of threads.
  void UpdatePriorityOffset(double priority);

  // This function does the determinization (if needed) and finds the best path through
  // the lattice to update the stats.  It is expected that when it is called, 'output' must
  // have its 'lat' member set up.
  void ProcessOutputUtterance(UtteranceOutput *output);

  const fst::Fst<fst::StdArc> &fst_;
  const LatticeFasterDecoderConfig &decoder_opts_;
  const TransitionModel &trans_model_;
  const fst::SymbolTable *word_syms_;  // May be NULL.  Owned here.
  bool allow_partial_;
  NnetBatchComputer *computer_;
  std::vector<std::thread*> decode_threads_;
  std::thread compute_thread_;  // Thread that calls computer_->Compute().


  // 'input_utterance_', together with input_ready_semaphore_ and
  // input_consumed_semaphore_, is used to 'hand off' information about a
  // newly provided utterance from AcceptInput() to a decoder thread that is
  // ready to process a new utterance.
  UtteranceInput input_utterance_;
  Semaphore input_ready_semaphore_;  // Is signaled by the main thread when
                                     // AcceptInput() is called and a new
                                     // utterance is being provided (or when the
                                     // input is finished), and waited on in
                                     // decoder thread.
  Semaphore input_consumed_semaphore_;  // Is signaled in decoder thread when it
                                        // has finished consuming the input, so
                                        // the main thread can know when it
                                        // should continue (to avoid letting
                                        // 'input' go out of scope while it's
                                        // still needed).

  Semaphore tasks_ready_semaphore_; // Is signaled when new tasks are added to
                                    // the computer_ object (or when we're finished).

  bool is_finished_;  // True if the input is finished.  If this is true, a
                      // signal to input_ready_semaphore_ indicates to the
                      // decoder thread that it should terminate.

  bool tasks_finished_;  // True if we know that no more tasks will be given
                         // to the computer_ object.


  // pending_utts_ is a list of utterances that have been provided via
  // AcceptInput(), but their decoding has not yet finished.  AcceptInput() will
  // push_back to it, and GetOutput() will pop_front().  When a decoding thread
  // has finished an utterance it will set its 'finished' member to true.  There
  // is no need to synchronize or use mutexes here.
  std::list<UtteranceOutput*> pending_utts_;

  // priority_offset_ is something used in determining the priorities of nnet
  // computation tasks.  It starts off at zero and becomes more negative with
  // time, with the aim being that the priority of the first task (i.e. the
  // leftmost chunk) of a new utterance should be at about the same priority as
  // whatever chunks we are just now getting around to decoding.
  double priority_offset_;

  // Some statistics accumulated by this class, for logging and timing purposes.
  double tot_like_;  // Total likelihood (of best path) over all lattices that
                     // we output.
  int64 frame_count_;  // Frame count over all lattices that we output.
  int32 num_success_;  // Number of successfully decoded files.
  int32 num_fail_;  // Number of files where decoding failed.
  int32 num_partial_;  // Number of files that were successfully decoded but
                       // reached no final-state (can only be nonzero if
                       // allow_partial_ is true).
  std::mutex stats_mutex_;  // Mutex that guards the statistics from tot_like_
                            // through num_partial_.
  Timer timer_;  // Timer used to print real-time info.
};
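
// A hedged usage sketch for NnetBatchDecoder, assuming determinize_lattice ==
// true; the reader/writer objects and the way the model and FST are obtained
// are illustrative, not prescribed by this header:
//
//   NnetBatchComputer computer(compute_opts, am_nnet.GetNnet(), am_nnet.Priors());
//   NnetBatchDecoder decoder(*decode_fst, decoder_opts, trans_model, word_syms,
//                            /*allow_partial=*/ true, num_threads, &computer);
//   for (; !feature_reader.Done(); feature_reader.Next())
//     decoder.AcceptInput(feature_reader.Key(), feature_reader.Value(),
//                         NULL, NULL, 0);
//   int32 num_done = decoder.Finished();
//   std::string utt, sentence;
//   CompactLattice clat;
//   while (decoder.GetOutput(&utt, &clat, &sentence))
//     clat_writer.Write(utt, clat);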


}  // namespace nnet3
}  // namespace kaldi

#endif  // KALDI_NNET3_NNET_BATCH_COMPUTE_H_