Blame view
egs/yomdle_russian/v1/local/process_data.py
2.68 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
#!/usr/bin/env python3

# Copyright 2018  Ashish Arora
#           2018  Chun Chieh Chang

""" This script reads the extracted OCR (yomdle and slam) database files
    and creates the following files for the data split given on the
    command line: text, utt2spk, images.scp.

  Eg. local/process_data.py data/download/ data/local/splits/train.txt data/train

  Eg. text file: english_phone_books_0001_1 To sum up, then, it would appear that
      utt2spk file: english_phone_books_0001_0 english_phone_books_0001
      images.scp file: english_phone_books_0001_0 data/download/truth_line_image/english_phone_books_0001_0.png
"""

import argparse
import os
import sys
import csv
import itertools
import unicodedata
import re
import string

# Columns of the truth CSV files that this script reads.
IMAGE_NAME_COLUMN = 1   # file name of the line image, e.g. "<doc>_<n>.png"
TEXT_COLUMN = 11        # transcription of that line image


def _iter_transcriptions(csv_filepath, image_filename):
    """Yield the cleaned transcription of every row describing image_filename.

    Rows that are too short to hold the text column (e.g. blank lines) are
    skipped instead of raising IndexError.  Transcriptions that are empty
    after whitespace normalization are skipped as well.
    """
    # newline='' is what the csv module expects so that it can handle
    # line breaks embedded in quoted fields itself.
    with open(csv_filepath, 'r', encoding='utf-8', newline='') as csv_file:
        for row in csv.reader(csv_file):
            if len(row) <= TEXT_COLUMN:
                continue
            if row[IMAGE_NAME_COLUMN] != image_filename:
                continue
            # str.split() without arguments splits on any Unicode
            # whitespace (no-break spaces, tabs, newlines, ...), so
            # re-joining leaves single ASCII spaces between tokens.
            text = " ".join(row[TEXT_COLUMN].split())
            if text:
                yield text


def process_data(database_path, data_split, out_dir):
    """Write the text, utt2spk and images.scp files for one data split.

    Args:
        database_path: directory containing the 'truth_line_image' and
            'truth_csv' sub-directories.
        data_split: path to a file listing one line-image id per line, in
            the form '<document id>_<line number>'.
        out_dir: existing directory in which the three files are created.

    Ids whose image file is missing are reported on stdout and skipped.
    A missing CSV file is not handled here and raises FileNotFoundError.
    """
    print("Processing '{}' data...".format(out_dir))

    text_file = os.path.join(out_dir, 'text')
    utt2spk_file = os.path.join(out_dir, 'utt2spk')
    image_file = os.path.join(out_dir, 'images.scp')

    # The context manager guarantees that every handle is flushed and
    # closed, even if processing stops half-way with an exception.
    with open(text_file, 'w', encoding='utf-8') as text_fh, \
            open(utt2spk_file, 'w', encoding='utf-8') as utt2spk_fh, \
            open(image_file, 'w', encoding='utf-8') as image_fh, \
            open(data_split, encoding='utf-8') as split_fh:
        for line in split_fh:
            image_id = line.strip()
            if not image_id:
                continue  # tolerate blank lines in the split file

            image_filename = image_id + '.png'
            image_filepath = os.path.join(database_path, 'truth_line_image',
                                          image_filename)
            if not os.path.isfile(image_filepath):
                print("File does not exist {}".format(image_filepath))
                continue

            # Everything before the last '_' identifies the document; it
            # names the CSV file and doubles as the "speaker" id.
            document_id = image_id.rpartition('_')[0]
            csv_filepath = os.path.join(database_path, 'truth_csv',
                                        document_id + '.csv')

            for text in _iter_transcriptions(csv_filepath, image_filename):
                # One record per line, as expected of Kaldi data files.
                text_fh.write(image_id + ' ' + text + '\n')
                utt2spk_fh.write(image_id + ' ' + document_id + '\n')
                image_fh.write(image_id + ' ' + image_filepath + '\n')


def main(argv=None):
    """Parse the command line (sys.argv when argv is None) and run."""
    parser = argparse.ArgumentParser(
        description="Creates text, utt2spk, and images.scp files")
    parser.add_argument('database_path', type=str, help='Path to data')
    parser.add_argument('data_split', type=str,
                        help='Path to file that contain datasplits')
    parser.add_argument('out_dir', type=str,
                        help='directory to output files')
    args = parser.parse_args(argv)

    ### main ###
    process_data(args.database_path, args.data_split, args.out_dir)


if __name__ == '__main__':
    main()