Skip to content
Snippets Groups Projects
Commit 950a64e8 authored by Samuel GAIST's avatar Samuel GAIST
Browse files

[advanced][databases][mnist] Add V2 version of mnist database

This new version also fixes a bug making in not working
on the platform.
parent e41651c7
No related branches found
No related tags found
1 merge request!29Add V2 version of mnist database
{
"description": "The MNIST Database of Handwritten Digits",
"root_folder": "/idiap/group/biometric/databases/mnist",
"protocols": [
{
"name": "idiap",
"template": "simple_digit_recognition/1",
"views": {
"train": {
"view": "View",
"parameters": {
"group": "train"
}
},
"test": {
"view": "View",
"parameters": {
"group": "test"
}
}
}
}
],
"schema_version": 2
}
\ No newline at end of file
###############################################################################
# #
# Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.examples module of the BEAT platform. #
# #
# Commercial License Usage #
# Licensees holding valid commercial BEAT licenses may use this file in #
# accordance with the terms contained in a written agreement between you #
# and Idiap. For further information contact tto@idiap.ch #
# #
# Alternatively, this file may be used under the terms of the GNU Affero #
# Public License version 3 as published by the Free Software and appearing #
# in the file LICENSE.AGPL included in the packaging of this file. #
# The BEAT platform is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY #
# or FITNESS FOR A PARTICULAR PURPOSE. #
# #
# You should have received a copy of the GNU Affero Public License along #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/. #
# #
###############################################################################
import numpy as np
from collections import namedtuple
from beat.backend.python.database import View as BaseView
import bob.db.mnist
#----------------------------------------------------------
class View(BaseView):
"""Outputs:
- image: "{{ system_user.username }}/array_2d_uint8/1"
- id: "{{ system_user.username }}/uint64/1"
- class_id: "{{ system_user.username }}/uint64/1"
One "id" is associated with a given "image".
Several "image" are associated with a given "class_id".
--------------- --------------- --------------- --------------- --------------- ---------------
| image | | image | | image | | image | | image | | image |
--------------- --------------- --------------- --------------- --------------- ---------------
--------------- --------------- --------------- --------------- --------------- ---------------
| id | | id | | id | | id | | id | | id |
--------------- --------------- --------------- --------------- --------------- ---------------
----------------------------------------------- -----------------------------------------------
| class_id | | class_id |
----------------------------------------------- -----------------------------------------------
"""
def index(self, root_folder, parameters):
Entry = namedtuple('Entry', ['class_id', 'id', 'image'])
# Open the database and load the objects to provide via the outputs
db = bob.db.mnist.Database()
features, labels = db.data(groups=parameters['group'])
objs = sorted([ (labels[i], i, features[i]) for i in range(len(features)) ],
key=lambda x: (x[0], x[1]))
return [ Entry(x[0], x[1], x[2]) for x in objs ]
def get(self, output, index):
obj = self.objs[index]
if output == 'class_id':
return {
'value': np.uint64(obj.class_id)
}
elif output == 'id':
return {
'value': np.uint64(obj.id)
}
elif output == 'image':
return {
'value': obj.image.reshape((28, 28))
}
#----------------------------------------------------------
def setup_tests():
pass
#----------------------------------------------------------
# Test the behavior of the views (on fake data)
if __name__ == '__main__':
setup_tests()
# Note: This database can't be tested without the actual data, since
# the actual files are needed by this implementation
view = View()
view.objs = view.index(root_folder='', parameters=dict(group='train'))
view.get('class_id', 0)
view.get('id', 0)
view.get('image', 0)
The MNIST Database of Handwritten Digits
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment