# replay_mobile.py
1
import logging
2

3
import numpy as np
4
from bob.extension import rc
5
from bob.extension.download import get_file
6
7
from bob.pad.base.database import FileListPadDatabase
from bob.pad.face.database import VideoPadSample
8
9
from bob.pipelines.transformers import Str_To_Types
from bob.pipelines.transformers import str_to_bool
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from sklearn.pipeline import make_pipeline

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


def get_rm_video_transform(sample):
    """Return the per-video transform to apply to a Replay-Mobile sample.

    ``sample.should_flip`` is read once, when this factory is called, and
    captured by the returned closure.

    The returned callable converts its input to a NumPy array and swaps its
    last two axes (presumably the frame's two spatial axes -- confirm against
    the video loader). If ``should_flip`` is true, it then reverses the new
    second-to-last axis.

    Parameters
    ----------
    sample
        Any object exposing a ``should_flip`` attribute (expected to be a
        bool once ``Str_To_Types`` has converted it).

    Returns
    -------
    callable
        ``transform(video) -> numpy.ndarray``. The input must have at least
        two dimensions.
    """
    should_flip = sample.should_flip

    def transform(video):
        video = np.asarray(video)
        # Move the last axis one position left, i.e. swap the last two axes.
        # ``np.moveaxis`` is NumPy's recommended replacement for the legacy
        # ``np.rollaxis(video, -1, -2)`` and gives the same view.
        video = np.moveaxis(video, -1, -2)
        if should_flip:
            video = video[..., ::-1, :]
        return video

    return transform


def ReplayMobilePadDatabase(
    protocol="grandtest",
    selection_style=None,
    max_number_of_frames=None,
    step_size=None,
    annotation_directory=None,
    annotation_type=None,
    fixed_positions=None,
    **kwargs,
):
    """Build a ``FileListPadDatabase`` for the Replay-Mobile dataset.

    The protocol file lists are obtained through ``get_file``
    (``cache_subdir="protocols"``). If ``annotation_directory`` is ``None``,
    a default annotation archive is obtained the same way
    (``cache_subdir="annotations"``), and ``annotation_type`` is set to
    ``"eyes-center"``.

    Parameters
    ----------
    protocol : str
        Protocol name passed to ``FileListPadDatabase``.
    selection_style, max_number_of_frames, step_size
        Frame-selection options forwarded unchanged to ``VideoPadSample``.
    annotation_directory : str or None
        Annotation location forwarded to ``VideoPadSample``. ``None`` selects
        the default downloadable archive.
    annotation_type : str or None
        Stored on ``database.annotation_type``. Overwritten with
        ``"eyes-center"`` when ``annotation_directory`` is ``None``.
    fixed_positions
        Stored on ``database.fixed_positions``.
    **kwargs
        Extra keyword arguments forwarded to ``FileListPadDatabase``.

    Returns
    -------
    FileListPadDatabase
        The configured database, with ``annotation_type`` and
        ``fixed_positions`` attributes set.
    """
    base_url = "http://www.idiap.ch/software/bob/data/bob/bob.pad.face"

    protocols_archive = "pad-face-replay-mobile-586b7e81.tar.gz"
    dataset_protocols_path = get_file(
        protocols_archive,
        [f"{base_url}/{protocols_archive}"],
        cache_subdir="protocols",
        file_hash="586b7e81",
    )

    if annotation_directory is None:
        annotations_archive = "annotations-replaymobile-mtcnn-9cd6e452.tar.xz"
        annotation_directory = get_file(
            annotations_archive,
            [f"{base_url}/{annotations_archive}"],
            cache_subdir="annotations",
            file_hash="9cd6e452",
        )
        # The default archive's annotations are treated as eyes-center, so
        # any caller-supplied annotation_type is replaced here.
        annotation_type = "eyes-center"

    transformer = make_pipeline(
        # ``should_flip`` arrives as a string (presumably from the file
        # lists); convert it to bool before get_rm_video_transform reads it.
        Str_To_Types(fieldtypes=dict(should_flip=str_to_bool)),
        VideoPadSample(
            # Raw-data location is taken from the bob rc configuration.
            original_directory=rc.get("bob.db.replaymobile.directory"),
            annotation_directory=annotation_directory,
            selection_style=selection_style,
            max_number_of_frames=max_number_of_frames,
            step_size=step_size,
            get_transform=get_rm_video_transform,
        ),
    )

    database = FileListPadDatabase(
        dataset_protocols_path,
        protocol,
        transformer=transformer,
        **kwargs,
    )
    database.annotation_type = annotation_type
    database.fixed_positions = fixed_positions
    return database