scripts/rnnlm/get_unigram_probs.py
#!/usr/bin/env python3

# Copyright 2017 Jian Wang
# License: Apache 2.0.

import os
import argparse
import sys

parser = argparse.ArgumentParser(description="This script gets the unigram probabilities of words.",
                                 epilog="E.g. " + sys.argv[0] + " --vocab-file=data/rnnlm/vocab/words.txt "
                                        "--data-weights-file=exp/rnnlm/data_weights.txt data/rnnlm/data "
                                        "> exp/rnnlm/unigram_probs.txt",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--vocab-file", type=str, default='', required=True,
                    help="Specify the vocab file.")
parser.add_argument("--unk-word", type=str, default='',
                    help="String form of unknown word, e.g. <unk>. Words in the counts "
                         "but not present in the vocabulary will be mapped to this word. "
                         "If the empty string, we act as if there is no unknown word, and "
                         "OOV words are treated as an error.")
parser.add_argument("--data-weights-file", type=str, default='', required=True,
                    help="File that specifies multiplicities and weights for each data source: "
                         "e.g. if <text_dir> contains foo.txt and bar.txt, it should have lines "
                         "like 'foo 1 0.5' and 'bar 5 1.5'. These "
                         "don't have to sum to one.")
parser.add_argument("--smooth-unigram-counts", type=float, default=1.0,
                    help="Specify the constant for smoothing. We will add "
                         "(smooth_unigram_counts * num_words_with_non_zero_counts / vocab_size) "
                         "to every unigram count.")
parser.add_argument("text_dir",
                    help="Directory in which to look for data")

args = parser.parse_args()

SPECIAL_SYMBOLS = ["<eps>", "<s>", "<brk>"]


# Get the .txt and .counts file paths for all data sources except dev.
# Returns a dict whose key is the name of a data source and whose
# value is a tuple (txt_file_path, counts_file_path).
def get_all_data_sources_except_dev(text_dir):
    data_sources = {}
    for f in os.listdir(text_dir):
        full_path = text_dir + "/" + f
        if f == 'dev.txt' or f == 'dev.counts' or os.path.isdir(full_path):
            continue
        if f.endswith(".txt"):
            name = f[0:-4]
            if name in data_sources:
                data_sources[name] = (full_path, data_sources[name][1])
            else:
                data_sources[name] = (full_path, None)
        elif f.endswith(".counts"):
            name = f[0:-7]
            if name in data_sources:
                data_sources[name] = (data_sources[name][0], full_path)
            else:
                data_sources[name] = (None, full_path)
        else:
            sys.exit(sys.argv[0] + ": Text directory should not contain files with suffixes "
                     "other than .txt or .counts: " + f)

    for name, (txt_file, counts_file) in data_sources.items():
        if txt_file is None or counts_file is None:
            sys.exit(sys.argv[0] + ": Missing .txt or .counts file for data source: " + name)

    return data_sources


# Read the data-weights for data_sources from weights_file.
# Returns a dict whose key is the name of a data source and whose
# value is a tuple (repeated_times_per_epoch, weight).
def read_data_weights(weights_file, data_sources):
    data_weights = {}
    with open(weights_file, 'r', encoding="utf-8") as f:
        for line in f:
            try:
                fields = line.split()
                assert len(fields) == 3
                if fields[0] in data_weights:
                    raise Exception("duplicated data source({0}) specified in "
                                    "data-weights: {1}".format(fields[0], weights_file))
                data_weights[fields[0]] = (int(fields[1]), float(fields[2]))
            except Exception as e:
                sys.exit(sys.argv[0] + ": bad data-weights line: '" + line.rstrip() +
                         "': " + str(e))

    for name in data_sources.keys():
        if name not in data_weights:
            sys.exit(sys.argv[0] + ": Weight for data source '{0}' not set".format(name))

    return data_weights
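
# For illustration (an assumed example, not part of the original file): given a
# text_dir containing foo.txt/foo.counts and bar.txt/bar.counts, a valid
# data-weights file could contain
#
#   foo 1 0.5
#   bar 5 1.5
#
# meaning foo is repeated once per epoch with weight 0.5, and bar five times
# per epoch with weight 1.5. get_counts() below scales each source's counts by
# multiplicity * weight, i.e. by 0.5 and 7.5 respectively in this example.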

# Read the vocab.
# Returns the vocab, which is a dict mapping each word to an integer id.
def read_vocab(vocab_file):
    vocab = {}
    with open(vocab_file, 'r', encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            assert len(fields) == 2
            if fields[0] in vocab:
                sys.exit(sys.argv[0] + ": duplicated word({0}) in vocab: {1}"
                                       .format(fields[0], vocab_file))
            vocab[fields[0]] = int(fields[1])

    # Check that there is no duplication and no gap among the word ids.
    sorted_ids = sorted(vocab.values())
    for idx, word_id in enumerate(sorted_ids):
        assert idx == word_id
    if args.unk_word != '' and args.unk_word not in vocab:
        sys.exit(sys.argv[0] + ": --unk-word={0} does not appear in vocab file {1}".format(
            args.unk_word, vocab_file))
    return vocab


# Get the total (weighted) count for each word over all data sources.
# Returns a list of counts indexed by word id.
def get_counts(data_sources, data_weights, vocab):
    counts = [0.0] * len(vocab)
    for name, (_, counts_file) in data_sources.items():
        weight = data_weights[name][0] * data_weights[name][1]
        if weight == 0.0:
            continue
        with open(counts_file, 'r', encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if len(fields) != 2:
                    sys.exit(sys.argv[0] + ": bad line (expected 2 fields) in counts "
                             "file {0}: {1}".format(counts_file, line.rstrip()))
                word = fields[0]
                count = int(fields[1])
                if word not in vocab:
                    if args.unk_word == '':
                        sys.exit(sys.argv[0] + ": error: an OOV word {0} is present in the "
                                 "counts file {1} but you have not specified an unknown word to "
                                 "map it to (--unk-word option).".format(word, counts_file))
                    else:
                        word = args.unk_word
                counts[vocab[word]] += weight * count
    return counts


# Smooth the counts and get the unigram probs for all words.
# Returns a list of probs indexed by word id.
def get_unigram_probs(vocab, counts, smooth_constant):
    special_symbol_ids = set(vocab[x] for x in SPECIAL_SYMBOLS)
    vocab_size = len(vocab) - len(SPECIAL_SYMBOLS)

    num_words_with_non_zero_counts = 0
    for word_id, count in enumerate(counts):
        if word_id in special_symbol_ids:
            continue
        if count > 0:
            num_words_with_non_zero_counts += 1
    if num_words_with_non_zero_counts < vocab_size and smooth_constant == 0.0:
        sys.exit(sys.argv[0] + ": --smooth-unigram-counts should not be zero, "
                 "since there are words with zero counts")

    smooth_count = smooth_constant * num_words_with_non_zero_counts / vocab_size
    total_counts = 0.0
    for word_id in range(len(counts)):
        if word_id in special_symbol_ids:
            continue
        counts[word_id] += smooth_count
        total_counts += counts[word_id]

    probs = []
    for count in counts:
        probs.append(count / total_counts)

    return probs


if os.system("rnnlm/ensure_counts_present.sh {0}".format(args.text_dir)) != 0:
    sys.exit(sys.argv[0] + ": command 'rnnlm/ensure_counts_present.sh {0}' failed.".format(
        args.text_dir))

data_sources = get_all_data_sources_except_dev(args.text_dir)
data_weights = read_data_weights(args.data_weights_file, data_sources)
vocab = read_vocab(args.vocab_file)
counts = get_counts(data_sources, data_weights, vocab)
probs = get_unigram_probs(vocab, counts, args.smooth_unigram_counts)

# Output one line per word: the word id followed by its unigram probability.
for idx, p in enumerate(probs):
    print(idx, p)

print(sys.argv[0] + ": generated unigram probs.", file=sys.stderr)
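
# A worked example of the smoothing above (illustrative numbers only, not from
# any real data): with 4 non-special words whose raw weighted counts are
# [2, 0, 1, 1] and --smooth-unigram-counts=1.0, there are 3 words with
# non-zero counts, so smooth_count = 1.0 * 3 / 4 = 0.75. The smoothed counts
# become [2.75, 0.75, 1.75, 1.75], which total 7.0, giving unigram probs of
# approximately [0.393, 0.107, 0.25, 0.25]. The special symbols <eps>, <s>
# and <brk> are never smoothed, so their probs stay at zero.
#
# Typical invocation, mirroring the epilog above (the paths are examples):
#   scripts/rnnlm/get_unigram_probs.py --vocab-file=data/rnnlm/vocab/words.txt \
#     --data-weights-file=exp/rnnlm/data_weights.txt data/rnnlm/data \
#     > exp/rnnlm/unigram_probs.txt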