Blame view
egs/cifar/v1/local/process_data.py
5.64 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
#!/usr/bin/env python

# Copyright 2017 Johns Hopkins University (author: Hossein Hadian)
# Apache 2.0

""" This script prepares the training and test data for CIFAR-10 or CIFAR-100.

It reads the downloaded CIFAR binary batch files and writes a Kaldi
text-format feature archive (one W x (H*C) matrix per image) plus a
labels.txt mapping each utterance-like key to its class label.
"""

import argparse
import os
import sys

# CIFAR image dimensions, fixed by the dataset's binary record format:
C = 3   # num_channels
H = 32  # num_rows
W = 32  # num_cols


def _bytes_to_image(bin_img):
    """Decode one raw CHW image record (C*H*W bytes) into a nested list
    img[channel][row][col] of floats scaled to [0, 1].

    bytearray() yields ints on both Python 2 and 3; the original code's
    ord(byte) raised TypeError on Python 3, where indexing bytes already
    returns ints.
    """
    raw = bytearray(bin_img)
    return [[[raw[channel * H * W + row * W + col] / 255.0
              for col in range(W)]
             for row in range(H)]
            for channel in range(C)]


def load_cifar10_data_batch(datafile):
    """Read one CIFAR-10 binary batch file.

    Each of the 10000 records is 1 label byte followed by C*H*W image
    bytes. Returns (data, labels) where data[i] is img[channel][row][col].
    """
    num_images_in_batch = 10000
    data = []
    labels = []
    with open(datafile, 'rb') as fh:
        for _ in range(num_images_in_batch):
            label = ord(fh.read(1))
            bin_img = fh.read(C * H * W)
            labels.append(label)
            data.append(_bytes_to_image(bin_img))
    return data, labels


def load_cifar100_data_batch(datafile, num_images_in_batch):
    """Read a CIFAR-100 binary file.

    Each record is 1 coarse-label byte, 1 fine-label byte, then C*H*W
    image bytes. Returns (data, fine_labels, coarse_labels).
    """
    data = []
    fine_labels = []
    coarse_labels = []
    with open(datafile, 'rb') as fh:
        for _ in range(num_images_in_batch):
            coarse_label = ord(fh.read(1))
            fine_label = ord(fh.read(1))
            bin_img = fh.read(C * H * W)
            fine_labels.append(fine_label)
            coarse_labels.append(coarse_label)
            data.append(_bytes_to_image(bin_img))
    return data, fine_labels, coarse_labels


def image_to_feat_matrix(img):
    """Lay out img[channel][row][col] as a W x (H*C) feature matrix.

    mat[col][row*C + channel] = img[channel][row][col]: one matrix row per
    image column, with the C channels interleaved along the feature axis.
    Note H == W == 32, so preallocating H rows and filling W is safe.
    """
    mat = [0] * H  # 32 rows of C*H = 96 values each
    for i in range(W):
        mat[i] = [0] * C * H
        for ch in range(C):
            for j in range(H):
                mat[i][j * C + ch] = img[ch][j][i]
    return mat


def write_kaldi_matrix(file_handle, matrix, key):
    """Write `matrix` (a list of equal-length lists) to `file_handle` in
    Kaldi text-archive format under utterance key `key`.

    Rows are newline-separated and the matrix is terminated by " ]\\n",
    as Kaldi's text matrix reader requires. Raises on an empty or ragged
    matrix.
    """
    file_handle.write(key + " [ ")
    num_rows = len(matrix)
    if num_rows == 0:
        raise Exception("Matrix is empty")
    num_cols = len(matrix[0])
    for row_index in range(num_rows):
        if num_cols != len(matrix[row_index]):
            raise Exception("All the rows of a matrix are expected to "
                            "have the same length")
        file_handle.write(" ".join([str(x) for x in matrix[row_index]]))
        if row_index != num_rows - 1:
            file_handle.write("\n")
    file_handle.write(" ]\n")


def zeropad(x, length):
    """Return str(x) left-padded with zeros to at least `length` chars."""
    return str(x).zfill(length)


def _write_images(data, labels, out_fh, labels_fh, img_id):
    """Write one batch of images: a labels.txt line and a Kaldi feature
    matrix per image, keyed by the zero-padded running image id.
    Returns the next unused img_id.
    """
    for img, label in zip(data, labels):
        key = zeropad(img_id, 5)
        labels_fh.write(key + ' ' + str(label) + '\n')
        write_kaldi_matrix(out_fh, image_to_feat_matrix(img), key)
        img_id += 1
    return img_id


def main():
    parser = argparse.ArgumentParser(
        description="""Converts train/test data of CIFAR-10 or CIFAR-100
                       to Kaldi feature format""")
    parser.add_argument('database', default='data/dl/cifar-10-batches-bin',
                        help='path to downloaded cifar data (binary version)')
    parser.add_argument('dir', help='output dir')
    parser.add_argument('--cifar-version', default='CIFAR-10',
                        choices=['CIFAR-10', 'CIFAR-100'])
    parser.add_argument('--dataset', default='train',
                        choices=['train', 'test'])
    parser.add_argument('--out-ark', default='-',
                        help='where to write output feature data')
    args = parser.parse_args()

    cifar10 = (args.cifar_version.lower() == 'cifar-10')

    if args.out_ark == '-':
        out_fh = sys.stdout  # output file handle to write the feats to
    else:
        # Text mode: we only ever write str (would TypeError in 'wb' on
        # Python 3).
        out_fh = open(args.out_ark, 'w')

    img_id = 1  # running key, similar to utt_id
    # For CIFAR-100 these are the fine labels; coarse labels are read but
    # currently discarded (upstream kept them commented out too).
    labels_file = os.path.join(args.dir, 'labels.txt')
    with open(labels_file, 'w') as labels_fh:
        if cifar10:
            if args.dataset == 'train':
                for batch in range(1, 6):
                    fpath = os.path.join(args.database,
                                         'data_batch_' + str(batch) + '.bin')
                    data, labels = load_cifar10_data_batch(fpath)
                    img_id = _write_images(data, labels, out_fh,
                                           labels_fh, img_id)
            else:
                fpath = os.path.join(args.database, 'test_batch.bin')
                data, labels = load_cifar10_data_batch(fpath)
                img_id = _write_images(data, labels, out_fh,
                                       labels_fh, img_id)
        else:
            if args.dataset == 'train':
                fpath = os.path.join(args.database, 'train.bin')
                data, fine_labels, _ = load_cifar100_data_batch(fpath, 50000)
            else:
                fpath = os.path.join(args.database, 'test.bin')
                data, fine_labels, _ = load_cifar100_data_batch(fpath, 10000)
            img_id = _write_images(data, fine_labels, out_fh,
                                   labels_fh, img_id)

    if out_fh is not sys.stdout:
        out_fh.close()


if __name__ == '__main__':
    main()