Skip to content
Snippets Groups Projects
5.py 4.08 KiB
###############################################################################
#                                                                             #
# Copyright (c) 2018 Idiap Research Institute, http://www.idiap.ch/           #
# Contact: beat.support@idiap.ch                                              #
#                                                                             #
# This file is part of the beat.examples module of the BEAT platform.         #
#                                                                             #
# Commercial License Usage                                                    #
# Licensees holding valid commercial BEAT licenses may use this file in       #
# accordance with the terms contained in a written agreement between you      #
# and Idiap. For further information contact tto@idiap.ch                     #
#                                                                             #
# Alternatively, this file may be used under the terms of the GNU Affero      #
# Public License version 3 as published by the Free Software and appearing    #
# in the file LICENSE.AGPL included in the packaging of this file.            #
# The BEAT platform is distributed in the hope that it will be useful, but    #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY  #
# or FITNESS FOR A PARTICULAR PURPOSE.                                        #
#                                                                             #
# You should have received a copy of the GNU Affero Public License along      #
# with the BEAT platform. If not, see http://www.gnu.org/licenses/.           #
#                                                                             #
###############################################################################

import os
import numpy as np
from collections import namedtuple

from beat.backend.python.database import View

import bob.io.base
import bob.io.image
from bob.db.livdet2013 import Database


#----------------------------------------------------------


class All(View):
    """Outputs:
        - image: "{{ system_user.username }}/array_3d_uint8/1"
        - spoof: "{{ system_user.username }}/boolean/1"

    Several "image" are associated with a given "spoof".

    --------------- --------------- --------------- ---------------
    |    image    | |    image    | |    image    | |    image    |
    --------------- --------------- --------------- ---------------
    ------------------------------- -------------------------------
    |            spoof            | |            spoof            |
    ------------------------------- -------------------------------
    """

    def index(self, root_folder, parameters):
        """Build the ordered list of samples served by this view.

        Parameters:
            root_folder: base directory handed to each sample's
                ``make_path`` to obtain its image path.
            parameters: mapping with a mandatory ``'group'`` key and optional
                ``'protocol'`` and ``'class'`` keys; they are forwarded to
                ``Database.objects`` (a missing optional key is passed as
                ``None``).

        Returns:
            A list of ``Entry(spoof, image)`` namedtuples. ``spoof`` is the
            result of the sample's ``is_live()`` and ``image`` is its path.
            Samples whose ``is_live()`` is False come first; within each of
            the two groups the order returned by the database is kept,
            because the sort is stable.
        """
        Entry = namedtuple('Entry', ['spoof', 'image'])

        database = Database()
        protocol = parameters.get('protocol')
        group = parameters['group']
        sample_class = parameters.get('class')

        samples = list(database.objects(protocols=protocol,
                                        groups=group,
                                        classes=sample_class))
        # Stable sort on a boolean key: False (not live) before True (live).
        samples.sort(key=lambda sample: sample.is_live())

        # NOTE(review): the 'spoof' field is filled from is_live(), so it is
        # True for *live* samples despite its name. Downstream consumers may
        # rely on this convention -- confirm before changing it.
        return [Entry(spoof=sample.is_live(),
                      image=sample.make_path(root_folder))
                for sample in samples]

    def get(self, output, index):
        """Return the data for one output of the sample at ``index``.

        ``self.objs`` must already hold the list produced by ``index()``.
        It is assigned from outside this class; the test block at the bottom
        of this file does so explicitly.

        Parameters:
            output: name of the requested output, ``'spoof'`` or ``'image'``.
            index: position of the sample in ``self.objs``.

        Returns:
            ``{'value': ...}`` holding either the stored ``spoof`` flag or the
            array returned by ``bob.io.base.load`` for the image path.
            Returns ``None`` for any other output name.
        """
        entry = self.objs[index]

        if output == 'image':
            # The image file is only read when this output is requested.
            return {'value': bob.io.base.load(entry.image)}

        if output == 'spoof':
            return {'value': entry.spoof}

        return None


#----------------------------------------------------------


def setup_tests():
    """Replace ``bob.io.base.load`` with a stub so views run without images.

    After this call, every ``bob.io.base.load(path)`` ignores ``path`` and
    returns a zero-filled ``uint8`` array of shape ``(3, 10, 20)``.
    The patch replaces the module attribute for the whole process and is
    never undone.
    """
    # Install a mock load function for the images
    def mock_load(path):
        # np.zeros initialises the buffer. The np.ndarray constructor used
        # previously returned uninitialised memory, which made the fake image
        # content non-deterministic.
        return np.zeros((3, 10, 20), dtype=np.uint8)

    bob.io.base.load = mock_load


#----------------------------------------------------------


# Smoke test of the view on fake data: image loading is mocked by
# setup_tests(), but the bob.db.livdet2013 Database itself is used for real.
# Both outputs of the first sample are fetched. The results are discarded,
# so the run only checks that nothing raises.
if __name__ == '__main__':
    setup_tests()

    view = All()
    view.objs = view.index(
        root_folder='',
        parameters={'protocol': 'Biometrika', 'group': 'train'},
    )

    for output_name in ('spoof', 'image'):
        view.get(output_name, 0)