Blame view

egs/madcat_ar/v1/local/process_data.py 9.83 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
  #!/usr/bin/env python3
  
  # Copyright  2018  Ashish Arora
  
  """ This script reads MADCAT files and creates the following files for the
      pages listed in the data_splits file (optionally restricted to the 'IUC'
      writing condition with --subset true): text, utt2spk, images.scp.
    Eg. local/process_data.py /export/corpora/LDC/LDC2012T15 /export/corpora/LDC/LDC2013T09
        /export/corpora/LDC/LDC2013T15 data/download/data_splits/madcat.train.raw.lineid
        data/dev data/local/lines/images.scp <writing_conditions_1> <writing_conditions_2>
        <writing_conditions_3>
    Eg. text file: LDC0001_000404_NHR_ARB_20070113.0052_11_LDC0001_00z2 وجه وعقل غارق حتّى النخاع
        utt2spk file: LDC0001_000397_NHR_ARB_20070113.0052_11_LDC0001_00z1 LDC0001
        images.scp file: LDC0009_000000_arb-NG-2-76513-5612324_2_LDC0009_00z0
        data/local/lines/1/arb-NG-2-76513-5612324_2_LDC0009_00z0.png
  """
  
  import argparse
  import os
  import sys
  import xml.dom.minidom as minidom
  import unicodedata
  
  # Command-line interface.  Positional order: three LDC database roots, the
  # data split list, the output directory, the input line-image list and the
  # three writing-condition files (one per database, in the same order).
  parser = argparse.ArgumentParser(description="Creates text, utt2spk and images.scp files",
                                   epilog="E.g.  " + sys.argv[0] + "  data/LDC2012T15"
                                   " data/LDC2013T09 data/LDC2013T15 data/madcat.train.raw.lineid "
                                   " data/train data/local/lines/images.scp"
                                   " data/LDC2012T15/docs/writing_conditions.tab"
                                   " data/LDC2013T09/docs/writing_conditions.tab"
                                   " data/LDC2013T15/docs/writing_conditions.tab",
                                   formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument('database_path1',
                      help='Path to the first downloaded (and extracted) madcat database; '
                           'searched first for each page')
  parser.add_argument('database_path2',
                      help='Path to the second downloaded (and extracted) madcat database; '
                           'searched if the page is not in database_path1')
  parser.add_argument('database_path3',
                      help='Path to the third downloaded (and extracted) madcat database; '
                           'searched last')
  parser.add_argument('data_splits',
                      help='Path to file that contains the train/test/dev split information')
  parser.add_argument('out_dir',
                      help='directory location to write output files.')
  parser.add_argument('images_scp_path',
                      help='Path of input images.scp file(maps line image and location)')
  parser.add_argument('writing_condition1',
                      help='Path to the writing conditions file of database_path1 '
                           '(tab-separated: page name in column 1, condition in column 4)')
  parser.add_argument('writing_condition2',
                      help='Path to the writing conditions file of database_path2 '
                           '(tab-separated: page name in column 1, condition in column 4)')
  parser.add_argument('writing_condition3',
                      help='Path to the writing conditions file of database_path3 '
                           '(tab-separated: page name in column 1, condition in column 4)')
  # Boolean flags are given as "true"/"false" (case-insensitive); any value
  # other than "true" is treated as False.
  parser.add_argument("--augment", type=lambda x: (str(x).lower()=='true'), default=False,
                      help="use the augmented line images (<line>_scale0..2.png), "
                           "writing one utterance per scaled copy")
  parser.add_argument("--subset", type=lambda x: (str(x).lower()=='true'), default=False,
                      help="only process pages whose writing condition is IUC")
  args = parser.parse_args()
  
  
  def check_file_location():
      """ Locates the madcat xml file and page image of the current page.

      The three databases are searched in order (database_path1, then 2,
      then 3); the first one containing ``madcat/<base_name>.madcat.xml``
      is used.  Reads the module-level globals ``base_name`` (the page being
      processed), ``args`` and ``wc_dict1``/``wc_dict2``/``wc_dict3``.

      Returns:
          tuple: (madcat_file_path, image_file_path, wc_dict) where
              madcat_file_path (string): complete path of the madcat xml file.
              image_file_path (string): complete path of the corresponding
                  page image ``images/<base_name>.tif`` in the same database
                  (its existence is not checked).
              wc_dict (dict): writing-condition dictionary of that database.
          (None, None, None) if no database contains the xml file.
      """
      search_order = (
          (args.database_path1, wc_dict1),
          (args.database_path2, wc_dict2),
          (args.database_path3, wc_dict3),
      )
      for database_path, wc_dict in search_order:
          madcat_file_path = os.path.join(database_path, 'madcat', base_name + '.madcat.xml')
          if os.path.exists(madcat_file_path):
              image_file_path = os.path.join(database_path, 'images', base_name + '.tif')
              return madcat_file_path, image_file_path, wc_dict
      return None, None, None
  
  
  def parse_writing_conditions(writing_conditions):
      """ Returns a dictionary which maps each page image to its writing condition.

      The file is read as UTF-8, tab-separated text: the first column is the
      page image name and the fourth column its writing condition.  Blank
      lines are skipped.

      Args:
           writing_conditions(string): complete path of writing condition file.
      Returns:
          (dict): dictionary with key as page image name and value as writing condition.
      Raises:
          ValueError: if a non-blank line has fewer than four tab-separated fields.
      """
      file_writing_cond = dict()
      # Explicit encoding: the default would depend on the machine's locale.
      with open(writing_conditions, encoding='utf-8') as f:
          for line_num, line in enumerate(f, start=1):
              line_list = line.strip().split("\t")
              if line_list == ['']:
                  continue  # blank line, e.g. stray empty lines at the end
              if len(line_list) < 4:
                  raise ValueError("{}:{}: expected at least 4 tab-separated fields, got {}"
                                   .format(writing_conditions, line_num, len(line_list)))
              file_writing_cond[line_list[0]] = line_list[3]
      return file_writing_cond
  
  
  def check_writing_condition(wc_dict):
      """ Decides whether the current page belongs to the selected subset.

      Without --subset every page is kept.  With --subset only pages whose
      writing condition (looked up under the module-level ``base_name``,
      surrounding whitespace ignored) equals 'IUC' are kept.

      Args:
           wc_dict (dict): page image name -> writing condition string.
      Returns:
          (bool): True if the page should be processed.
      """
      if not args.subset:
          return True
      return wc_dict[base_name].strip() == 'IUC'
  
  
  def read_text(madcat_file_path):
      """ Groups the transcribed words of a page by the text line they belong to.

      Every <token-image> inside a <zone> maps a word id to that zone's id
      (the line id).  Every <token> inside a <segment> names a word id in its
      ``ref_id`` attribute and carries the word text in its first <source>.

      Args:
          madcat_file_path (string): complete path and name of the madcat xml file
                                    corresponding to the page image.
      Returns:
          dict: line id -> list of words, in the order the tokens appear.
      Raises:
          KeyError: if a token's ref_id matches no token-image of any zone.
      """
      doc = minidom.parse(madcat_file_path)

      line_of_word = {}
      for zone in doc.getElementsByTagName('zone'):
          zone_id = zone.getAttribute('id')
          for token_image in zone.getElementsByTagName('token-image'):
              line_of_word[token_image.getAttribute('id')] = zone_id

      words_by_line = {}
      for segment in doc.getElementsByTagName('segment'):
          for token in segment.getElementsByTagName('token'):
              word_ref = token.getAttribute('ref_id')
              word = token.getElementsByTagName('source')[0].firstChild.nodeValue
              words_by_line.setdefault(line_of_word[word_ref], []).append(word)
      return words_by_line
  
  
  def get_line_image_location(scp_fh=None):
      """ Maps every line-image file name to the directory that contains it.

      Each non-blank line of the input is treated as a path to a line image.
      Lines are stripped individually, so CR characters or trailing spaces
      never become part of a key.

      Args:
          scp_fh (file-like, optional): open text handle to read from.
              Defaults to the module-level ``input_image_fh``.
      Returns:
          dict: line-image base name -> everything before the last '/' of its
              path ('' when the path contains no '/').
      """
      if scp_fh is None:
          scp_fh = input_image_fh
      image_loc_dict = dict()  # Stores image base name and location
      for raw_line in scp_fh.read().split("\n"):
          line = raw_line.strip()
          if not line:
              continue  # blank line / empty file: don't record an empty key
          image_loc_dict[os.path.basename(line)] = line.rpartition('/')[0]
      return image_loc_dict
  
  
  ### main ###
  print("Processing '{}' data...".format(args.out_dir))

  text_file = os.path.join(args.out_dir, 'text')
  utt2spk_file = os.path.join(args.out_dir, 'utt2spk')
  image_file = os.path.join(args.out_dir, 'images.scp')
  input_image_file = args.images_scp_path


  def _write_utterance(utt_id, text, line_image_path, speaker_id):
      """ Appends one utterance to the text, utt2spk and images.scp outputs.

      Uses the module-level handles text_fh, utt2spk_fh and image_fh, so it
      may only be called while the `with` statement below keeps them open.
      """
      text_fh.write(utt_id + ' ' + text + '\n')
      utt2spk_fh.write(utt_id + ' ' + speaker_id + '\n')
      image_fh.write(utt_id + ' ' + line_image_path + '\n')


  # The `with` statement flushes and closes every handle even if processing
  # stops on an error.  The handles remain module-level names because
  # get_line_image_location() reads `input_image_fh` as a global.
  with open(text_file, 'w', encoding='utf-8') as text_fh, \
          open(utt2spk_file, 'w', encoding='utf-8') as utt2spk_fh, \
          open(image_file, 'w', encoding='utf-8') as image_fh, \
          open(input_image_file, 'r', encoding='utf-8') as input_image_fh:
      wc_dict1 = parse_writing_conditions(args.writing_condition1)
      wc_dict2 = parse_writing_conditions(args.writing_condition2)
      wc_dict3 = parse_writing_conditions(args.writing_condition3)
      image_loc_dict = get_line_image_location()

      image_num = 0  # running counter embedded in every utterance id
      prev_base_name = ''
      with open(args.data_splits, encoding='utf-8') as f:
          for split_line in f:
              # strip() keeps the trailing newline out of base_name.  Skipping
              # blank lines also stops them from resetting prev_base_name,
              # which would otherwise make a page be processed twice.
              split_line = split_line.strip()
              if not split_line:
                  continue
              # First field with up to two extensions removed.
              base_name = os.path.splitext(os.path.splitext(split_line.split(' ')[0])[0])[0]
              # Consecutive entries of the same page are handled in one pass.
              if base_name == prev_base_name:
                  continue
              prev_base_name = base_name
              # check_file_location() and check_writing_condition() read the
              # global `base_name` assigned just above.
              madcat_xml_path, page_image_path, wc_dict = check_file_location()
              if wc_dict is None or not check_writing_condition(wc_dict):
                  continue
              madcat_doc = minidom.parse(madcat_xml_path)
              writer_id = madcat_doc.getElementsByTagName('writer')[0].getAttribute('id')
              text_line_word_dict = read_text(madcat_xml_path)
              base_name = os.path.basename(page_image_path).split('.tif')[0]
              for line_id in sorted(text_line_word_dict):
                  text = ' '.join(text_line_word_dict[line_id])
                  if args.augment:
                      # One utterance per scaled copy (scale0..scale2) of the line.
                      for i in range(3):
                          stem = "{}_{}_scale{}".format(base_name, line_id, i)
                          line_image_name = stem + '.png'
                          line_image_path = os.path.join(image_loc_dict[line_image_name],
                                                         line_image_name)
                          utt_id = "{}_{}_{}".format(writer_id, str(image_num).zfill(6), stem)
                          _write_utterance(utt_id, text, line_image_path, writer_id)
                          image_num += 1
                  else:
                      padded_line_id = str(line_id).zfill(4)
                      line_image_name = "{}_{}.png".format(base_name, padded_line_id)
                      line_image_path = os.path.join(image_loc_dict[line_image_name],
                                                     line_image_name)
                      utt_id = "{}_{}_{}_{}".format(writer_id, str(image_num).zfill(6),
                                                    base_name, padded_line_id)
                      _write_utterance(utt_id, text, line_image_path, writer_id)
                      image_num += 1