egs/wsj/s5/steps/cleanup/internal/retrieve_similar_docs.py
#! /usr/bin/env python

# Copyright 2017  Vimal Manohar
# Apache 2.0.

"""This script retrieves documents similar to the query documents
using a similarity score based on the total TF-IDF for all the terms in
the query document.

Some terminology:
    original utterance-id = The utterance-id of the original long audio
        segments and the corresponding reference transcript
    source-text = reference transcript
    source-text-id = original utterance-id
    sub-segment = Approximately 30s long chunk of the original utterance
    query-id = utterance-id of the sub-segment
    document = Approximately 1000 words of a source-text
    doc-id = Id of the document

e.g. foo1 A B C D E F is in the original text file
and foo1 foo 100 200 is in the original segments file.
Here, foo1 is the source-text-id and "A B C D E F" is the reference
transcript. It is a 100s long segment from the recording foo.

foo1 is split into 30s long sub-segments as follows:
    foo1-1 foo1 100 130
    foo1-2 foo1 125 155
    foo1-3 foo1 150 180
    foo1-4 foo1 175 200
foo1-{1,2,3,4} are query-ids.

The source-text for foo1 is split into two-word documents:
    doc1 A B
    doc2 C D
    doc3 E F
doc{1,2,3} are doc-ids.

The --source-text-id2doc-ids option is given a mapping that contains
    foo1 doc1 doc2 doc3

The --query-id2source-text-id option is given a mapping that contains
    foo1-1 foo1
    foo1-2 foo1
    foo1-3 foo1
    foo1-4 foo1

The query TF-IDFs are all indexed by the utterance-ids of the sub-segments
of the original utterances. The source TF-IDFs use the document-ids created
by splitting the source-text (corresponding to the original utterances)
into documents.

For each query (sub-segment), we need to retrieve the documents that were
created from the same original utterance that the sub-segment was from.
For this, we have to load the source TF-IDF that has those documents. This
information is provided using the option --source-text-id2tfidf, which is
like an SCP file with the first column being the source-text-id and the
second column being the location of the TF-IDF for the documents
corresponding to that source-text-id.

The output of this script is a file where the first column is the query-id
(i.e. the sub-segment-id) and the remaining columns (at least one, and at
most 1 + 2 * num-neighbors-to-search) are comma-separated tuples
<doc-id>,<start-fraction>,<end-fraction>, where
    <doc-id> is the document-id,
    <start-fraction> is the proportion of the document from the beginning
        that needs to be in the retrieved set,
    <end-fraction> is the proportion of the document from the end that
        needs to be in the retrieved set.
If both <start-fraction> and <end-fraction> are 1, then the full document
is added to the retrieved set.

Some examples of lines in the output file:
    foo1-1 doc1,1,1
    foo1-2 doc1,0,0.2 doc2,1,1 doc3,0.2,0
"""

from __future__ import print_function
import argparse
import logging

import tf_idf

logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(filename)s:%(lineno)s - "
                              "%(funcName)s - %(levelname)s ] %(message)s")
handler.setFormatter(formatter)

for l in [logger, logging.getLogger('tf_idf'), logging.getLogger('libs')]:
    l.setLevel(logging.DEBUG)
    l.addHandler(handler)
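
# A worked example of the <start-fraction>/<end-fraction> semantics in the
# docstring above (illustration only; these numbers are not used anywhere in
# the code): in the output line
#     foo1-2 doc1,0,0.2 doc2,1,1 doc3,0.2,0
# the retrieved set contains the last 20% of the words of doc1 (for a
# 1000-word document, 0.2 * 1000 = its final 200 words), all of doc2, and
# the first 20% of the words of doc3.
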
def get_args():
    parser = argparse.ArgumentParser(
        description="""This script retrieves documents similar to the query
        documents using a similarity score based on the total TF-IDF for
        all the terms in the query document.
        See the beginning of the script for more details about the
        arguments to the script.""")

    parser.add_argument("--verbose", type=int, default=0,
                        choices=[0, 1, 2, 3],
                        help="Higher for more logging statements")

    parser.add_argument("--num-neighbors-to-search", type=int, default=0,
                        help="""Number of neighboring documents to search
                        around the one retrieved based on maximum tf-idf
                        similarity. A value of 0 means only the document
                        with the maximum tf-idf similarity is retrieved,
                        and none of the documents adjacent to it.""")
    parser.add_argument("--neighbor-tfidf-threshold", type=float,
                        default=0.9,
                        help="""Ignore neighbors whose tf-idf similarity
                        with the query document is less than this threshold
                        factor times the best score.""")
    parser.add_argument("--partial-doc-fraction", type=float, default=0.2,
                        help="""The fraction of a neighboring document that
                        will be part of the retrieved document set. If this
                        is greater than 0, then a fraction of words from the
                        neighboring documents is added to the retrieved
                        document.""")
    parser.add_argument("--source-text-id2doc-ids",
                        type=argparse.FileType('r'), required=True,
                        help="""A mapping from the source text to the list
                        of documents that it is broken into:
                        <text-utterance-id> <document-id-1> ...
                        <document-id-N>""")
    parser.add_argument("--query-id2source-text-id",
                        type=argparse.FileType('r'), required=True,
                        help="""A mapping from the query document-id to the
                        source text from which a document needs to be
                        retrieved.""")
    parser.add_argument("--source-text-id2tfidf",
                        type=argparse.FileType('r'), required=True,
                        help="""An SCP file for the TF-IDF of the source
                        documents, indexed by the source-text-id.""")
    parser.add_argument("--query-tfidf", type=argparse.FileType('r'),
                        required=True,
                        help="""Archive of TF-IDF objects for query
                        documents indexed by the query-id. The format is
                        query-id <TFIDF> ... </TFIDF>""")
    parser.add_argument("--relevant-docs", type=argparse.FileType('w'),
                        required=True,
                        help="""Output archive of a list of source documents
                        similar to a query document, indexed by the query
                        document id.""")

    args = parser.parse_args()

    if args.partial_doc_fraction < 0 or args.partial_doc_fraction > 1:
        logger.error("--partial-doc-fraction must be in the range [0, 1]")
        raise ValueError("--partial-doc-fraction must be in the range "
                         "[0, 1]")

    return args
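
# A hypothetical invocation, run from a Kaldi recipe directory such as
# egs/wsj/s5 (the input/output file names below are made up for
# illustration; only the flags come from the argument parser above):
#
#   steps/cleanup/internal/retrieve_similar_docs.py \
#     --num-neighbors-to-search=1 \
#     --neighbor-tfidf-threshold=0.9 \
#     --partial-doc-fraction=0.2 \
#     --source-text-id2doc-ids=source_text2docs.map \
#     --query-id2source-text-id=query2source_text.map \
#     --source-text-id2tfidf=source_tfidf.scp \
#     --query-tfidf=query_tfidf.ark \
#     --relevant-docs=relevant_docs.txt
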
def read_map(file_handle, num_values_per_key=None,
             min_num_values_per_key=None, must_contain_unique_key=True):
    """Reads a map from a file into a dictionary and returns it.

    Expects the map to be stored in the file in the following format:
        <key> <value-1> <value-2> ... <value-N>

    The values are returned as a list stored in a dictionary indexed by the
    "key" (or as a single value when num_values_per_key is 1).

    Arguments:
        file_handle - A handle to an opened input file containing the map
        num_values_per_key - If provided, the function raises an error if
            the number of values read for a key in the input file does not
            match "num_values_per_key"
        min_num_values_per_key - If provided, the function raises an error
            if the number of values read for a key in the input file is
            less than "min_num_values_per_key"
        must_contain_unique_key - If True, the function raises an error
            when the same key appears more than once in the file

    Returns:
        { key: list-of-values }
    """
    dict_map = {}
    for line in file_handle:
        try:
            parts = line.strip().split()
            key = parts[0]

            if (num_values_per_key is not None
                    and len(parts) - 1 != num_values_per_key):
                logger.error(
                    "Expecting {0} columns; Got {1}.".format(
                        num_values_per_key + 1, len(parts)))
                raise TypeError

            if (min_num_values_per_key is not None
                    and len(parts) - 1 < min_num_values_per_key):
                logger.error(
                    "Expecting at least {0} columns; Got {1}.".format(
                        min_num_values_per_key + 1, len(parts)))
                raise TypeError

            if must_contain_unique_key and key in dict_map:
                logger.error("Found duplicate key %s", key)
                raise TypeError

            if num_values_per_key is not None and num_values_per_key == 1:
                dict_map[key] = parts[1]
            else:
                dict_map[key] = parts[1:]
        except Exception:
            logger.error("Failed reading line %s in file %s",
                         line, file_handle.name)
            raise
    file_handle.close()
    return dict_map


def get_document_ids(source_docs, indexes):
    """Converts a map from document index to
    (start-fraction, end-fraction) into a list of tuples
    (doc-id, start-fraction, end-fraction) sorted by document index.
    Indexes outside the range of source_docs are silently skipped."""
    indexes = sorted(
        [(key, value[0], value[1]) for key, value in indexes.items()],
        key=lambda x: x[0])

    doc_ids = []
    for i, partial_start, partial_end in indexes:
        try:
            doc_ids.append((source_docs[i], partial_start, partial_end))
        except IndexError:
            pass
    return doc_ids
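
# A doctest-style sketch of get_document_ids() (inputs made up for
# illustration): with three source documents, requesting the last 20% of
# the first document and all of the second gives
#     >>> get_document_ids(['doc1', 'doc2', 'doc3'],
#     ...                  {0: (0, 0.2), 1: (1, 1)})
#     [('doc1', 0, 0.2), ('doc2', 1, 1)]
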
def run(args):
    """The main function that does all the processing.
    Takes as argument the Namespace object obtained from get_args().
    """
    query_id2source_text_id = read_map(args.query_id2source_text_id,
                                       num_values_per_key=1)
    source_text_id2doc_ids = read_map(args.source_text_id2doc_ids,
                                      min_num_values_per_key=1)
    source_text_id2tfidf = read_map(args.source_text_id2tfidf,
                                    num_values_per_key=1)

    num_queries = 0
    prev_source_text_id = ""
    for query_id, query_tfidf in tf_idf.read_tfidf_ark(args.query_tfidf):
        num_queries += 1

        # The source text from which a document is to be retrieved for the
        # input query
        source_text_id = query_id2source_text_id[query_id]

        # Load the source TF-IDF only when the source-text changes from the
        # previous query; this caches the TF-IDF across consecutive queries
        # from the same source text.
        if prev_source_text_id != source_text_id:
            source_tfidf = tf_idf.TFIDF()
            source_tfidf.read(open(source_text_id2tfidf[source_text_id]))
            prev_source_text_id = source_text_id

        # The source documents corresponding to the source text.
        # This is the set of documents that will be searched over for the
        # query.
        source_doc_ids = source_text_id2doc_ids[source_text_id]

        scores = query_tfidf.compute_similarity_scores(
            source_tfidf, source_docs=source_doc_ids, query_id=query_id)

        assert len(scores) > 0, (
            "Did not get scores for query {0}".format(query_id))

        if args.verbose > 2:
            for tup, score in scores.items():
                logger.debug("Score, {num}: {0} {1} {2}".format(
                    tup[0], tup[1], score, num=num_queries))

        best_index, best_doc_id = max(
            enumerate(source_doc_ids),
            key=lambda x: scores[(query_id, x[1])])
        best_score = scores[(query_id, best_doc_id)]

        assert source_doc_ids[best_index] == best_doc_id
        assert best_score == max([scores[(query_id, x)]
                                  for x in source_doc_ids])

        best_indexes = {}

        if args.num_neighbors_to_search == 0:
            best_indexes[best_index] = (1, 1)
            if best_index > 0:
                best_indexes[best_index - 1] = (0, args.partial_doc_fraction)
            if best_index < len(source_doc_ids) - 1:
                best_indexes[best_index + 1] = (args.partial_doc_fraction, 0)
        else:
            # In the loop below, a neighboring document can enter the
            # retrieved set in one of three ways ("types"):
            #   Type 2: its score is within the threshold of the best
            #       score, so the full document is added.
            #   Type 1: it scored below the threshold but sits just to the
            #       left of a selected document, so only its trailing
            #       partial-doc-fraction of words is added.
            #   Type 3: it scored below the threshold but sits just to the
            #       right of a selected document, so only its leading
            #       partial-doc-fraction of words is added.
            excluded_indexes = set()
            for index in range(
                    max(best_index - args.num_neighbors_to_search, 0),
                    min(best_index + args.num_neighbors_to_search + 1,
                        len(source_doc_ids))):
                if (scores[(query_id, source_doc_ids[index])]
                        >= args.neighbor_tfidf_threshold * best_score):
                    best_indexes[index] = (1, 1)  # Type 2

                    if index > 0 and index - 1 in excluded_indexes:
                        try:
                            # Type 1 and 3
                            start_frac, end_frac = best_indexes[index - 1]
                            assert end_frac == 0
                            best_indexes[index - 1] = (
                                start_frac, args.partial_doc_fraction)
                        except KeyError:
                            # Type 1
                            best_indexes[index - 1] = (
                                0, args.partial_doc_fraction)
                else:
                    excluded_indexes.add(index)
                    if index > 0 and index - 1 not in excluded_indexes:
                        # Type 3
                        best_indexes[index] = (args.partial_doc_fraction, 0)

        best_docs = get_document_ids(source_doc_ids, best_indexes)
        assert len(best_docs) > 0, (
            "Did not get best docs for query {0} "
            "Scores: {1} "
            "Source docs: {2} "
            "Best index: {best_index}, score: {best_score}".format(
                query_id, scores, source_doc_ids,
                best_index=best_index, best_score=best_score))
        assert (best_doc_id, 1.0, 1.0) in best_docs

        print("{0} {1}".format(
            query_id,
            " ".join(["%s,%.2f,%.2f" % x for x in best_docs])),
            file=args.relevant_docs)

    if num_queries == 0:
        raise RuntimeError("Failed to retrieve any document.")

    logger.info("Retrieved similar documents for %d queries", num_queries)


def main():
    args = get_args()
    if args.verbose > 1:
        handler.setLevel(logging.DEBUG)
    try:
        run(args)
    finally:
        for f in [args.query_id2source_text_id, args.source_text_id2doc_ids,
                  args.relevant_docs, args.query_tfidf,
                  args.source_text_id2tfidf]:
            f.close()


if __name__ == '__main__':
    main()
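
# For illustration, one way a downstream consumer might parse a line of the
# --relevant-docs output written by this script (a sketch, not part of the
# script itself):
#     >>> line = "foo1-2 doc1,0,0.2 doc2,1,1 doc3,0.2,0"
#     >>> [(d, float(s), float(e))
#     ...  for d, s, e in (t.split(',') for t in line.split()[1:])]
#     [('doc1', 0.0, 0.2), ('doc2', 1.0, 1.0), ('doc3', 0.2, 0.0)]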