Commit 78943632 authored by Manuel Günther's avatar Manuel Günther
Browse files

When database directory is specified, now the database files are persistently...

When a database directory is specified, the database files are now downloaded persistently into that directory if they are not already there.
parent 4d6999b4
......@@ -3,14 +3,14 @@
================
The MNIST database is a database of handwritten digits, which consists of a
training set of 60,000 examples, and a test set of 10,000 examples. It was
made available by Yann Le Cun and Corinna Cortes (`MNIST database
<http://yann.lecun.com/exdb/mnist/>`_). The data was originally extracted
from a larger set made available by `NIST <http://www.nist.gov/>`_, before
training set of 60,000 examples, and a test set of 10,000 examples. It was
made available by Yann Le Cun and Corinna Cortes (`MNIST database
<http://yann.lecun.com/exdb/mnist/>`_). The data was originally extracted
from a larger set made available by `NIST <http://www.nist.gov/>`_, before
being size-normalized and centered in a fixed-size image (28x28 pixels).
The actual raw data for the database should be downloaded from the `original
website <http://yann.lecun.com/exdb/mnist/>`_. This package only contains
website <http://yann.lecun.com/exdb/mnist/>`_. This package only contains
the `Bob <http://www.idiap.ch/software/bob/>`_ accessor methods to use this
database directly from python.
......@@ -33,7 +33,7 @@ The package is available in two different distribution formats:
1. You can download it from `PyPI <http://pypi.python.org/pypi/xbob.db.mnist>`_, or
2. You can download it in its source form from `its git repository
<https://github.com/bioidiap/xbob.db.mnist>`_.
<https://github.com/bioidiap/xbob.db.mnist>`_.
The database raw files must be installed somewhere in your environment.
......@@ -94,13 +94,15 @@ In this case, this should return two NumPy arrays:
2. `labels` are the corresponding classes (digits 0 to 9) for each of the 60,000 samples
If you don't have the data installed on your machine, you can also use the following
If you don't have the data installed on your machine, you can also use the following
set of commands that will:
1. first look for the database in the xbob/db/mnist/ subdirectory and use it if it is available
2. or automatically download it from Yann Lecun's website into a temporary folder, that will
be erased when the destructor of the xbob.db.mnist database is called.
2. automatically download it from Yann Lecun's website into a temporary folder that will
be erased when the destructor of the xbob.db.mnist database is called.
3. if a directory is provided, automatically download it into that directory, which will **not** be deleted.
::
......@@ -109,3 +111,11 @@ be erased when the destructor of the xbob.db.mnist database is called.
>>> images, labels = db.data(groups='train', labels=[0,1,2,3,4,5,6,7,8,9])
>>> del db # delete the temporary downloaded files if any
or
::
>>> db = xbob.db.mnist.Database("Directory") # Persistently downloads files into the folder "Directory"
>>> images, labels = db.data(groups='train', labels=[0,1,2,3,4,5,6,7,8,9])
>>> del db # The download directory stays
......@@ -23,7 +23,7 @@ from setuptools import setup, find_packages
setup(
name='xbob.db.mnist',
version='1.0.3a0',
version='1.0.4a0',
description='MNIST Database Access API for Bob',
url='https://pypi.python.org/pypi/xbob.db.mnist',
license='GPLv3',
......
......@@ -18,6 +18,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import shutil
import os
class Database():
"""Wrapper class for the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/).
......@@ -33,16 +34,19 @@ class Database():
self.m_mnist_filenames = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
self.m_tmp_dir = None
if data_dir: # Path to the data is provided
# check if the data is available in the given directory (or if not given, in the default directory)
if not self._db_is_installed(data_dir):
self.m_data_dir = self._create_tmp_dir_and_download(data_dir)
if data_dir is None:
# if we create a temporary directory, mark it to be deleted at the end
self.m_tmp_dir = self.m_data_dir
elif data_dir is not None:
self.m_data_dir = data_dir
else:
if self.__db_is_installed__():
from pkg_resources import resource_filename
import os
self.m_data_dir = os.path.dirname(resource_filename(__name__, 'query.py'))
else:
self.m_data_dir = self.__create_tmp_dir_and_download__()
self.m_tmp_dir = self.m_data_dir # To avoid bad surprises to the user
from pkg_resources import resource_filename
self.m_data_dir = os.path.dirname(resource_filename(__name__, 'query.py'))
self.m_train_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[0])
self.m_train_fname_labels = os.path.join(self.m_data_dir, self.m_mnist_filenames[1])
self.m_test_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[2])
......@@ -56,23 +60,27 @@ class Database():
if e.errno != 2: # code 2 - no such file or directory
raise("xbob.db.mnist: Error while erasing temporarily downloaded data files")
def __db_is_installed__(self):
def _db_is_installed(self, directory = None):
"""Checks whether all raw MNIST files are present.

Keyword parameters:

directory
  The directory in which to look for the files listed in ``self.m_mnist_filenames``.
  If ``None``, each file name is resolved with ``pkg_resources.resource_filename``
  relative to this module instead.

Returns ``True`` if every file exists and ``False`` as soon as one is missing.
"""
from pkg_resources import resource_filename
# NOTE(review): the local import and the unconditional ``db_files`` assignment
# below look like pre-change lines left over from the diff; the if/else that
# follows reassigns ``db_files`` in both branches -- confirm before relying on them.
import os
db_files = [resource_filename(__name__, k) for k in self.m_mnist_filenames]
if directory is None:
# no directory given: look for the files next to this module
db_files = [resource_filename(__name__, k) for k in self.m_mnist_filenames]
else:
# directory given: look for the files inside it
db_files = [os.path.join(directory, k) for k in self.m_mnist_filenames]
# a single missing file means the database is not (completely) installed
for f in db_files:
if not os.path.exists(f):
return False
return True
def __create_tmp_dir_and_download__(self):
import tempfile, os, sys
def _create_tmp_dir_and_download(self, directory=None):
import tempfile, sys
tmp_directory = tempfile.mkdtemp(prefix='mnist_db')
if directory is None:
directory = tempfile.mkdtemp(prefix='mnist_db')
print("Downloading the mnist database from http://yann.lecun.com/exdb/mnist/ ...")
for f in self.m_mnist_filenames:
tmp_file = os.path.join(tmp_directory, f)
tmp_file = os.path.join(directory, f)
url = 'http://yann.lecun.com/exdb/mnist/'+f
if sys.version_info[0] < 3:
......@@ -90,9 +98,9 @@ class Database():
with open(tmp_file, 'wb') as out_file:
copyfileobj(response, out_file)
return tmp_directory
return directory
def __read_labels__(self, fname):
def _read_labels(self, fname):
"""Reads the labels from the original MNIST label binary file"""
import struct, gzip, numpy
f = gzip.GzipFile(fname, 'rb')
......@@ -102,7 +110,7 @@ class Database():
labels = numpy.fromstring(f.read(), dtype='uint8')
return labels
def __read_images__(self, fname):
def _read_images(self, fname):
"""Reads the images from the original MNIST image binary file"""
import struct, gzip, numpy
f = gzip.GzipFile(fname, 'rb')
......@@ -113,7 +121,7 @@ class Database():
images = numpy.fromstring(f.read(), dtype='uint8').reshape(shape)
return images
def __check_parameters_for_validity__(self, parameters, parameter_description, valid_parameters, default_parameters = None):
def _check_parameters_for_validity(self, parameters, parameter_description, valid_parameters, default_parameters = None):
"""Checks the given parameters for validity, i.e., if they are contained in the set of valid parameters.
It also assures that the parameters form a tuple or a list.
If parameters is 'None' or empty, the default_parameters will be returned (if default_parameters is omitted, all valid_parameters are returned).
......@@ -187,24 +195,24 @@ class Database():
"""
# check if groups set are valid
groups = self.__check_parameters_for_validity__(groups, "group", self.m_groups)
vlabels = self.__check_parameters_for_validity__(labels, "label", self.m_labels)
groups = self._check_parameters_for_validity(groups, "group", self.m_groups)
vlabels = self._check_parameters_for_validity(labels, "label", self.m_labels)
# Reads data from the groups
import numpy
if 'train' in groups and 'test' in groups:
images1 = self.__read_images__(self.m_train_fname_images)
labels1 = self.__read_labels__(self.m_train_fname_labels)
images2 = self.__read_images__(self.m_test_fname_images)
labels2 = self.__read_labels__(self.m_test_fname_labels)
images1 = self._read_images(self.m_train_fname_images)
labels1 = self._read_labels(self.m_train_fname_labels)
images2 = self._read_images(self.m_test_fname_images)
labels2 = self._read_labels(self.m_test_fname_labels)
images = numpy.vstack([images1,images2])
labels = numpy.hstack([labels1,labels2])
elif 'train' in groups:
images = self.__read_images__(self.m_train_fname_images)
labels = self.__read_labels__(self.m_train_fname_labels)
images = self._read_images(self.m_train_fname_images)
labels = self._read_labels(self.m_train_fname_labels)
elif 'test' in groups:
images = self.__read_images__(self.m_test_fname_images)
labels = self.__read_labels__(self.m_test_fname_labels)
images = self._read_images(self.m_test_fname_images)
labels = self._read_labels(self.m_test_fname_labels)
else:
images = numpy.ndarray(shape=(0,784), dtype=numpy.uint8)
labels = numpy.ndarray(shape=(0,), dtype=numpy.uint8)
......
......@@ -51,3 +51,18 @@ class MNISTDatabaseTest(unittest.TestCase):
self.assertEqual(d.shape[0], 70000)
self.assertEqual(d.shape[1], 784)
self.assertEqual(l.shape[0], 70000)
def test02_download(self):
    """Tests that the files are downloaded *and stored* when a directory is specified.

    A fresh, empty temporary directory is handed to ``Database``; since the
    raw files are not in it yet, constructing the database has to fetch them
    (presumably over the network -- this test will fail without access).
    After the database object is deleted, the directory and its contents must
    still be there, and a second ``Database`` on the same directory must be
    able to read the data.
    """
    import tempfile, os, shutil
    temp_dir = tempfile.mkdtemp(prefix='mnist_db_test_')
    try:
        db = Database(temp_dir)
        del db
        # the user-supplied directory must survive the destructor ...
        self.assertTrue(os.path.exists(temp_dir))
        # ... and must actually contain something; the directory itself was
        # created by mkdtemp above, so its mere existence proves nothing
        # about the files having been stored
        self.assertTrue(os.listdir(temp_dir))

        # check that the database works when data is downloaded already
        db = Database(temp_dir)
        db.data()
        del db
    finally:
        # always clean up, even when the download or an assertion fails;
        # ignore_errors keeps a cleanup problem from masking the real failure
        shutil.rmtree(temp_dir, ignore_errors=True)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment