Blame view

datasets/utils.py 3.4 KB
f2d3bd141   Parcollet Titouan   Initial commit wi...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
  import hashlib
  import os
  import os
  import shutil
  from fuel.streams import DataStream
  from fuel.transformers import (Mapping, ForceFloatX, Padding,
                                 SortMapping, Cast)
  from fuel.schemes import ShuffledScheme
  from schemes import SequentialShuffledScheme
  from transformers import (MaximumFrameCache, Transpose, Normalize,
                            Reshape, Subsample, ConvReshape, DictRep)
  
# Folding table that maps fine-grained TIMIT phone labels onto a reduced
# phoneme set (all closures, pauses and silences collapse to 'sil').
# NOTE(review): this looks like the standard 61 -> 39 TIMIT folding of
# Lee & Hon (1989) -- confirm before relying on that claim.
phone_to_phoneme_dict = {'ao':   'aa',
                         'ax':   'ah',
                         'ax-h': 'ah',
                         'axr':  'er',
                         'hv':   'hh',
                         'ix':   'ih',
                         'el':   'l',
                         'em':   'm',
                         'en':   'n',
                         'nx':   'n',
                         'eng':   'ng',
                         'zh':   'sh',
                         'pcl':  'sil',
                         'tcl':  'sil',
                         'kcl':  'sil',
                         'bcl':  'sil',
                         'dcl':  'sil',
                         'gcl':  'sil',
                         'h#':   'sil',
                         'pau':  'sil',
                         'epi':  'sil',
                         'ux':   'uw'}
  
  def file_hash(afile, blocksize=65536):
      buf = afile.read(blocksize)
      hasher = hashlib.md5()
      while len(buf) > 0:
          hasher.update(buf)
          buf = afile.read(blocksize)
      return hasher.digest()
  
  def make_local_copy(filename):
      local_name = os.path.join('/Tmp/', os.environ['USER'],
                                os.path.basename(filename))
      if (not os.path.isfile(local_name) or
                  file_hash(open(filename)) != file_hash(open(local_name))):
          print '.. made local copy at', local_name
          shutil.copy(filename, local_name)
      return local_name
  
  def key(x):
      return x[0].shape[0]
  
def construct_conv_stream(dataset, rng, pool_size, maximum_frames,
                          quaternion=False, **kwargs):
    """Construct the data stream pipeline for convolutional models.

    Parameters
    ----------
    dataset : Dataset
        Dataset to use.
    rng : numpy.random.RandomState
        Random number generator.
    pool_size : int
        Pool size for the TIMIT dataset.
    maximum_frames : int
        Maximum number of frames per batch for the TIMIT dataset.
    quaternion : bool, optional
        Forwarded to ConvReshape; presumably selects a quaternion-
        compatible feature layout -- confirm in transformers.py.
    **kwargs
        Accepted but unused; kept for call-site compatibility.

    Returns
    -------
    DictRep
        The fully wrapped data stream.
    """
    # Shuffle pools of `pool_size` examples over the whole dataset.
    stream = DataStream(
        dataset,
        iteration_scheme=SequentialShuffledScheme(dataset.num_examples,
                                                  pool_size, rng))
    # Rebuild per-example feature matrices from their stored shapes, then
    # normalize with dataset-wide mean/std statistics.
    stream = Reshape('features', 'features_shapes', data_stream=stream)
    means, stds = dataset.get_normalization_factors()
    stream = Normalize(stream, means, stds)
    stream.produces_examples = False
    # Sort each pool by utterance length (see `key` above) so that the
    # batching below groups similarly-sized examples (less padding).
    stream = Mapping(stream,
                     SortMapping(key=key))
    # Re-batch so that no batch exceeds `maximum_frames` total frames.
    stream = MaximumFrameCache(max_frames=maximum_frames, data_stream=stream,
                               rng=rng)
    # Zero-pad the variable-length sources and emit their binary masks.
    stream = Padding(data_stream=stream,
                     mask_sources=['features', 'phonemes'])
    # Axis permutations around ConvReshape; the per-source patterns below
    # are positional -- NOTE(review): verify source order matches them.
    stream = Transpose(stream, [(0, 1, 2), (1, 0), (0, 1), (1, 0)])
    stream = ConvReshape('features',
                          data_stream=stream, quaternion=quaternion)
    stream = Transpose(stream, [(0, 2, 3, 1), (0, 1), (0, 1), (0, 1)])
    stream.produces_examples = False
    # Targets must be int32 for the cost; everything else to floatX.
    stream = Cast(stream, 'int32', which_sources=('phonemes',))
    stream = ForceFloatX(stream)
    return DictRep(stream)