From 4db20087987651a4313308334fb5538404b0bad9 Mon Sep 17 00:00:00 2001 From: Yannick DAYER <yannick.dayer@idiap.ch> Date: Wed, 19 Jul 2023 11:18:40 +0200 Subject: [PATCH] fix: Database now contains an expected all_samples Corresponds to the bob.bio.base equivalent, allowing use of pad datasets with bio commands (like bob bio annotate). --- .../pad/base/pipelines/abstract_classes.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/src/bob/pad/base/pipelines/abstract_classes.py b/src/bob/pad/base/pipelines/abstract_classes.py index 8383a61..caf2498 100644 --- a/src/bob/pad/base/pipelines/abstract_classes.py +++ b/src/bob/pad/base/pipelines/abstract_classes.py @@ -1,11 +1,15 @@ +from __future__ import annotations + from abc import ABCMeta, abstractmethod +from bob.pipelines import Sample + class Database(metaclass=ABCMeta): """Base database class for PAD experiments.""" @abstractmethod - def fit_samples(self): + def fit_samples(self) -> list[Sample]: """Returns :any:`bob.pipelines.Sample`'s to train a PAD model. Returns @@ -16,7 +20,7 @@ class Database(metaclass=ABCMeta): pass @abstractmethod - def predict_samples(self, group="dev"): + def predict_samples(self, group: str = "dev") -> list[Sample]: """Returns :any:`bob.pipelines.Sample`'s to be scored. Parameters @@ -30,3 +34,20 @@ class Database(metaclass=ABCMeta): List of samples to be scored. """ pass + + def all_samples( + self, groups: str | list[str] | None = None + ) -> list[Sample]: + """Returns all the samples of the database in one list. + + Giving ``groups`` will restrict the ``predict_samples`` to those groups. + """ + samples = self.fit_samples() + if groups is not None: + if type(groups) is str: + groups = [groups] + for group in groups: + samples.extend(self.predict_samples(group=group)) + else: + samples.extend(group=None) + return samples -- GitLab