Skip to content
Snippets Groups Projects
Commit 3b51060c authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Merge branch 'fix/database' into 'develop'

fix: harmonize all the Database Interface classes.

See merge request !146
parents b6182b3d 5d0418bd
Branches
Tags
1 merge request!146fix: harmonize all the Database Interface classes.
"""Config file for the CASIA FASD dataset.
Please run ``bob config set bob.db.casia_fasd.directory /path/to/database/casia_fasd/``
in terminal to point to the original files of the dataset on your computer."""
in a terminal to point to the original files of the dataset on your computer.
"""
from bob.pad.face.database import CasiaFasdPadDatabase
......
"""Database Interface definition for the CASIA-FASD dataset."""
import logging
from clapper.rc import UserDefaults
......@@ -10,7 +12,9 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml")
def CasiaFasdPadDatabase(
def CasiaFasdPadDatabase( # noqa: N802
protocol: str = "grandtest",
*,
selection_style=None,
max_number_of_frames=None,
step_size=None,
......@@ -19,6 +23,7 @@ def CasiaFasdPadDatabase(
fixed_positions=None,
**kwargs,
):
"""Return a Database Interface for the Casia FASD dataset."""
name = "pad-face-casia-fasd-0b07ea45.tar.gz"
dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
......@@ -38,10 +43,13 @@ def CasiaFasdPadDatabase(
database = FileListPadDatabase(
name="casia-fsd",
dataset_protocols_path=dataset_protocols_path,
protocol="grandtest",
protocol=protocol,
transformer=transformer,
**kwargs,
)
database.annotation_type = annotation_type
database.fixed_positions = fixed_positions
database.protocols = lambda: ["grandtest"]
database.groups = lambda: ["dev", "eval", "train"]
return database
"""Database Interface definition for the CASIA-SURF dataset."""
import logging
import os
from functools import partial
from pathlib import Path
from clapper.rc import UserDefaults
from sklearn.preprocessing import FunctionTransformer
......@@ -17,23 +19,27 @@ rc = UserDefaults("bobrc.toml")
def load_multi_stream(path):
"""Helper loader to use in :py:class:`bob.pipelines.DelayedSample` objects."""
data = bob.io.base.load(path)
video = VideoLikeContainer(data[None, ...], [0])
return video
return VideoLikeContainer(data[None, ...], [0])
def casia_surf_multistream_load(samples, original_directory):
def casia_surf_multistream_load(samples, original_directory: str | None):
"""Make :py:class:`bob.pipelines.DelayedSample` objects for multi-stream samples."""
mod_to_attr = {}
mod_to_attr["color"] = "filename"
mod_to_attr["infrared"] = "ir_filename"
mod_to_attr["depth"] = "depth_filename"
mods = list(mod_to_attr.keys())
if original_directory is None:
original_directory = ""
def _load(sample):
paths = dict()
for mod in mods:
paths[mod] = os.path.join(
original_directory or "", getattr(sample, mod_to_attr[mod])
paths[mod] = str(
Path(original_directory) / getattr(sample, mod_to_attr[mod]),
)
data = partial(load_multi_stream, paths["color"])
depth = partial(load_multi_stream, paths["depth"])
......@@ -57,7 +63,8 @@ def casia_surf_multistream_load(samples, original_directory):
return [_load(s) for s in samples]
def CasiaSurfMultiStreamSample(original_directory):
def CasiaSurfMultiStreamSample(original_directory): # noqa: N802
"""Transformer for loading multi-stream samples."""
return FunctionTransformer(
casia_surf_multistream_load,
kw_args=dict(original_directory=original_directory),
......@@ -70,20 +77,25 @@ class CasiaSurfPadDatabase(FileListPadDatabase):
Parameters
----------
stream_type : str
A str or a list of str of the following choices: ``all``, ``color``, ``depth``, ``infrared``, by default ``all``
A str or a list of str of the following choices: ``all``, ``color``, ``depth``,
``infrared``, by default ``all``
The returned sample either have their data as a VideoLikeContainer or
a dict of VideoLikeContainers depending on the chosen stream_type.
TODO 2024: WTH is stream_type??
"""
def __init__(
self,
protocol: str = "Testing",
**kwargs,
):
original_directory = rc.get("bob.db.casia_surf.directory")
if original_directory is None or not os.path.isdir(original_directory):
if original_directory is None or not Path(original_directory).is_dir():
raise FileNotFoundError(
"The original_directory is not set. Please set it in the terminal using `bob config set bob.db.casia_surf.directory /path/to/database/CASIA-SURF/`."
"The original_directory is not set. Please set it in the terminal "
"using\n\t`bob config set bob.db.casia_surf.directory "
"/path/to/database/CASIA-SURF/`.",
)
transformer = CasiaSurfMultiStreamSample(
original_directory=original_directory,
......@@ -91,7 +103,7 @@ class CasiaSurfPadDatabase(FileListPadDatabase):
super().__init__(
name="casia-surf",
dataset_protocols_path=original_directory,
protocol="all",
protocol=protocol,
reader_cls=partial(
CSVToSamples,
dict_reader_kwargs=dict(
......@@ -110,16 +122,21 @@ class CasiaSurfPadDatabase(FileListPadDatabase):
self.annotation_type = None
self.fixed_positions = None
def protocols(self):
return ["all"]
@classmethod
def protocols(cls):
"""Return the list of all available protocols for this dataset."""
return ["Training", "Val", "Testing"]
def groups(self):
@classmethod
def groups(cls):
"""Return the list of all existing groups for this dataset."""
return ["train", "dev", "eval"]
def list_file(self, group):
"""Return the protocol definition file full filename for the given group."""
filename = {
"train": "train_list.txt",
"dev": "val_private_list.txt",
"eval": "test_private_list.txt",
}[group]
return os.path.join(self.dataset_protocols_path, filename)
return str(Path(self.dataset_protocols_path) / filename)
"""Database Interface definition for the 3DMAD (3D Mask Attack Database) dataset."""
import logging
import os
from functools import partial
from pathlib import Path
import h5py
import numpy as np
......@@ -25,6 +27,7 @@ def load_frames_from_hdf5(
max_number_of_frames=None,
step_size=None,
):
"""Helper loader to use in :py:class:`bob.pipelines.DelayedSample` objects."""
with h5py.File(hdf5_file) as f:
video = f[key][()]
# reduce the shape of depth from (N, C, H, W) to (N, H, W) since H == 1
......@@ -36,25 +39,23 @@ def load_frames_from_hdf5(
selection_style=selection_style,
step_size=step_size,
)
data = VideoLikeContainer(video[indices], indices)
return data
return VideoLikeContainer(video[indices], indices)
def load_annotations_from_hdf5(
hdf5_file,
):
"""Return a dictionary of annotations loaded from a path."""
with h5py.File(hdf5_file) as f:
eye_pos = f["Eye_Pos"][()]
annotations = {
return {
str(i): {
"reye": [row[1], row[0]],
"leye": [row[3], row[2]],
}
for i, row in enumerate(eye_pos)
}
return annotations
def delayed_maskattack_video_load(
......@@ -64,10 +65,11 @@ def delayed_maskattack_video_load(
max_number_of_frames=None,
step_size=None,
):
"""Make :py:class:`bob.pipelines.DelayedSample` objects for mask-attack samples."""
original_directory = original_directory or ""
results = []
for sample in samples:
hdf5_file = os.path.join(original_directory, sample.filename)
hdf5_file = str(Path(original_directory) / sample.filename)
data = partial(
load_frames_from_hdf5,
key="Color_Data",
......@@ -98,17 +100,18 @@ def delayed_maskattack_video_load(
data,
parent=sample,
delayed_attributes=delayed_attributes,
)
),
)
return results
def MaskAttackPadSample(
def MaskAttackPadSample( # noqa: N802
original_directory,
selection_style=None,
max_number_of_frames=None,
step_size=None,
):
"""Return a Transformer creating samples of videos."""
return FunctionTransformer(
delayed_maskattack_video_load,
validate=False,
......@@ -121,13 +124,15 @@ def MaskAttackPadSample(
)
def MaskAttackPadDatabase(
def MaskAttackPadDatabase( # noqa: N802
protocol="classification",
*,
selection_style=None,
max_number_of_frames=None,
step_size=None,
**kwargs,
):
"""Return a Database Interface object for the mask-attack dataset."""
name = "pad-face-mask-attack-6d8854c2.tar.gz"
dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
......
"""Database Interface definition for the OULU-NPU dataset."""
import logging
from clapper.rc import UserDefaults
......@@ -10,8 +12,9 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml")
def OuluNpuPadDatabase(
def OuluNpuPadDatabase( # noqa: N802
protocol="Protocol_1",
*,
selection_style=None,
max_number_of_frames=None,
step_size=None,
......@@ -20,6 +23,7 @@ def OuluNpuPadDatabase(
fixed_positions=None,
**kwargs,
):
"""Return a Database Interface object for the OULU-NPU dataset."""
name = "pad-face-oulunpu-7bfb90c9.tar.gz"
dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
......@@ -32,7 +36,7 @@ def OuluNpuPadDatabase(
name = "annotations-oulunpu-mtcnn-903addac.tar.gz"
annotation_directory = download_file(
urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
],
destination_sub_directory="annotations",
destination_filename=name,
......
"""Database Interface definition for the Replay Attack dataset."""
import logging
from clapper.rc import UserDefaults
......@@ -10,7 +12,7 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml")
def ReplayAttackPadDatabase(
def ReplayAttackPadDatabase( # noqa: N802
protocol="grandtest",
selection_style=None,
max_number_of_frames=None,
......@@ -20,6 +22,7 @@ def ReplayAttackPadDatabase(
fixed_positions=None,
**kwargs,
):
"""Returns a Database Interface object for the Replay Attack dataset."""
name = "pad-face-replay-attack-aca6b46f.tar.gz"
dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
......@@ -32,7 +35,7 @@ def ReplayAttackPadDatabase(
name = "annotations-replay-attack-mtcnn-8d1f4c12.tar.gz"
annotation_directory = download_file(
urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
],
destination_filename=name,
destination_sub_directory="annotations",
......
"""Database Interface definition for the Replay Mobile dataset."""
import logging
from clapper.rc import UserDefaults
......@@ -13,6 +15,7 @@ rc = UserDefaults("bobrc.toml")
def get_rm_video_transform(sample):
"""Return a lazy loader for the video that is correctly oriented."""
should_flip = sample.should_flip
def transform(video):
......@@ -24,7 +27,7 @@ def get_rm_video_transform(sample):
return transform
def ReplayMobilePadDatabase(
def ReplayMobilePadDatabase( # noqa: N802
protocol="grandtest",
selection_style=None,
max_number_of_frames=None,
......@@ -34,6 +37,7 @@ def ReplayMobilePadDatabase(
fixed_positions=None,
**kwargs,
):
"""Return a Database Interface object for the Replay Mobile dataset."""
name = "pad-face-replay-mobile-620dded2.tar.gz"
dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
......@@ -46,7 +50,7 @@ def ReplayMobilePadDatabase(
name = "annotations-replay-mobile-mtcnn-20055a07.tar.gz"
annotation_directory = download_file(
urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
],
destination_filename=name,
destination_sub_directory="annotations",
......
"""Database Interface definition for the SWAN dataset."""
import logging
from clapper.rc import UserDefaults
......@@ -10,7 +12,7 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml")
def SwanPadDatabase(
def SwanPadDatabase( # noqa: N802
protocol="pad_p2_face_f1",
selection_style=None,
max_number_of_frames=None,
......@@ -20,6 +22,7 @@ def SwanPadDatabase(
fixed_positions=None,
**kwargs,
):
"""Return a Database interface object for the SWAN dataset."""
name = "pad-face-swan-ce83ebd8.tar.gz"
dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
......@@ -32,7 +35,7 @@ def SwanPadDatabase(
name = "annotations-swan-mtcnn-9f9e12d8.tar.gz"
annotation_directory = download_file(
urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
],
destination_filename=name,
destination_sub_directory="annotations",
......
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# Thu May 24 10:41:42 CEST 2012
"""Test all the Database Interfaces for general function and sample loading if possible.
When the data samples are not available (e.g. on the CI runners), the ``loading`` tests
are skipped.
"""
from unittest import SkipTest
......@@ -9,7 +13,8 @@ import numpy as np
import bob.bio.base
def test_replay_attack():
def test_replay_attack_protocol():
"""Test the replay-attack database interface without loading samples."""
database = bob.bio.base.load_resource(
"replay-attack",
"database",
......@@ -42,6 +47,15 @@ def test_replay_attack():
== 1000
)
def test_replay_attack_loading():
"""Test the loading of samples of the replay-attack database."""
database = bob.bio.base.load_resource(
"replay-attack",
"database",
preferred_package="bob.pad.face",
package_prefix="bob.pad.",
)
sample = database.sort(database.samples())[0]
try:
annot = dict(sample.annotations["0"])
......@@ -61,7 +75,8 @@ def test_replay_attack():
raise SkipTest(e)
def test_replay_mobile():
def test_replay_mobile_protocol():
"""Test the replay-mobile database interface without loading samples."""
database = bob.bio.base.load_resource(
"replay-mobile",
"database",
......@@ -122,6 +137,18 @@ def test_replay_mobile():
"bottomright": [1111, 495],
}, annot
def test_replay_mobile_loading():
"""Test the loading of samples of the replay-mobile database."""
database = bob.bio.base.load_resource(
"replay-mobile",
"database",
preferred_package="bob.pad.face",
package_prefix="bob.pad.",
)
all_samples = database.sort(database.samples())
sample = all_samples[0]
sample2 = [s for s in all_samples if not s.should_flip][0]
try:
assert sample.data.shape == (20, 3, 1280, 720), sample.data.shape
np.testing.assert_equal(sample.data[0][:, 0, 0], [13, 13, 13])
......@@ -132,7 +159,8 @@ def test_replay_mobile():
# Test the mask_attack database
def test_mask_attack():
def test_mask_attack_protocol():
"""Test the mask-attack database interface without loading samples."""
mask_attack = bob.bio.base.load_resource(
"mask-attack",
"database",
......@@ -171,6 +199,16 @@ def test_mask_attack():
assert len(mask_attack.samples(groups=["dev"], purposes="attack")) == 25
assert len(mask_attack.samples(groups=["eval"], purposes="attack")) == 25
def test_mask_attack_loading():
"""Test the loading of samples of the mask-attack database."""
mask_attack = bob.bio.base.load_resource(
"mask-attack",
"database",
preferred_package="bob.pad.face",
package_prefix="bob.pad.",
)
sample = mask_attack.samples()[0]
try:
assert sample.data.shape == (20, 3, 480, 640)
......@@ -186,7 +224,8 @@ def test_mask_attack():
raise SkipTest(e)
def test_casia_fasd():
def test_casia_fasd_protocols():
"""Test the casia-fasd database interface without loading samples."""
casia_fasd = bob.bio.base.load_resource(
"casia-fasd",
"database",
......@@ -201,6 +240,16 @@ def test_casia_fasd():
assert len(casia_fasd.samples(groups="train")) == 180
assert len(casia_fasd.samples(groups="dev")) == 60
assert len(casia_fasd.samples(groups="eval")) == 360
def test_casia_fasd_loading():
"""Test the loading of samples of the casia-fasd database."""
casia_fasd = bob.bio.base.load_resource(
"casia-fasd",
"database",
preferred_package="bob.pad.face",
package_prefix="bob.pad.",
)
sample = casia_fasd.samples()[0]
try:
assert sample.data.shape == (20, 3, 480, 640)
......@@ -209,7 +258,12 @@ def test_casia_fasd():
raise SkipTest(e)
def test_casia_surf():
def test_casia_surf_protocol():
"""Test the casia-surf database interface without loading samples.
As the protocol definition files are shipped with the data, this test will also be
skipped when the original directory is not set.
"""
try:
casia_surf = bob.bio.base.load_resource(
"casia-surf",
......@@ -225,6 +279,19 @@ def test_casia_surf():
assert len(casia_surf.samples(groups="train")) == 29266
assert len(casia_surf.samples(groups="dev")) == 9608
assert len(casia_surf.samples(groups="eval")) == 57710
except FileNotFoundError as e:
raise SkipTest(e)
def test_casia_surf_loading():
"""Test the loading of samples of the casia-surf database."""
try:
casia_surf = bob.bio.base.load_resource(
"casia-surf",
"database",
preferred_package="bob.pad.face",
package_prefix="bob.pad.",
)
sample = casia_surf.samples()[0]
assert sample.data.shape == (1, 3, 279, 279)
np.testing.assert_equal(sample.data[0][:, 0, 0], [0, 0, 0])
......@@ -234,7 +301,8 @@ def test_casia_surf():
raise SkipTest(e)
def test_swan():
def test_swan_protocol():
"""Test the swan database interface without loading samples."""
database = bob.bio.base.load_resource(
"swan",
"database",
......@@ -264,6 +332,15 @@ def test_swan():
== 2502
)
def test_swan_loading():
"""Test the loading of samples of the swan database."""
database = bob.bio.base.load_resource(
"swan",
"database",
preferred_package="bob.pad.face",
package_prefix="bob.pad.",
)
sample = database.sort(database.samples())[0]
try:
annot = dict(sample.annotations["0"])
......@@ -283,7 +360,8 @@ def test_swan():
raise SkipTest(e)
def test_oulu_npu():
def test_oulu_npu_protocol():
"""Test the oulu-npu database interface without loading samples."""
database = bob.bio.base.load_resource(
"oulu-npu",
"database",
......@@ -327,6 +405,16 @@ def test_oulu_npu():
== 960 + 720 + 480
)
def test_oulu_npu_loading():
"""Test the loading of samples of the oulu-npu database."""
database = bob.bio.base.load_resource(
"oulu-npu",
"database",
preferred_package="bob.pad.face",
package_prefix="bob.pad.",
)
sample = database.sort(database.samples())[0]
try:
annot = dict(sample.annotations["0"])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment