Commit e4f7eafd authored by Laurent EL SHAFEY's avatar Laurent EL SHAFEY
Browse files

Fix for old numpy/python version

parent 8fef6f1c
......@@ -9,7 +9,7 @@ from setuptools import setup, find_packages
setup(
name='xbob.db.mnist',
version='1.0.0f',
version='1.0.0g',
description='MNIST Database Access API for Bob',
url='http://github.com/bioidiap/xbob.db.mnist',
license='GPLv3',
......
......@@ -79,22 +79,22 @@ class Database():
def __read_labels__(self, fname):
"""Reads the labels from the original MNIST label binary file"""
import struct, gzip, numpy
with gzip.GzipFile(fname, 'rb') as f:
# reads 2 big-ending integers
magic_nr, n_examples = struct.unpack(">II", f.read(8))
# reads the rest, using an uint8 dataformat (endian-less)
labels = numpy.fromstring(f.read(), dtype='uint8')
f = gzip.GzipFile(fname, 'rb')
# reads 2 big-ending integers
magic_nr, n_examples = struct.unpack(">II", f.read(8))
# reads the rest, using an uint8 dataformat (endian-less)
labels = numpy.fromstring(f.read(), dtype='uint8')
return labels
def __read_images__(self, fname):
"""Reads the images from the original MNIST label binary file"""
import struct, gzip, numpy
with gzip.GzipFile(fname, 'rb') as f:
# reads 4 big-ending integers
magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
shape = (n_examples, rows*cols)
# reads the rest, using an uint8 dataformat (endian-less)
images = numpy.fromstring(f.read(), dtype='uint8').reshape(shape)
f = gzip.GzipFile(fname, 'rb')
# reads 4 big-ending integers
magic_nr, n_examples, rows, cols = struct.unpack(">IIII", f.read(16))
shape = (n_examples, rows*cols)
# reads the rest, using an uint8 dataformat (endian-less)
images = numpy.fromstring(f.read(), dtype='uint8').reshape(shape)
return images
def __check_parameters_for_validity__(self, parameters, parameter_description, valid_parameters, default_parameters = None):
......@@ -194,7 +194,7 @@ class Database():
labels = numpy.ndarray(shape=(0,), dtype=numpy.uint8)
# List of indices for which the labels are in the list of requested labels
indices = numpy.where(numpy.in1d(labels, list(vlabels)))[0]
indices = numpy.where(numpy.array([v in vlabels for v in labels]))[0]
images = images[indices,:]
labels = labels[indices]
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment