diff --git a/advanced/databases/mnist/5.json b/advanced/databases/mnist/5.json new file mode 100644 index 0000000000000000000000000000000000000000..d02296885ca83c2bbc0d73de665fa26d599b499d --- /dev/null +++ b/advanced/databases/mnist/5.json @@ -0,0 +1,25 @@ +{ + "description": "The MNIST Database of Handwritten Digits", + "root_folder": "/idiap/group/biometric/databases/mnist", + "protocols": [ + { + "name": "idiap", + "template": "simple_digit_recognition/1", + "views": { + "train": { + "view": "View", + "parameters": { + "group": "train" + } + }, + "test": { + "view": "View", + "parameters": { + "group": "test" + } + } + } + } + ], + "schema_version": 2 +} \ No newline at end of file diff --git a/advanced/databases/mnist/5.py b/advanced/databases/mnist/5.py new file mode 100644 index 0000000000000000000000000000000000000000..a1bb0544a96e5fc3e799fcd21dd02ec6f3b12927 --- /dev/null +++ b/advanced/databases/mnist/5.py @@ -0,0 +1,111 @@ +############################################################################### +# # +# Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/ # +# Contact: beat.support@idiap.ch # +# # +# This file is part of the beat.examples module of the BEAT platform. # +# # +# Commercial License Usage # +# Licensees holding valid commercial BEAT licenses may use this file in # +# accordance with the terms contained in a written agreement between you # +# and Idiap. For further information contact tto@idiap.ch # +# # +# Alternatively, this file may be used under the terms of the GNU Affero # +# Public License version 3 as published by the Free Software and appearing # +# in the file LICENSE.AGPL included in the packaging of this file. # +# The BEAT platform is distributed in the hope that it will be useful, but # +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY # +# or FITNESS FOR A PARTICULAR PURPOSE. # +# # +# You should have received a copy of the GNU Affero Public License along # +# with the BEAT platform. If not, see http://www.gnu.org/licenses/. # +# # +############################################################################### + +import numpy as np +from collections import namedtuple + +from beat.backend.python.database import View as BaseView + +import bob.db.mnist + + +#---------------------------------------------------------- + + +class View(BaseView): + """Outputs: + - image: "{{ system_user.username }}/array_2d_uint8/1" + - id: "{{ system_user.username }}/uint64/1" + - class_id: "{{ system_user.username }}/uint64/1" + + One "id" is associated with a given "image". + Several "image" are associated with a given "class_id". + + --------------- --------------- --------------- --------------- --------------- --------------- + | image | | image | | image | | image | | image | | image | + --------------- --------------- --------------- --------------- --------------- --------------- + --------------- --------------- --------------- --------------- --------------- --------------- + | id | | id | | id | | id | | id | | id | + --------------- --------------- --------------- --------------- --------------- --------------- + ----------------------------------------------- ----------------------------------------------- + | class_id | | class_id | + ----------------------------------------------- ----------------------------------------------- + """ + + def index(self, root_folder, parameters): + Entry = namedtuple('Entry', ['class_id', 'id', 'image']) + + # Open the database and load the objects to provide via the outputs + db = bob.db.mnist.Database() + + features, labels = db.data(groups=parameters['group']) + + objs = sorted([ (labels[i], i, features[i]) for i in range(len(features)) ], + key=lambda x: (x[0], x[1])) + + return [ Entry(x[0], x[1], x[2]) for x in objs ] + + + def get(self, output, index): + obj = self.objs[index] + + if output == 'class_id': + return { + 'value': np.uint64(obj.class_id) + } + + elif output == 'id': + return { + 'value': np.uint64(obj.id) + } + + elif output == 'image': + return { + 'value': obj.image.reshape((28, 28)) + } + + +#---------------------------------------------------------- + + +def setup_tests(): + pass + + +#---------------------------------------------------------- + + +# Test the behavior of the views (on fake data) +if __name__ == '__main__': + + setup_tests() + + # Note: This database can't be tested without the actual data, since + # the actual files are needed by this implementation + + view = View() + view.objs = view.index(root_folder='', parameters=dict(group='train')) + view.get('class_id', 0) + view.get('id', 0) + view.get('image', 0) diff --git a/advanced/databases/mnist/5.rst b/advanced/databases/mnist/5.rst new file mode 100644 index 0000000000000000000000000000000000000000..2584c54faa06a406f71eb0586e54df25a8f8e659 --- /dev/null +++ b/advanced/databases/mnist/5.rst @@ -0,0 +1 @@ +The MNIST Database of Handwritten Digits \ No newline at end of file