Commit aae51a2c authored by Laurent COLBOIS

Rework IJBC database using Pandas, to have both performance and portability

parent ffbe30af
Pipeline #51430 failed with stages
in 237 minutes and 9 seconds
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# Sat 20 Aug 15:43:10 CEST 2020
from bob.pipelines.utils import hash_string
from bob.extension.download import get_file, find_element_in_tarball
import pickle
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
import pandas as pd
from bob.pipelines.sample import DelayedSample, SampleSet
from bob.extension import rc
import os
import bob.io.image
from functools import partial
def load_ijbc_sample(original_path, extension=(".jpg", ".png")):
    """Return the first existing file obtained by appending a suffix to a path.

    Parameters
    ----------
    original_path : str
        File path without its extension.
    extension : str or iterable of str
        Candidate suffixes (leading dot included), tried in the given order.
        A single string is treated as one candidate. The default is a tuple so
        that no mutable object is shared between calls.

    Returns
    -------
    str
        ``original_path + suffix`` for the first suffix whose file exists, or
        an empty string when none of the candidates exists.
    """
    if isinstance(extension, str):
        # Iterating a bare string would try one character at a time.
        extension = (extension,)

    for suffix in extension:
        candidate = original_path + suffix
        if os.path.exists(candidate):
            return candidate

    # Every candidate was tried and none of them exists on disk.
    return ""
def load(path):
    """Load an image located under the configured database directory.

    ``path`` is joined to the ``bob.db.ijbc.directory`` entry of the ``rc``
    configuration and the resulting file is read with ``bob.io.image.load``,
    whose return value is handed back unchanged.
    """
    database_root = rc["bob.db.ijbc.directory"]
    full_path = os.path.join(database_root, path)
    return bob.io.image.load(full_path)
# NOTE(review): a second `class IJBCDatabase(Database)` is declared further
# down in this file; being defined later, it rebinds the name `IJBCDatabase`.
class IJBCDatabase:
# `pkl_directory` defaults to None; the handling of that case comes after the
# attribute assignments below.
def __init__(self, pkl_directory=None):
# Fixed settings exposed as plain attributes: bounding-box annotations, no
# fixed positions, scoring against all biometric references disabled.
self.annotation_type = "bounding-box"
self.fixed_positions = None
self.allow_scoring_with_all_biometric_references = False
# `hash_string` is the helper imported from bob.pipelines.utils.
self.hash_fn = hash_string
self.memory_demanding = True
def _make_sample_from_template_row(row, image_directory):
    """Turn one row of a template table into a ``DelayedSample``.

    The image itself is not read here: a loader callable bound to the image
    path is handed to ``DelayedSample``.

    Parameters
    ----------
    row : mapping
        Must provide the ``FILENAME``, ``TEMPLATE_ID``, ``SUBJECT_ID``,
        ``FACE_X``, ``FACE_Y``, ``FACE_WIDTH`` and ``FACE_HEIGHT`` entries.
    image_directory : str
        Directory that ``FILENAME`` is joined to.

    Returns
    -------
    DelayedSample
        Sample keyed by the file name without its extension, carrying the
        template and subject identifiers as strings and a face bounding box
        expressed as ``(y, x)`` pairs.
    """
    filename = row["FILENAME"]
    image_loader = partial(
        bob.io.image.load, path=os.path.join(image_directory, filename)
    )
    template_id = str(row["TEMPLATE_ID"])
    subject_id = str(row["SUBJECT_ID"])
    sample_key = os.path.splitext(filename)[0]

    # Bounding box: top-left corner plus extent, all converted to float.
    top = float(row["FACE_Y"])
    left = float(row["FACE_X"])
    height = float(row["FACE_HEIGHT"])
    width = float(row["FACE_WIDTH"])
    bounding_box = {
        "topleft": (top, left),
        "bottomright": (top + height, left + width),
        "size": (height, width),
    }

    return DelayedSample(
        load=image_loader,
        template_id=template_id,
        subject_id=subject_id,
        key=sample_key,
        annotations=bounding_box,
    )
# When no location is given, obtain "ijbc.tar.gz" through `get_file`, using the
# candidate URLs returned by `IJBCDatabase.urls()`.
# NOTE(review): presumably `file_hash` is used to validate the download —
# confirm against bob.extension.download.get_file.
if pkl_directory is None:
urls = IJBCDatabase.urls()
pkl_directory = get_file(
"ijbc.tar.gz", urls, file_hash="4b25d7f10595eb9f97f328a2d448d957"
)
# Either the caller's value or the value returned by `get_file`.
self.pkl_directory = pkl_directory
def _make_sample_set_from_template_group(template_group, image_directory):
    """Bundle the rows of one template group into a single ``SampleSet``.

    Parameters
    ----------
    template_group : pandas.DataFrame
        Rows to convert; each one is passed to
        ``_make_sample_from_template_row``.
    image_directory : str
        Forwarded unchanged to ``_make_sample_from_template_row``.

    Returns
    -------
    SampleSet
        Set holding one sample per row. Its ``template_id`` and ``subject_id``
        are copied from the first sample.
    """
    per_row_samples = template_group.apply(
        _make_sample_from_template_row, axis=1, image_directory=image_directory
    )
    samples = list(per_row_samples)

    # The set-level identifiers are taken from the first sample of the group.
    first = samples[0]
    return SampleSet(
        samples, template_id=first.template_id, subject_id=first.subject_id
    )
def _assert_group(self, group):
    """Check that the requested group is ``"dev"``.

    Parameters
    ----------
    group : str
        Name of the requested group.

    Raises
    ------
    AssertionError
        If ``group`` is anything other than ``"dev"``.
    """
    is_dev = group == "dev"
    assert is_dev, "The IJBC database only has a `dev` group. Received : {}".format(
        group
    )
class IJBCDatabase(Database):
    """IJB-C database interface backed by the protocol CSV files.

    The protocol tables are read with pandas when the object is created; the
    probe and reference sample sets are built on first request and cached.
    """

    def __init__(
        self,
        protocol="1:1",
        original_directory=rc["bob.bio.face.ijbc.directory"],
        **kwargs
    ):
        """Validate the protocol, configure the base class and load metadata.

        Parameters
        ----------
        protocol : str
            Protocol name; must be one of the values returned by
            ``protocols()``.
        original_directory : str
            Folder containing the ``images`` and ``protocols`` sub-folders.
            The default is read from the ``rc`` configuration once, when the
            class is defined.
        **kwargs
            Accepted for call compatibility; not used by this constructor.

        Raises
        ------
        AssertionError
            If ``protocol`` is not supported.
        """
        self._check_protocol(protocol)

        super().__init__(
            name="ijbc",
            protocol=protocol,
            allow_scoring_with_all_biometric_references=False,
            # The samples built by this module carry `topleft`/`bottomright`
            # annotations, i.e. bounding boxes rather than eye positions.
            annotation_type="bounding-box",
            fixed_positions=None,
            memory_demanding=True,
        )

        self.image_directory = os.path.join(original_directory, "images")
        self.protocol_directory = os.path.join(original_directory, "protocols")

        # Filled on the first call to `probes()` / `references()`.
        self._cached_probes = None
        self._cached_references = None

        self._load_metadata()
def _load_metadata(self):
    """Read the protocol CSV files found in ``self.protocol_directory``.

    Sets three attributes on the instance:

    * ``reference_templates`` -- the two gallery files ``G1`` and ``G2``
      concatenated in that order (row indices are kept as read);
    * ``probe_templates`` -- the mixed probe file;
    * ``matches`` -- the comparison list, read without a header row and given
      the column names ``REFERENCE_TEMPLATE_ID`` and ``PROBE_TEMPLATE_ID``.
    """

    def read_protocol_file(filename, **read_options):
        # Every protocol file lives directly inside the protocol directory.
        return pd.read_csv(
            os.path.join(self.protocol_directory, filename), **read_options
        )

    gallery_files = ("ijbc_1N_gallery_G1.csv", "ijbc_1N_gallery_G2.csv")
    gallery_tables = [read_protocol_file(filename) for filename in gallery_files]
    self.reference_templates = pd.concat(gallery_tables)

    self.probe_templates = read_protocol_file("ijbc_1N_probe_mixed.csv")

    self.matches = read_protocol_file(
        "ijbc_11_G1_G2_matches.csv",
        names=["REFERENCE_TEMPLATE_ID", "PROBE_TEMPLATE_ID"],
    )
def background_model_samples(self):
    """Return the samples available for training a background model.

    Returns
    -------
    None
        This implementation provides no background-model samples.
    """
    return None
def probes(self, group="dev"):
    """Return the probe sample sets, building and caching them on first use.

    One sample set is created per ``TEMPLATE_ID`` of the probe table. Each set
    receives a ``references`` attribute listing the ``REFERENCE_TEMPLATE_ID``
    values paired with it in the matches table.

    Parameters
    ----------
    group : str
        Requested group; validated with ``_check_group``.

    Returns
    -------
    list
        The cached list of probe sample sets (the same list object on every
        call once built).

    Raises
    ------
    AssertionError
        If ``group`` is not supported.
    KeyError
        Raised by pandas ``get_group`` when a probe template has no row in the
        matches table.
    """
    self._check_group(group)

    if self._cached_probes is None:
        probe_samplesets = list(
            self.probe_templates.groupby("TEMPLATE_ID").apply(
                _make_sample_set_from_template_group,
                image_directory=self.image_directory,
            )
        )

        # Link probes to the references they have to be compared with.
        # This might be made faster if written as a single pandas instruction.
        grouped_matches = self.matches.groupby("PROBE_TEMPLATE_ID")
        for probe_sampleset in probe_samplesets:
            probe_sampleset.references = list(
                grouped_matches.get_group(int(probe_sampleset.template_id))[
                    "REFERENCE_TEMPLATE_ID"
                ]
            )

        # Publish the cache only once every probe is linked, so that a failure
        # above cannot leave half-initialised sample sets behind for later calls.
        self._cached_probes = probe_samplesets

    return self._cached_probes
def references(self, group="dev"):
    """Return the reference sample sets, building and caching them on first use.

    One sample set is created per ``TEMPLATE_ID`` of the reference table.

    Parameters
    ----------
    group : str
        Requested group; validated with ``_check_group``.

    Returns
    -------
    list
        The cached list of reference sample sets.

    Raises
    ------
    AssertionError
        If ``group`` is not supported.
    """
    self._check_group(group)

    if self._cached_references is not None:
        return self._cached_references

    templates = self.reference_templates.groupby("TEMPLATE_ID")
    samplesets = templates.apply(
        _make_sample_set_from_template_group,
        image_directory=self.image_directory,
    )
    self._cached_references = list(samplesets)
    return self._cached_references
def all_samples(self, group="dev"):
    """Return the reference sample sets followed by the probe sample sets.

    Parameters
    ----------
    group : str
        Requested group; validated with ``_check_group``. The references and
        probes themselves are requested with their default group.

    Returns
    -------
    list
        New list: ``references()`` first, then ``probes()``.
    """
    self._check_group(group)
    reference_sets = self.references()
    probe_sets = self.probes()
    return reference_sets + probe_sets
def groups(self):
    """Return the names of the available groups (only ``"dev"``)."""
    available_groups = ["dev"]
    return available_groups
def protocols(self):
    """Return the names of the supported protocols (only ``"1:1"``)."""
    supported_protocols = ["1:1"]
    return supported_protocols
def _check_protocol(self, protocol):
    """Check that ``protocol`` is one of the values given by ``protocols()``.

    Parameters
    ----------
    protocol : str
        Protocol name to validate.

    Raises
    ------
    AssertionError
        If the protocol is not supported; the message lists the valid names.
    """
    supported = self.protocols()
    assert protocol in supported, "Unvalid protocol `{}` not in {}".format(
        protocol, supported
    )
@staticmethod
def urls():
    """Return the download locations of ``ijbc.tar.gz``, HTTPS first."""
    https_location = "https://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz"
    http_location = "http://www.idiap.ch/software/bob/databases/latest/ijbc.tar.gz"
    return [https_location, http_location]
def _check_group(self, group):
    """Check that ``group`` is one of the values given by ``groups()``.

    Parameters
    ----------
    group : str
        Group name to validate.

    Raises
    ------
    AssertionError
        If the group is not available; the message lists the valid names.
    """
    available = self.groups()
    assert group in available, "Unvalid group `{}` not in {}".format(
        group, available
    )
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment