#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Sat 20 Aug 15:43:10 CEST 2020

"""IJB-C database interface that reads pre-built pickles from a tarball."""

import os
import pickle

from bob.extension.download import find_element_in_tarball, get_file
from bob.pipelines.utils import hash_string


def load_ijbc_sample(original_path, extension=(".jpg", ".png")):
    """Return the first existing path built from ``original_path`` + extension.

    Parameters
    ----------
    original_path : str
        File path without its extension.
    extension : iterable of str
        Candidate extensions, tried in the given order. The default is an
        immutable tuple so it cannot be mutated across calls.

    Returns
    -------
    str
        ``original_path + e`` for the first ``e`` whose file exists on disk,
        or the empty string when none of the candidates exists.
    """
    for e in extension:
        path = original_path + e
        if os.path.exists(path):
            return path
    # No candidate exists: callers receive an empty string, not an exception.
    return ""


class IJBCDatabase:
    """Access the IJB-C references, probes and background samples.

    The data is not built from the raw protocol files; it is unpickled from
    members of a tarball (``db_references.pickle``, ``db_probes.pickle`` and
    ``db_background_model_samples.pickle``).

    Parameters
    ----------
    pkl_directory : str, optional
        Location handed to ``find_element_in_tarball`` when reading the
        pickles. When ``None``, ``ijbc.tar.gz`` is obtained through
        ``get_file`` from the addresses returned by :meth:`urls`.
        NOTE(review): despite the name, this is presumably the path of the
        tarball itself -- confirm against ``bob.extension.download``.
    """

    def __init__(self, pkl_directory=None):
        self.annotation_type = "bounding-box"
        self.fixed_positions = None
        self.allow_scoring_with_all_biometric_references = False
        self.hash_fn = hash_string
        self.memory_demanding = True

        if pkl_directory is None:
            urls = IJBCDatabase.urls()
            pkl_directory = get_file(
                "ijbc.tar.gz", urls, file_hash="4b25d7f10595eb9f97f328a2d448d957"
            )

        self.pkl_directory = pkl_directory

    def _assert_group(self, group):
        """Raise ``AssertionError`` unless ``group`` is ``"dev"``.

        The exception is raised explicitly instead of through an ``assert``
        statement so that the check still runs under ``python -O``; the
        exception type and message seen by callers are unchanged.
        """
        if group != "dev":
            raise AssertionError(
                "The IJBC database only has a `dev` group. Received : {}".format(
                    group
                )
            )

    def _read_member(self, filename):
        """Return the content of ``filename`` read from the tarball.

        The third positional argument (``True``) is forwarded unchanged;
        presumably it requests the raw bytes of the member -- confirm against
        ``find_element_in_tarball``.
        """
        return find_element_in_tarball(self.pkl_directory, filename, True)

    def references(self, group="dev"):
        """Return the unpickled content of ``db_references.pickle``.

        Raises
        ------
        AssertionError
            If ``group`` is not ``"dev"``.
        """
        self._assert_group(group)
        # NOTE(security): unpickling executes arbitrary code. Only use
        # archives from a trusted source; one of the fallback URLs is plain
        # HTTP, so the hash passed to ``get_file`` is the only integrity
        # guard (presumably verified there -- confirm).
        return pickle.loads(self._read_member("db_references.pickle"))

    def probes(self, group="dev"):
        """Return the unpickled content of ``db_probes.pickle``.

        Raises
        ------
        AssertionError
            If ``group`` is not ``"dev"``.
        """
        self._assert_group(group)
        # NOTE(security): same trust requirement as in ``references``.
        return pickle.loads(self._read_member("db_probes.pickle"))

    def background_model_samples(self):
        """Return the content of ``db_background_model_samples.pickle``.

        This member is loaded with ``cloudpickle`` and not ``pickle``; the
        import is kept local so the dependency is only needed when this
        method is called.
        """
        import cloudpickle

        # NOTE(security): same trust requirement as in ``references``.
        return cloudpickle.loads(
            self._read_member("db_background_model_samples.pickle")
        )

    @staticmethod
    def urls():
        """Return the download addresses of ``ijbc.tar.gz`` (HTTPS first)."""
        return [
            "https://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz",
            "http://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz",
        ]