Blame view

egs/iam/v1/local/unk_arc_post_to_transcription.py 4.8 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  #!/usr/bin/env python3
  
  #Copyright      2017  Ashish Arora

"""Replace <unk> in one-best hypotheses with words spelled by the unk-model.

This module is used by scripts for the open-vocabulary setup. Wherever the
hypothesis transcription contains <unk>, the <unk> is replaced by the word
predicted by the unk-model, obtained by concatenating the phones decoded by
the unk-model for that arc (after removing word-position markers such as
"_B"). It is currently supported only for the triphone setup.

Args:
  phones: File name of the phones.txt symbol table for phones. Each line is
          "<phone> <phone-id>", e.g. "a 217" (the phone-id of 'a' is 217).
  words: File name of the words.txt symbol table for words. Each line is
         "<word> <word-id>", e.g. "ACCOUNTANCY 234" (the word-id of
         'ACCOUNTANCY' is 234).
  unk: File name of a file whose content is the word-id of <unk>,
       e.g. "231" (such as data/lang/oov.int).
  --one-best-arc-post: A file in arc-post format ('-', the default, reads
       stdin): timing info and posterior of the arcs along the one-best path
       of the lattice, one arc per line, with tab-separated fields
         <utterance-id> <start-frame> <num-frames> <posterior> <word> <phones>
       where <phones> is a space-separated list of phone-ids, e.g.
         506_m01-049-00  8  12  1  7722  282 272 288 231
       The phones are read from the sixth field, so the optional <ali>
       field of the arc-post format must not be present.
  --output-text: Output file ('-', the default, writes stdout) with one
       hypothesis transcription per utterance, sorted by utterance-id, in
       which <unk> is replaced by the word recognized by the unk-model:
         <utterance-id> <word1> <word2> ...
       e.g. "506_m01-049-00 A move to stop mr. gaitskell."
       Arcs whose word-id is 0 contribute a single space.

Example:
  local/unk_arc_post_to_transcription.py lang/phones.txt lang/words.txt
      data/lang/oov.int --one-best-arc-post <arc-post-file>
      --output-text <output-file>
"""
import argparse
import io
import os
import sys

parser = argparse.ArgumentParser(description="""uses phones to convert unk to word""")
# Adjacent string literals are concatenated verbatim, so every fragment of a
# multi-line help message must end with a space.
parser.add_argument('phones', type=str, help='File name of a file that contains the '
                    'symbol-table for phones. Each line must be: <phone> <phoneID>')
parser.add_argument('words', type=str, help='File name of a file that contains the '
                    'symbol-table for words. Each line must be: <word> <word-id>')
# 'unk' is a required positional: argparse never applies a `default` to it,
# so none is given.
parser.add_argument('unk', type=str, help='File name of a file that '
                    'contains the ID of <unk>. The content must be: <oov-id>, e.g. 231')
parser.add_argument('--one-best-arc-post', type=str, default='-', help='A file in arc-post '
                    'format, which is a list of timing info and posterior of arcs '
                    'along the one-best path from the lattice')
parser.add_argument('--output-text', type=str, default='-', help='File containing '
                    'hypothesis transcription with <unk> recognized by the unk-model')
args = parser.parse_args()
  
  ### main ###


def read_symbol_table(handle):
    """Read a symbol table (phones.txt / words.txt) and map each id to its symbol.

    Each non-blank line must be "<symbol> <id>", whitespace separated; extra
    columns are ignored. Ids are kept as strings because they are compared
    against the string fields of the arc-post input.

    Args:
        handle: An iterable of text lines, e.g. an open text file.

    Returns:
        dict mapping id (str) -> symbol (str).

    Raises:
        ValueError: If a non-blank line has fewer than two fields.
    """
    id2sym = dict()
    for line in handle:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines, e.g. a trailing empty line
        if len(fields) < 2:
            raise ValueError("Bad symbol-table line: '{}'".format(line.rstrip('\n')))
        id2sym[fields[1]] = fields[0]
    return id2sym


def read_unk_id(handle):
    """Return the word-id of <unk>, i.e. the first token of the file content.

    Raises:
        ValueError: If the content is empty. Without this check an empty file
            would silently disable the <unk> replacement.
    """
    tokens = handle.read().split()
    if not tokens:
        raise ValueError("The <unk> file is empty; expected <oov-id>, e.g. 231")
    return tokens[0]


def unk_phones_to_word(phone_ids, id2phone):
    """Spell a word from the phones the unk-model decoded for one arc.

    Args:
        phone_ids: Space-separated phone-ids, e.g. "282 272 288".
        id2phone: dict mapping phone-id (str) -> phone symbol (str).

    Returns:
        The phone symbols concatenated, each truncated at its first '_' so
        that word-position markers (e.g. "_B") are removed.

    Raises:
        KeyError: If a phone-id is missing from ``id2phone``.
    """
    return ''.join(id2phone[pid].split('_')[0] for pid in phone_ids.split())


def collect_utterance_words(arc_post_lines, id2word, id2phone, unk_id):
    """Group the one-best words of each utterance, spelling out <unk> arcs.

    Args:
        arc_post_lines: Iterable of arc-post lines with tab-separated fields
            <utt-id> <start-frame> <num-frames> <posterior> <word-id> <phone-ids>.
        id2word: dict mapping word-id (str) -> word.
        id2phone: dict mapping phone-id (str) -> phone.
        unk_id: Word-id (str) of <unk>.

    Returns:
        dict mapping utterance-id (str) -> list of words, in input order.
        <unk> arcs contribute the word built by unk_phones_to_word(), arcs with
        word-id '0' contribute a single space, and all other arcs contribute
        their word from ``id2word``. Lines with fewer than 6 fields are
        reported on stderr and skipped.

    Raises:
        KeyError: If a word-id or phone-id is missing from its symbol table.
    """
    utt_words = dict()
    for line in arc_post_lines:
        fields = line.strip().split("\t")
        if len(fields) < 6:  # Check for 1best-arc-post output
            print("Error: Bad line: '{}'   Expecting 6 fields. Skipping...".format(
                line.rstrip('\n')), file=sys.stderr)
            continue
        utt_id, word, phones = fields[0], fields[4], fields[5]
        # setdefault is an O(1) lookup, unlike scanning list(dict.keys()) per line.
        words = utt_words.setdefault(utt_id, [])
        if word == unk_id:
            words.append(unk_phones_to_word(phones, id2phone))
        elif word == '0':
            words.append(' ')  # store space/silence
        else:
            words.append(id2word[word])
    return utt_words


def format_transcriptions(utt_words):
    """Yield "<utt-id> <word1> <word2> ..." lines, sorted by utterance-id."""
    for utt_id in sorted(utt_words):
        yield " ".join([utt_id] + utt_words[utt_id])


def _main(args):
    """Convert the arc-post input named by ``args`` into transcriptions."""
    with open(args.phones, 'r', encoding='utf8') as phone_handle:
        id2phone = read_symbol_table(phone_handle)
    with open(args.words, 'r', encoding='utf8') as word_handle:
        id2word = read_symbol_table(word_handle)
    with open(args.unk, 'r', encoding='utf8') as unk_handle:
        unk_id = read_unk_id(unk_handle)

    # '-' selects stdin/stdout, re-wrapped so that I/O is UTF-8.
    if args.one_best_arc_post == '-':
        arc_post_handle = io.TextIOWrapper(sys.stdin.buffer, encoding='utf8')
    else:
        arc_post_handle = open(args.one_best_arc_post, 'r', encoding='utf8')
    try:
        utt_words = collect_utterance_words(arc_post_handle, id2word, id2phone, unk_id)
    finally:
        arc_post_handle.close()

    if args.output_text == '-':
        output_text_handle = io.TextIOWrapper(sys.stdout.buffer, encoding='utf8')
    else:
        output_text_handle = open(args.output_text, 'w', encoding='utf8')
    try:
        for transcription in format_transcriptions(utt_words):
            output_text_handle.write(transcription + '\n')
    finally:
        # An explicit close flushes the output. Without it, the script relies
        # on interpreter shutdown to finalize the wrapper.
        output_text_handle.close()


if __name__ == "__main__":
    # `args` is produced by the module-level argument parsing earlier in the file.
    _main(args)