Blame view

egs/yomdle_tamil/v1/local/process_data.py 2.45 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
  #!/usr/bin/env python3
  
  # Copyright      2018  Ashish Arora
  #                2018  Chun Chieh Chang
  
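  """ This script reads the extracted Tamil OCR (yomdle and slam) database files
      and creates the following files for the utterances listed in the
      data_split file (one utterance id per line): text, utt2spk, images.scp.
      Line images are read from <database_path>/truth_line_image/<utt_id>.png and
      transcriptions from <database_path>/truth_csv/<page_id>.csv, where <page_id>
      is <utt_id> with its final '_'-separated field removed; <page_id> is also
      written as the speaker id in utt2spk.
    E.g. local/process_data.py data/download/ data/local/splits/train.txt data/train

    E.g. text file: english_phone_books_0001_0 To sum up, then, it would appear that
         utt2spk file: english_phone_books_0001_0 english_phone_books_0001
         images.scp file: english_phone_books_0001_0 \
         data/download/truth_line_image/english_phone_books_0001_0.png
  """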
  
  import argparse
  import os
  import sys
  import csv
  import itertools
  import unicodedata
  import re
  import string
  parser = argparse.ArgumentParser(description="Creates text, utt2spk, and images.scp files")
  parser.add_argument('database_path', type=str, help='Path to data')
  parser.add_argument('data_split', type=str, help='Path to file that contain datasplits')
  parser.add_argument('out_dir', type=str, help='directory to output files')

  # 0-based columns of a truth_csv row used here: the line-image filename
  # the row describes, and that line's transcription.
  _IMAGE_FILENAME_COL = 1
  _TEXT_COL = 11


  def _read_transcriptions(csv_filepath, image_filename):
      """Return the non-empty transcriptions of one line image, in file order.

      Scans every row of ``csv_filepath`` and collects column ``_TEXT_COL``
      of each row whose column ``_IMAGE_FILENAME_COL`` equals
      ``image_filename``. Rows too short to hold a transcription (for
      example blank lines, which csv.reader yields as ``[]``) are skipped
      instead of raising IndexError.

      Raises:
          OSError: if ``csv_filepath`` cannot be opened.
      """
      texts = []
      # The with-block closes the CSV; previously one handle leaked per utterance.
      with open(csv_filepath, 'r', encoding='utf-8') as csv_file:
          for row in csv.reader(csv_file):
              if len(row) <= _TEXT_COL:
                  continue
              if row[_IMAGE_FILENAME_COL] == image_filename and row[_TEXT_COL]:
                  texts.append(row[_TEXT_COL])
      return texts


  def main(args=None):
      """Write text, utt2spk and images.scp into ``args.out_dir``.

      For each utterance id listed in ``args.data_split`` (one per line,
      blank lines ignored), the line image
      ``<database_path>/truth_line_image/<id>.png`` must exist; otherwise a
      message is printed and the id is skipped. Transcriptions are looked up
      in ``<database_path>/truth_csv/<page>.csv``, where ``<page>`` is the id
      with its last ``_``-separated field removed. ``<page>`` is also written
      as the speaker in utt2spk. Each matching, non-empty transcription
      yields one line in each of the three output files.

      Args:
          args: namespace with ``database_path``, ``data_split`` and
              ``out_dir`` attributes; parsed from the command line when None.

      Raises:
          OSError: if an output file, the split file, or a required
              truth_csv file cannot be opened.
      """
      if args is None:
          args = parser.parse_args()
      print("Processing '{}' data...".format(args.out_dir))

      text_file = os.path.join(args.out_dir, 'text')
      utt2spk_file = os.path.join(args.out_dir, 'utt2spk')
      image_file = os.path.join(args.out_dir, 'images.scp')
      # Context managers guarantee every output is flushed and closed, even
      # if processing stops on an exception part-way through.
      with open(text_file, 'w', encoding='utf-8') as text_fh, \
              open(utt2spk_file, 'w', encoding='utf-8') as utt2spk_fh, \
              open(image_file, 'w', encoding='utf-8') as image_fh, \
              open(args.data_split, encoding='utf-8') as split_fh:
          for line in split_fh:
              image_id = line.strip()
              if not image_id:
                  # A blank line names no utterance; don't report a missing '.png'.
                  continue
              image_filename = image_id + '.png'
              image_filepath = os.path.join(args.database_path, 'truth_line_image', image_filename)
              if not os.path.isfile(image_filepath):
                  print("File does not exist {}".format(image_filepath))
                  continue
              page_id = '_'.join(image_id.split('_')[:-1])
              csv_filepath = os.path.join(args.database_path, 'truth_csv', page_id + '.csv')
              for text in _read_transcriptions(csv_filepath, image_filename):
                  text_fh.write(image_id + ' ' + text + '\n')
                  utt2spk_fh.write(image_id + ' ' + page_id + '\n')
                  image_fh.write(image_id + ' ' + image_filepath + '\n')


  if __name__ == '__main__':
      main()