Commit 5840e121 authored by Laurent EL SHAFEY's avatar Laurent EL SHAFEY
Browse files

Automatically download the db if required, fix bug, add test

parent 2cdd119c
......@@ -93,3 +93,17 @@ In this case, this should return two NumPy arrays:
2. `labels` are the corresponding classes (digits 0 to 9) for each of the 60,000 samples
If you don't have the data installed on your machine, you can also use the following
set of command that will:
1. first look for the database in the xbob/db/mnist/ subdirectory and use it if is available
2. or automatically download it from Yann Lecun's website into a temporary folder, that will
be erased when the destructor of the xbob.db.mnist database is called::
>>> import xbob.db.mnist
>>> db = xbob.db.mnist.Database() # Check for the data files locally, and download them if required
>>> images, labels ='train', labels=[0,1,2,3,4,5,6,7,8,9])
>>> del db # delete the temporary downloaded files if any
......@@ -9,7 +9,7 @@ from setuptools import setup, find_packages
description='MNIST Database Access API for Bob',
......@@ -17,6 +17,8 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <>.
import shutil
class Database():
"""Wrapper class for the MNIST database of handwritten digits (
......@@ -28,11 +30,51 @@ class Database():
import os
self.m_labels = set(range(0,10))
self.m_groups = ('train', 'test')
self.m_data_dir = data_dir
self.m_train_fname_images = os.path.join(self.m_data_dir, 'train-images-idx3-ubyte.gz')
self.m_train_fname_labels = os.path.join(self.m_data_dir, 'train-labels-idx1-ubyte.gz')
self.m_test_fname_images = os.path.join(self.m_data_dir, 't10k-images-idx3-ubyte.gz')
self.m_test_fname_labels = os.path.join(self.m_data_dir, 't10k-labels-idx1-ubyte.gz')
self.m_mnist_filenames = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
self.m_tmp_dir = None
if data_dir: # Path to the data is provided
self.m_data_dir = data_dir
if self.__db_is_installed__():
from pkg_resources import resource_filename
import os
self.m_data_dir = os.path.dirname(resource_filename(__name__, ''))
self.m_data_dir = self.__create_tmp_dir_and_download__()
self.m_tmp_dir = self.m_data_dir # To avoid bad surprises to the user
self.m_train_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[0])
self.m_train_fname_labels = os.path.join(self.m_data_dir, self.m_mnist_filenames[1])
self.m_test_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[2])
self.m_test_fname_labels = os.path.join(self.m_data_dir, self.m_mnist_filenames[3])
def __del__(self):
if self.m_tmp_dir:
shutil.rmtree(self.m_tmp_dir) # delete directory
except OSError, e:
if e.errno != 2: # code 2 - no such file or directory
raise 'xbob.db.mnist: Error while erasing temporarily downloaded data files'
def __db_is_installed__(self):
from pkg_resources import resource_filename
import os
db_files = [resource_filename(__name__, k) for k in self.m_mnist_filenames]
for f in db_files:
if not os.path.exists(f):
return False
return True
def __create_tmp_dir_and_download__(self):
import urllib2, tempfile, os
tmp_directory = tempfile.mkdtemp(prefix='mnist_db')
print "Downloading the mnist database from ..."
for f in self.m_mnist_filenames:
url = urllib2.urlopen(''+f)
dfile = open(os.path.join(tmp_directory, f), 'w')
return tmp_directory
def __read_labels__(self, fname):
"""Reads the labels from the original MNIST label binary file"""
......@@ -139,7 +181,7 @@ class Database():
labels1 = self.__read_labels__(self.m_train_fname_labels)
images2 = self.__read_images__(self.m_test_fname_images)
labels2 = self.__read_labels__(self.m_test_fname_labels)
images = numpy.hstack([images1,images2])
images = numpy.vstack([images1,images2])
labels = numpy.hstack([labels1,labels2])
elif 'train' in groups:
images = self.__read_images__(self.m_train_fname_images)
......@@ -27,7 +27,7 @@ class MNISTDatabaseTest(unittest.TestCase):
"""Performs various tests on the MNIST database."""
def test01_query(self):
db = Database('')
db = Database()
f = db.labels()
self.assertEqual(len(f), 10) # number of labels (digits 0 to 9)
......@@ -39,4 +39,16 @@ class MNISTDatabaseTest(unittest.TestCase):
self.assertTrue('train' in f)
self.assertTrue('test' in f)
# TODO: Test the number of samples?
# Test the number of samples/labels
d, l ='train')
self.assertEqual(d.shape[0], 60000)
self.assertEqual(d.shape[1], 784)
self.assertEqual(l.shape[0], 60000)
d, l ='test')
self.assertEqual(d.shape[0], 10000)
self.assertEqual(d.shape[1], 784)
self.assertEqual(l.shape[0], 10000)
d, l =
self.assertEqual(d.shape[0], 70000)
self.assertEqual(d.shape[1], 784)
self.assertEqual(l.shape[0], 70000)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment