Blame view
datasets/utils.py
3.4 KB
f2d3bd141 Initial commit wi... |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
"""Utilities for building TIMIT data streams on top of Fuel.

Provides the phone-to-phoneme folding table, helpers for caching a
dataset file on local disk, and a convolutional data-stream builder.
"""
import hashlib
import os
import shutil

from fuel.streams import DataStream
from fuel.transformers import (Mapping, ForceFloatX, Padding, SortMapping,
                               Cast)
from fuel.schemes import ShuffledScheme

from schemes import SequentialShuffledScheme
from transformers import (MaximumFrameCache, Transpose, Normalize,
                          Reshape, Subsample, ConvReshape, DictRep)


# Folds the fine-grained TIMIT phone labels into a reduced phoneme set;
# all closures/silence-like phones collapse to 'sil'.
phone_to_phoneme_dict = {'ao': 'aa', 'ax': 'ah', 'ax-h': 'ah', 'axr': 'er',
                         'hv': 'hh', 'ix': 'ih', 'el': 'l', 'em': 'm',
                         'en': 'n', 'nx': 'n', 'eng': 'ng', 'zh': 'sh',
                         'pcl': 'sil', 'tcl': 'sil', 'kcl': 'sil',
                         'bcl': 'sil', 'dcl': 'sil', 'gcl': 'sil',
                         'h#': 'sil', 'pau': 'sil', 'epi': 'sil',
                         'ux': 'uw'}


def file_hash(afile, blocksize=65536):
    """Return the MD5 digest (raw bytes) of an open file object.

    Reads in ``blocksize`` chunks so arbitrarily large files can be
    hashed without loading them into memory.
    """
    hasher = hashlib.md5()
    buf = afile.read(blocksize)
    while len(buf) > 0:
        hasher.update(buf)
        buf = afile.read(blocksize)
    return hasher.digest()


def make_local_copy(filename):
    """Mirror ``filename`` under /Tmp/$USER and return the local path.

    The copy is skipped when a local file with an identical MD5 digest
    already exists.
    """
    local_name = os.path.join('/Tmp/', os.environ['USER'],
                              os.path.basename(filename))
    # Compare in binary mode so the digest is byte-exact, and use
    # context managers so the handles are closed (the original leaked
    # both file objects and hashed in text mode).
    need_copy = not os.path.isfile(local_name)
    if not need_copy:
        with open(filename, 'rb') as source_file:
            source_digest = file_hash(source_file)
        with open(local_name, 'rb') as local_file:
            need_copy = source_digest != file_hash(local_file)
    if need_copy:
        # Single parenthesized argument prints the same text whether
        # `print` is a statement (Python 2) or a function (Python 3).
        print('.. made local copy at ' + local_name)
        shutil.copy(filename, local_name)
    return local_name


def key(x):
    """Sort key for examples: number of frames in the features array."""
    return x[0].shape[0]


def construct_conv_stream(dataset, rng, pool_size, maximum_frames,
                          quaternion=False, **kwargs):
    """Construct data stream.

    Parameters:
    -----------
    dataset : Dataset
        Dataset to use.
    rng : numpy.random.RandomState
        Random number generator.
    pool_size : int
        Pool size for TIMIT dataset.
    maximum_frames : int
        Maximum frames for TIMIT dataset.
    subsample : bool, optional
        Subsample features.
    """
    stream = DataStream(
        dataset,
        iteration_scheme=SequentialShuffledScheme(dataset.num_examples,
                                                  pool_size, rng))
    # Unflatten features using their stored shapes, then normalize with
    # the dataset's precomputed statistics.
    stream = Reshape('features', 'features_shapes', data_stream=stream)
    means, stds = dataset.get_normalization_factors()
    stream = Normalize(stream, means, stds)
    stream.produces_examples = False
    # Sort within each pool by length so padded batches waste less space.
    stream = Mapping(stream, SortMapping(key=key))
    stream = MaximumFrameCache(max_frames=maximum_frames, data_stream=stream,
                               rng=rng)
    stream = Padding(data_stream=stream,
                     mask_sources=['features', 'phonemes'])
    # Transpose to time-major, reshape into conv layout, then transpose
    # back to a batch-major layout for the convolutional frontend.
    stream = Transpose(stream, [(0, 1, 2), (1, 0), (0, 1), (1, 0)])
    stream = ConvReshape('features', data_stream=stream,
                         quaternion=quaternion)
    stream = Transpose(stream, [(0, 2, 3, 1), (0, 1), (0, 1), (0, 1)])
    stream.produces_examples = False
    stream = Cast(stream, 'int32', which_sources=('phonemes',))
    stream = ForceFloatX(stream)
    return DictRep(stream)