diff --git a/bob/bio/face/database/ijbc.py b/bob/bio/face/database/ijbc.py index d9428f7a4baac1ccb0108e7810290c2da159ad29..566695318a1ac96d94241aed96c9eca36596ee47 100644 --- a/bob/bio/face/database/ijbc.py +++ b/bob/bio/face/database/ijbc.py @@ -1,68 +1,147 @@ -#!/usr/bin/env python -# vim: set fileencoding=utf-8 : -# Tiago de Freitas Pereira <tiago.pereira@idiap.ch> -# Sat 20 Aug 15:43:10 CEST 2020 - -from bob.pipelines.utils import hash_string -from bob.extension.download import get_file, find_element_in_tarball -import pickle +from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database +import pandas as pd +from bob.pipelines.sample import DelayedSample, SampleSet +from bob.extension import rc import os +import bob.io.image +from functools import partial -def load_ijbc_sample(original_path, extension=[".jpg", ".png"]): - for e in extension: - path = original_path + e - if os.path.exists(path): - return path - else: - return "" +def load(path): + return bob.io.image.load(os.path.join(rc["bob.db.ijbc.directory"], path)) -class IJBCDatabase: - def __init__(self, pkl_directory=None): - self.annotation_type = "bounding-box" - self.fixed_positions = None - self.allow_scoring_with_all_biometric_references = False - self.hash_fn = hash_string - self.memory_demanding = True +def _make_sample_from_template_row(row, image_directory): + return DelayedSample( + load=partial( + bob.io.image.load, path=os.path.join(image_directory, row["FILENAME"]) + ), + template_id=str(row["TEMPLATE_ID"]), + subject_id=str(row["SUBJECT_ID"]), + key=os.path.splitext(row["FILENAME"])[0], + annotations={ + "topleft": (float(row["FACE_Y"]), float(row["FACE_X"])), + "bottomright": ( + float(row["FACE_Y"]) + float(row["FACE_HEIGHT"]), + float(row["FACE_X"]) + float(row["FACE_WIDTH"]), + ), + "size": (float(row["FACE_HEIGHT"]), float(row["FACE_WIDTH"])), + }, + ) - if pkl_directory is None: - urls = IJBCDatabase.urls() - pkl_directory = get_file( - "ijbc.tar.gz", urls, file_hash="4b25d7f10595eb9f97f328a2d448d957" - ) - self.pkl_directory = pkl_directory +def _make_sample_set_from_template_group(template_group, image_directory): + samples = list( + template_group.apply( + _make_sample_from_template_row, axis=1, image_directory=image_directory + ) + ) + return SampleSet( + samples, template_id=samples[0].template_id, subject_id=samples[0].subject_id + ) + - def _assert_group(self, group): - assert ( - group == "dev" - ), "The IJBC database only has a `dev` group. Received : {}".format(group) +class IJBCDatabase(Database): + def __init__( + self, + protocol="1:1", + original_directory=rc["bob.bio.face.ijbc.directory"], + **kwargs + ): + self._check_protocol(protocol) - def references(self, group="dev"): - self._assert_group(group) - return pickle.loads( - find_element_in_tarball(self.pkl_directory, "db_references.pickle", True) + super().__init__( + name="ijbc", + protocol=protocol, + allow_scoring_with_all_biometric_references=False, + annotation_type="eyes-center", + fixed_positions=None, + memory_demanding=True, ) - def probes(self, group="dev"): - self._assert_group(group) - return pickle.loads( - find_element_in_tarball(self.pkl_directory, "db_probes.pickle", True) + self.image_directory = os.path.join(original_directory, "images") + self.protocol_directory = os.path.join(original_directory, "protocols") + self._cached_probes = None + self._cached_references = None + + self._load_metadata() + + def _load_metadata(self): + # Load CSV files + self.reference_templates = pd.concat( + [ + pd.read_csv( + os.path.join(self.protocol_directory, "ijbc_1N_gallery_G1.csv") + ), + pd.read_csv( + os.path.join(self.protocol_directory, "ijbc_1N_gallery_G2.csv") + ), + ] + ) + + self.probe_templates = pd.read_csv( + os.path.join(self.protocol_directory, "ijbc_1N_probe_mixed.csv") + ) + + self.matches = pd.read_csv( + os.path.join(self.protocol_directory, "ijbc_11_G1_G2_matches.csv"), + names=["REFERENCE_TEMPLATE_ID", "PROBE_TEMPLATE_ID"], ) def background_model_samples(self): - import cloudpickle + return None - return cloudpickle.loads( - find_element_in_tarball( - self.pkl_directory, "db_background_model_samples.pickle", True + def probes(self, group="dev"): + self._check_group(group) + if self._cached_probes is None: + self._cached_probes = list( + self.probe_templates.groupby("TEMPLATE_ID").apply( + _make_sample_set_from_template_group, + image_directory=self.image_directory, + ) ) + + # Link probes to the references they have to be compared with + # We might make that faster if we manage to write it as a Panda instruction + grouped_matches = self.matches.groupby("PROBE_TEMPLATE_ID") + for probe_sampleset in self._cached_probes: + probe_sampleset.references = list( + grouped_matches.get_group(int(probe_sampleset.template_id))[ + "REFERENCE_TEMPLATE_ID" + ] + ) + + return self._cached_probes + + def references(self, group="dev"): + self._check_group(group) + if self._cached_references is None: + self._cached_references = list( + self.reference_templates.groupby("TEMPLATE_ID").apply( + _make_sample_set_from_template_group, + image_directory=self.image_directory, + ) + ) + + return self._cached_references + + def all_samples(self, group="dev"): + self._check_group(group) + + return self.references() + self.probes() + + def groups(self): + return ["dev"] + + def protocols(self): + return ["1:1"] + + def _check_protocol(self, protocol): + assert protocol in self.protocols(), "Unvalid protocol `{}` not in {}".format( + protocol, self.protocols() ) - @staticmethod - def urls(): - return [ - "https://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz", - "http://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz", - ] + def _check_group(self, group): + assert group in self.groups(), "Unvalid group `{}` not in {}".format( + group, self.groups() + )