signal.cc
4.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// feat/signal.cc
// Copyright 2015 Tom Ko
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "feat/signal.h"
namespace kaldi {
void ElementwiseProductOfFft(const Vector<BaseFloat> &a, Vector<BaseFloat> *b) {
int32 num_fft_bins = a.Dim() / 2;
for (int32 i = 0; i < num_fft_bins; i++) {
// do complex multiplication
ComplexMul(a(2*i), a(2*i + 1), &((*b)(2*i)), &((*b)(2*i + 1)));
}
}
void ConvolveSignals(const Vector<BaseFloat> &filter, Vector<BaseFloat> *signal) {
int32 signal_length = signal->Dim();
int32 filter_length = filter.Dim();
int32 output_length = signal_length + filter_length - 1;
Vector<BaseFloat> signal_padded(output_length);
signal_padded.SetZero();
for (int32 i = 0; i < signal_length; i++) {
for (int32 j = 0; j < filter_length; j++) {
signal_padded(i + j) += (*signal)(i) * filter(j);
}
}
signal->Resize(output_length);
signal->CopyFromVec(signal_padded);
}
void FFTbasedConvolveSignals(const Vector<BaseFloat> &filter, Vector<BaseFloat> *signal) {
int32 signal_length = signal->Dim();
int32 filter_length = filter.Dim();
int32 output_length = signal_length + filter_length - 1;
int32 fft_length = RoundUpToNearestPowerOfTwo(output_length);
KALDI_VLOG(1) << "fft_length for full signal convolution is " << fft_length;
SplitRadixRealFft<BaseFloat> srfft(fft_length);
Vector<BaseFloat> filter_padded(fft_length);
filter_padded.Range(0, filter_length).CopyFromVec(filter);
srfft.Compute(filter_padded.Data(), true);
Vector<BaseFloat> signal_padded(fft_length);
signal_padded.Range(0, signal_length).CopyFromVec(*signal);
srfft.Compute(signal_padded.Data(), true);
ElementwiseProductOfFft(filter_padded, &signal_padded);
srfft.Compute(signal_padded.Data(), false);
signal_padded.Scale(1.0 / fft_length);
signal->Resize(output_length);
signal->CopyFromVec(signal_padded.Range(0, output_length));
}
void FFTbasedBlockConvolveSignals(const Vector<BaseFloat> &filter, Vector<BaseFloat> *signal) {
int32 signal_length = signal->Dim();
int32 filter_length = filter.Dim();
int32 output_length = signal_length + filter_length - 1;
signal->Resize(output_length, kCopyData);
KALDI_VLOG(1) << "Length of the filter is " << filter_length;
int32 fft_length = RoundUpToNearestPowerOfTwo(4 * filter_length);
KALDI_VLOG(1) << "Best FFT length is " << fft_length;
int32 block_length = fft_length - filter_length + 1;
KALDI_VLOG(1) << "Block size is " << block_length;
SplitRadixRealFft<BaseFloat> srfft(fft_length);
Vector<BaseFloat> filter_padded(fft_length);
filter_padded.Range(0, filter_length).CopyFromVec(filter);
srfft.Compute(filter_padded.Data(), true);
Vector<BaseFloat> temp_pad(filter_length - 1);
temp_pad.SetZero();
Vector<BaseFloat> signal_block_padded(fft_length);
for (int32 po = 0; po < output_length; po += block_length) {
// get a block of the signal
int32 process_length = std::min(block_length, output_length - po);
signal_block_padded.SetZero();
signal_block_padded.Range(0, process_length).CopyFromVec(signal->Range(po, process_length));
srfft.Compute(signal_block_padded.Data(), true);
ElementwiseProductOfFft(filter_padded, &signal_block_padded);
srfft.Compute(signal_block_padded.Data(), false);
signal_block_padded.Scale(1.0 / fft_length);
// combine the block
if (po + block_length < output_length) { // current block is not the last block
signal->Range(po, block_length).CopyFromVec(signal_block_padded.Range(0, block_length));
signal->Range(po, filter_length - 1).AddVec(1.0, temp_pad);
temp_pad.CopyFromVec(signal_block_padded.Range(block_length, filter_length - 1));
} else {
signal->Range(po, output_length - po).CopyFromVec(
signal_block_padded.Range(0, output_length - po));
if (filter_length - 1 < output_length - po)
signal->Range(po, filter_length - 1).AddVec(1.0, temp_pad);
else
signal->Range(po, output_length - po).AddVec(1.0, temp_pad.Range(0, output_length - po));
}
}
}
}