Blame view

scripts/rnnlm/get_vocab.py 1.97 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
  #!/usr/bin/env python3
  
  # Copyright  2017  Jian Wang
  # License: Apache 2.0.
  
  import os
  import argparse
  import sys
  sys.stdout = open(1, 'w', encoding='utf-8', closefd=False)
  
  import re
  
  
  # Command-line interface: one positional argument, the directory that
  # holds the *.counts files to be merged into a vocabulary.
  parser = argparse.ArgumentParser(
      description=("This script get a vocab from unigram counts "
                   "of words produced by get_unigram_counts.sh"),
      epilog="E.g. {0} data/rnnlm/data > data/rnnlm/vocab/words.txt".format(
          sys.argv[0]),
      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  )
  parser.add_argument(
      "data_dir",
      help="Directory in which to look for unigram counts.",
  )
  args = parser.parse_args()

  # The end-of-sentence marker and the other special tokens; these are the
  # symbols the script prints with the fixed ids 0-3.
  eos_symbol = '</s>'
  special_symbols = ['<s>', '<brk>', '<eps>']
  
  
  def add_counts(word_counts, counts_file):
      """Add the unigram counts stored in *counts_file* into *word_counts*.

      Every line of *counts_file* must contain exactly two whitespace-separated
      fields: a word followed by its integer count.  The count is added to the
      word's existing entry in *word_counts*, or inserted if the word is new.

      Args:
          word_counts: dict mapping word -> int count; updated in place.
          counts_file: path of a UTF-8 text file of "<word> <count>" lines.

      Raises:
          ValueError: if a line does not have exactly two fields, or if the
              count field is not an integer.
      """
      with open(counts_file, 'r', encoding="utf-8") as f:
          for line_number, line in enumerate(f, start=1):
              # split() with no separator already ignores leading/trailing
              # whitespace (including "\r\n"), so no separate strip is needed.
              fields = line.split()
              # Explicit check rather than `assert`, which is removed under -O.
              if len(fields) != 2:
                  raise ValueError(
                      "{0}:{1}: expected '<word> <count>', got {2!r}".format(
                          counts_file, line_number, line))
              word, count = fields
              word_counts[word] = word_counts.get(word, 0) + int(count)
  
  # Merge the counts of every non-directory entry in data_dir whose name ends
  # in ".counts"; sub-directories are skipped.
  word_counts = {}
  for file_name in os.listdir(args.data_dir):
      full_path = os.path.join(args.data_dir, file_name)
      if os.path.isdir(full_path):
          continue
      if file_name.endswith(".counts"):
          add_counts(word_counts, full_path)

  if not word_counts:
      sys.exit(sys.argv[0] + ": Directory {0} should contain at least one .counts file "
               .format(args.data_dir))

  # Reserved symbols occupy the fixed ids 0-3.
  print("<eps> 0")
  print("<s> 1")
  print("</s> 2")
  print("<brk> 3")

  # A reserved symbol must never be emitted a second time with a different id,
  # even if it occurs in the counts files.
  reserved_symbols = set(special_symbols)
  reserved_symbols.add(eos_symbol)

  # Remaining words get consecutive ids starting at 4, most frequent first.
  # Ties are broken by the word itself so the numbering is reproducible
  # instead of depending on os.listdir() order.
  idx = 4
  for word, _ in sorted(word_counts.items(), key=lambda item: (-item[1], item[0])):
      if word in reserved_symbols:
          continue
      print("{0} {1}".format(word, idx))
      idx += 1

  # idx is one past the last id written, i.e. the total number of vocab entries.
  print(sys.argv[0] + ": vocab is generated with {0} words.".format(idx), file=sys.stderr)