Blame view

src/rnnlm/rnnlm-core-compute.cc 4.41 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
  // rnnlm/rnnlm-core-compute.cc
  
  // Copyright 2017  Daniel Povey
  
  // See ../../COPYING for clarification regarding multiple authors
  //
  // Licensed under the Apache License, Version 2.0 (the "License");
  // you may not use this file except in compliance with the License.
  // You may obtain a copy of the License at
  //
  //  http://www.apache.org/licenses/LICENSE-2.0
  //
  // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
  // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
  // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
  // MERCHANTABILITY OR NON-INFRINGEMENT.
  // See the Apache 2 License for the specific language governing permissions and
  // limitations under the License.
  
  #include <numeric>
  #include "rnnlm/rnnlm-core-compute.h"
  #include "rnnlm/rnnlm-example-utils.h"
  #include "nnet3/nnet-utils.h"
  
  namespace kaldi {
  namespace rnnlm {
  
  
  // Runs the neural net on one minibatch and returns the value produced by
  // ProcessOutput() for it.
  //
  //  minibatch, derived    The example and its derived quantities; both are
  //                        forwarded to the helpers below.
  //  word_embedding        Embedding matrix used to form the network input and
  //                        passed on to ProcessOutput().
  //  weight                Forwarded to ProcessOutput() as its 'weight_out'
  //                        argument.
  //  word_embedding_deriv  If non-NULL, a backward pass is also run and the
  //                        derivative w.r.t. the network input is accumulated
  //                        into this matrix via AddMatSmat().
  //
  // Each call increments num_minibatches_processed_.
  BaseFloat RnnlmCoreComputer::Compute(
      const RnnlmExample &minibatch,
      const RnnlmExampleDerived &derived,
      const CuMatrixBase<BaseFloat> &word_embedding,
      BaseFloat *weight,
      CuMatrixBase<BaseFloat> *word_embedding_deriv) {
    using namespace nnet3;

    // The only derivative ever requested here is the one w.r.t. the input, and
    // only when the caller supplied somewhere to put it.
    const bool want_embedding_deriv = (word_embedding_deriv != NULL);

    ComputationRequest request;
    GetRnnlmComputationRequest(minibatch,
                               false,                 // need_model_derivative
                               want_embedding_deriv,  // need_input_derivative
                               false,                 // store_component_stats
                               &request);

    std::shared_ptr<const NnetComputation> computation =
        compiler_.Compile(request);

    NnetComputeOptions compute_opts;
    NnetComputer computer(compute_opts, *computation, nnet_, NULL);

    // Forward pass: feed the input embeddings, then run.
    ProvideInput(minibatch, derived, word_embedding, &computer);
    computer.Run();

    const BaseFloat objf = ProcessOutput(minibatch, derived, word_embedding,
                                         &computer, word_embedding_deriv,
                                         weight);

    if (want_embedding_deriv) {
      // Backward pass, followed by retrieval of the derivative at "input".
      computer.Run();

      CuMatrix<BaseFloat> input_deriv;
      computer.GetOutputDestructive("input", &input_deriv);

      // Accumulate (both scaling coefficients are 1.0) into the caller's
      // matrix, combining with derived.input_words_smat.
      word_embedding_deriv->AddMatSmat(1.0, input_deriv,
                                       derived.input_words_smat,
                                       kTrans, 1.0);
    }

    ++num_minibatches_processed_;
    return objf;
  }
  
  
  void RnnlmCoreComputer::ProvideInput(
      const RnnlmExample &minibatch,
      const RnnlmExampleDerived &derived,
      const CuMatrixBase<BaseFloat> &word_embedding,
      nnet3::NnetComputer *computer) {
    int32 embedding_dim = word_embedding.NumCols();
    CuMatrix<BaseFloat> input_embeddings(derived.cu_input_words.Dim(),
                                         embedding_dim,
                                         kUndefined);
    input_embeddings.CopyRows(word_embedding,
                              derived.cu_input_words);
    computer->AcceptInput("input", &input_embeddings);
  }
  
  
  // Constructs a computer for 'nnet'.  The same 'nnet' is given to both the
  // nnet_ member and the computation compiler; the minibatch counter starts
  // at zero.
  RnnlmCoreComputer::RnnlmCoreComputer(const nnet3::Nnet &nnet)
      : nnet_(nnet),
        // Only the network is passed to the compiler; no other compiler
        // options are exposed for now.
        compiler_(nnet),
        num_minibatches_processed_(0),
        // NOTE(review): the meaning of the constant 10 is defined by the type
        // of objf_info_, which is not visible here -- confirm.
        objf_info_(10) {
  }
  
  // Retrieves the network output from 'computer', evaluates the objective on
  // it with ProcessRnnlmOutput(), records the resulting statistics in
  // objf_info_, and returns objf_num + objf_den.
  //
  //  minibatch, derived    The example and its derived quantities; forwarded
  //                        to ProcessRnnlmOutput().
  //  word_embedding        Forwarded to ProcessRnnlmOutput().
  //  computer              Computer on which the forward pass has been run.
  //                        When 'word_embedding_deriv' is non-NULL it is also
  //                        given the derivative w.r.t. the output so that the
  //                        caller can run the backward pass afterwards.
  //  word_embedding_deriv  May be NULL.  Forwarded to ProcessRnnlmOutput();
  //                        non-NULL also signals that a backward pass will
  //                        follow.
  //  weight_out            If non-NULL, receives the weight reported by
  //                        ProcessRnnlmOutput().
  BaseFloat RnnlmCoreComputer::ProcessOutput(
      const RnnlmExample &minibatch,
      const RnnlmExampleDerived &derived,
      const CuMatrixBase<BaseFloat> &word_embedding,
      nnet3::NnetComputer *computer,
      CuMatrixBase<BaseFloat> *word_embedding_deriv,
      BaseFloat *weight_out) {
    // 'output' is the output of the neural network.  The row-index
    // combines the time (with higher stride) and the member 'n'
    // of the minibatch (with stride 1); the number of columns is
    // the word-embedding dimension.
    CuMatrix<BaseFloat> output;
    CuMatrix<BaseFloat> output_deriv;
    computer->GetOutputDestructive("output", &output);
    output_deriv.Resize(output.NumRows(), output.NumCols());

    BaseFloat weight, objf_num, objf_den, objf_den_exact;

    RnnlmObjectiveOptions objective_opts;  // Use the defaults; we're not training
                                           // so they won't matter.
    ProcessRnnlmOutput(objective_opts, minibatch, derived, word_embedding,
                       output, word_embedding_deriv, &output_deriv,
                       &weight, &objf_num, &objf_den,
                       &objf_den_exact);

    objf_info_.AddStats(weight, objf_num, objf_den, objf_den_exact);
    if (weight_out)
      *weight_out = weight;

    // Compute() runs a backward pass exactly when word_embedding_deriv is
    // non-NULL.  Previously 'output_deriv' was computed and then discarded, so
    // that backward pass never received the derivative at the "output" node.
    // Hand it over here.  The call is guarded because, when no derivative was
    // requested, the computation presumably has no matching input for
    // "output".
    // NOTE(review): relies on NnetComputer expecting the output derivative via
    // AcceptInput("output", ...) before the backward Run() -- confirm against
    // nnet3.
    if (word_embedding_deriv != NULL)
      computer->AcceptInput("output", &output_deriv);

    return objf_num + objf_den;
  }
  
  
  
  
  }  // namespace rnnlm
  }  // namespace kaldi