Commit 260cba6d authored by André Anjos's avatar André Anjos 💬
Browse files

Remove private download function in favour of bob.db.base's implementation (closes #4)

parent 5bd4aad4
Pipeline #9534 failed with stages
in 6 minutes and 40 seconds
......@@ -17,3 +17,4 @@ dist
build
*.egg
src/
bob/db/mnist/data/
This diff is collapsed.
include README.rst bootstrap-buildout.py buildout.cfg develop.cfg COPYING version.txt requirements.txt
include README.rst buildout.cfg develop.cfg LICENSE version.txt requirements.txt
recursive-include doc *.py *.rst
recursive-include bob *.gz
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
# @date: Wed May 8 19:18:16 CEST 2013
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
The MNIST Database is a database of handwritten digits, which has a training
set of 60,000 examples, and a test set of 10,000 examples. It is a subset of
a larger set available from NIST. The digits have been size-normalized and
centered in a fixed-size image. You can download the MNIST database from:
https://web-beta.archive.org/web/20161231041016/http://yann.lecun.com/exdb/mnist/
centered in a fixed-size image.
"""
from .query import Database
def get_config():
    """Returns a string with the configuration information of this package.

    The text is obtained from ``bob.extension.get_config``, called with the
    name of this module.
    """
    import bob.extension as _extension

    report = _extension.get_config(__name__)
    return report
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
    """Says object was actually declared here, and not in the import module.

    Sets the ``__module__`` attribute of every given object to the name of
    this module.

    Parameters:

      *args: An iterable of objects to modify

    Resolves `Sphinx referencing issues
    <https://github.com/sphinx-doc/sphinx/issues/3048>`_
    """
    for obj in args:
        obj.__module__ = __name__
__appropriate__(
Database,
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
# @date: Wed May 8 19:15:47 CEST 2013
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""Commands the MNIST database can respond to.
"""
import os
import sys
from bob.db.base.driver import Interface as BaseInterface
import pkg_resources
def download_mnist(self):
    """Downloads the MNIST data files into the directory of this module.

    This function is installed as the ``func`` default of the download
    command parser.

    Parameters:

      self: Not used by this function; kept so the call signature stays
        unchanged.  NOTE(review): named ``self`` although this is a
        module-level function; presumably it receives the parsed command-line
        arguments -- confirm against the caller.
    """
    import pkg_resources
    import bob.db.mnist

    # directory in which this module is installed; handed over as ``data_dir``
    db_folder = pkg_resources.resource_filename(__name__, '')

    # building the database object is what performs the download; the object
    # itself is not needed afterwards
    bob.db.mnist.Database(data_dir=db_folder)
class Interface(BaseInterface):
def name(self):
return 'mnist'
def version(self):
import pkg_resources # part of setuptools
return pkg_resources.require('bob.db.%s' % self.name())[0].version
def files(self):
return ()
basedir = pkg_resources.resource_filename(__name__, '')
filelist = os.path.join(basedir, 'files.txt')
return [os.path.join(basedir, k.strip()) for k in \
open(filelist, 'rt').readlines() if k.strip()]
def type(self):
return 'binary'
return 'text'
def add_commands(self, parser):
def add_commands(self, parser):
from . import __doc__ as docs
from bob.db.base.driver import download_command
subparsers = self.setup_parser(parser, "MNIST database", docs)
parser = download_command(subparsers)
parser.set_defaults(func=download_mnist)
self.setup_parser(parser, "MNIST database", docs)
data/train-images-idx3-ubyte.gz
data/train-labels-idx1-ubyte.gz
data/t10k-images-idx3-ubyte.gz
data/t10k-labels-idx1-ubyte.gz
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Laurent El Shafey <Laurent.El-Shafey@idiap.ch>
# @date: Wed May 8 19:42:39 CEST 2013
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import shutil
import os
import shutil
import struct
import gzip
import numpy
from bob.db.base.utils import check_parameters_for_validity
class Database():
"""Wrapper class for the MNIST database of handwritten digits (https://web-beta.archive.org/web/20161231041016/http://yann.lecun.com/exdb/mnist/).
class Database:
"""Wrapper class for the MNIST database of handwritten digits.
The original database files are distributed over:
http://yann.lecun.com/exdb/mnist/.
"""
def __init__(self, data_dir = None):
"""Creates the database. The data_dir argument should be the path to the directory
containing the four binary files available from https://web-beta.archive.org/web/20161231041016/http://yann.lecun.com/exdb/mnist/"""
# initialize members
import os
self.m_labels = set(range(0,10))
self.m_groups = ('train', 'test')
self.m_mnist_filenames = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz',
't10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
self.m_tmp_dir = None
# check if the data is available in the given directory (or if not given, in the default directory)
if not self._db_is_installed(data_dir):
self.m_data_dir = self._create_tmp_dir_and_download(data_dir)
if data_dir is None:
# if we create a temporary directory, mark it to be deleted at the end
self.m_tmp_dir = self.m_data_dir
elif data_dir is not None:
self.m_data_dir = data_dir
else:
from pkg_resources import resource_filename
self.m_data_dir = os.path.dirname(resource_filename(__name__, 'query.py'))
self.m_train_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[0])
self.m_train_fname_labels = os.path.join(self.m_data_dir, self.m_mnist_filenames[1])
self.m_test_fname_images = os.path.join(self.m_data_dir, self.m_mnist_filenames[2])
self.m_test_fname_labels = os.path.join(self.m_data_dir, self.m_mnist_filenames[3])
def __del__(self):
    """Removes the temporary download directory, if one is recorded.

    Nothing is deleted unless ``self.m_tmp_dir`` holds a path.  A directory
    that has already disappeared is silently ignored.

    Raises:

      RuntimeError: If the directory exists but cannot be erased.
    """
    # the attribute does not exist if construction failed before it was set
    tmp_dir = getattr(self, 'm_tmp_dir', None)
    if not tmp_dir:
        return
    try:
        shutil.rmtree(tmp_dir)  # delete directory
    except OSError as e:
        if e.errno != 2:  # code 2 - no such file or directory
            # a bare string cannot be raised (TypeError); use a real exception
            raise RuntimeError("bob.db.mnist: Error while erasing temporarily downloaded data files")
def _db_is_installed(self, directory = None):
    """Tells whether all MNIST data files are present.

    Parameters:

      directory (str): Directory to inspect.  If ``None``, the files are
        looked up as resources of this module through
        ``pkg_resources.resource_filename``.

    Returns:

      bool: ``True`` only if every file named in ``self.m_mnist_filenames``
      exists.
    """
    import os
    if directory is None:
        # only the package-relative lookup needs pkg_resources
        from pkg_resources import resource_filename
        db_files = [resource_filename(__name__, k) for k in self.m_mnist_filenames]
    else:
        db_files = [os.path.join(directory, k) for k in self.m_mnist_filenames]
    return all(os.path.exists(f) for f in db_files)
def _create_tmp_dir_and_download(self, directory=None):
    """Downloads every file of ``self.m_mnist_filenames`` into ``directory``.

    Parameters:

      directory (str): Target directory; it is created when it does not
        exist.  If ``None``, a new temporary directory (prefix ``mnist_db``)
        is created and used instead.

    Returns:

      str: The directory that received the downloaded files.
    """
    import tempfile
    from contextlib import closing
    from shutil import copyfileobj
    try:
        from urllib.request import urlopen  # python3
    except ImportError:
        from urllib2 import urlopen  # python2

    if directory is None:
        directory = tempfile.mkdtemp(prefix='mnist_db')
    elif not os.path.exists(directory):
        os.makedirs(directory)

    print("Downloading the mnist database from https://web-beta.archive.org/web/20161231041016/http://yann.lecun.com/exdb/mnist/ ...")
    for f in self.m_mnist_filenames:
        tmp_file = os.path.join(directory, f)
        url = 'https://web-beta.archive.org/web/20161231041016/http://yann.lecun.com/exdb/mnist/'+f
        complete = False
        try:
            # closing() releases the connection on both python versions
            # (python2 responses are not context managers); copyfileobj
            # streams the data instead of holding the whole file in memory
            with closing(urlopen(url)) as response:
                with open(tmp_file, 'wb') as out_file:
                    copyfileobj(response, out_file)
            complete = True
        finally:
            # _db_is_installed only tests for existence, so a truncated file
            # left behind would later be mistaken for a complete download
            if not complete and os.path.exists(tmp_file):
                os.remove(tmp_file)
    return directory
def __init__(self):
from .driver import Interface
f = Interface().files()
self.train_images = f[0]
self.train_labels = f[1]
self.test_images = f[2]
self.test_labels = f[3]
self._labels = set(range(0,10))
self._groups = ('train', 'test')
def _read_labels(self, fname):
    """Reads the labels from the original MNIST label binary file

    Parameters:

      fname (str): Path to the gzip-compressed label file.

    Returns:

      numpy.ndarray: A writable 1D ``uint8`` array holding every byte that
      follows the 8-byte header.
    """
    import struct
    import gzip
    import numpy

    # the context manager closes the file even if parsing fails
    with gzip.open(fname, 'rb') as f:
        # reads 2 big-endian integers (magic number, number of examples);
        # a header shorter than 8 bytes raises struct.error
        magic_nr, n_examples = struct.unpack(">II", f.read(8))
        # reads the rest, using an uint8 dataformat (endian-less);
        # frombuffer replaces the deprecated binary mode of fromstring and
        # the copy keeps the returned array writable, as it was before
        labels = numpy.frombuffer(f.read(), dtype='uint8').copy()
    return labels
def _read_images(self, fname):
    """Reads the images from the original MNIST image binary file

    Parameters:

      fname (str): Path to the gzip-compressed image file.

    Returns:

      numpy.ndarray: A writable 2D ``uint8`` array of shape
      ``(n_examples, rows*cols)``, with the three sizes taken from the file
      header.
    """
    import struct
    import gzip
    import numpy

    # the context manager closes the file even if parsing fails
    with gzip.open(fname, 'rb') as f:
        # reads 4 big-endian integers
        magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
        shape = (n_examples, rows*cols)
        # reads the rest, using an uint8 dataformat (endian-less);
        # frombuffer replaces the deprecated binary mode of fromstring and
        # the copy keeps the returned array writable, as it was before
        images = numpy.frombuffer(f.read(), dtype='uint8').copy().reshape(shape)
    return images
def _check_parameters_for_validity(self, parameters, parameter_description, valid_parameters, default_parameters = None):
    """Checks the given parameters for validity, i.e., if they are contained in the set of valid parameters.

    It also assures that the parameters form a tuple, list or set.
    If parameters is 'None', the default_parameters will be returned (if default_parameters is omitted, all valid_parameters are returned).
    An empty list/tuple/set is not replaced by the defaults; it is returned as it is.

    This function will return a tuple, list or set of parameters, or raise a ValueError.

    Keyword parameters:

    parameters
      The parameters to be checked.
      Might be a single value, a list/tuple/set of values, or None.

    parameter_description
      A short description of the parameter.
      This will be used to raise an exception in case the parameter is not valid.

    valid_parameters
      A list/tuple/set of valid values for the parameters.

    default_parameters
      The list/tuple of default parameters that will be returned in case parameters is None.
      If omitted, all valid_parameters are used.
    """
    if parameters is None:
        # parameters are not specified -> fall back to the defaults
        parameters = default_parameters if default_parameters is not None else valid_parameters

    if not isinstance(parameters, (list, tuple, set)):
        # parameter is just a single element, not a tuple or list -> transform it into a tuple
        parameters = (parameters,)

    # perform the checks
    for parameter in parameters:
        if parameter not in valid_parameters:
            raise ValueError("Invalid %s '%s'. Valid values are %s, or lists/tuples of those" % (parameter_description, parameter, valid_parameters))

    # check passed, now return the list/tuple of parameters
    return parameters
def labels(self):
"""Returns the vector of labels
"""
return self.m_labels
return self._labels
def groups(self):
"""Returns the vector of groups
"""
return self.m_groups
return self._groups
def data(self, groups=None, labels=None):
"""Loads the MNIST samples and labels and returns them in NumPy arrays
Keyword Parameters:
groups
One of the groups 'train' or 'test' or a list with both of them (which is the default).
Parameters:
groups (:py:class:`str` or :py:class:`list`): One of the groups ``train``
or ``test``, or a list with both of them (which is the default)
labels (:py:class:`int` or :py:class:`list`): A subset of the labels
(digits 0 to 9) (everything is the default)
labels
A subset of the labels (digits 0 to 9) (everything is the default).
Returns:
Returns: A tuple composed of images and labels as 2D numpy arrays considering
all the filtering criteria and organized as follows:
images (numpy.ndarray): A 2D array with as many rows as examples in the
dataset, as many columns as pixels (actually, there are 28x28 = 784
columns). The pixels of each image are unrolled in C-scan order (i.e., first
row 0, then row 1, etc.)
images
A 2D numpy.ndarray with as many rows as examples in the dataset, as many
columns as pixels (actually, there are 28x28 = 784 columns). The pixels of each
image are unrolled in C-scan order (i.e., first row 0, then row 1, etc.).
labels (numpy.ndarray): A 1D array with as many elements as examples in
the dataset
labels
A 1D numpy.ndarray with as many elements as examples in the dataset.
"""
# check if groups set are valid
groups = self._check_parameters_for_validity(groups, "group", self.m_groups)
vlabels = self._check_parameters_for_validity(labels, "label", self.m_labels)
groups = check_parameters_for_validity(groups, "group", self._groups)
vlabels = check_parameters_for_validity(labels, "label", self._labels)
# Reads data from the groups
import numpy
if 'train' in groups and 'test' in groups:
images1 = self._read_images(self.m_train_fname_images)
labels1 = self._read_labels(self.m_train_fname_labels)
images2 = self._read_images(self.m_test_fname_images)
labels2 = self._read_labels(self.m_test_fname_labels)
images1 = self._read_images(self.train_images)
labels1 = self._read_labels(self.train_labels)
images2 = self._read_images(self.test_images)
labels2 = self._read_labels(self.test_labels)
images = numpy.vstack([images1,images2])
labels = numpy.hstack([labels1,labels2])
elif 'train' in groups:
images = self._read_images(self.m_train_fname_images)
labels = self._read_labels(self.m_train_fname_labels)
images = self._read_images(self.train_images)
labels = self._read_labels(self.train_labels)
elif 'test' in groups:
images = self._read_images(self.m_test_fname_images)
labels = self._read_labels(self.m_test_fname_labels)
images = self._read_images(self.test_images)
labels = self._read_labels(self.test_labels)
else:
images = numpy.ndarray(shape=(0,784), dtype=numpy.uint8)
labels = numpy.ndarray(shape=(0,), dtype=numpy.uint8)
......@@ -226,4 +132,3 @@ class Database():
labels = labels[indices]
return images, labels
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Laurent El Shafey <laurent.el-shafey@idiap.ch>
#
# Copyright (C) 2011-2013 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""A few checks at the MNIST database.
"""
import unittest
import bob.db.mnist
from . import Database
def db_available(test):
    """Decorator that skips a test when the raw database files are missing.

    Before the wrapped test runs, every path returned by
    ``Interface().files()`` is checked for existence; if one of them is not
    there, ``SkipTest`` is raised instead of calling the test.
    """
    import os
    import functools
    from nose.plugins.skip import SkipTest

    @functools.wraps(test)
    def wrapper(*args, **kwargs):
        from .driver import Interface
        raw_files = Interface().files()
        if any(not os.path.exists(path) for path in raw_files):
            raise SkipTest("Raw database files are not available")
        return test(*args, **kwargs)

    return wrapper
@db_available
def test_query():
db = bob.db.mnist.Database()
db = Database()
f = db.labels()
assert len(f) == 10 # number of labels (digits 0 to 9)
......@@ -48,19 +57,3 @@ def test_query():
assert d.shape[0] == 70000
assert d.shape[1] == 784
assert l.shape[0] == 70000
def test_download():
    """Checks that data is downloaded *and stored* when a directory is given."""
    import tempfile, os, shutil
    temp_dir = tempfile.mkdtemp(prefix='mnist_db_test_')
    try:
        # first instantiation has to download the files into temp_dir
        db = bob.db.mnist.Database(temp_dir)
        del db
        # the directory was given explicitly, so it must survive the object
        assert os.path.exists(temp_dir)

        # check that the database works when data is downloaded already
        db = bob.db.mnist.Database(temp_dir)
        assert db.data() is not None
        del db
    finally:
        # clean up even when the download or one of the assertions fails
        shutil.rmtree(temp_dir)
##############################################################################
#
# Copyright (c) 2006 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
"""Bootstrap a buildout-based project
Simply run this script in a directory containing a buildout.cfg.
The script accepts buildout command-line options, so you can
use the -c option to specify an alternate configuration file.
"""
import os
import shutil
import sys
import tempfile
from optparse import OptionParser
__version__ = '2015-07-01'
# See zc.buildout's changelog if this version is up to date.
tmpeggs = tempfile.mkdtemp(prefix='bootstrap-')
usage = '''\
[DESIRED PYTHON FOR BUILDOUT] bootstrap.py [options]
Bootstraps a buildout-based project.
Simply run this script in a directory containing a buildout.cfg, using the
Python that you want bin/buildout to use.
Note that by using --find-links to point to local resources, you can keep
this script from going over the network.
'''
# Command-line interface of the bootstrap script, built with optparse; the
# ``usage`` text defined above is shown in the help output.
parser = OptionParser(usage=usage)
# --version: boolean flag, off by default.
parser.add_option("--version",
action="store_true", default=False,
help=("Return bootstrap.py version."))
# -t: boolean flag, stored as ``options.accept_buildout_test_releases``.
parser.add_option("-t", "--accept-buildout-test-releases",
dest='accept_buildout_test_releases',
action="store_true", default=False,
help=("Normally, if you do not specify a --version, the "
"bootstrap script and buildout gets the newest "
"*final* versions of zc.buildout and its recipes and "
"extensions for you. If you use this flag, "
"bootstrap and buildout will get the newest releases "
"even if they are alphas or betas."))
# -c and -f take a value; no explicit default is set for either.
parser.add_option("-c", "--config-file",
help=("Specify the path to the buildout configuration "
"file to be used."))
parser.add_option("-f", "--find-links",
help=("Specify a URL to search for buildout releases"))
# --allow-site-packages: boolean flag, off by default.
parser.add_option("--allow-site-packages",
action="store_true", default=False,
help=("Let bootstrap.py use existing site packages"))
# The remaining options take a value; no explicit default is set for them.
parser.add_option("--buildout-version",
help="Use a specific zc.buildout version")
parser.add_option("--setuptools-version",
help="Use a specific setuptools version")
parser.add_option("--setuptools-to-dir",
help=("Allow for re-use of existing directory of "
"setuptools versions"))
# Parse the command line: ``options`` carries the values of the options
# declared above, ``args`` the remaining positional arguments.
options, args = parser.parse_args()
if options.version:
print("bootstrap.py version %s" % __version__)
sys.exit(0)
######################################################################
# load/install setuptools