egs/wsj/s5/steps/tfrnnlm/reader.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
# Copyright (C) 2017 Intellisist, Inc. (Author: Hainan Xu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for parsing RNNLM text files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf


def _read_words(filename):
  """Reads a whitespace-separated text file into a list of word strings."""
  with tf.gfile.GFile(filename, "r") as f:
    return f.read().decode("utf-8").split()


def _build_vocab(filename):
  """Maps each word in the vocab file to its position, which is its id."""
  words = _read_words(filename)
  word_to_id = dict(zip(words, range(len(words))))
  return word_to_id


def _file_to_word_ids(filename, word_to_id):
  """Converts a text file to a list of word ids, skipping out-of-vocab words."""
  data = _read_words(filename)
  return [word_to_id[word] for word in data if word in word_to_id]


def rnnlm_raw_data(data_path, vocab_path):
  """Load RNNLM raw data from data directory "data_path".

  Args:
    data_path: string path to the directory where the train/valid files
      are stored.
    vocab_path: string path to the vocabulary file; a word's position in
      this file becomes its integer id.

  Returns:
    tuple (train_data, valid_data, vocabulary, word_to_id) where train_data
    and valid_data can be passed to RNNLMIterator.
  """

  train_path = os.path.join(data_path, "train")
  valid_path = os.path.join(data_path, "valid")

  word_to_id = _build_vocab(vocab_path)
  train_data = _file_to_word_ids(train_path, word_to_id)
  valid_data = _file_to_word_ids(valid_path, word_to_id)
  vocabulary = len(word_to_id)
  return train_data, valid_data, vocabulary, word_to_id


def rnnlm_producer(raw_data, batch_size, num_steps, name=None):
  """Iterate on the raw RNNLM data.

  This chunks up raw_data into batches of examples and returns Tensors that
  are drawn from these batches.

  Args:
    raw_data: one of the raw data outputs from rnnlm_raw_data.
    batch_size: int, the batch size.
    num_steps: int, the number of unrolls.
    name: the name of this operation (optional).

  Returns:
    A pair of Tensors, each shaped [batch_size, num_steps]. The second
    element of the tuple is the same data time-shifted to the right by one.

  Raises:
    tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
  """
  with tf.name_scope(name, "RNNLMProducer", [raw_data, batch_size, num_steps]):
    raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)

    # Trim the data so it divides evenly into batch_size rows, then lay it
    # out as one long id sequence per batch row.
    data_len = tf.size(raw_data)
    batch_len = data_len // batch_size
    data = tf.reshape(raw_data[0 : batch_size * batch_len],
                      [batch_size, batch_len])

    # Each epoch walks the rows in windows of num_steps; the "- 1" leaves
    # room for the one-step-shifted targets at the end of each row.
    epoch_size = (batch_len - 1) // num_steps
    assertion = tf.assert_positive(
        epoch_size,
        message="epoch_size == 0, decrease batch_size or num_steps")
    with tf.control_dependencies([assertion]):
      epoch_size = tf.identity(epoch_size, name="epoch_size")

    # A queue yields the window index i; x is the current window and y is
    # the same window shifted by one word (the prediction targets).
    i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
    x = tf.strided_slice(data, [0, i * num_steps],
                         [batch_size, (i + 1) * num_steps])
    x.set_shape([batch_size, num_steps])
    y = tf.strided_slice(data, [0, i * num_steps + 1],
                         [batch_size, (i + 1) * num_steps + 1])
    y.set_shape([batch_size, num_steps])
    return x, y
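

# ------------------------------------------------------------------------------
# Usage sketch: a minimal, illustrative example of wiring rnnlm_raw_data and
# rnnlm_producer together under the TF1 queue-runner model this file is
# written against. The directory layout, batch_size, and num_steps below are
# assumptions for illustration, not values fixed by this file.

if __name__ == "__main__":
  # Assumed layout: data_path holds "train" and "valid" text files, and the
  # vocabulary file's token order defines the word ids (see _build_vocab).
  train_data, valid_data, vocab_size, word_to_id = rnnlm_raw_data(
      "data/tfrnnlm", "data/tfrnnlm/wordlist")

  # x holds the ids of one num_steps window per batch row; y is the same
  # window shifted by one word, i.e. the next-word prediction targets for x.
  x, y = rnnlm_producer(train_data, batch_size=64, num_steps=20)

  with tf.Session() as sess:
    # range_input_producer registers a queue runner, so the queue threads
    # must be started before the first sess.run over x or y.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    inputs, targets = sess.run([x, y])  # each is shaped [64, 20]
    coord.request_stop()
    coord.join(threads)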