diff --git a/src/bob/pad/base/pipelines/abstract_classes.py b/src/bob/pad/base/pipelines/abstract_classes.py index 8383a610b0ae05da2006b3d5ac8c105acec9ff9f..f0ed71a21fae01d984165ac84003e74340f6693e 100644 --- a/src/bob/pad/base/pipelines/abstract_classes.py +++ b/src/bob/pad/base/pipelines/abstract_classes.py @@ -1,11 +1,15 @@ +from __future__ import annotations + from abc import ABCMeta, abstractmethod +from bob.pipelines import Sample + class Database(metaclass=ABCMeta): """Base database class for PAD experiments.""" @abstractmethod - def fit_samples(self): + def fit_samples(self) -> list[Sample]: """Returns :any:`bob.pipelines.Sample`'s to train a PAD model. Returns @@ -16,7 +20,7 @@ class Database(metaclass=ABCMeta): pass @abstractmethod - def predict_samples(self, group="dev"): + def predict_samples(self, group: str = "dev") -> list[Sample]: """Returns :any:`bob.pipelines.Sample`'s to be scored. Parameters @@ -30,3 +34,20 @@ class Database(metaclass=ABCMeta): List of samples to be scored. """ pass + + def all_samples( + self, groups: str | list[str] | None = None + ) -> list[Sample]: + """Returns all the samples of the database in one list. + + Giving ``groups`` will restrict the ``predict_samples`` to those groups. + """ + samples = self.fit_samples() + if groups is not None: + if type(groups) is str: + groups = [groups] + for group in groups: + samples.extend(self.predict_samples(group=group)) + else: + samples.extend(self.predict_samples(group=group)) + return samples