putvein.py 5.27 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
# vim: set fileencoding=utf-8 :

"""
PUTVEIN database implementation of the bob.bio.base.database.BioDatabase
interface. It extends the low-level database interface, which talks directly
to the PUTVEIN database, for use in verification experiments (suitable for
the bob.bio.base framework).
"""

from bob.bio.base.database import BioFile, BioDatabase
import bob.ip.color
import numpy as np


class File(BioFile):
    """
    Vein sample of the PUTVEIN database, wrapping a low-level file object.

    Parameters:

      f (object): Low-level file (or sample) object. Its ``client_id``,
        ``path`` and ``id`` attributes are forwarded to ``BioFile`` and the
        object itself is kept as ``self.f``.
    """
    def __init__(self, f):
        bio_file_args = dict(client_id=f.client_id,
                             path=f.path,
                             file_id=f.id)
        super(File, self).__init__(**bio_file_args)

        # Keep the low-level object around: ``load`` delegates to it.
        self.f = f

    def load(self, directory=None, extension='.bmp'):
        """
        Loads the sample as a rotated greyscale image.

        The low-level ``bob.db.putvein`` file object returns an RGB image
        (documented shape (3, 768, 1024)). The image is converted to
        greyscale (shape (768, 1024)) and then rotated by 270 degrees, so
        that it can be used with ``bob.bio.vein`` algorithms designed for the
        ``bob.db.biowave_v1`` database.
        Output image dimensions - (1024, 768).

        Parameters:

          directory: Forwarded unchanged to the low-level ``load``.

          extension: Forwarded unchanged to the low-level ``load``
            (``'.bmp'`` by default).

        Returns:

          The rotated greyscale image.
        """
        rgb = self.f.load(directory=directory, extension=extension)
        # k=3 means three quarter turns, i.e. the 270 degree rotation.
        return np.rot90(bob.ip.color.rgb_to_gray(rgb), k=3)


class PutveinBioDatabase(BioDatabase):
    """
    Verification API for querying the PUTVEIN database.

    A protocol name has the form ``<kind>-<low level protocol>``, where the
    kind is ``palm`` or ``wrist``. The following protocols can be used:

    palm-L_1
    palm-LR_1
    palm-R_1
    palm-RL_1
    palm-R_BEAT_1

    palm-L_4
    palm-LR_4
    palm-R_4
    palm-RL_4
    palm-R_BEAT_4

    wrist-L_1
    wrist-LR_1
    wrist-R_1
    wrist-RL_1
    wrist-R_BEAT_1

    wrist-L_4
    wrist-LR_4
    wrist-R_4
    wrist-RL_4
    wrist-R_BEAT_4
    """

    def __init__(self, **kwargs):
        """
        Creates the high-level database and its low-level counterpart.

        All keyword arguments are passed on to ``BioDatabase`` together with
        ``name='putvein'``.
        """
        super(PutveinBioDatabase, self).__init__(name='putvein', **kwargs)

        # Imported here rather than at module level, as in the original code.
        from bob.db.putvein.query import Database as LowLevelDatabase
        self.__db = LowLevelDatabase()

    def __protocol_split__(self, prot_name):
        """
        Splits a "high level" protocol name (see the class documentation)
        into the kind (``palm`` or ``wrist``) and the low level protocol.

        There are 10 low level protocols:
            L_1;
            LR_1;
            R_1;
            RL_1;
            R_BEAT_1;
            L_4;
            LR_4;
            R_4;
            RL_4;
            R_BEAT_4;
        They are derived from the original 4:
            L;
            R;
            LR;
            RL;
        please read the ``bob.db.putvein`` documentation.

        **Parameters:**

        prot_name : a high level protocol name, e.g. ``"palm-L_1"``

        **Returns:**

        kind, prot : the two parts of ``prot_name`` around the ``-``

        **Raises:**

        IOError : if ``prot_name`` is not one of the allowed names
        """
        # Generated in the order: kind, then number suffix, then side, so the
        # list (which also appears in the error message) reads palm-*_1,
        # palm-*_4, wrist-*_1, wrist-*_4.
        allowed_prot_names = ["{}-{}_{}".format(kind, side, number)
                              for kind in ("palm", "wrist")
                              for number in ("1", "4")
                              for side in ("L", "LR", "R", "RL", "R_BEAT")]

        if prot_name in allowed_prot_names:
            # Every allowed name contains exactly one "-".
            kind, prot = prot_name.split("-")
            return kind, prot

        message = "Protocol name {} not allowed. Allowed names - {}"
        raise IOError(message.format(prot_name, allowed_prot_names))

    def client_id_from_model_id(self, model_id, group='dev'):
        """
        Returns the client id that the low-level database reports for
        ``model_id``.

        Required as ``model_id != client_id`` on this database. The ``group``
        argument is accepted but not used by this implementation.
        """
        return self.__db.client_id_from_model_id(model_id)

    def model_ids_with_protocol(self, groups=None, protocol=None, **kwargs):
        """model_ids_with_protocol(groups = None, protocol = None, **kwargs) -> ids

        Returns the model ids of the given groups for the given protocol.

        **Parameters:**

        groups : one or more of ``('world', 'dev', 'eval')``
          The groups to get the model ids for.

        protocol: a protocol name
          One of the high level names listed in the class documentation; any
          other value (including ``None``) raises ``IOError``.

        **Returns:**

        ids : [int]
          The list of (unique) model ids for the given groups.
        """
        kind, low_level_protocol = self.__protocol_split__(protocol)

        return self.__db.model_ids(protocol=low_level_protocol,
                                   groups=groups,
                                   kinds=kind)

    def objects(self, protocol=None, groups=None, purposes=None, model_ids=None, kinds=None, **kwargs):
        """
        Returns the samples matching the query, each wrapped in ``File``.

        **Parameters:**

        protocol: a protocol name
          One of the high level names listed in the class documentation; any
          other value (including ``None``) raises ``IOError``.

        groups, purposes, model_ids :
          Passed unchanged to the low-level database query.

        kinds :
          Accepted but not used; the kind is always taken from the
          ``protocol`` name.

        **Returns:**

        files : [File]
          One ``File`` per object returned by the low-level database.
        """
        kind, low_level_protocol = self.__protocol_split__(protocol)
        low_level_files = self.__db.objects(protocol=low_level_protocol,
                                            groups=groups,
                                            purposes=purposes,
                                            model_ids=model_ids,
                                            kinds=kind)
        return list(map(File, low_level_files))

    def annotations(self, file):
        """Always returns ``None``, whatever ``file`` is given."""
        return None