Blame view

egs/madcat_ar/v1/local/tl/process_waldo_data.py 3.01 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
  #!/usr/bin/env python3
  
  """ This script reads image and transcription mapping and creates the following files :text, utt2spk, images.scp.
    Eg. local/process_waldo_data.py lines/hyp_line_image_transcription_mapping_kaldi.txt data/test
    Eg. text file: LDC0001_000404_NHR_ARB_20070113.0052_11_LDC0001_00z2 ﻮﺠﻫ ﻮﻌﻘﻟ ﻍﺍﺮﻗ ﺢﺗّﻯ ﺎﻠﻨﺧﺎﻋ
        utt2spk file: LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 LDC0001
        images.scp file: LDC0009_000000_arb-NG-2-76513-5612324_2_LDC0009_00z0
        data/local/lines/1/arb-NG-2-76513-5612324_2_LDC0009_00z0.tif
  """
  
  import argparse
  import os
  import sys
  
  # Command-line interface: two required positional arguments. Parsing happens
  # at import time because the script body below reads its inputs from `args`.
  parser = argparse.ArgumentParser(
      description="Creates text, utt2spk and images.scp files",
      # The example lists the arguments in the order they are declared below:
      # the image/transcription mapping file first, then the output directory.
      epilog="E.g.  " + sys.argv[0]
             + " lines/hyp_line_image_transcription_mapping_kaldi.txt data/test",
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('image_transcription_file', type=str,
                      help='Path to the file containing line image path and transcription information')
  parser.add_argument('out_dir', type=str,
                      help='directory location to write output files.')
  args = parser.parse_args()
  
  
  def read_image_text(image_text_path):
      """ Read the line-image / transcription mapping file into a dict.

      Every non-blank line of the file is split at its first space: the part
      before it is the image path, the part after it is the transcription
      (kept verbatim, including any repeated spaces). The line_id is the base
      name of the image path with everything from the first '.png' onwards
      removed. If the same line_id occurs more than once, the last
      occurrence wins.

      Blank or whitespace-only lines are skipped; otherwise they would create
      an entry with an empty line_id.

      Parameters:
      -----------
      image_text_path: path of the UTF-8 encoded mapping file

      Returns:
      --------
      dict: line_id and transcription mapping
      """
      image_transcription_dict = dict()
      with open(image_text_path, encoding='utf-8') as f:
          for line in f:
              stripped = line.strip()
              if not stripped:
                  # Nothing to map on an empty line.
                  continue
              # A line without any space yields an empty transcription.
              image_path, _, transcription = stripped.partition(' ')
              line_id = os.path.basename(image_path).split('.png')[0]
              image_transcription_dict[line_id] = transcription
      return image_transcription_dict
  
  
  ### main ###
  # Writes three files into args.out_dir, one line per line image:
  #   text       : <utt_id> <transcription>
  #   utt2spk    : <utt_id> <writer_id>
  #   images.scp : <utt_id> lines/<utt_id>.png
  print("Processing '{}' data...".format(args.out_dir))
  text_file = os.path.join(args.out_dir, 'text')
  utt2spk_file = os.path.join(args.out_dir, 'utt2spk')
  image_file = os.path.join(args.out_dir, 'images.scp')

  # The context manager flushes and closes all three outputs even when an
  # error interrupts the loop below.
  with open(text_file, 'w', encoding='utf-8') as text_fh, \
          open(utt2spk_file, 'w', encoding='utf-8') as utt2spk_fh, \
          open(image_file, 'w', encoding='utf-8') as image_fh:
      image_transcription_dict = read_image_text(args.image_transcription_file)
      # Sorted so that all three files list the utterances in the same order.
      for line_id in sorted(image_transcription_dict.keys()):
          # The writer id is the third '_'-separated field counted from the end.
          # NOTE(review): for ids shaped like the utt2spk example in the module
          # docstring this index would select the field before the writer id --
          # confirm against the actual line-image names.
          writer_id = line_id.strip().split('_')[-3]
          updated_line_id = line_id + '.png'
          image_file_path = os.path.join('lines', updated_line_id)
          text = image_transcription_dict[line_id]
          utt_id = line_id
          text_fh.write(utt_id + ' ' + text + '\n')
          utt2spk_fh.write(utt_id + ' ' + writer_id + '\n')
          image_fh.write(utt_id + ' ' + image_file_path + '\n')