Blame view

datasets/schemes.py 1.65 KB
f2d3bd141   Parcollet Titouan   Initial commit wi...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
  import six
  import math
  import numpy
  from picklable_itertools import imap
  from picklable_itertools.extras import partition_all
  from fuel.schemes import BatchScheme
  
  
  class SequentialShuffledScheme(BatchScheme):
      """Batch scheme that visits contiguous batches in a shuffled order.

      The examples of a fixed-size dataset are cut into consecutive batches
      of a given size. Each request is a ``slice`` covering one such batch;
      the order in which the batches are requested is drawn from `rng`.

      Parameters
      ----------
      num_examples : int
          Total number of examples in the dataset.
      batch_size : int
          Number of examples per batch.
      rng : object
          Random number generator exposing an in-place ``shuffle`` method
          (presumably a ``numpy.random.RandomState`` -- confirm with
          callers).

      Notes
      -----
      The batch size isn't enforced, so the batch covering the tail of the
      dataset could be smaller.

      """
      def __init__(self, num_examples, batch_size, rng):
          # NOTE: the parent initialiser is intentionally not invoked here,
          # exactly as in the original implementation.
          self.rng = rng
          self.batch_size = batch_size
          self.num_examples = num_examples

      def get_request_iterator(self):
          """Return a fresh iterator over shuffled batch slices."""
          request_iterator = SequentialShuffledIterator(
              self.num_examples, self.batch_size, self.rng)
          return request_iterator
  
  class SequentialShuffledIterator(six.Iterator):
      """Yield contiguous batch slices of a dataset in a shuffled order.

      The range ``[0, num_examples)`` is cut into consecutive batches of
      `batch_size` examples. The batch order is shuffled with `rng`, and
      each call to ``next`` returns the ``slice`` covering one batch.

      Parameters
      ----------
      num_examples : int
          Total number of examples to cover.
      batch_size : int
          Number of examples per batch. The batch covering the tail of the
          dataset is truncated to `num_examples`, so it may be smaller.
      rng : object
          Random number generator exposing an in-place ``shuffle`` method
          accepting a list (presumably a ``numpy.random.RandomState`` --
          confirm with callers).

      """
      def __init__(self, num_examples, batch_size, rng):
          self.num_examples = num_examples
          self.batch_size = batch_size
          self.rng = rng
          num_batches = int(math.ceil(num_examples / float(batch_size)))
          # Materialise the indexes as a list: on Python 3 ``range`` is an
          # immutable sequence, so shuffling it in place raises TypeError.
          self.batch_indexes = list(range(num_batches))
          self.rng.shuffle(self.batch_indexes)
          # ``current`` counts the examples handed out so far (in whole
          # batches); ``current_batch`` is the position in the shuffled
          # batch order.
          self.current = 0
          self.current_batch = 0

      def __iter__(self):
          # The batch order is reshuffled every time iteration starts.
          # NOTE: the cursor is not reset here, so calling ``iter`` in the
          # middle of an epoch changes the order of the remaining batches.
          self.rng.shuffle(self.batch_indexes)
          return self

      def __next__(self):
          """Return the slice of the next batch.

          Raises
          ------
          StopIteration
              Once every batch has been returned.

          """
          if self.current >= self.num_examples:
              raise StopIteration
          current_index = self.batch_indexes[self.current_batch]
          start = current_index * self.batch_size
          # Clamp the end so the tail batch never runs past the dataset.
          stop = min(self.num_examples, start + self.batch_size)
          self.current += self.batch_size
          self.current_batch += 1
          return slice(start, stop)