query.py 3.65 KB
Newer Older
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
1
2
3
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :

4

5
import os
6
7
8
9
10
11
import shutil
import struct
import gzip
import numpy

from bob.db.base.utils import check_parameters_for_validity
12

13
14
15
16
17
18

class Database:
  """Wrapper class for the MNIST database of handwritten digits.

  The original database files are distributed over:
  http://yann.lecun.com/exdb/mnist/.
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
19
20
  """

21
22
23
24
25
26
27
28
29
30
31
32
33
34

  def __init__(self):

    from .driver import Interface
    f = Interface().files()

    self.train_images = f[0]
    self.train_labels = f[1]
    self.test_images  = f[2]
    self.test_labels  = f[3]

    self._labels = set(range(0,10))
    self._groups = ('train', 'test')

Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
35

36
  def _read_labels(self, fname):
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
37
    """Reads the labels from the original MNIST label binary file"""
38
39
40
41
42
43
44
45
46
47
48

    with gzip.open(fname, 'rb') as f:

      # reads 2 big-ending integers
      magic_nr, n_examples = struct.unpack(">II", f.read(8))
      # reads the rest, using an uint8 dataformat (endian-less)

      labels = numpy.fromstring(f.read(), dtype='uint8')

      return labels

49

50
  def _read_images(self, fname):
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
51
52
    """Reads the images from the original MNIST label binary file"""

53
54
55
56
57
    with gzip.open(fname, 'rb') as f:

      # reads 4 big-ending integers
      magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
      shape = (n_examples, rows*cols)
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
58

59
60
61
62
      # reads the rest, using an uint8 dataformat (endian-less)
      images = numpy.fromstring(f.read(), dtype='uint8').reshape(shape)

      return images
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
63
64
65
66
67


  def labels(self):
    """Returns the vector of labels
    """
68
69
70

    return self._labels

Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
71
72
73
74

  def groups(self):
    """Returns the vector of groups
    """
75
76
77

    return self._groups

Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
78
79
80
81
82

  def data(self, groups=None, labels=None):
    """Loads the MNIST samples and labels and returns them in NumPy arrays


83
84
85
86
87
88
89
90
    Parameters:

      groups (:py:class:`str` or :py:class:`list`): One of the groups ``train``
        or ``test``, or a list with both of them (which is the default)

      labels (:py:class:`int` or :py:class:`list`): A subset of the labels
        (digits 0 to 9) (everything is the default)

Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
91

92
    Returns:
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
93

André Anjos's avatar
André Anjos committed
94
95
96
97
      numpy.ndarray: A 2D array representing the digit images, with as many
      rows as examples in the dataset, as many columns as pixels (actually,
      there are 28x28 = 784 rows). The pixels of each image are unrolled in
      C-scan order (i.e., first row 0, then row 1, etc.)
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
98

André Anjos's avatar
André Anjos committed
99
100
      numpy.ndarray: A 1D array with as many elements as examples in the
      dataset, containing the labels for each image returned above.
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
101
102
103
104

    """

    # check if groups set are valid
105
106
    groups = check_parameters_for_validity(groups, "group", self._groups)
    vlabels = check_parameters_for_validity(labels, "label", self._labels)
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
107
108
109

    # Reads data from the groups
    if 'train' in groups and 'test' in groups:
110
111
112
113
      images1 = self._read_images(self.train_images)
      labels1 = self._read_labels(self.train_labels)
      images2 = self._read_images(self.test_images)
      labels2 = self._read_labels(self.test_labels)
114
      images = numpy.vstack([images1,images2])
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
115
      labels = numpy.hstack([labels1,labels2])
116

Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
117
    elif 'train' in groups:
118
119
120
      images = self._read_images(self.train_images)
      labels = self._read_labels(self.train_labels)

Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
121
    elif 'test' in groups:
122
123
124
      images = self._read_images(self.test_images)
      labels = self._read_labels(self.test_labels)

Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
125
126
127
128
129
    else:
      images = numpy.ndarray(shape=(0,784), dtype=numpy.uint8)
      labels = numpy.ndarray(shape=(0,), dtype=numpy.uint8)

    # List of indices for which the labels are in the list of requested labels
130
    indices = numpy.where(numpy.array([v in vlabels for v in labels]))[0]
Laurent EL SHAFEY's avatar
Laurent EL SHAFEY committed
131
132
133
134
    images = images[indices,:]
    labels = labels[indices]

    return images, labels