Blame view

egs/iam/v1/local/gen_topo.py 4.29 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
  #!/usr/bin/env python
  
  # Copyright 2017 (author: Chun-Chieh Chang)
  
  # Generate a topology file.  This allows control of the number of states in the
  # non-silence HMMs, in the punctuation HMMs, and in the silence HMMs. This is a
  # modified version of 'utils/gen_topo.pl'. The difference is that this creates
  # two topologies for the non-silence HMMs, so that punctuation phones can use a
  # different number of states from the phones for other characters.
  
  from __future__ import print_function
  import argparse
  import string
  
  # Command-line interface.  Positional arguments, in order: the number of
  # states for ordinary non-silence phones, for silence phones and for
  # punctuation phones, then the colon-separated integer ids of the non-silence
  # and silence phones, and finally the phone list file (a phone symbol and its
  # integer id per line).  The topology is printed to stdout.
  parser = argparse.ArgumentParser(
      description="Usage: local/gen_topo.py <num-nonsil-states> <num-sil-states> "
                  "<num-punctuation-states> <colon-separated-nonsilence-phones> "
                  "<colon-separated-silence-phones> <phone-list>\n"
                  "e.g.:  local/gen_topo.py 8 4 4 4:5:6:7:8:9:10 1:2:3 data/lang/phones.txt\n",
      # NOTE(review): this epilog appears to be inherited from
      # steps/nnet3/chain/gen_topo.py; confirm it is still the right pointer.
      epilog="See egs/swbd/s5c/local/chain/train_tdnn_a.sh for example of usage.")
  parser.add_argument("num_nonsil_states", type=int, help="number of states for nonsilence phones")
  parser.add_argument("num_sil_states", type=int, help="number of states for silence phones")
  parser.add_argument("num_punctuation_states", type=int, help="number of states for punctuation")
  parser.add_argument("nonsilence_phones", type=str,
                      help="List of non-silence phones as integers, separated by colons, e.g. 4:5:6:7:8:9")
  parser.add_argument("silence_phones", type=str,
                      help="List of silence phones as integers, separated by colons, e.g. 1:2:3")
  parser.add_argument("phone_list", type=str, help="file containing all phones and their corresponding number.")

  args = parser.parse_args()

  # Reject state counts below 1.  With 0, the non-silence and punctuation
  # chains would get no emitting state at all, and the single-state silence
  # branch would define state 0 twice, so the generated topology would be
  # malformed.  parser.error prints the usage message and exits.
  for _opt in ("num_nonsil_states", "num_sil_states", "num_punctuation_states"):
      if getattr(args, _opt) < 1:
          parser.error("{0} must be at least 1, got {1}".format(_opt, getattr(args, _opt)))
  
  # Integer phone ids given on the command line as colon-separated lists.
  silence_phones = [int(x) for x in args.silence_phones.split(":")]
  nonsilence_phones = [int(x) for x in args.nonsilence_phones.split(":")]
  all_phones = silence_phones + nonsilence_phones

  # Collect the integer ids of punctuation phones from the phone list.  Each
  # line is expected to hold "<phone-symbol> <integer-id>".  A phone counts as
  # punctuation when the part of its symbol before the first '_' is a single
  # character from `exclude`.  The '_' part is presumably a word-position
  # suffix such as _B/_I/_E/_S (assumption; verify against the lang dir).
  # A set is used because it only serves membership tests.
  punctuation_phones = set()
  exclude = set("!(),.?;:'-\"")
  with open(args.phone_list) as f:
      for line in f:
          fields = line.split()
          if not fields:
              # Blank line: nothing to classify.
              continue
          # Split the phone symbol, not the whole line, on '_'.  This way,
          # phones without a '_' suffix (e.g. ". 12") are recognised too.
          # Whitespace splitting also accepts tab-separated files.
          phone = fields[0].split('_')[0]
          if len(phone) == 1 and phone in exclude:
              punctuation_phones.add(int(fields[1]))
  # Topology entry for non-silence phones that are not punctuation.  It has
  # num_nonsil_states emitting states in a left-to-right chain.  Each state has
  # a 0.75 self-loop and a 0.25 transition to the next state.  The chain is
  # closed by a state with no <PdfClass> and no transitions.
  print("<Topology>")
  non_punct_phones = [x for x in nonsilence_phones if x not in punctuation_phones]
  # An entry whose <ForPhones> list is empty applies to no phone, so it is
  # omitted.  NOTE(review): Kaldi's HmmTopology check is expected to reject
  # entries that no phone uses; confirm.
  if non_punct_phones:
      print("<TopologyEntry>")
      print("<ForPhones>")
      print(" ".join(str(x) for x in non_punct_phones))
      print("</ForPhones>")
      for x in range(args.num_nonsil_states):
          print("<State> {0} <PdfClass> {0} <Transition> {0} 0.75 "
                "<Transition> {1} 0.25 </State>".format(x, x + 1))
      print("<State> " + str(args.num_nonsil_states) + " </State>")
      print("</TopologyEntry>")
  
  # Topology entry for non-silence phones that are punctuation.  It has
  # num_punctuation_states emitting states in a left-to-right chain.  Each
  # state has a 0.75 self-loop and a 0.25 transition to the next state.  The
  # chain is closed by a state with no <PdfClass> and no transitions.
  punct_phones = [x for x in nonsilence_phones if x in punctuation_phones]
  # If the phone list contains no punctuation, this entry would have an empty
  # <ForPhones> list and apply to no phone, so it is omitted.  NOTE(review):
  # Kaldi's HmmTopology check is expected to reject such unused entries; confirm.
  if punct_phones:
      print("<TopologyEntry>")
      print("<ForPhones>")
      print(" ".join(str(x) for x in punct_phones))
      print("</ForPhones>")
      for x in range(args.num_punctuation_states):
          print("<State> {0} <PdfClass> {0} <Transition> {0} 0.75 "
                "<Transition> {1} 0.25 </State>".format(x, x + 1))
      print("<State> " + str(args.num_punctuation_states) + " </State>")
      print("</TopologyEntry>")
  
  # Topology entry for silence phones, followed by the closing </Topology> tag.
  # Let n = num_sil_states.  For n > 1:
  #   - state 0 may move to any of states 0 .. n-2;
  #   - each middle state 1 .. n-2 may move to any of states 1 .. n-1;
  #   - all of these transitions share the probability 1/(n-1);
  #   - the last emitting state n-1 has a 0.75 self-loop and a 0.25
  #     transition to state n.
  # Otherwise a single emitting state 0 is written with a 0.75 self-loop and a
  # 0.25 transition to state 1.  Either way, the entry ends with a bare state
  # numbered n.
  print("<TopologyEntry>")
  print("<ForPhones>")
  print(" ".join(map(str, silence_phones)))
  print("</ForPhones>")

  def _sil_state_line(state, targets, prob_text):
      """Format emitting `state` with a transition of `prob_text` to each target."""
      arcs = "".join("<Transition> %d %s " % (target, prob_text) for target in targets)
      return "<State> %d <PdfClass> %d %s</State>" % (state, state, arcs)

  num_sil = args.num_sil_states
  if num_sil <= 1:
      print("<State> 0 <PdfClass> 0 <Transition> 0 0.75 <Transition> 1 0.25 </State>")
  else:
      uniform_prob = str(1.0 / (num_sil - 1))
      print(_sil_state_line(0, range(num_sil - 1), uniform_prob))
      for state in range(1, num_sil - 1):
          print(_sil_state_line(state, range(1, num_sil), uniform_prob))
      last_emitting = num_sil - 1
      print("<State> %d <PdfClass> %d <Transition> %d 0.75 <Transition> %d 0.25 </State>"
            % (last_emitting, last_emitting, last_emitting, num_sil))
  print("<State> %d </State>" % num_sil)
  print("</TopologyEntry>")
  print("</Topology>")