# casiasurf.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
import numpy as np
import bob.io.video
from bob.bio.video import FrameSelector, FrameContainer
from bob.pad.face.database import VideoPadFile  
from bob.pad.base.database import PadDatabase

12
13
from bob.extension import rc

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class CasiaSurfPadFile(VideoPadFile):
    """
    A high level implementation of the File class for the CASIA-SURF database.

    Note that this does not represent a file per se, but rather a sample
    that may contain more than one file.

    Attributes
    ----------
    s : :py:class:`object`
      An instance of the Sample class defined in the low level db interface
      of the CASIA-SURF database, in the bob.db.casiasurf.models.py file.
    stream_type : str or list of str
      The streams to be loaded; forwarded as ``modality`` to the low-level
      sample when loading.

    """

    def __init__(self, s, stream_type):
        """ Init

        Parameters
        ----------
        s : :py:class:`object`
          An instance of the Sample class defined in the low level db interface
          of the CASIA-SURF database, in the bob.db.casiasurf.models.py file.
        stream_type: str or list of str
          The streams to be loaded.
        """
        self.s = s
        self.stream_type = stream_type

        # Normalise the low-level attack type to a string (or None).
        raw_attack_type = s.attack_type
        attack_type = None if raw_attack_type is None else str(raw_attack_type)

        # An attack type of '0' denotes a bona fide sample, which the parent
        # class is told about by passing ``attack_type=None``. The previous
        # code assigned None to ``s.attack_type`` instead of the local
        # variable, so '0' was still forwarded to the parent constructor
        # (and the low-level sample was modified as a side effect).
        if attack_type == '0':
            attack_type = None

        super(CasiaSurfPadFile, self).__init__(
            client_id=s.id,
            file_id=s.id,
            attack_type=attack_type,
            path=s.id)

    def load(self, directory=rc['bob.db.casiasurf.directory'], extension='.jpg', frame_selector=FrameSelector(selection_style='all')):
        """Overloaded version of the load method defined in ``VideoPadFile``.

        Parameters
        ----------
        directory : :py:class:`str`
          String containing the path to the CASIA-SURF database
        extension : :py:class:`str`
          Extension of the image files
        frame_selector : :py:class:`bob.bio.video.FrameSelector`
          Not used by this implementation; it is accepted only so that the
          signature stays compatible with existing callers.

        Returns
        -------
        dict:
          image data for multiple streams stored in the dictionary.
          The structure of the dictionary: ``data={"stream1_name" : numpy array, "stream2_name" : numpy array}``
          Names of the streams are defined in ``self.stream_type``.
          (This is whatever the low-level sample's ``load`` returns.)
        """
        # Loading is fully delegated to the low-level sample object.
        return self.s.load(directory, extension, modality=self.stream_type)
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95


class CasiaSurfPadDatabase(PadDatabase):
    """High level implementation of the Database class for the CASIA-SURF database.

    Note that at the moment, this database only contains a training and validation set.

    The protocol specifies the modality(ies) to load.

    Attributes
    ----------
    db : :py:class:`bob.db.casiasurf.Database`
      the low-level database interface
    low_level_group_names : list of :py:obj:`str`
      the group names in the low-level interface (train, validation, test)
    high_level_group_names : list of :py:obj:`str`
      the group names in the high-level interface (train, dev, eval)

    """

    def __init__(self, protocol='all', original_directory=rc['bob.db.casiasurf.directory'], original_extension='.jpg', **kwargs):
        """Init function

        Parameters
        ----------
        protocol : :py:class:`str`
          The name of the protocol that defines the default experimental setup for this database.
        original_directory : :py:class:`str`
          The directory where the original data of the database are stored.
        original_extension : :py:class:`str`
          The file name extension of the original data.

        """
        # Imported here so the low-level package is only required when the
        # database is actually instantiated.
        from bob.db.casiasurf import Database as LowLevelDatabase

        # ``self.db`` must exist before the parent constructor runs, because
        # the ``original_directory`` property below reads and writes it.
        self.db = LowLevelDatabase()

        self.low_level_group_names = ('train', 'validation', 'test')
        self.high_level_group_names = ('train', 'dev', 'eval')

        super(CasiaSurfPadDatabase, self).__init__(
            name='casiasurf',
            protocol=protocol,
            original_directory=original_directory,
            original_extension=original_extension,
            **kwargs)

    @property
    def original_directory(self):
        """The data directory, as stored on the low-level database object."""
        return self.db.original_directory

    @original_directory.setter
    def original_directory(self, value):
        self.db.original_directory = value

    def objects(self,
                groups=None,
                protocol='all',
                purposes=None,
                model_ids=None,
                **kwargs):
        """Returns a list of CasiaSurfPadFile objects, which fulfill the given restrictions.

        Parameters
        ----------
        groups : list of :py:class:`str`
          The groups of which the clients should be returned.
          Usually, groups are one or more elements of ('train', 'dev', 'eval').
          If ``None``, all high-level groups are used.
        protocol : :py:class:`str`
          The protocol for which the samples should be retrieved.
          It is passed to each returned file as its ``stream_type``.
        purposes : :py:class:`str` or list of :py:class:`str`
          The purposes for which Sample objects should be retrieved.
          Usually it is either 'real' or 'attack'.
          If ``None``, both 'real' and 'attack' are used.
        model_ids
          This parameter is not supported in PAD databases yet.

        Returns
        -------
        samples : list of :py:class:`CasiaSurfPadFile`
            A list of CasiaSurfPadFile objects.
        """
        # Previously the defaults (groups=None / purposes=None) crashed:
        # ``lowlevel_purposes`` was never bound when groups was None, and
        # ``'real' in None`` raised a TypeError.
        if groups is None:
            groups = self.high_level_group_names
        if purposes is None:
            purposes = ('real', 'attack')
        elif isinstance(purposes, str):
            # Wrap a single purpose so the membership tests below compare
            # whole names instead of relying on substring matching.
            purposes = (purposes,)

        groups = self.convert_names_to_lowlevel(groups, self.low_level_group_names, self.high_level_group_names)

        lowlevel_purposes = []

        # for training: both 'real' and 'attack' map directly
        if 'train' in groups:
            if 'real' in purposes:
                lowlevel_purposes.append('real')
            if 'attack' in purposes:
                lowlevel_purposes.append('attack')

        # for dev and eval: only an 'attack' request maps to a low-level
        # purpose, namely 'unknown'
        if ('validation' in groups or 'test' in groups) and 'attack' in purposes:
            lowlevel_purposes.append('unknown')

        samples = self.db.objects(groups=groups, purposes=lowlevel_purposes, **kwargs)
        return [CasiaSurfPadFile(s, stream_type=protocol) for s in samples]

    def annotations(self, file):
        """No annotations are provided with this DB
        """
        return None