ijbc.py 4.56 KB
Newer Older
1
2
3
4
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
import pandas as pd
from bob.pipelines.sample import DelayedSample, SampleSet
from bob.extension import rc
5
import os
6
7
import bob.io.image
from functools import partial
8
9


10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def _make_sample_from_template_row(row, image_directory):
    """Turn one row of a template table into a lazily-loaded sample.

    Parameters
    ----------
    row
        Mapping-like record providing the ``FILENAME``, ``TEMPLATE_ID``,
        ``SUBJECT_ID``, ``FACE_X``, ``FACE_Y``, ``FACE_WIDTH`` and
        ``FACE_HEIGHT`` entries.
    image_directory
        Directory that ``row["FILENAME"]`` is relative to.

    Returns
    -------
    DelayedSample
        Sample whose ``load`` callable reads the image with
        ``bob.io.image.load`` when invoked. The sample key is the file name
        without its extension, and the annotations describe the face box in
        (y, x) order.
    """
    filename = row["FILENAME"]
    image_loader = partial(
        bob.io.image.load, path=os.path.join(image_directory, filename)
    )
    reference_id = str(row["TEMPLATE_ID"])
    subject_id = str(row["SUBJECT_ID"])
    sample_key = os.path.splitext(filename)[0]

    # Face box is stored as top-left corner plus extent.
    top, left = float(row["FACE_Y"]), float(row["FACE_X"])
    height, width = float(row["FACE_HEIGHT"]), float(row["FACE_WIDTH"])
    face_box = {
        "topleft": (top, left),
        "bottomright": (top + height, left + width),
        "size": (height, width),
    }

    return DelayedSample(
        load=image_loader,
        reference_id=reference_id,
        subject_id=subject_id,
        key=sample_key,
        annotations=face_box,
    )
27
28


29
30
31
32
33
34
35
36
37
def _make_sample_set_from_template_group(template_group, image_directory):
    """Bundle all rows of one template into a single ``SampleSet``.

    Parameters
    ----------
    template_group
        ``pandas.DataFrame`` holding the rows of one template (the callers
        obtain it from ``groupby("TEMPLATE_ID")``).
    image_directory
        Directory that the ``FILENAME`` entries are relative to.

    Returns
    -------
    SampleSet
        Contains one ``DelayedSample`` per row. ``reference_id`` and
        ``subject_id`` are copied from the first sample.
    """
    samples = list(
        template_group.apply(
            _make_sample_from_template_row, axis=1, image_directory=image_directory
        )
    )
    # The samples are created with a ``reference_id`` attribute (there is no
    # ``template_id`` attribute on them), so that is the one to propagate.
    # NOTE(review): presumably every row of a template shares one subject;
    # the subject is taken from the first row only.
    first_sample = samples[0]
    return SampleSet(
        samples,
        reference_id=first_sample.reference_id,
        subject_id=first_sample.subject_id,
    )
38
39


40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class IJBCDatabase(Database):
    """Database interface for IJB-C, exposing a single ``"1:1"`` protocol.

    The protocol CSV files are read at construction time from
    ``<original_directory>/protocols``; images are looked up lazily under
    ``<original_directory>/images``. Only the ``"dev"`` group exists.

    Parameters
    ----------
    protocol
        Protocol name; must be one of :meth:`protocols`.
    original_directory
        Root directory of the dataset. Defaults to the
        ``bob.bio.face.ijbc.directory`` entry of the ``rc`` configuration.
    **kwargs
        Accepted for call compatibility; not used by this class.

    Raises
    ------
    AssertionError
        If ``protocol`` is not supported.
    ValueError
        If ``original_directory`` is ``None``.
    """

    def __init__(
        self,
        protocol="1:1",
        original_directory=rc["bob.bio.face.ijbc.directory"],
        **kwargs
    ):
        self._check_protocol(protocol)

        # Fail early with an actionable message instead of the opaque
        # ``TypeError`` that ``os.path.join(None, ...)`` would raise below.
        if original_directory is None:
            raise ValueError(
                "`original_directory` is not set; pass it explicitly or "
                "configure `bob.bio.face.ijbc.directory`"
            )

        # NOTE(review): the samples built in this module carry bounding-box
        # style annotations ("topleft"/"bottomright"/"size") while
        # "eyes-center" is declared here -- confirm which one is intended.
        super().__init__(
            name="ijbc",
            protocol=protocol,
            allow_scoring_with_all_biometric_references=False,
            annotation_type="eyes-center",
            fixed_positions=None,
            memory_demanding=True,
        )

        self.image_directory = os.path.join(original_directory, "images")
        self.protocol_directory = os.path.join(original_directory, "protocols")
        # Sample sets are built once on first request and then reused.
        self._cached_probes = None
        self._cached_references = None

        self._load_metadata()

    def _load_metadata(self):
        """Read the gallery, probe and match CSV files of the protocol."""
        # The references are the concatenation of the two gallery files.
        self.reference_templates = pd.concat(
            [
                pd.read_csv(
                    os.path.join(self.protocol_directory, "ijbc_1N_gallery_G1.csv")
                ),
                pd.read_csv(
                    os.path.join(self.protocol_directory, "ijbc_1N_gallery_G2.csv")
                ),
            ]
        )

        self.probe_templates = pd.read_csv(
            os.path.join(self.protocol_directory, "ijbc_1N_probe_mixed.csv")
        )

        # Column names are supplied here, so the first line of the file is
        # read as data rather than as a header.
        self.matches = pd.read_csv(
            os.path.join(self.protocol_directory, "ijbc_11_G1_G2_matches.csv"),
            names=["REFERENCE_TEMPLATE_ID", "PROBE_TEMPLATE_ID"],
        )

    def background_model_samples(self):
        """Return ``None``: this interface provides no training samples."""
        return None

    def probes(self, group="dev"):
        """Return the probe sample sets, one per probe template.

        Each returned sample set gets a ``references`` attribute listing the
        ``REFERENCE_TEMPLATE_ID`` values it is paired with in the matches
        file. The result is computed on the first call and cached.

        Raises
        ------
        AssertionError
            If ``group`` is not supported.
        KeyError
            If a probe template has no entry in the matches file.
        """
        self._check_group(group)
        if self._cached_probes is None:
            probe_samplesets = list(
                self.probe_templates.groupby("TEMPLATE_ID").apply(
                    _make_sample_set_from_template_group,
                    image_directory=self.image_directory,
                )
            )

            # Link probes to the references they have to be compared with.
            # This is done once, together with the construction, instead of
            # on every call.
            # We might make that faster if we manage to write it as a pandas
            # instruction.
            grouped_matches = self.matches.groupby("PROBE_TEMPLATE_ID")
            for probe_sampleset in probe_samplesets:
                # Sample sets are created with ``reference_id`` (a string);
                # the matches table is looked up with the integer form.
                probe_sampleset.references = list(
                    grouped_matches.get_group(int(probe_sampleset.reference_id))[
                        "REFERENCE_TEMPLATE_ID"
                    ]
                )

            # Only publish the cache once every probe is fully linked.
            self._cached_probes = probe_samplesets

        return self._cached_probes

    def references(self, group="dev"):
        """Return the reference sample sets, one per gallery template.

        The result is computed on the first call and cached.

        Raises
        ------
        AssertionError
            If ``group`` is not supported.
        """
        self._check_group(group)
        if self._cached_references is None:
            self._cached_references = list(
                self.reference_templates.groupby("TEMPLATE_ID").apply(
                    _make_sample_set_from_template_group,
                    image_directory=self.image_directory,
                )
            )

        return self._cached_references

    def all_samples(self, group="dev"):
        """Return the reference sample sets followed by the probe sample sets."""
        self._check_group(group)

        return self.references(group=group) + self.probes(group=group)

    def groups(self):
        """Return the list of supported groups."""
        return ["dev"]

    def protocols(self):
        """Return the list of supported protocols."""
        return ["1:1"]

    def _check_protocol(self, protocol):
        """Raise ``AssertionError`` if ``protocol`` is not supported."""
        # Raised explicitly (rather than via ``assert``) so the validation is
        # not stripped when Python runs with ``-O``.
        if protocol not in self.protocols():
            raise AssertionError(
                "Unvalid protocol `{}` not in {}".format(protocol, self.protocols())
            )

    def _check_group(self, group):
        """Raise ``AssertionError`` if ``group`` is not supported."""
        # Raised explicitly (rather than via ``assert``) so the validation is
        # not stripped when Python runs with ``-O``.
        if group not in self.groups():
            raise AssertionError(
                "Unvalid group `{}` not in {}".format(group, self.groups())
            )