Skip to content
Snippets Groups Projects
Commit 6b38ab4c authored by Laurent COLBOIS's avatar Laurent COLBOIS
Browse files

Rework IJBC database using Pandas, to have both performance and portability

parent ffbe30af
No related branches found
No related tags found
1 merge request!124Resolve "IJBC database will fail on non-Idiap filesystems"
Pipeline #51432 failed
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Sat 20 Aug 15:43:10 CEST 2020
from bob.pipelines.utils import hash_string
from bob.extension.download import get_file, find_element_in_tarball
import pickle
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
import pandas as pd
from bob.pipelines.sample import DelayedSample, SampleSet
from bob.extension import rc
import os
import bob.io.image
from functools import partial
def load_ijbc_sample(original_path, extension=[".jpg", ".png"]):
for e in extension:
path = original_path + e
if os.path.exists(path):
return path
else:
return ""
def load(path):
return bob.io.image.load(os.path.join(rc["bob.db.ijbc.directory"], path))
class IJBCDatabase:
def __init__(self, pkl_directory=None):
self.annotation_type = "bounding-box"
self.fixed_positions = None
self.allow_scoring_with_all_biometric_references = False
self.hash_fn = hash_string
self.memory_demanding = True
def _make_sample_from_template_row(row, image_directory):
return DelayedSample(
load=partial(
bob.io.image.load, path=os.path.join(image_directory, row["FILENAME"])
),
template_id=str(row["TEMPLATE_ID"]),
subject_id=str(row["SUBJECT_ID"]),
key=os.path.splitext(row["FILENAME"])[0],
annotations={
"topleft": (float(row["FACE_Y"]), float(row["FACE_X"])),
"bottomright": (
float(row["FACE_Y"]) + float(row["FACE_HEIGHT"]),
float(row["FACE_X"]) + float(row["FACE_WIDTH"]),
),
"size": (float(row["FACE_HEIGHT"]), float(row["FACE_WIDTH"])),
},
)
if pkl_directory is None:
urls = IJBCDatabase.urls()
pkl_directory = get_file(
"ijbc.tar.gz", urls, file_hash="4b25d7f10595eb9f97f328a2d448d957"
)
self.pkl_directory = pkl_directory
def _make_sample_set_from_template_group(template_group, image_directory):
samples = list(
template_group.apply(
_make_sample_from_template_row, axis=1, image_directory=image_directory
)
)
return SampleSet(
samples, template_id=samples[0].template_id, subject_id=samples[0].subject_id
)
def _assert_group(self, group):
assert (
group == "dev"
), "The IJBC database only has a `dev` group. Received : {}".format(group)
class IJBCDatabase(Database):
def __init__(
self,
protocol="1:1",
original_directory=rc["bob.bio.face.ijbc.directory"],
**kwargs
):
self._check_protocol(protocol)
def references(self, group="dev"):
self._assert_group(group)
return pickle.loads(
find_element_in_tarball(self.pkl_directory, "db_references.pickle", True)
super().__init__(
name="ijbc",
protocol=protocol,
allow_scoring_with_all_biometric_references=False,
annotation_type="eyes-center",
fixed_positions=None,
memory_demanding=True,
)
def probes(self, group="dev"):
self._assert_group(group)
return pickle.loads(
find_element_in_tarball(self.pkl_directory, "db_probes.pickle", True)
self.image_directory = os.path.join(original_directory, "images")
self.protocol_directory = os.path.join(original_directory, "protocols")
self._cached_probes = None
self._cached_references = None
self._load_metadata()
def _load_metadata(self):
# Load CSV files
self.reference_templates = pd.concat(
[
pd.read_csv(
os.path.join(self.protocol_directory, "ijbc_1N_gallery_G1.csv")
),
pd.read_csv(
os.path.join(self.protocol_directory, "ijbc_1N_gallery_G2.csv")
),
]
)
self.probe_templates = pd.read_csv(
os.path.join(self.protocol_directory, "ijbc_1N_probe_mixed.csv")
)
self.matches = pd.read_csv(
os.path.join(self.protocol_directory, "ijbc_11_G1_G2_matches.csv"),
names=["REFERENCE_TEMPLATE_ID", "PROBE_TEMPLATE_ID"],
)
def background_model_samples(self):
import cloudpickle
return None
return cloudpickle.loads(
find_element_in_tarball(
self.pkl_directory, "db_background_model_samples.pickle", True
def probes(self, group="dev"):
self._check_group(group)
if self._cached_probes is None:
self._cached_probes = list(
self.probe_templates.groupby("TEMPLATE_ID").apply(
_make_sample_set_from_template_group,
image_directory=self.image_directory,
)
)
# Link probes to the references they have to be compared with
# We might make that faster if we manage to write it as a Panda instruction
grouped_matches = self.matches.groupby("PROBE_TEMPLATE_ID")
for probe_sampleset in self._cached_probes:
probe_sampleset.references = list(
grouped_matches.get_group(int(probe_sampleset.template_id))[
"REFERENCE_TEMPLATE_ID"
]
)
return self._cached_probes
def references(self, group="dev"):
self._check_group(group)
if self._cached_references is None:
self._cached_references = list(
self.reference_templates.groupby("TEMPLATE_ID").apply(
_make_sample_set_from_template_group,
image_directory=self.image_directory,
)
)
return self._cached_references
def all_samples(self, group="dev"):
self._check_group(group)
return self.references() + self.probes()
def groups(self):
return ["dev"]
def protocols(self):
return ["1:1"]
def _check_protocol(self, protocol):
assert protocol in self.protocols(), "Unvalid protocol `{}` not in {}".format(
protocol, self.protocols()
)
@staticmethod
def urls():
return [
"https://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz",
"http://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz",
]
def _check_group(self, group):
assert group in self.groups(), "Unvalid group `{}` not in {}".format(
group, self.groups()
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment