Blame view

egs/wsj/s5/steps/cleanup/internal/stitch_documents.py 5.22 KB
8dcb6dfcb   Yannick Estève   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
  #! /usr/bin/env python
  
  # Copyright 2016    Vimal Manohar
  # Apache 2.0.
  
  """This script reads an archive of mapping from query to
  documents and stitches the documents for each query into a
  new document.
  Here "document" is just a list of words.
  
  query2docs is a mapping from query-id to a list of tuples
  (document-id, start-fraction, end-fraction)
  The tuple can be just the document-id, which is equivaluent to
  specifying a start-fraction and end-fraction of 1.0
  The start and end fractions are used to stitch only a part of the
  document to the retrieved set for the query.
  
  e.g.
  query1 doc1 doc2
  query2 doc1,0,0.3 doc2,1,1
  
  input-documents
  doc1 A B C
  doc2 D E
  output-documents
  query1 A B C D E
  query2 C D E
  """
  
  from __future__ import print_function
  import argparse
  import logging
  
  logger = logging.getLogger(__name__)
  handler = logging.StreamHandler()
  handler.setLevel(logging.INFO)
  formatter = logging.Formatter("%(asctime)s [%(pathname)s:%(lineno)s - "
                                "%(funcName)s - %(levelname)s ] %(message)s")
  handler.setFormatter(formatter)
  
  for l in [logger, logging.getLogger('libs')]:
      l.setLevel(logging.DEBUG)
      l.addHandler(handler)
  
  
  def get_args():
      """Returns arguments parsed from command-line."""
  
      parser = argparse.ArgumentParser(
          description="""This script reads an archive of mapping from query to
          documents and stitches the documents for each query into a new
          document.""")
  
      parser.add_argument("--query2docs", type=argparse.FileType('r'),
                          required=True,
                          help="""Input file containing an archive
                          of list of documents indexed by a query document
                          id.""")
      parser.add_argument("--input-documents", type=argparse.FileType('r'),
                          required=True,
                          help="""Input file containing the documents
                          indexed by the document id.""")
      parser.add_argument("--output-documents", type=argparse.FileType('w'),
                          required=True,
                          help="""Output documents indexed by the query
                          document-id, obtained by stitching input documents
                          corresponding to the query.""")
      parser.add_argument("--check-sorted-docs-per-query", type=str,
                          choices=["true", "false"], default="false",
                          help="If specified, the script will expect "
                          "the document ids in --query2docs to be "
                          "sorted.")
  
      args = parser.parse_args()
  
      args.check_sorted_docs_per_query = bool(
          args.check_sorted_docs_per_query == "true")
  
      return args
  
  
  def run(args):
      documents = {}
      for line in args.input_documents:
          parts = line.strip().split()
          key = parts[0]
          documents[key] = parts[1:]
      args.input_documents.close()
  
      for line in args.query2docs:
          try:
              parts = line.strip().split()
              query = parts[0]
              document_infos = parts[1:]
  
              output_document = []
              prev_doc_id = ''
              for doc_info in document_infos:
                  try:
                      doc_id, start_fraction, end_fraction = doc_info.split(',')
                      start_fraction = float(start_fraction)
                      end_fraction = float(end_fraction)
                  except ValueError:
                      doc_id = doc_info
                      start_fraction = 1.0
                      end_fraction = 1.0
  
                  if args.check_sorted_docs_per_query:
                      if prev_doc_id != '':
                          if doc_id <= prev_doc_id:
                              raise RuntimeError(
                                  "Documents not sorted and "
                                  "--check-sorted-docs-per-query was True; "
                                  "{0} <= {1}".format(doc_id, prev_doc_id))
                      prev_doc_id = doc_id
  
                  doc = documents[doc_id]
                  num_words = len(doc)
  
                  if start_fraction == 1.0 or end_fraction == 1.0:
                      assert end_fraction == end_fraction
                      output_document.extend(doc)
                  else:
                      assert (start_fraction + end_fraction < 1.0)
                      if start_fraction > 0:
                          output_document.extend(
                              doc[0:int(start_fraction * num_words)])
                      if end_fraction > 0:
                          output_document.extend(
                              doc[int(end_fraction * num_words):])
  
              print ("{0} {1}".format(query, " ".join(output_document)),
                     file=args.output_documents)
          except Exception:
              logger.error("Error processing line %s in file %s", line,
                           args.query2docs.name)
              raise
  
  
  def main():
      args = get_args()
  
      try:
          run(args)
      except:
          logger.error("Failed to stictch document; got error ",
                       exc_info=True)
          raise SystemExit(1)
      finally:
          for f in [args.query2docs, args.input_documents,
                    args.output_documents]:
              if f is not None:
                  f.close()
  
  
  if __name__ == '__main__':
      main()