tf_idf.py 15.5 KB
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424
# Copyright 2016    Vimal Manohar
# Apache 2.0.

"""This module contains structures to accumulate, store and use stats
for Term-frequency and Inverse-document-frequency values.
"""

from __future__ import print_function
from __future__ import division
import logging
import math
import re
import sys

# NOTE(review): presumably lets modules under the relative 'steps' directory
# be imported when run from a Kaldi egs directory -- confirm it is needed.
sys.path.insert(0, 'steps')

# Name the logger after this module (the original passed the literal string
# '__name__'), so it can be configured via the module's logger name.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class IDFStats(object):
    """Stores stats for computing inverse-document-frequencies.

    Attributes:
        num_docs_for_term - dict mapping a term (a tuple of words, i.e. an
                            n-gram) to n(t): the number of accumulate() calls
                            made for it, or the value read from file.
        num_docs - N in the IDF formulas; it is only incremented for unigram
                   terms (see accumulate() and read()).
    """

    # Names accepted for the probabilistic IDF scheme. "probabilitic" is a
    # historical misspelling, still accepted for backward compatibility.
    _PROBABILISTIC_SCHEMES = ("probabilistic", "probabilitic")

    def __init__(self):
        self.num_docs_for_term = {}
        self.num_docs = 0

    def get_inverse_document_frequency(self, term, weighting_scheme="log"):
        """Get IDF for a term.

        The weighting scheme is the function applied to the raw document
        frequency n(t) = |{d in D: t in d}| when computing idf(t, D).
        Let N = self.num_docs.

        IDF weighting schemes:-
        unary         : idf(t,D) = 1
        log           : idf(t,D) = log(N / (1 + n(t)))
        log-smoothed  : idf(t,D) = log(1 + N / (1 + n(t)))
        probabilistic : idf(t,D) = log((N - n(t) - 1) / (1 + n(t)))
                        ("probabilitic" is accepted as an alias)

        Raises:
            RuntimeError - if no IDF stats have been accumulated.
            KeyError - if weighting_scheme is not one of the above.
            ValueError - (from math.log) if the argument of the log is not
                         positive, e.g. N = 0 with the "log" scheme.
        """
        n_t = float(self.num_docs_for_term.get(term, 0))

        if not self.num_docs_for_term:
            raise RuntimeError("No IDF stats have been accumulated.")

        num_docs = float(self.num_docs)
        if weighting_scheme == "unary":
            return 1
        if weighting_scheme == "log":
            return math.log(num_docs / (1.0 + n_t))
        if weighting_scheme == "log-smoothed":
            return math.log(1.0 + num_docs / (1.0 + n_t))
        if weighting_scheme in self._PROBABILISTIC_SCHEMES:
            return math.log((num_docs - n_t - 1) / (1.0 + n_t))
        # Previously an unknown scheme silently returned None, which only
        # failed later when multiplied with a TF value.
        raise KeyError("Unknown idf-weighting-scheme {0}".format(
            weighting_scheme))

    def accumulate(self, term):
        """Adds one count to the number of docs containing the term "term".

        num_docs is also incremented when "term" is a unigram.
        NOTE(review): when driven by TFStats.compute_term_stats() (one call
        per (term, doc) pair), num_docs ends up as the number of
        (unigram, document) pairs rather than the number of documents;
        confirm this is the intended N.
        """
        self.num_docs_for_term[term] = self.num_docs_for_term.get(term, 0) + 1
        if len(term) == 1:
            self.num_docs += 1

    def write(self, file_handle):
        """Writes the IDF stats to file, one line per term, in the format:
        <term-1> <term-2> ... <term-N> <num-docs>
        for n-gram (<term-1>, ... <term-N>). Terms with a zero count are
        skipped.
        """
        for term, num in self.num_docs_for_term.items():
            if num == 0:
                continue
            assert isinstance(term, tuple)
            print("{term} {n}".format(term=" ".join(term), n=num),
                  file=file_handle)

    def read(self, file_handle):
        """Loads IDF stats from file in the format produced by write().

        Blank lines are ignored.

        Raises:
            RuntimeError - if no stats were read.
        """
        for line in file_handle:
            parts = line.split()
            if not parts:
                continue
            term = tuple(parts[:-1])
            self.num_docs_for_term[term] = float(parts[-1])
            # NOTE(review): here num_docs counts distinct unigram entries,
            # whereas accumulate() increments it once per unigram call, so
            # num_docs is not preserved by a write()/read() round trip;
            # confirm which is intended.
            if len(term) == 1:
                self.num_docs += 1

        if not self.num_docs_for_term:
            raise RuntimeError("Read no IDF stats.")


class TFStats(object):
    """Store raw term-frequency stats for TF-IDF computation.

    Attributes:
        raw_counts - dict mapping (term, document-id) to the raw count f(t,d)
                     of the term in that document; a term is a tuple of
                     words (an n-gram).
        max_counts_for_term - dict mapping a term to the maximum f(t,d) over
                              all documents d; filled in by
                              compute_term_stats() or read().

    IDF stats are not stored in this object; an IDFStats object can
    optionally be passed to compute_term_stats() or read() so that it is
    accumulated alongside.
    """
    def __init__(self):
        self.raw_counts = {}
        self.max_counts_for_term = {}

    def get_term_frequency(self, term, doc, weighting_scheme="raw",
                           normalization_factor=0.5):
        """Returns the term-frequency for (term, document) pair.

        The function applied on the raw term-frequencies f(t,d) when computing
        tf(t,d) is specified by the weighting_scheme.
        binary : tf(t,d) = 1 if t in d else 0
        raw    : tf(t,d) = f(t,d)
        log    : tf(t,d) = 1 + log(f(t,d)) if t in d else 0
        normalized : tf(t,d) = K + (1-K) * f(t,d) / (1 + max{f(t,d'): d'})
            where K = normalization_factor and the max is over all documents
            d' (max_counts_for_term). NOTE(review): the textbook variant
            normalizes by max{f(t',d): t' in d} instead; confirm intent.

        Raises:
            KeyError - for an unknown weighting_scheme.
        """
        key = (term, doc)
        if weighting_scheme == "binary":
            return 1 if key in self.raw_counts else 0
        if weighting_scheme == "raw":
            return self.raw_counts.get(key, 0)
        if weighting_scheme == "log":
            if key in self.raw_counts:
                return 1 + math.log(self.raw_counts[key])
            return 0
        if weighting_scheme == "normalized":
            return (normalization_factor
                    + (1 - normalization_factor)
                    * self.raw_counts.get(key, 0)
                    / (1.0 + self.max_counts_for_term.get(term, 0)))
        raise KeyError("Unknown tf-weighting-scheme {0}".format(
            weighting_scheme))

    def accumulate(self, doc, text, ngram_order):
        """Accumulate raw counts of all n-grams of order 1 up to ngram_order
        found in "text" (a sequence of words) for document "doc".

        Only complete n-grams are counted.
        """
        for n in range(1, ngram_order + 1):
            # The last complete n-gram starts at len(text) - n. Going further
            # would yield truncated tuples, i.e. spurious extra counts of
            # lower-order n-grams.
            for i in range(len(text) - n + 1):
                key = (tuple(text[i:(i + n)]), doc)
                self.raw_counts[key] = self.raw_counts.get(key, 0) + 1

    def compute_term_stats(self, idf_stats=None):
        """Compute the maximum counts for each term over all the documents
        based on the stored raw counts.

        If idf_stats is provided, idf_stats.accumulate(term) is also called
        once for every (term, doc) pair.

        Raises:
            RuntimeError - if no raw counts have been accumulated.
        """
        if not self.raw_counts:
            raise RuntimeError("No (term, doc) found in tf-stats.")
        for (term, _), counts in self.raw_counts.items():
            if counts > self.max_counts_for_term.get(term, 0):
                self.max_counts_for_term[term] = counts

            if idf_stats is not None:
                idf_stats.accumulate(term)

    def __str__(self):
        """Returns a string with all the stats in the following format:
        <n-gram order> <term-1> <term-2> ... <term-n> <document-id> <counts>
        """
        lines = []
        for (term, doc), counts in self.raw_counts.items():
            lines.append("{order} {term} {doc} {counts}".format(
                order=len(term), term=" ".join(term),
                doc=doc, counts=counts))
        return "\n".join(lines)

    def read(self, file_handle, ngram_order=None, idf_stats=None):
        """Reads the TF stats stored in a file in the following format:
        <ngram-order> <term-1> <term-2> ... <term-n> <document-id> <counts>

        Blank lines are ignored. Lines with an n-gram order above
        ngram_order (when given) are skipped. If idf_stats is provided then
        idf_stats is accumulated simultaneously.

        Raises:
            TypeError - if a line has the wrong number of fields for its
                        n-gram order.
            RuntimeError - if no stats were read.
        """
        for line in file_handle:
            parts = line.split()
            if not parts:
                continue
            # The order field must be an int; comparing or slicing with the
            # raw string fails.
            order = int(parts[0])
            if len(parts) != order + 3:
                raise TypeError(
                    "Invalid line in TF stats; expected {0} fields for an "
                    "n-gram of order {1}; got {2}".format(
                        order + 3, order, line))
            if ngram_order is not None and order > ngram_order:
                continue
            term = tuple(parts[1:(order + 1)])
            doc = parts[-2]
            counts = float(parts[-1])

            self.raw_counts[(term, doc)] = counts

            if counts > self.max_counts_for_term.get(term, 0):
                self.max_counts_for_term[term] = counts

            if idf_stats is not None:
                idf_stats.accumulate(term)

        if not self.raw_counts:
            raise RuntimeError("Read no TF stats.")


class TFIDF(object):
    """Class to store TF-IDF values for term-document pairs.

    Attributes:
        tf_idf - A dictionary of TF-IDF values indexed by the
                 (term, document) tuple, where term is a tuple of words.
    """

    def __init__(self):
        self.tf_idf = {}

    def get_value(self, term, doc):
        """Returns the TF-IDF value for the (term, doc) pair.

        Raises:
            KeyError - if no value is stored for (term, doc). Callers that
                       want a default of 0 catch it, as
                       compute_similarity_scores() does.
        """
        return self.tf_idf[(term, doc)]

    def compute_similarity_scores(self, source_tfidf, source_docs=None,
                                  do_length_normalization=False,
                                  query_id=None):
        """Computes TF-IDF similarity score between each pair of query
        document contained in this object and the source documents
        in the source_tfidf object.

        The score of a pair is the sum, over the query document's terms, of
        query_value * source_value for the same term.

        Arguments:
            source_tfidf - TFIDF object holding the source documents.
            source_docs - If provided, the similarity scores are computed
                          for only the source documents contained in
                          source_docs; every (query, source) pair gets an
                          entry, with missing source values taken as 0.
                          If not provided, only source documents sharing at
                          least one term with a query document get an
                          entry.
            do_length_normalization - If True, then each similarity score
                          is divided by the number of terms in the query
                          document. This is usually not required when the
                          scores are only utilized for ranking the source
                          documents.
            query_id - If provided, check that this tf_idf object
                       contains values only for document with id 'query_id'

        Returns a dictionary
            { (query_document_id, source_document_id): similarity_score }

        Raises:
            RuntimeError - if query_id is given and another document is
                           found in this object.
        """
        num_terms_per_doc = {}
        similarity_scores = {}

        if source_docs is not None:
            # Materialize once: it is iterated for every query term, which
            # would silently exhaust a one-shot iterator after the first.
            source_docs = list(source_docs)
            source_values_for_term = None
        else:
            # Index the source values by term once, so that each query term
            # is matched only against source entries with the same term.
            source_values_for_term = {}
            for (src_term, src_doc), src_value in source_tfidf.tf_idf.items():
                source_values_for_term.setdefault(src_term, []).append(
                    (src_doc, src_value))

        for (term, doc), value in self.tf_idf.items():
            num_terms_per_doc[doc] = num_terms_per_doc.get(doc, 0) + 1

            if query_id is not None and doc != query_id:
                raise RuntimeError("TF-IDF contains document {0}, which is "
                                   "not the required query {1}. \n"
                                   "Something wrong in how this TF-IDF object "
                                   "was created or a bug in the "
                                   "calling script.".format(
                                       doc, query_id))

            if source_docs is not None:
                for src_doc in source_docs:
                    try:
                        src_value = source_tfidf.get_value(term, src_doc)
                    except KeyError:
                        logger.debug(
                            "Could not find ({term}, {src}) in "
                            "source_tfidf. "
                            "Choosing a tf-idf value of 0.".format(
                                term=term, src=src_doc))
                        src_value = 0

                    similarity_scores[(doc, src_doc)] = (
                        similarity_scores.get((doc, src_doc), 0)
                        + src_value * value)
            else:
                for src_doc, src_value in source_values_for_term.get(
                        term, ()):
                    similarity_scores[(doc, src_doc)] = (
                        similarity_scores.get((doc, src_doc), 0)
                        + src_value * value)

        if do_length_normalization:
            similarity_scores = {
                (doc, src_doc): score / num_terms_per_doc[doc]
                for (doc, src_doc), score in similarity_scores.items()}

        if logger.isEnabledFor(logging.DEBUG):
            for doc, count in num_terms_per_doc.items():
                logger.debug(
                    'Seen {0} terms in query document {1}'.format(count, doc))

        return similarity_scores

    def read(self, tf_idf_file):
        """Loads TFIDF object from file.

        Expected format (as produced by write()):
            <TFIDF> [first entry may follow on the same line]
            <ngram-order> <term-1> ... <term-n> <document-id> <tfidf>
            ...
            </TFIDF>
        Reading stops right after the footer line, so the file handle can be
        positioned inside a larger stream (e.g. an archive). Blank lines
        between entries are ignored.

        Raises:
            RuntimeError - if this object is not empty, on duplicate entries,
                           or if no values were read.
            TypeError - on a missing/misplaced header or footer, or an entry
                        with the wrong number of fields.
        """
        if self.tf_idf:
            raise RuntimeError("TD-IDF object is not empty.")
        seen_footer = False
        line = tf_idf_file.readline()
        parts = line.split()
        if not line.startswith('<TFIDF>') or parts[0] != "<TFIDF>":
            raise TypeError(
                "Invalid format of TD-IDF object. "
                "Missing header <TFIDF>; got {0}".format(line))
        if len(parts) > 1:
            # Read header; go to the rest of line
            line = " ".join(parts[1:])
        else:
            # Nothing in this line. Read the next lines.
            line = tf_idf_file.readline()
        while line:
            parts = line.split()
            if not parts:
                line = tf_idf_file.readline()
                continue
            if '</TFIDF>' in line:
                if len(parts) > 1 or parts[0] != "</TFIDF>":
                    raise TypeError(
                        "Expecting footer </TFIDF> "
                        "to be on a separate line; got {0}".format(line))
                seen_footer = True
                break
            if '<TFIDF>' in line:
                raise TypeError("Got unexpected header <TFIDF> in line "
                                "{0}".format(line))

            order = int(parts[0])
            if len(parts) != order + 3:
                raise TypeError(
                    "Invalid TF-IDF entry; expected {0} fields for an "
                    "n-gram of order {1}; got {2}".format(
                        order + 3, order, line))
            term = tuple(parts[1:(order + 1)])
            doc = parts[-2]
            tfidf = float(parts[-1])

            entry = (term, doc)
            if entry in self.tf_idf:
                raise RuntimeError("Duplicate entry {0} found while reading "
                                   "TFIDF object.".format(entry))
            self.tf_idf[entry] = tfidf

            line = tf_idf_file.readline()
        if not seen_footer:
            raise TypeError(
                "Did not see footer </TFIDF> "
                "in TFIDF object; got {0}".format(line))

        if not self.tf_idf:
            # In-memory streams have no 'name' attribute.
            raise RuntimeError(
                "Read no TF-IDF values from file {0}".format(
                    getattr(tf_idf_file, 'name', tf_idf_file)))

    def write(self, tf_idf_file):
        """Writes TFIDF object to file in the format accepted by read()."""

        print("<TFIDF>", file=tf_idf_file)
        for (term, doc), value in self.tf_idf.items():
            print("{order} {term} {doc} {tfidf}".format(
                order=len(term), term=" ".join(term),
                doc=doc, tfidf=value),
                  file=tf_idf_file)
        print("</TFIDF>", file=tf_idf_file)


def write_tfidf_from_stats(
        tf_stats, idf_stats, tf_idf_file, tf_weighting_scheme="raw",
        idf_weighting_scheme="log", tf_normalization_factor=0.5,
        expected_document_id=None):
    """Writes TF-IDF values computed from TF and IDF stats to tf_idf_file.

    The format used is
    <ngram-order> <term-1> ... <term-n> <document> <tfidf>
    with one line per (term, document) pair in tf_stats.raw_counts.
    Markers "<TFIDF>" and "</TFIDF>" are added for parsing this file
    easily. All values are computed before anything is written, so an
    error does not leave a partially written object in tf_idf_file.

    Arguments:
        tf_stats - A TFStats object
        idf_stats - An IDFStats object
        tf_idf_file - Output file to which the TF-IDF values will be written
        tf_weighting_scheme - See doc_string in TFStats class
        idf_weighting_scheme - See doc_string in IDFStats class
        tf_normalization_factor - See doc_string in TFStats class
        expected_document_id - If provided, checks that the TFStats object
                               contains stats only for this document id.

    Raises:
        RuntimeError - if either stats object is empty, or tf_stats contains
                       a document other than expected_document_id.
    """
    if len(tf_stats.raw_counts) == 0:
        raise RuntimeError("Supplied tf-stats object is empty.")

    if idf_stats.num_docs == 0:
        raise RuntimeError("Supplied idf-stats object is empty.")

    # Validate before writing so a bad document id does not leave a dangling
    # "<TFIDF>" header in the output.
    if expected_document_id is not None:
        for _, doc in tf_stats.raw_counts:
            if doc != expected_document_id:
                raise RuntimeError("TFStats object contains stats with "
                                   "document {0}, "
                                   "which is not the specified "
                                   "document {1}.".format(
                                       doc, expected_document_id))

    lines = ["<TFIDF>"]
    for term, doc in tf_stats.raw_counts:
        tf_value = tf_stats.get_term_frequency(
            term, doc,
            weighting_scheme=tf_weighting_scheme,
            normalization_factor=tf_normalization_factor)

        idf_value = idf_stats.get_inverse_document_frequency(
            term, weighting_scheme=idf_weighting_scheme)

        lines.append("{order} {term} {doc} {tfidf}".format(
            order=len(term), term=" ".join(term),
            doc=doc, tfidf=tf_value * idf_value))
    lines.append("</TFIDF>")

    # One print of the newline-joined lines emits exactly the same text as
    # printing each line separately.
    print("\n".join(lines), file=tf_idf_file)


def read_key(fd):
    """Read the utterance-key from the opened ark/stream descriptor 'fd'.

    Characters are consumed one at a time up to and including the first
    space (or until end of stream); the space is not part of the key.

    Returns:
        The key with surrounding whitespace stripped, or None if that leaves
        an empty string (e.g. at end of stream).
    """
    chars = []
    while True:
        char = fd.read(1)
        if char == '' or char == ' ':
            break
        chars.append(char)
    key = ''.join(chars).strip()
    if key == '':
        return None  # end of file
    return key


def read_tfidf_ark(file_handle):
    """Read a kaldi archive of TFIDF objects indexed by a key (document-id).
    <document-id1> <tf-idf-object1>
    <document-id2> <tf-idf-object2>
    ...

    This is a generator yielding (key, TFIDF) tuples. file_handle is closed
    when the archive is exhausted, when reading raises, or when the
    generator is closed early.
    """
    try:
        key = read_key(file_handle)
        while key:
            tf_idf = TFIDF()
            # Errors from TFIDF.read() propagate to the caller unchanged.
            tf_idf.read(file_handle)
            yield key, tf_idf
            key = read_key(file_handle)
    finally:
        file_handle.close()