Skip to content
Snippets Groups Projects
Verified commit 3f8ca2aa, authored by Yannick DAYER
Browse files

fix: harmonize all the database interface classes.

Ensure all Database Interface classes accept `protocol` as a parameter.
parent b6182b3d
Branches
No related tags found
1 merge request: !146 "fix: harmonize all the Database Interface classes."
"""Config file for the CASIA FASD dataset. """Config file for the CASIA FASD dataset.
Please run ``bob config set bob.db.casia_fasd.directory /path/to/database/casia_fasd/`` Please run ``bob config set bob.db.casia_fasd.directory /path/to/database/casia_fasd/``
in terminal to point to the original files of the dataset on your computer.""" in a terminal to point to the original files of the dataset on your computer.
"""
from bob.pad.face.database import CasiaFasdPadDatabase from bob.pad.face.database import CasiaFasdPadDatabase
......
"""Database Interface definition for the CASIA-FASD dataset."""
import logging import logging
from clapper.rc import UserDefaults from clapper.rc import UserDefaults
...@@ -10,7 +12,9 @@ logger = logging.getLogger(__name__) ...@@ -10,7 +12,9 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml") rc = UserDefaults("bobrc.toml")
def CasiaFasdPadDatabase( def CasiaFasdPadDatabase( # noqa: N802
protocol: str = "grandtest",
*,
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
step_size=None, step_size=None,
...@@ -19,6 +23,7 @@ def CasiaFasdPadDatabase( ...@@ -19,6 +23,7 @@ def CasiaFasdPadDatabase(
fixed_positions=None, fixed_positions=None,
**kwargs, **kwargs,
): ):
"""Return a Database Interface for the Casia FASD dataset."""
name = "pad-face-casia-fasd-0b07ea45.tar.gz" name = "pad-face-casia-fasd-0b07ea45.tar.gz"
dataset_protocols_path = download_file( dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"], urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
...@@ -38,10 +43,13 @@ def CasiaFasdPadDatabase( ...@@ -38,10 +43,13 @@ def CasiaFasdPadDatabase(
database = FileListPadDatabase( database = FileListPadDatabase(
name="casia-fsd", name="casia-fsd",
dataset_protocols_path=dataset_protocols_path, dataset_protocols_path=dataset_protocols_path,
protocol="grandtest", protocol=protocol,
transformer=transformer, transformer=transformer,
**kwargs, **kwargs,
) )
database.annotation_type = annotation_type database.annotation_type = annotation_type
database.fixed_positions = fixed_positions database.fixed_positions = fixed_positions
database.protocols = lambda: ["grandtest"]
database.groups = lambda: ["dev", "eval", "train"]
return database return database
"""Database Interface definition for the CASIA-SURF dataset."""
import logging import logging
import os
from functools import partial from functools import partial
from pathlib import Path
from clapper.rc import UserDefaults from clapper.rc import UserDefaults
from sklearn.preprocessing import FunctionTransformer from sklearn.preprocessing import FunctionTransformer
...@@ -17,23 +19,27 @@ rc = UserDefaults("bobrc.toml") ...@@ -17,23 +19,27 @@ rc = UserDefaults("bobrc.toml")
def load_multi_stream(path): def load_multi_stream(path):
"""Helper loader to use in :py:class:`bob.pipelines.DelayedSample` objects."""
data = bob.io.base.load(path) data = bob.io.base.load(path)
video = VideoLikeContainer(data[None, ...], [0]) return VideoLikeContainer(data[None, ...], [0])
return video
def casia_surf_multistream_load(samples, original_directory): def casia_surf_multistream_load(samples, original_directory: str | None):
"""Make :py:class:`bob.pipelines.DelayedSample` objects for multi-stream samples."""
mod_to_attr = {} mod_to_attr = {}
mod_to_attr["color"] = "filename" mod_to_attr["color"] = "filename"
mod_to_attr["infrared"] = "ir_filename" mod_to_attr["infrared"] = "ir_filename"
mod_to_attr["depth"] = "depth_filename" mod_to_attr["depth"] = "depth_filename"
mods = list(mod_to_attr.keys()) mods = list(mod_to_attr.keys())
if original_directory is None:
original_directory = ""
def _load(sample): def _load(sample):
paths = dict() paths = dict()
for mod in mods: for mod in mods:
paths[mod] = os.path.join( paths[mod] = str(
original_directory or "", getattr(sample, mod_to_attr[mod]) Path(original_directory) / getattr(sample, mod_to_attr[mod]),
) )
data = partial(load_multi_stream, paths["color"]) data = partial(load_multi_stream, paths["color"])
depth = partial(load_multi_stream, paths["depth"]) depth = partial(load_multi_stream, paths["depth"])
...@@ -57,7 +63,8 @@ def casia_surf_multistream_load(samples, original_directory): ...@@ -57,7 +63,8 @@ def casia_surf_multistream_load(samples, original_directory):
return [_load(s) for s in samples] return [_load(s) for s in samples]
def CasiaSurfMultiStreamSample(original_directory): def CasiaSurfMultiStreamSample(original_directory): # noqa: N802
"""Transformer for loading multi-stream samples."""
return FunctionTransformer( return FunctionTransformer(
casia_surf_multistream_load, casia_surf_multistream_load,
kw_args=dict(original_directory=original_directory), kw_args=dict(original_directory=original_directory),
...@@ -70,20 +77,25 @@ class CasiaSurfPadDatabase(FileListPadDatabase): ...@@ -70,20 +77,25 @@ class CasiaSurfPadDatabase(FileListPadDatabase):
Parameters Parameters
---------- ----------
stream_type : str stream_type : str
A str or a list of str of the following choices: ``all``, ``color``, ``depth``, ``infrared``, by default ``all`` A str or a list of str of the following choices: ``all``, ``color``, ``depth``,
``infrared``, by default ``all``
The returned sample either have their data as a VideoLikeContainer or The returned sample either have their data as a VideoLikeContainer or
a dict of VideoLikeContainers depending on the chosen stream_type. a dict of VideoLikeContainers depending on the chosen stream_type.
TODO 2024: WTH is stream_type??
""" """
def __init__( def __init__(
self, self,
protocol: str = "Testing",
**kwargs, **kwargs,
): ):
original_directory = rc.get("bob.db.casia_surf.directory") original_directory = rc.get("bob.db.casia_surf.directory")
if original_directory is None or not os.path.isdir(original_directory): if original_directory is None or not Path(original_directory).is_dir():
raise FileNotFoundError( raise FileNotFoundError(
"The original_directory is not set. Please set it in the terminal using `bob config set bob.db.casia_surf.directory /path/to/database/CASIA-SURF/`." "The original_directory is not set. Please set it in the terminal "
"using\n\t`bob config set bob.db.casia_surf.directory "
"/path/to/database/CASIA-SURF/`.",
) )
transformer = CasiaSurfMultiStreamSample( transformer = CasiaSurfMultiStreamSample(
original_directory=original_directory, original_directory=original_directory,
...@@ -91,7 +103,7 @@ class CasiaSurfPadDatabase(FileListPadDatabase): ...@@ -91,7 +103,7 @@ class CasiaSurfPadDatabase(FileListPadDatabase):
super().__init__( super().__init__(
name="casia-surf", name="casia-surf",
dataset_protocols_path=original_directory, dataset_protocols_path=original_directory,
protocol="all", protocol=protocol,
reader_cls=partial( reader_cls=partial(
CSVToSamples, CSVToSamples,
dict_reader_kwargs=dict( dict_reader_kwargs=dict(
...@@ -110,16 +122,21 @@ class CasiaSurfPadDatabase(FileListPadDatabase): ...@@ -110,16 +122,21 @@ class CasiaSurfPadDatabase(FileListPadDatabase):
self.annotation_type = None self.annotation_type = None
self.fixed_positions = None self.fixed_positions = None
def protocols(self): @classmethod
return ["all"] def protocols(cls):
"""Return the list of all available protocols for this dataset."""
return ["Training", "Val", "Testing"]
def groups(self): @classmethod
def groups(cls):
"""Return the list of all existing groups for this dataset."""
return ["train", "dev", "eval"] return ["train", "dev", "eval"]
def list_file(self, group): def list_file(self, group):
"""Return the protocol definition file full filename for the given group."""
filename = { filename = {
"train": "train_list.txt", "train": "train_list.txt",
"dev": "val_private_list.txt", "dev": "val_private_list.txt",
"eval": "test_private_list.txt", "eval": "test_private_list.txt",
}[group] }[group]
return os.path.join(self.dataset_protocols_path, filename) return str(Path(self.dataset_protocols_path) / filename)
"""Database Interface definition for the 3DMAD (3D Mask Attack Database) dataset."""
import logging import logging
import os
from functools import partial from functools import partial
from pathlib import Path
import h5py import h5py
import numpy as np import numpy as np
...@@ -25,6 +27,7 @@ def load_frames_from_hdf5( ...@@ -25,6 +27,7 @@ def load_frames_from_hdf5(
max_number_of_frames=None, max_number_of_frames=None,
step_size=None, step_size=None,
): ):
"""Helper loader to use in :py:class:`bob.pipelines.DelayedSample` objects."""
with h5py.File(hdf5_file) as f: with h5py.File(hdf5_file) as f:
video = f[key][()] video = f[key][()]
# reduce the shape of depth from (N, C, H, W) to (N, H, W) since H == 1 # reduce the shape of depth from (N, C, H, W) to (N, H, W) since H == 1
...@@ -36,25 +39,23 @@ def load_frames_from_hdf5( ...@@ -36,25 +39,23 @@ def load_frames_from_hdf5(
selection_style=selection_style, selection_style=selection_style,
step_size=step_size, step_size=step_size,
) )
data = VideoLikeContainer(video[indices], indices) return VideoLikeContainer(video[indices], indices)
return data
def load_annotations_from_hdf5( def load_annotations_from_hdf5(
hdf5_file, hdf5_file,
): ):
"""Return a dictionary of annotations loaded from a path."""
with h5py.File(hdf5_file) as f: with h5py.File(hdf5_file) as f:
eye_pos = f["Eye_Pos"][()] eye_pos = f["Eye_Pos"][()]
annotations = { return {
str(i): { str(i): {
"reye": [row[1], row[0]], "reye": [row[1], row[0]],
"leye": [row[3], row[2]], "leye": [row[3], row[2]],
} }
for i, row in enumerate(eye_pos) for i, row in enumerate(eye_pos)
} }
return annotations
def delayed_maskattack_video_load( def delayed_maskattack_video_load(
...@@ -64,10 +65,11 @@ def delayed_maskattack_video_load( ...@@ -64,10 +65,11 @@ def delayed_maskattack_video_load(
max_number_of_frames=None, max_number_of_frames=None,
step_size=None, step_size=None,
): ):
"""Make :py:class:`bob.pipelines.DelayedSample` objects for mask-attack samples."""
original_directory = original_directory or "" original_directory = original_directory or ""
results = [] results = []
for sample in samples: for sample in samples:
hdf5_file = os.path.join(original_directory, sample.filename) hdf5_file = str(Path(original_directory) / sample.filename)
data = partial( data = partial(
load_frames_from_hdf5, load_frames_from_hdf5,
key="Color_Data", key="Color_Data",
...@@ -98,17 +100,18 @@ def delayed_maskattack_video_load( ...@@ -98,17 +100,18 @@ def delayed_maskattack_video_load(
data, data,
parent=sample, parent=sample,
delayed_attributes=delayed_attributes, delayed_attributes=delayed_attributes,
) ),
) )
return results return results
def MaskAttackPadSample( def MaskAttackPadSample( # noqa: N802
original_directory, original_directory,
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
step_size=None, step_size=None,
): ):
"""Return a Transformer creating samples of videos."""
return FunctionTransformer( return FunctionTransformer(
delayed_maskattack_video_load, delayed_maskattack_video_load,
validate=False, validate=False,
...@@ -121,13 +124,15 @@ def MaskAttackPadSample( ...@@ -121,13 +124,15 @@ def MaskAttackPadSample(
) )
def MaskAttackPadDatabase( def MaskAttackPadDatabase( # noqa: N802
protocol="classification", protocol="classification",
*,
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
step_size=None, step_size=None,
**kwargs, **kwargs,
): ):
"""Return a Database Interface object for the mask-attack dataset."""
name = "pad-face-mask-attack-6d8854c2.tar.gz" name = "pad-face-mask-attack-6d8854c2.tar.gz"
dataset_protocols_path = download_file( dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"], urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
......
"""Database Interface definition for the OULU-NPU dataset."""
import logging import logging
from clapper.rc import UserDefaults from clapper.rc import UserDefaults
...@@ -10,8 +12,9 @@ logger = logging.getLogger(__name__) ...@@ -10,8 +12,9 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml") rc = UserDefaults("bobrc.toml")
def OuluNpuPadDatabase( def OuluNpuPadDatabase( # noqa: N802
protocol="Protocol_1", protocol="Protocol_1",
*,
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
step_size=None, step_size=None,
...@@ -20,6 +23,7 @@ def OuluNpuPadDatabase( ...@@ -20,6 +23,7 @@ def OuluNpuPadDatabase(
fixed_positions=None, fixed_positions=None,
**kwargs, **kwargs,
): ):
"""Return a Database Interface object for the OULU-NPU dataset."""
name = "pad-face-oulunpu-7bfb90c9.tar.gz" name = "pad-face-oulunpu-7bfb90c9.tar.gz"
dataset_protocols_path = download_file( dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"], urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
...@@ -32,7 +36,7 @@ def OuluNpuPadDatabase( ...@@ -32,7 +36,7 @@ def OuluNpuPadDatabase(
name = "annotations-oulunpu-mtcnn-903addac.tar.gz" name = "annotations-oulunpu-mtcnn-903addac.tar.gz"
annotation_directory = download_file( annotation_directory = download_file(
urls=[ urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}" f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
], ],
destination_sub_directory="annotations", destination_sub_directory="annotations",
destination_filename=name, destination_filename=name,
......
"""Database Interface definition for the Replay Attack dataset."""
import logging import logging
from clapper.rc import UserDefaults from clapper.rc import UserDefaults
...@@ -10,7 +12,7 @@ logger = logging.getLogger(__name__) ...@@ -10,7 +12,7 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml") rc = UserDefaults("bobrc.toml")
def ReplayAttackPadDatabase( def ReplayAttackPadDatabase( # noqa: N802
protocol="grandtest", protocol="grandtest",
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
...@@ -20,6 +22,7 @@ def ReplayAttackPadDatabase( ...@@ -20,6 +22,7 @@ def ReplayAttackPadDatabase(
fixed_positions=None, fixed_positions=None,
**kwargs, **kwargs,
): ):
"""Returns a Database Interface object for the Replay Attack dataset."""
name = "pad-face-replay-attack-aca6b46f.tar.gz" name = "pad-face-replay-attack-aca6b46f.tar.gz"
dataset_protocols_path = download_file( dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"], urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
...@@ -32,7 +35,7 @@ def ReplayAttackPadDatabase( ...@@ -32,7 +35,7 @@ def ReplayAttackPadDatabase(
name = "annotations-replay-attack-mtcnn-8d1f4c12.tar.gz" name = "annotations-replay-attack-mtcnn-8d1f4c12.tar.gz"
annotation_directory = download_file( annotation_directory = download_file(
urls=[ urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}" f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
], ],
destination_filename=name, destination_filename=name,
destination_sub_directory="annotations", destination_sub_directory="annotations",
......
"""Database Interface definition for the Replay Mobile dataset."""
import logging import logging
from clapper.rc import UserDefaults from clapper.rc import UserDefaults
...@@ -13,6 +15,7 @@ rc = UserDefaults("bobrc.toml") ...@@ -13,6 +15,7 @@ rc = UserDefaults("bobrc.toml")
def get_rm_video_transform(sample): def get_rm_video_transform(sample):
"""Return a lazy loader for the video that is correctly oriented."""
should_flip = sample.should_flip should_flip = sample.should_flip
def transform(video): def transform(video):
...@@ -24,7 +27,7 @@ def get_rm_video_transform(sample): ...@@ -24,7 +27,7 @@ def get_rm_video_transform(sample):
return transform return transform
def ReplayMobilePadDatabase( def ReplayMobilePadDatabase( # noqa: N802
protocol="grandtest", protocol="grandtest",
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
...@@ -34,6 +37,7 @@ def ReplayMobilePadDatabase( ...@@ -34,6 +37,7 @@ def ReplayMobilePadDatabase(
fixed_positions=None, fixed_positions=None,
**kwargs, **kwargs,
): ):
"""Return a Database Interface object for the Replay Mobile dataset."""
name = "pad-face-replay-mobile-620dded2.tar.gz" name = "pad-face-replay-mobile-620dded2.tar.gz"
dataset_protocols_path = download_file( dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"], urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
...@@ -46,7 +50,7 @@ def ReplayMobilePadDatabase( ...@@ -46,7 +50,7 @@ def ReplayMobilePadDatabase(
name = "annotations-replay-mobile-mtcnn-20055a07.tar.gz" name = "annotations-replay-mobile-mtcnn-20055a07.tar.gz"
annotation_directory = download_file( annotation_directory = download_file(
urls=[ urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}" f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
], ],
destination_filename=name, destination_filename=name,
destination_sub_directory="annotations", destination_sub_directory="annotations",
......
"""Database Interface definition for the SWAN dataset."""
import logging import logging
from clapper.rc import UserDefaults from clapper.rc import UserDefaults
...@@ -10,7 +12,7 @@ logger = logging.getLogger(__name__) ...@@ -10,7 +12,7 @@ logger = logging.getLogger(__name__)
rc = UserDefaults("bobrc.toml") rc = UserDefaults("bobrc.toml")
def SwanPadDatabase( def SwanPadDatabase( # noqa: N802
protocol="pad_p2_face_f1", protocol="pad_p2_face_f1",
selection_style=None, selection_style=None,
max_number_of_frames=None, max_number_of_frames=None,
...@@ -20,6 +22,7 @@ def SwanPadDatabase( ...@@ -20,6 +22,7 @@ def SwanPadDatabase(
fixed_positions=None, fixed_positions=None,
**kwargs, **kwargs,
): ):
"""Return a Database interface object for the SWAN dataset."""
name = "pad-face-swan-ce83ebd8.tar.gz" name = "pad-face-swan-ce83ebd8.tar.gz"
dataset_protocols_path = download_file( dataset_protocols_path = download_file(
urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"], urls=[f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}"],
...@@ -32,7 +35,7 @@ def SwanPadDatabase( ...@@ -32,7 +35,7 @@ def SwanPadDatabase(
name = "annotations-swan-mtcnn-9f9e12d8.tar.gz" name = "annotations-swan-mtcnn-9f9e12d8.tar.gz"
annotation_directory = download_file( annotation_directory = download_file(
urls=[ urls=[
f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}" f"http://www.idiap.ch/software/bob/data/bob/bob.pad.face/{name}",
], ],
destination_filename=name, destination_filename=name,
destination_sub_directory="annotations", destination_sub_directory="annotations",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment