// nnet3/attention-test.cc

// Copyright      2017  Hossein Hadian
//                2017  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "nnet3/attention.h"
#include "util/common-utils.h"

namespace kaldi {
namespace nnet3 {
namespace attention {

// Naive reference implementation of GetAttentionDotProducts():
// (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
void GetAttentionDotProductsSimple(BaseFloat alpha,
                                   const CuMatrixBase<BaseFloat> &A,
                                   const CuMatrixBase<BaseFloat> &B,
                                   CuMatrixBase<BaseFloat> *C) {
  KALDI_ASSERT(A.NumCols() == B.NumCols() &&
               A.NumRows() == C->NumRows());
  int32 input_num_cols = A.NumCols(),
      num_extra_rows = B.NumRows() - A.NumRows(),
      context_dim = C->NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < C->NumRows(); i++) {
    for (int32 j = 0; j < C->NumCols(); j++) {
      (*C)(i, j) = 0.0;
      for (int32 k = 0; k < input_num_cols; k++) {
        (*C)(i, j) += alpha * A(i, k) * B(i + (j * row_shift), k);
      }
    }
  }
}

// Naive reference implementation of ApplyScalesToOutput():
// A->Row(i) += \sum_k alpha * C(i, k) * B.Row(i + k * row_shift).
void ApplyScalesToOutputSimple(BaseFloat alpha,
                               const CuMatrixBase<BaseFloat> &B,
                               const CuMatrixBase<BaseFloat> &C,
                               CuMatrixBase<BaseFloat> *A) {
  KALDI_ASSERT(A->NumCols() == B.NumCols() &&
               A->NumRows() == C.NumRows());
  int32 num_extra_rows = B.NumRows() - A->NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < A->NumRows(); i++) {
    for (int32 j = 0; j < A->NumCols(); j++) {
      for (int32 k = 0; k < context_dim; k++) {
        (*A)(i, j) += alpha * C(i, k) * B(i + (k * row_shift), j);
      }
    }
  }
}
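// A brief worked example of the indexing convention shared by these
// reference implementations (illustrative numbers, not from the source):
// if B has num_extra_rows = 4 more rows than A and context_dim = 3, then
// row_shift = 4 / (3 - 1) = 2, so C(i, 0) pairs A.Row(i) with B.Row(i),
// C(i, 1) pairs it with B.Row(i + 2), and C(i, 2) with B.Row(i + 4).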
// Naive reference implementation of ApplyScalesToInput():
// B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
void ApplyScalesToInputSimple(BaseFloat alpha,
                              const CuMatrixBase<BaseFloat> &A,
                              const CuMatrixBase<BaseFloat> &C,
                              CuMatrixBase<BaseFloat> *B) {
  KALDI_ASSERT(A.NumCols() == B->NumCols() &&
               A.NumRows() == C.NumRows());
  int32 num_extra_rows = B->NumRows() - A.NumRows(),
      context_dim = C.NumCols();
  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
  int32 row_shift = num_extra_rows / (context_dim - 1);
  for (int32 i = 0; i < A.NumRows(); i++) {
    for (int32 j = 0; j < A.NumCols(); j++) {
      for (int32 k = 0; k < context_dim; k++) {
        (*B)(i + (k * row_shift), j) += alpha * C(i, k) * A(i, j);
      }
    }
  }
}

// Checks the optimized kernels against the naive reference implementations
// above, on random inputs of random sizes.
void UnitTestAttentionDotProductAndAddScales() {
  int32 output_num_rows = RandInt(1, 50),
      input_num_cols = RandInt(1, 10),
      row_shift = RandInt(1, 5),
      context_dim = RandInt(2, 5),
      num_extra_rows = (context_dim - 1) * row_shift,
      input_num_rows = output_num_rows + num_extra_rows;
  BaseFloat alpha = 0.25 * RandInt(1, 5);
  CuMatrix<BaseFloat> A(output_num_rows, input_num_cols),
      B(input_num_rows, input_num_cols),
      C(output_num_rows, context_dim);
  B.SetRandn();
  C.SetRandn();
  A.Set(0.0);
  CuMatrix<BaseFloat> A2(A);
  ApplyScalesToOutput(alpha, B, C, &A);
  ApplyScalesToOutputSimple(alpha, B, C, &A2);
  AssertEqual(A, A2);

  CuMatrix<BaseFloat> C2(C);
  GetAttentionDotProductsSimple(alpha, A, B, &C);
  GetAttentionDotProducts(alpha, A, B, &C2);
  AssertEqual(C, C2);

  CuMatrix<BaseFloat> B2(B);
  ApplyScalesToInput(alpha, A, C, &B);
  ApplyScalesToInputSimple(alpha, A, C, &B2);
  AssertEqual(B, B2);
}

// Runs AttentionForward()/AttentionBackward() and verifies the computed
// derivatives with a first-order check on the values, keys and queries.
void TestAttentionForwardBackward() {
  BaseFloat key_scale = 0.5 * RandInt(1, 3);
  BaseFloat epsilon = 1.0e-03;
  int32 test_dim = 3;
  bool output_context = (RandInt(0, 1) == 0);
  int32 output_num_rows = RandInt(1, 50),
      value_dim = RandInt(10, 30),
      key_dim = RandInt(10, 30),
      row_shift = RandInt(1, 5),
      context_dim = RandInt(2, 5),
      num_extra_rows = (context_dim - 1) * row_shift,
      input_num_rows = output_num_rows + num_extra_rows,
      query_dim = key_dim + context_dim;
  CuMatrix<BaseFloat> keys(input_num_rows, key_dim),
      queries(output_num_rows, query_dim),
      values(input_num_rows, value_dim),
      C(output_num_rows, context_dim),
      output(output_num_rows,
             value_dim + (output_context ? context_dim : 0));

  keys.SetRandn();
  queries.SetRandn();
  values.SetRandn();

  AttentionForward(key_scale, keys, queries, values, &C, &output);

  CuMatrix<BaseFloat> keys_deriv(input_num_rows, key_dim),
      queries_deriv(output_num_rows, query_dim),
      values_deriv(input_num_rows, value_dim),
      output_deriv(output_num_rows, output.NumCols());

  output_deriv.SetRandn();

  AttentionBackward(key_scale, keys, queries, values, C,
                    output_deriv, &keys_deriv, &queries_deriv,
                    &values_deriv);

  BaseFloat objf_baseline = TraceMatMat(output_deriv, output, kTrans);

  { // perturb the values and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> values2(input_num_rows, value_dim);
      values2.SetRandn();
      values2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(values_deriv, values2,
                                                   kTrans);
      values2.AddMat(1.0, values);
      output.SetZero();
      AttentionForward(key_scale, keys, queries, values2, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing values: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }
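  // The same first-order check is repeated below for the keys and queries:
  // for a random perturbation dX scaled by epsilon, the predicted change in
  // the objective TraceMatMat(output_deriv, output, kTrans) is
  // TraceMatMat(X_deriv, dX, kTrans), which should agree with the observed
  // change to within the 10% tolerance passed to ApproxEqual.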
  { // perturb the keys and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> keys2(input_num_rows, key_dim);
      keys2.SetRandn();
      keys2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(keys_deriv, keys2, kTrans);
      keys2.AddMat(1.0, keys);
      output.SetZero();
      AttentionForward(key_scale, keys2, queries, values, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing keys: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }

  { // perturb the queries and see if the objf changes as predicted.
    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
    for (int32 i = 0; i < test_dim; i++) {
      CuMatrix<BaseFloat> queries2(output_num_rows, query_dim);
      queries2.SetRandn();
      queries2.Scale(epsilon);
      BaseFloat predicted_delta_objf = TraceMatMat(queries_deriv, queries2,
                                                   kTrans);
      queries2.AddMat(1.0, queries);
      output.SetZero();
      AttentionForward(key_scale, keys, queries2, values, &C, &output);
      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
          observed_delta_objf = objf2 - objf_baseline;
      KALDI_LOG << "Changing queries: predicted objf change is "
                << predicted_delta_objf << ", observed objf change is "
                << observed_delta_objf;
      predicted_vec(i) = predicted_delta_objf;
      observed_vec(i) = observed_delta_objf;
    }
    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
  }
}

void UnitTestAttention() {
  UnitTestAttentionDotProductAndAddScales();
  TestAttentionForwardBackward();
}

} // namespace attention
} // namespace nnet3
} // namespace kaldi

int main() {
  using namespace kaldi;
  using namespace kaldi::nnet3;
  using namespace kaldi::nnet3::attention;

  for (int32 loop = 0; loop < 2; loop++) {
#if HAVE_CUDA == 1
    CuDevice::Instantiate().SetDebugStrideMode(true);
    if (loop == 0)
      CuDevice::Instantiate().SelectGpuId("no");  // "no": never use a GPU
    else
      CuDevice::Instantiate().SelectGpuId("optional");  // "optional": use a GPU if available
#endif
    for (int32 i = 0; i < 5; i++) {
      UnitTestAttention();
    }
  }
}
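// Note (an assumption about usage, not stated in the source): in a configured
// Kaldi tree this test is typically built and run from src/nnet3, e.g. with
// "make attention-test && ./attention-test"; it exercises the CPU path on the
// first outer loop and, when compiled with CUDA, the GPU path on the second.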