Blame view

scripts/rnnlm/get_embedding_dim.py 3.83 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
  #!/usr/bin/env python3
  
  # Copyright  2017  Johns Hopkins University (author: Daniel Povey)
  # License: Apache 2.0.
  
  import os
  import argparse
  import subprocess
  import sys
  import re
  
  
  # Command-line interface: a single positional argument naming the raw
  # nnet3 model.  The description text doubles as the usage documentation
  # shown by --help.
  parser = argparse.ArgumentParser(
      description=(
          "This script works out the embedding dimension from a nnet3 neural "
          "network (e.g. 0.raw).  "
          "It does this by invoking nnet3-info to print information about the "
          "neural network, and parsing it.  "
          "You should make sure nnet3-info is on your path before you call "
          "this script.  "
          "It is an error if the input and output dimensions of the neural "
          "network are not the same.  "
          "This script prints the embedding dimension to the standard output."
      ),
      epilog="E.g. {0} 0.raw".format(sys.argv[0]),
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )

  # The only argument: where to find the model to inspect.
  parser.add_argument(
      "nnet",
      help="Path for raw neural net (e.g. 0.raw)",
  )
  
  args = parser.parse_args()

  # Fail early with a one-line message if the model file is missing.
  if not os.path.exists(args.nnet):
      sys.exit(sys.argv[0] + ": input neural net '{0}' does not exist.".format(args.nnet))

  # Run nnet3-info on the model and capture its standard output.
  # Popen raises OSError when the executable cannot be started at all (for
  # example when nnet3-info is not on the PATH); report that the same way as
  # every other failure in this script instead of dumping a traceback.
  try:
      proc = subprocess.Popen(["nnet3-info", args.nnet], stdout=subprocess.PIPE)
  except OSError as err:
      sys.exit(sys.argv[0] + ": could not run 'nnet3-info {0}' (is nnet3-info "
               "on your path?): {1}".format(args.nnet, err))

  # out_lines is a list of bytes objects, one per output line; the parsing
  # code further down decodes each of them.
  out_lines = proc.stdout.readlines()
  # stdout has already been read to end-of-file; communicate() waits for the
  # process to terminate so that proc.returncode is set.
  proc.communicate()
  if proc.returncode != 0:
      sys.exit(sys.argv[0] + ": error running command 'nnet3-info {0}'".format(args.nnet))
  
  
  # Scan the nnet3-info output.  We're looking for lines like:
  # input-node name=input dim=600
  # output-node name=output input=output.affine dim=600
  # and for lines that start with 'left-context: <N>' or 'right-context: <N>'.

  # Compile the patterns once instead of re-specifying them on every line.
  input_node_re = re.compile(r'input-node name=input dim=(\d+)')
  output_node_re = re.compile(r'output-node name=output .* dim=(\d+)')
  left_context_re = re.compile(r'left-context: (\d+)')
  right_context_re = re.compile(r'right-context: (\d+)')

  # -1 means "not found"; the dims are checked against it after the loop.
  input_dim = -1
  output_dim = -1
  # NOTE(review): the context values default to 0, not -1, so a model whose
  # nnet3-info output has no left-context/right-context lines is treated as
  # needing no context.  Confirm that this is intended before changing it.
  left_context = 0
  right_context = 0

  for line in out_lines:
      # The subprocess pipe yields bytes; decode before matching.
      line = line.decode('utf-8')

      m = input_node_re.search(line)
      if m is not None:
          try:
              input_dim = int(m.group(1))
          except ValueError:
              # Only a conversion failure is expected here; a bare 'except'
              # would also swallow KeyboardInterrupt and SystemExit.
              sys.exit(sys.argv[0] + ": error processing line {0}".format(line))

      m = output_node_re.search(line)
      if m is not None:
          try:
              output_dim = int(m.group(1))
          except ValueError:
              sys.exit(sys.argv[0] + ": error processing line {0}".format(line))

      # match() rather than search(): these keys must be at the line start.
      m = left_context_re.match(line)
      if m is not None:
          left_context = int(m.group(1))
      m = right_context_re.match(line)
      if m is not None:
          right_context = int(m.group(1))
  
  # Validate what was parsed from the nnet3-info output.  Each failure exits
  # with a message on stderr and a non-zero status; on success the embedding
  # dimension is the only thing written to stdout.

  if input_dim == -1:
      sys.exit(sys.argv[0] + ": could not get input dim from output "
               "of 'nnet3-info {0}'".format(args.nnet))

  if output_dim == -1:
      sys.exit(sys.argv[0] + ": could not get output dim from output "
               "of 'nnet3-info {0}'".format(args.nnet))

  # NOTE(review): left_context and right_context are initialised to 0 (not
  # -1) before the parsing loop, and the patterns only capture digits, so
  # the next two checks cannot currently trigger.  They are kept in case the
  # defaults are ever changed to -1.
  if left_context == -1:
      sys.exit(sys.argv[0] + ": could not get left context output "
               "of 'nnet3-info {0}'".format(args.nnet))

  if right_context == -1:
      sys.exit(sys.argv[0] + ": could not get right context output "
               "of 'nnet3-info {0}'".format(args.nnet))

  if right_context > 0:
      # Report the right context here (the message used to print the left
      # context by mistake).
      sys.exit(sys.argv[0] + ": right-context of model {0} is >0: (it's {1}). "
               "This model cannot be used as an RNNLM because it sees the "
               "future.".format(args.nnet, right_context))

  if left_context > 0:
      sys.exit(sys.argv[0] + ": left-context of model {0} is >0: (it's {1}). "
               "This model cannot be used as an RNNLM because it requires left "
               "context and the code does not support this.  You can generally use "
               "IfDefined() in descriptors, and set configs of layers, in such "
               "a way that left-context is not required"
               "".format(args.nnet, left_context))

  # The embedding dimension is only well defined when the network's input
  # and output dimensions agree.
  if input_dim != output_dim:
      sys.exit(sys.argv[0] + ": input and output dims differ for "
               "nnet '{0}': {1} != {2}".format(
              args.nnet, input_dim, output_dim))

  print('{}'.format(input_dim))