datasets/timit.py
import tables
from collections import OrderedDict
from fuel.datasets.hdf5 import PytablesDataset
from utils import make_local_copy


class Timit(PytablesDataset):
    """TIMIT dataset.

    Parameters
    ----------
    which_set : str, optional
        Which split to load: either 'train', 'dev' or 'test'.
    local_copy : bool, optional
        Whether to copy the HDF5 file to local storage (/Tmp) before
        opening it, to avoid repeated reads over a network filesystem.
    """
    def __init__(self, which_set='train', local_copy=False, **kwargs):
        #self.path = '/home/parcollt/projects/rpp-bengioy/parcollt/Deep-Quaternary-Convolutional-Neural-Networks/TIMIT/timit_fbank_energy_deltas.h5'
        self.path = '/u/parcollt/WORKSPACE/QCNN/Deep-Quaternary-Convolutional-Neural-Networks/TIMIT/timit_fbank_energy_deltas.h5'
        #self.path = '/Users/titouanparcollet/CloudStation/LABO/WORKSPACE/EXPS/QCNN/timit_fbank_energy_deltas.h5'
        if local_copy and not self.path.startswith('/Tmp'):
            self.path = make_local_copy(self.path)
        self.which_set = which_set
        self.sources = ('features', 'features_shapes', 'phonemes')
        super(Timit, self).__init__(
            self.path, self.sources, data_node=which_set, **kwargs)

    def get_phoneme_dict(self):
        """Return an OrderedDict mapping integer indices to phonemes."""
        phoneme_list = self.h5file.root._v_attrs.phones_list
        return OrderedDict(enumerate(phoneme_list))

    def get_phoneme_ind_dict(self):
        """Return an OrderedDict mapping phonemes to integer indices."""
        phoneme_list = self.h5file.root._v_attrs.phones_list
        return OrderedDict(zip(phoneme_list, range(len(phoneme_list))))

    def get_normalization_factors(self):
        """Return the per-feature means and standard deviations stored
        as attributes of the HDF5 file."""
        means = self.h5file.root._v_attrs.means
        stds = self.h5file.root._v_attrs.stds
        return means, stds

    def open_file(self, path):
        self.h5file = tables.open_file(path, mode="r")
        # Each split ('train', 'dev', 'test') lives under its own node.
        node = self.h5file.get_node('/', self.data_node)
        self.nodes = [getattr(node, source)
                      for source in self.sources_in_file]
        if self.stop is None:
            self.stop = self.nodes[0].nrows
        self.num_examples = self.stop - self.start
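

# Minimal usage sketch (not part of the original module): it assumes the
# hard-coded HDF5 path above exists and that Fuel and PyTables are
# installed, so that the parent PytablesDataset can open the file on
# construction.
if __name__ == '__main__':
    train_set = Timit(which_set='train')
    ind_to_phoneme = train_set.get_phoneme_dict()
    means, stds = train_set.get_normalization_factors()
    print('num examples:', train_set.num_examples)
    print('num phonemes:', len(ind_to_phoneme))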