# batl.py 17 KB
# Newer Older
1
2
3
4
5
6
7
8
#!/usr/bin/env python2
# -*- coding: utf-8 -*-

# Used in BATLMobilePadFile class
from bob.pad.base.database import PadDatabase, PadFile
from bob.bio.video import FrameSelector
from bob.extension import rc

9
10
from bob.pad.face.preprocessor.FaceCropAlign import detect_face_landmarks_in_image

11
12
import json

13
14
import os

15
16
import bob.io.base

17

18
class BatlPadFile(PadFile):
    """
    High level PAD file object wrapping a file of the low level BATL
    database interface.

    The wrapped low level file is kept in ``self.f``; the loading options
    passed to the constructor are stored as attributes and handed over to
    the low level ``load()`` call in :py:meth:`load`.
    """

    def __init__(self, f,
                 stream_type,  # a list of streams to be loaded
                 max_frames,
                 reference_stream_type="color",
                 warp_to_reference=True,
                 convert_to_rgb=False,
                 crop=None,
                 video_data_only=True):
        """
        **Parameters:**

        ``f`` : :py:class:`object`
            An instance of the File class defined in the low level db
            interface of the BATL database, in the
            ``bob.db.batl.models.py`` file.

        ``stream_type`` : [] or :py:class:`str`
            The type(s) of the streams to be loaded.

        ``max_frames`` : :py:class:`int`
            A maximum number of frames to be loaded. Frames are
            selected uniformly.

        ``reference_stream_type`` : :py:class:`str`
            Align/register all channels to this one.
            Default: "color".

        ``warp_to_reference`` : :py:class:`bool`
            Align/register the loaded video to ``reference_stream_type``,
            if set to ``True``.
            Default: ``True``.

        ``convert_to_rgb`` : :py:class:`bool`
            Type cast the non-RGB data to RGB data type,
            if set to ``True``.
            Default: ``False``.

        ``crop`` : []
            Pre-crop the frames if given, see ``bob.db.batl`` for more
            details.
            Default: ``None``.

        ``video_data_only`` : :py:class:`bool`
            If ``True``, :py:meth:`load` returns the ``'video'`` entry only.
            Otherwise everything returned by the low level ``load()`` is
            kept, for example timestamps for each frame.
            Default: ``True``.
        """

        self.f = f

        # The PAD framework only distinguishes bona-fide samples (``None``)
        # from attacks, so every attack gets the same generic label.
        attack_type = 'attack' if f.is_attack() else None

        super(BatlPadFile, self).__init__(
            client_id=f.client_id,
            path=f.path,
            attack_type=attack_type,
            file_id=f.id)

        # Loading options, consumed later by ``load()``.
        self.stream_type = stream_type
        self.max_frames = max_frames
        self.reference_stream_type = reference_stream_type
        self.warp_to_reference = warp_to_reference
        self.convert_to_rgb = convert_to_rgb
        self.crop = crop
        self.video_data_only = video_data_only

    def load(self, directory=None, extension='.h5',
             frame_selector=FrameSelector(selection_style='all')):
        """
        Load the data of this file through the low level database interface.

        **Parameters:**

        ``directory`` : :py:class:`str`
            String containing the path to BATL database.
            Default: ``None``.

        ``extension`` : :py:class:`str`
            Extension of the BATL database files.
            Default: ".h5".

        ``frame_selector`` : :any:`bob.bio.video.FrameSelector`, optional
            Callable applied to every loaded entry except ``'rppg'`` to
            select the frames to keep.

        **Returns:**

        ``data`` : FrameContainer
            If ``self.video_data_only`` is ``True``, the frame-selected
            ``'video'`` entry (see ``bob.bio.video.utils.FrameContainer``
            for further details). Otherwise the complete mapping returned
            by the low level ``load()``, with the frame selector applied
            to every entry except ``'rppg'``.
        """

        loaded = self.f.load(
            directory=directory,
            extension=extension,
            modality=self.stream_type,
            reference_stream_type=self.reference_stream_type,
            warp_to_reference=self.warp_to_reference,
            convert_to_rgb=self.convert_to_rgb,
            crop=self.crop,
            max_frames=self.max_frames)

        for entry_name in loaded.keys():
            # The 'rppg' entry is left exactly as the low level interface
            # returned it; only the other entries go through the selector.
            if entry_name == 'rppg':
                continue
            loaded[entry_name] = frame_selector(loaded[entry_name])

        return loaded['video'] if self.video_data_only else loaded


140
class BatlPadDatabase(PadDatabase):
    """
    A high level implementation of the Database class for the BATL
    database.
    """

    # Files whose path ends with one of these suffixes relate to FunnyEyes
    # presentation attacks.
    _FUNNY_EYES_SUFFIXES = ("_1_01", "_1_04", "_1_05", "_1_06", "_1_07")

    def __init__(
            self,
            protocol='nowig',
            original_directory=rc['bob.db.batl.directory'],
            original_extension='.h5',
            annotations_temp_dir="",
            landmark_detect_method="mtcnn",
            **kwargs):
        """
        **Parameters:**

        ``protocol`` : str or None
            The name of the protocol that defines the default experimental
            setup for this database. Also "complex" protocols can be
            parsed.
            For example:
            "nowig-color-5" - nowig protocol, color data only,
            load at most 5 frames.
            "nowig-depth-5" - nowig protocol, depth data only,
            load at most 5 frames.
            "nowig-color" - nowig protocol, color data only, no frame limit.
            "nowig-infrared-50-join_train_dev" - nowig protocol,
            infrared data only, at most 50 frames, join train and dev sets
            forming a single large training set.
            See the ``parse_protocol`` method of this class.

        ``original_directory`` : str
            The directory where the original data of the database are stored.

        ``original_extension`` : str
            The file name extension of the original data.

        ``annotations_temp_dir`` : str
            Annotations computed in ``self.annotations(f)`` method of this
            class will be saved to this directory if the path is specified /
            a non-empty string.
            Default: ``""``.

        ``landmark_detect_method`` : str
            Method to be used to compute annotations - face bounding box and
            landmarks. Possible options: "dlib" or "mtcnn".
            Default: ``"mtcnn"``.

        ``kwargs`` : dict
            Remaining keyword arguments, forwarded to the constructor of
            the ``PadDatabase`` base class.
        """

        from bob.db.batl import Database as LowLevelDatabase

        # Must exist before the base class constructor runs: the
        # ``original_directory`` property below stores its value on it.
        self.db = LowLevelDatabase()

        # Since the high level API expects different group names than what the
        # low level API offers, you need to convert them when necessary
        self.low_level_group_names = (
            'train', 'validation',
            'test')  # group names in the low-level database interface
        self.high_level_group_names = (
            'train', 'dev',
            'eval')  # names are expected to be like that in objects() function

        # Always use super to call parent class methods.
        super(BatlPadDatabase, self).__init__(
            name='batl',
            protocol=protocol,
            original_directory=original_directory,
            original_extension=original_extension,
            **kwargs)

        self.protocol = protocol
        self.original_directory = original_directory
        self.original_extension = original_extension
        self.annotations_temp_dir = annotations_temp_dir
        self.landmark_detect_method = landmark_detect_method

    @property
    def original_directory(self):
        """Directory of the original data, stored on the low level db."""
        return self.db.original_directory

    @original_directory.setter
    def original_directory(self, value):
        self.db.original_directory = value

    def parse_protocol(self, protocol):
        """
        Parse the protocol name, which is given as a string of
        ``-`` separated fields.
        An example of protocols it can parse:
        "nowig-color-5" - nowig protocol, color data only, at most 5 frames.
        "nowig-depth-5" - nowig protocol, depth data only, at most 5 frames.
        "nowig-color" - nowig protocol, color data only, no frame limit.

        **Parameters:**

        ``protocol`` : str
            Protocol name to be parsed. Example: "nowig-depth-5" .

        **Returns:**

        ``protocol`` : str
            The name of the protocol as defined in the low level db interface.

        ``stream_type`` : str
            The name of the channel/stream_type to be loaded,
            ``None`` if not given.

        ``max_frames`` : int
            The number of frames to be loaded, ``None`` if not given.

        ``extra`` : str
            An extra string which is handled in ``self.objects()`` method,
            ``None`` if not given.
            Extra strings which are currently handled are defined in
            ``possible_extras`` of this function.
            For example, if ``extra="join_train_dev"``, the train and dev
            sets will be joined in ``self.objects()``,
            forming a single training set.

        **Raises:**

        ``ValueError``
            If the frame-count field is present but is not an integer.
        """

        possible_extras = ['join_train_dev']

        components = protocol.split("-")

        # First recognised extra (if any) is pulled out of the field list,
        # so that it may appear at any position in the protocol string.
        extra = next(
            (item for item in possible_extras if item in components), None)

        if extra is not None:
            components.remove(extra)

        # Pad with ``None`` so that omitted optional fields unpack cleanly;
        # anything beyond the three expected fields is ignored.
        name, stream_type, max_frames = (components + [None, None])[:3]

        if max_frames is not None:
            max_frames = int(max_frames)

        return name, stream_type, max_frames, extra

    def _fix_funny_eyes_in_objects(self, protocol, groups, purposes,
                                   **kwargs):
        """
        This function redistributes FunnyEyes PAs across 'train', 'dev' and
        'eval' sets in the following way.

        Original (low-level DB) distribution is as follows:
        'train' = N1
        'dev' = N2
        'eval' = N3

        After this function is applied the distribution is:
        'train' = N1 + floor(N2 / 2)
        'dev' = N2 - floor(N2 / 2)
        'eval' = N3

        The first half of the FunnyEyes files of the low-level 'validation'
        group is moved to 'train', the remaining ones stay in 'dev'. Both
        halves are cut at the same index, so that no file is lost when the
        number of FunnyEyes files is odd.

        **Parameters:**

        ``protocol`` : str
            The protocol for which the clients should be retrieved.

        ``groups`` : :py:class:`str`
            OR a list of strings.
            The groups of which the clients should be returned, given with
            the LOW level names: one or more elements of
            ('train', 'validation', 'test'). ``None`` selects all of them.

        ``purposes`` : :obj:`str` or [:obj:`str`]
            The purposes for which File objects should be retrieved.
            Usually it is either 'real' or 'attack'.

        ``kwargs`` : dict
            Forwarded unchanged to ``objects()`` of the low level database.

        **Returns:**

        ``files`` : [VideoFile]
            A list of VideoFile objects defined in BATL Low Level Database
            Interface, ordered as train files, then dev files, then eval
            files.
        """

        if groups is None:
            groups = self.low_level_group_names
        elif isinstance(groups, (str, type(u''))):
            # A bare group name; wrap it so that the membership tests below
            # compare whole names instead of substrings.
            groups = [groups]

        funny_eyes = []  # FunnyEyes files of the 'validation' group
        regular_dev = []  # all other files of the 'validation' group

        # Both 'train' and 'validation' need the low-level validation files,
        # fetch and partition them once.
        if 'train' in groups or 'validation' in groups:
            validation_files = self.db.objects(
                protocol=protocol, groups='validation', purposes=purposes,
                **kwargs)

            for f in validation_files:
                if f.path.endswith(self._FUNNY_EYES_SUFFIXES):
                    funny_eyes.append(f)
                else:
                    regular_dev.append(f)

        # Single split point used for both sides of the redistribution.
        half = len(funny_eyes) // 2

        files = []

        if 'train' in groups:
            files.extend(self.db.objects(
                protocol=protocol, groups='train', purposes=purposes,
                **kwargs))
            # append first HALF of FunnyEyes files from "dev" to "train" set
            files.extend(funny_eyes[:half])

        if 'validation' in groups:
            # the rest of the FunnyEyes files, followed by "dev" files
            # without FunnyEyes
            files.extend(funny_eyes[half:])
            files.extend(regular_dev)

        if 'test' in groups:
            # this group remains unchanged
            files.extend(self.db.objects(
                protocol=protocol, groups='test', purposes=purposes,
                **kwargs))

        return files

    def objects(self,
                protocol=None,
                groups=None,
                purposes=None,
                model_ids=None,
                **kwargs):
        """
        This function returns lists of BatlPadFile objects, which fulfill the
        given restrictions.

        **Parameters:**

        ``protocol`` : str
            The protocol for which the clients should be retrieved.
            May be a "complex" protocol string, see ``parse_protocol``.
            Defaults to ``self.protocol`` if ``None``.

        ``groups`` : :py:class:`str`
            OR a list of strings.
            The groups of which the files should be returned, given with
            the HIGH level names: one or more elements of
            ('train', 'dev', 'eval'). ``None`` selects all of them.

        ``purposes`` : :obj:`str` or [:obj:`str`]
            The purposes for which File objects should be retrieved.
            Usually it is either 'real' or 'attack'.
            Defaults to both if ``None``.

        ``model_ids``
            This parameter is not supported in PAD databases yet

        ``kwargs`` : dict
            Forwarded to ``objects()`` of the low level database.

        **Returns:**

        ``files`` : [BatlPadFile]
            A list of BATLPadFile objects.
        """

        if protocol is None:
            protocol = self.protocol

        if groups is None:
            groups = self.high_level_group_names

        if purposes is None:
            purposes = ['real', 'attack']

        protocol, stream_type, max_frames, extra = self.parse_protocol(protocol)

        # Convert group names to low-level group names here.
        groups = self.convert_names_to_lowlevel(
            groups, self.low_level_group_names, self.high_level_group_names)

        if isinstance(groups, (str, type(u''))):
            # A single group name: wrap it. ``list("train")`` would split
            # the name into single characters instead.
            groups = [groups]
        elif groups is not None and not isinstance(groups, list):
            groups = list(groups)

        if extra is not None and "join_train_dev" in extra:

            if groups == ['train']:  # join "train" and "dev" sets
                files = self.db.objects(protocol=protocol,
                                        groups=['train', 'validation'],
                                        purposes=purposes, **kwargs)

            # return ALL data if "train" and "some other" set/sets are requested
            elif len(groups) >= 2 and 'train' in groups:
                files = self.db.objects(protocol=protocol,
                                        groups=self.low_level_group_names,
                                        purposes=purposes, **kwargs)

            # addresses the cases when groups=['validation'] or ['test'] or ['validation', 'test']:
            else:
                files = self.db.objects(protocol=protocol,
                                        groups=['test'],
                                        purposes=purposes, **kwargs)

        else:
            files = self._fix_funny_eyes_in_objects(protocol=protocol,
                                                    groups=groups,
                                                    purposes=purposes, **kwargs)

        return [BatlPadFile(f, stream_type, max_frames) for f in files]

    def _load_color_video(self, f):
        """
        Load the un-warped, un-cropped color video of ``f``.

        The loading options of ``f`` are overridden only for the duration
        of the call and restored afterwards, so that later ``f.load()``
        calls still return the streams the file was configured for.
        """

        overrides = {
            'stream_type': "color",
            'reference_stream_type': "color",
            'warp_to_reference': False,
            'convert_to_rgb': False,
            'crop': None,
            'video_data_only': True,
        }

        saved = dict((name, getattr(f, name)) for name in overrides)

        try:
            for name, value in overrides.items():
                setattr(f, name, value)

            return f.load(directory=self.original_directory,
                          extension=self.original_extension)
        finally:
            # Restore the caller's settings even if loading failed.
            for name, value in saved.items():
                setattr(f, name, value)

    def annotations(self, f):
        """
        Computes annotations for a given file object ``f``, which
        is an instance of the ``BatlPadFile`` class.

        NOTE: you can pre-compute annotation in your first experiment
        and then reuse them in other experiments setting
        ``self.annotations_temp_dir`` path of this class, where
        precomputed annotations will be saved.

        **Parameters:**

        ``f`` : :py:class:`object`
            An instance of ``BatlPadFile`` defined above.

        **Returns:**

        ``annotations`` : :py:class:`dict`
            A dictionary containing annotations for
            each frame in the video, or ``None`` if no annotations
            are available for any frame.
            Dictionary structure:
            ``annotations = {'0': frame0_dict, '1': frame1_dict, ...}``.
            Where
            ``frameN_dict`` contains coordinates of the
            face bounding box and landmarks in frame N (frames are counted
            from 0). Frames for which the detector returned nothing have
            no entry.
        """

        file_path = os.path.join(self.annotations_temp_dir, f.f.path + ".json")

        if os.path.isfile(file_path):
            # file with annotations exists, load them from file
            with open(file_path, 'r') as json_file:
                annotations = json.load(json_file)

        else:  # no file with annotations, compute them from the color video
            video = self._load_color_video(f)

            annotations = {}

            for idx, image in enumerate(video.as_array()):

                frame_annotations = detect_face_landmarks_in_image(
                    image, method=self.landmark_detect_method)

                if frame_annotations:
                    annotations[str(idx)] = frame_annotations

            if self.annotations_temp_dir:  # if directory is not an empty string

                bob.io.base.create_directories_safe(
                    directory=os.path.split(file_path)[0], dryrun=False)

                with open(file_path, 'w+') as json_file:
                    json_file.write(json.dumps(annotations))

        if not annotations:  # if dictionary is empty
            return None

        return annotations