Commit 6678bcbe authored by Yannick DAYER's avatar Yannick DAYER
Browse files

Keep low level group name away from the high level

parent d67c8d56
Pipeline #46113 passed with stage
in 12 minutes and 34 seconds
......@@ -12,7 +12,11 @@ from bob.bio.base.algorithm import Algorithm
from bob.pipelines import DelayedSample
from bob.pipelines import DelayedSampleSet
from bob.pipelines import SampleSet
from bob.db.base.utils import check_parameters_for_validity
from bob.db.base.utils import (
check_parameters_for_validity,
convert_names_to_highlevel,
convert_names_to_lowlevel,
)
from .abstract_classes import BioAlgorithm
from .abstract_classes import Database
......@@ -189,7 +193,7 @@ class DatabaseConnector(Database):
Parameters
----------
groups: list or `None`
List of groups to consider ('world'/'train', 'dev', and/or 'eval').
List of groups to consider ('train', 'dev', and/or 'eval').
If `None` is given, returns samples from all the groups.
Returns
......@@ -198,7 +202,11 @@ class DatabaseConnector(Database):
List of all the samples of a database in :class:`bob.pipelines.Sample`
objects.
"""
valid_groups = self.database.groups()
valid_groups = convert_names_to_highlevel(
self.database.groups(),
low_level_names=["world", "dev", "eval"],
high_level_names=["train", "dev", "eval"],
)
groups = check_parameters_for_validity(
parameters=groups,
parameter_description="groups",
......@@ -206,7 +214,12 @@ class DatabaseConnector(Database):
default_parameters=valid_groups,
)
logger.debug(f"Fetching all samples of groups '{groups}'.")
objects = self.database.all_files(groups=groups)
low_level_groups = convert_names_to_lowlevel(
names=groups,
low_level_names=["world", "dev", "eval"],
high_level_names=["train", "dev", "eval"],
)
objects = self.database.all_files(groups=low_level_groups)
return [_biofile_to_delayed_sample(k, self.database) for k in objects]
......
......@@ -57,6 +57,6 @@ def test_all_samples():
all_samples = dummy_database.all_samples(groups=None)
assert len(all_samples) == 400
assert all([isinstance(s, DelayedSample) for s in all_samples])
assert len(dummy_database.all_samples(groups=["world"])) == 200
assert len(dummy_database.all_samples(groups=["train"])) == 200
assert len(dummy_database.all_samples(groups=["dev"])) == 200
assert len(dummy_database.all_samples(groups=[])) == 400
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment