Blame view

egs/aspire/s5/local/multi_condition/normalize_wavs.py 4.17 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  #!/usr/bin/env python
  # Copyright 2014  Johns Hopkins University (Authors: Vijayaditya Peddinti).  Apache 2.0.
  
  # Normalizes the wave files listed in the input file with a common scaling factor.
  # The common scaling factor is computed as 1/\sqrt{(1/total_samples) * \sum_i \sum_j x_i(j)^2},
  # where total_samples is the total number of samples across all wave files.
  # If the data is multi-channel, each channel is treated as a separate wave file.
  from __future__ import division
  from __future__ import print_function
  import argparse, scipy.io.wavfile, warnings, numpy as np, math
  
  def get_normalization_coefficient(file_list, is_rir, additional_scaling):
    """Return a common scaling coefficient that normalizes a list of wav files.

    Each file is read with scipy.io.wavfile and divided by the maximum value of
    its integer dtype. The squared samples of all files are pooled, with every
    channel counted as an independent signal. The coefficient is
    sqrt(total_samples / total_energy) * additional_scaling: the gain that
    gives the pooled data unit RMS, times the requested extra factor.

    Args:
      file_list: non-empty list of wav file paths. A file whose read raises
        IOError is skipped with a warning.
      is_rir: if True, treat each file as a room impulse response. Only the
        samples from 0.001 s before to 0.05 s after the peak of the first
        channel are counted, so the recording length does not influence the
        coefficient.
      additional_scaling: extra factor multiplied into the returned coefficient.

    Returns:
      The scaling coefficient (a numpy float64).

    Raises:
      ValueError: a file is not int16/int32/int64 PCM, the selected audio is
        entirely silent, or the coefficient is NaN.
      AssertionError: file_list is empty, the files disagree in dtype or
        sampling rate, or no samples were read.
    """
    assert(len(file_list) > 0)
    sampling_rate = None
    prev_dtype_max_value = None
    total_energy = 0.0
    total_samples = 0
    for wav_file in file_list:
      # keep the try narrow: only a missing/unreadable file is tolerated
      try:
        [rate, data] = scipy.io.wavfile.read(wav_file)
      except IOError:
        warnings.warn("Did not find the file {0}.".format(wav_file))
        continue
      if str(data.dtype) not in ('int16', 'int32', 'int64'):
        raise ValueError('Cannot process {0}, only wav files of integer type are suppported'.format(wav_file))

      dtype_max_value = np.iinfo(data.dtype).max
      # ensure that all the data in the current list is of the same format
      if prev_dtype_max_value is not None:
        assert(dtype_max_value == prev_dtype_max_value)
      prev_dtype_max_value = dtype_max_value

      if data.ndim == 1:
        # view mono audio as a single-channel (samples, 1) array
        data = data.reshape([data.shape[0], 1])
      if sampling_rate is not None:
        assert(rate == sampling_rate)
      else:
        sampling_rate = rate
      # convert to float before any arithmetic: numpy does not check for
      # overflow in integer types, and squaring raw samples could overflow
      data = data.astype(np.float64) / dtype_max_value

      if is_rir:
        # just count the energy of the direct impulse response: from 0.001 s
        # before the first-channel peak to 0.05 s after it, so the recording
        # length does not influence the scaling factor
        delay_impulse = int(np.argmax(data[:, 0]))  # first index of the maximum
        start_index = int(max(0, delay_impulse - np.floor(rate * 0.001)))
        end_index = int(min(data.shape[0], delay_impulse + np.floor(rate * 0.05)))
      else:
        start_index, end_index = 0, data.shape[0]
      window = data[start_index:end_index, :]
      total_energy += np.sum(window ** 2)
      total_samples += window.size
    assert(total_samples > 0)
    if total_energy <= 0:
      # all-zero input would otherwise silently yield an infinite coefficient
      raise ValueError("Input audio is silent; cannot compute a scaling coefficient.")
    scaling_coefficient = np.sqrt(total_samples / total_energy) * additional_scaling
    print("Scaling coefficient is {0}.".format(scaling_coefficient))
    if math.isnan(scaling_coefficient):
      raise ValueError(" Nan encountered while computing scaling coefficient. This is mostly due to numerical overflow")
    return scaling_coefficient
  
  if __name__ == "__main__":
    usage = """ Python script to normalize input wave file list"""

    # ArgumentParser's first positional argument is `prog`; pass the text by
    # keyword so it is shown as the description in --help
    parser = argparse.ArgumentParser(description=usage)
    parser.add_argument('--is-room-impulse-response', type=str, default="false", help='is the input a list of room impulse responses', choices=['True', 'False', 'true', 'false'])
    parser.add_argument('--extra-scaling-factor', type=float, default=1.0, help='additional scaling factor to be multiplied with the wav files')
    parser.add_argument('input_file_list', type=str, help='list of wav files to be normalized collectively')
    parser.add_argument('output_file', type=str, help='output file to store normalization coefficient')
    params = parser.parse_args()
    # turn the case-insensitive "true"/"false" string into a real boolean
    params.is_room_impulse_response = params.is_room_impulse_response.lower() == 'true'

    # one wav path per line; blank lines are ignored
    with open(params.input_file_list) as list_file:
      file_list = [line.strip() for line in list_file if line.strip()]
    norm_coefficient = get_normalization_coefficient(file_list, params.is_room_impulse_response, params.extra_scaling_factor)
    with open(params.output_file, 'w') as out_file:
      out_file.write('{0}'.format(norm_coefficient))