Blame view
egs/wsj/s5/steps/cleanup/internal/stitch_documents.py
5.22 KB
8dcb6dfcb first commit |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
#! /usr/bin/env python

# Copyright 2016 Vimal Manohar
# Apache 2.0.

"""This script reads an archive of mapping from query to
documents and stitches the documents for each query into a new document.
Here "document" is just a list of words.

query2docs is a mapping from query-id to a list of tuples
(document-id, start-fraction, end-fraction)
The tuple can be just the document-id, which is equivalent to specifying
a start-fraction and end-fraction of 1.0
The start and end fractions are used to stitch only a part of the
document to the retrieved set for the query: the start-fraction selects
that portion from the beginning of the document and the end-fraction
selects that portion from the end of the document.

e.g.
query1 doc1 doc2
query2 doc1,0,0.3 doc2,1,1

input-documents
doc1 A B C
doc2 D E

output-documents
query1 A B C D E
query2 C D E
"""

from __future__ import print_function
import argparse
import logging

logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - "
                              "%(funcName)s - %(levelname)s ] %(message)s")
handler.setFormatter(formatter)

# Attach the same stream handler to this module's logger and to 'libs'.
for configured_logger in [logger, logging.getLogger('libs')]:
    configured_logger.setLevel(logging.DEBUG)
    configured_logger.addHandler(handler)


def get_args():
    """Returns arguments parsed from command-line.

    The three file options are opened by argparse itself, so the returned
    namespace holds open file objects.  --check-sorted-docs-per-query is
    converted from the strings "true"/"false" to a bool.
    """
    parser = argparse.ArgumentParser(
        description="""This script reads an archive of mapping from query to
        documents and stitches the documents for each query into a
        new document.""")

    parser.add_argument("--query2docs", type=argparse.FileType('r'),
                        required=True,
                        help="""Input file containing an archive
                        of list of documents indexed by a query document
                        id.""")
    parser.add_argument("--input-documents", type=argparse.FileType('r'),
                        required=True,
                        help="""Input file containing the documents
                        indexed by the document id.""")
    parser.add_argument("--output-documents", type=argparse.FileType('w'),
                        required=True,
                        help="""Output documents indexed by the query
                        document-id, obtained by stitching input documents
                        corresponding to the query.""")
    parser.add_argument("--check-sorted-docs-per-query", type=str,
                        choices=["true", "false"], default="false",
                        help="If specified, the script will expect "
                        "the document ids in --query2docs to be "
                        "sorted.")

    args = parser.parse_args()

    args.check_sorted_docs_per_query = bool(
        args.check_sorted_docs_per_query == "true")
    return args


def _parse_doc_info(doc_info):
    """Splits one query2docs entry into (doc_id, start_fraction, end_fraction).

    An entry that is not a well-formed "doc-id,start,end" triple is taken
    verbatim as the document id with both fractions set to 1.0.
    """
    try:
        doc_id, start_fraction, end_fraction = doc_info.split(',')
        return doc_id, float(start_fraction), float(end_fraction)
    except ValueError:
        # Raised both by a failed 3-way unpacking and by float().
        return doc_info, 1.0, 1.0


def _select_words(doc, start_fraction, end_fraction):
    """Returns the words of doc selected by the two fractions.

    If either fraction is 1.0 the whole document is returned, and both
    fractions must then be 1.0.  Otherwise the leading start_fraction of the
    words is followed by the trailing end_fraction of the words.

    Raises:
        ValueError: if only one of the fractions is 1.0, or if the two
            fractions sum to 1.0 or more (the two parts would overlap).
    """
    if start_fraction == 1.0 or end_fraction == 1.0:
        if start_fraction != end_fraction:
            raise ValueError(
                "start-fraction and end-fraction must both be 1.0 if either "
                "one is; got {0} and {1}".format(start_fraction,
                                                 end_fraction))
        return list(doc)

    if start_fraction + end_fraction >= 1.0:
        raise ValueError(
            "start-fraction + end-fraction must be less than 1.0; "
            "got {0} and {1}".format(start_fraction, end_fraction))

    num_words = len(doc)
    words = []
    if start_fraction > 0:
        words.extend(doc[:int(start_fraction * num_words)])
    if end_fraction > 0:
        # The trailing end_fraction of the document starts at the
        # (1 - end_fraction) point, e.g. 0.3 of "A B C" is just "C".
        words.extend(doc[int((1.0 - end_fraction) * num_words):])
    return words


def run(args):
    """Stitches documents for every query and writes them out.

    Reads all of args.input_documents into memory (and closes it), then for
    each line of args.query2docs writes "<query-id> <words...>" to
    args.output_documents.  Blank lines in either input are skipped.

    Raises:
        RuntimeError: if args.check_sorted_docs_per_query is set and the
            document ids of a query are not strictly increasing.
        KeyError: if a query refers to an unknown document id.
        ValueError: if the fractions of an entry are inconsistent.
    Any error is logged together with the offending line before being
    re-raised.
    """
    documents = {}
    for line in args.input_documents:
        parts = line.strip().split()
        if not parts:
            continue
        documents[parts[0]] = parts[1:]
    args.input_documents.close()

    for line in args.query2docs:
        try:
            parts = line.strip().split()
            if not parts:
                continue
            query = parts[0]

            output_document = []
            prev_doc_id = ''
            for doc_info in parts[1:]:
                doc_id, start_fraction, end_fraction = _parse_doc_info(
                    doc_info)

                if args.check_sorted_docs_per_query:
                    if prev_doc_id != '' and doc_id <= prev_doc_id:
                        raise RuntimeError(
                            "Documents not sorted and "
                            "--check-sorted-docs-per-query was True; "
                            "{0} <= {1}".format(doc_id, prev_doc_id))
                    prev_doc_id = doc_id

                output_document.extend(
                    _select_words(documents[doc_id],
                                  start_fraction, end_fraction))

            print("{0} {1}".format(query, " ".join(output_document)),
                  file=args.output_documents)
        except Exception:
            logger.error("Error processing line %s in file %s", line,
                         args.query2docs.name)
            raise


def main():
    """Parses the command line, runs the stitching and closes all files.

    Exits with status 1 after logging the traceback if stitching fails.
    """
    args = get_args()

    try:
        run(args)
    except Exception:
        logger.error("Failed to stitch document; got error ",
                     exc_info=True)
        raise SystemExit(1)
    finally:
        for f in [args.query2docs, args.input_documents,
                  args.output_documents]:
            if f is not None:
                f.close()


if __name__ == '__main__':
    main()