get_embedding_dim.py
3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (author: Daniel Povey)
# License: Apache 2.0.
import os
import argparse
import subprocess
import sys
import re
# We're looking for lines like these in the output of nnet3-info:
#   input-node name=input dim=600
#   output-node name=output input=output.affine dim=600
#   left-context: 0
#   right-context: 0
# The patterns are compiled once here rather than once per line.
INPUT_DIM_RE = re.compile(r'input-node name=input dim=(\d+)')
OUTPUT_DIM_RE = re.compile(r'output-node name=output .* dim=(\d+)')
LEFT_CONTEXT_RE = re.compile(r'left-context: (\d+)')
RIGHT_CONTEXT_RE = re.compile(r'right-context: (\d+)')


def parse_nnet3_info(lines):
    """Extract dimensions and contexts from the output of nnet3-info.

    Args:
        lines: iterable of output lines, either bytes (decoded as UTF-8)
            or text; trailing newlines are allowed.

    Returns:
        A tuple (input_dim, output_dim, left_context, right_context).
        A dim that was not found is returned as -1.  A context that was
        not found is returned as 0, i.e. a missing context line is treated
        as "no context required".  If a pattern matches on several lines,
        the last match wins.
    """
    input_dim = -1
    output_dim = -1
    left_context = 0
    right_context = 0
    for line in lines:
        if isinstance(line, bytes):
            line = line.decode('utf-8')
        # The captured groups are \d+, so int() cannot fail on them.
        m = INPUT_DIM_RE.search(line)
        if m is not None:
            input_dim = int(m.group(1))
        m = OUTPUT_DIM_RE.search(line)
        if m is not None:
            output_dim = int(m.group(1))
        # The context lines must start at the beginning of the line, hence
        # match() rather than search().
        m = LEFT_CONTEXT_RE.match(line)
        if m is not None:
            left_context = int(m.group(1))
        m = RIGHT_CONTEXT_RE.match(line)
        if m is not None:
            right_context = int(m.group(1))
    return input_dim, output_dim, left_context, right_context


def validate_nnet_info(nnet, input_dim, output_dim, left_context,
                       right_context):
    """Check that the parsed values describe a usable embedding network.

    Args:
        nnet: path of the neural net; only used in error messages.
        input_dim: input dimension, or -1 if it could not be found.
        output_dim: output dimension, or -1 if it could not be found.
        left_context: left context of the model.
        right_context: right context of the model.

    Returns:
        None if everything is fine, otherwise an error message (without
        the program-name prefix) describing the first problem found.
    """
    # NOTE: an earlier version also tested the contexts against -1, but they
    # default to 0 and the regexes cannot capture a negative number, so those
    # tests could never fire and have been dropped.
    if input_dim == -1:
        return ("could not get input dim from output "
                "of 'nnet3-info {0}'".format(nnet))
    if output_dim == -1:
        return ("could not get output dim from output "
                "of 'nnet3-info {0}'".format(nnet))
    if right_context > 0:
        # Report right_context here (this message used to print
        # left_context by mistake).
        return ("right-context of model {0} is >0: (it's {1}).  "
                "This model cannot be used as an RNNLM because it sees the "
                "future.".format(nnet, right_context))
    if left_context > 0:
        return ("left-context of model {0} is >0: (it's {1}).  "
                "This model cannot be used as an RNNLM because it requires left "
                "context and the code does not support this.  You can generally use "
                "IfDefined() in descriptors, and set configs of layers, in such "
                "a way that left-context is not required"
                "".format(nnet, left_context))
    if input_dim != output_dim:
        return ("input and output dims differ for "
                "nnet '{0}': {1} != {2}".format(nnet, input_dim, output_dim))
    return None


def main():
    """Print the embedding dimension of the nnet3 model named on the command line.

    Runs 'nnet3-info <nnet>', parses its output and prints the input
    dimension to the standard output.  Exits with an error message if the
    file does not exist, nnet3-info cannot be run or fails, the dimensions
    cannot be found or differ, or the model requires left/right context.
    """
    parser = argparse.ArgumentParser(
        description="This script works out the embedding dimension from a "
        "nnet3 neural network (e.g. 0.raw).  It does this by invoking "
        "nnet3-info to print information about the neural network, and "
        "parsing it.  You should make sure nnet3-info is on your path "
        "before you call this script.  It is an error if the input and "
        "output dimensions of the neural network are not the same.  This "
        "script prints the embedding dimension to the standard output.",
        epilog="E.g. " + sys.argv[0] + " 0.raw",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("nnet",
                        help="Path for raw neural net (e.g. 0.raw)")
    args = parser.parse_args()

    prog = sys.argv[0]
    if not os.path.exists(args.nnet):
        sys.exit(prog + ": input neural net '{0}' does not exist.".format(args.nnet))

    try:
        proc = subprocess.Popen(["nnet3-info", args.nnet],
                                stdout=subprocess.PIPE)
    except OSError as e:
        # Typically: nnet3-info is not on the PATH or is not executable.
        sys.exit(prog + ": could not run 'nnet3-info {0}': {1}".format(args.nnet, e))
    # communicate() reads all of stdout and waits for the process to exit.
    out, _ = proc.communicate()
    if proc.returncode != 0:
        sys.exit(prog + ": error running command 'nnet3-info {0}'".format(args.nnet))

    input_dim, output_dim, left_context, right_context = parse_nnet3_info(
        out.splitlines())
    error = validate_nnet_info(args.nnet, input_dim, output_dim,
                               left_context, right_context)
    if error is not None:
        sys.exit(prog + ": " + error)
    print('{}'.format(input_dim))


if __name__ == "__main__":
    main()