Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.bio.base
Commits
193cf6fc
Commit
193cf6fc
authored
Nov 25, 2020
by
Yannick DAYER
Browse files
Handle 'train' group in database.all_samples()
Check parameters with existing function Add tests for train group
parent
a2499e03
Changes
4
Hide whitespace changes
Inline
Side-by-side
bob/bio/base/database/csv_dataset.py
View file @
193cf6fc
...
...
@@ -4,6 +4,7 @@
import
os
from
bob.pipelines
import
Sample
,
DelayedSample
,
SampleSet
from
bob.db.base.utils
import
check_parameters_for_validity
import
csv
import
bob.io.base
import
functools
...
...
@@ -325,16 +326,34 @@ class CSVDatasetDevEval:
Parameters
----------
groups: list or None
Groups to consider, or all groups if `None` is given.
Groups to consider ('train', 'dev', and/or 'eval'). If `None` is
given, returns the samples from all groups.
Returns
-------
samples: list
List of :class:`bob.pipelines.Sample` objects.
"""
valid_groups
=
[
"train"
]
if
self
.
dev_enroll_csv
and
self
.
dev_probe_csv
:
valid_groups
.
append
(
"dev"
)
if
self
.
eval_enroll_csv
and
self
.
eval_probe_csv
:
valid_groups
.
append
(
"eval"
)
groups
=
check_parameters_for_validity
(
parameters
=
groups
,
parameter_description
=
"groups"
,
valid_parameters
=
valid_groups
,
default_parameters
=
valid_groups
,
)
samples
=
[]
# Get train samples (background_model_samples returns a list of samples)
samples
=
self
.
background_model_samples
()
if
"train"
in
groups
:
samples
=
samples
+
self
.
background_model_samples
()
groups
.
remove
(
"train"
)
# Get enroll and probe samples
groups
=
[
"dev"
,
"eval"
]
if
not
groups
else
groups
if
"eval"
in
groups
and
(
not
self
.
eval_enroll_csv
or
not
self
.
eval_probe_csv
):
logger
.
warning
(
"'eval' requested, but dataset has no 'eval' group."
)
groups
.
remove
(
"eval"
)
for
group
in
groups
:
for
purpose
in
(
"enroll"
,
"probe"
):
label
=
f
"
{
group
}
_
{
purpose
}
_csv"
...
...
@@ -478,19 +497,33 @@ class CSVDatasetCrossValidation:
Parameters
----------
groups: list or None
Groups to consider, or all groups if `None` is given.
Groups to consider ('train' and/or 'dev'). If `None` is given,
returns the samples from all groups.
Returns
-------
samples: list
List of :class:`bob.pipelines.Sample` objects.
"""
valid_groups
=
[
"train"
,
"dev"
]
groups
=
check_parameters_for_validity
(
parameters
=
groups
,
parameter_description
=
"groups"
,
valid_parameters
=
valid_groups
,
default_parameters
=
valid_groups
,
)
samples
=
[]
# Get train samples (background_model_samples returns a list of samples)
samples
=
self
.
background_model_samples
()
if
"train"
in
groups
:
samples
=
samples
+
self
.
background_model_samples
()
groups
.
remove
(
"train"
)
# Get enroll and probe samples
groups
=
[
"dev"
]
if
not
groups
else
groups
if
"eval"
in
groups
:
logger
.
info
(
"'eval' requested but there is no 'eval' group defined."
)
groups
.
remove
(
"eval"
)
for
group
in
groups
:
samples
=
samples
+
[
s
for
s_set
in
self
.
references
(
group
)
for
s
in
s_set
]
samples
=
samples
+
[
s
for
s_set
in
self
.
probes
(
group
)
for
s
in
s_set
]
samples
=
samples
+
[
s
for
s_set
in
self
.
references
(
group
)
for
s
in
s_set
]
samples
=
samples
+
[
s
for
s_set
in
self
.
probes
(
group
)
for
s
in
s_set
]
return
samples
...
...
bob/bio/base/pipelines/vanilla_biometrics/legacy.py
View file @
193cf6fc
...
...
@@ -12,6 +12,7 @@ from bob.bio.base.algorithm import Algorithm
from
bob.pipelines
import
DelayedSample
from
bob.pipelines
import
DelayedSampleSet
from
bob.pipelines
import
SampleSet
from
bob.db.base.utils
import
check_parameters_for_validity
from
.abstract_classes
import
BioAlgorithm
from
.abstract_classes
import
Database
...
...
@@ -188,15 +189,23 @@ class DatabaseConnector(Database):
Parameters
----------
groups: list or `None`
List of groups to consider (
like 'dev' or 'eval'). If `None`, will
return samples from all the groups.
List of groups to consider (
'world'/'train', 'dev', and/or 'eval').
If `None` is given,
return
s
samples from all the groups.
Returns
-------
samples: list
List of all the samples of a database
, conforming to the pipeline
API. See, e.g., :py:func:`bob.pipelines.first`
.
List of all the samples of a database
in :class:`bob.pipelines.Sample`
objects
.
"""
valid_groups
=
self
.
database
.
groups
()
groups
=
check_parameters_for_validity
(
parameters
=
groups
,
parameter_description
=
"groups"
,
valid_parameters
=
valid_groups
,
default_parameters
=
valid_groups
,
)
logger
.
debug
(
f
"Fetching all samples of groups '
{
groups
}
'."
)
objects
=
self
.
database
.
all_files
(
groups
=
groups
)
return
[
_biofile_to_delayed_sample
(
k
,
self
.
database
)
for
k
in
objects
]
...
...
bob/bio/base/test/test_database_implementations.py
View file @
193cf6fc
...
...
@@ -57,3 +57,6 @@ def test_all_samples():
all_samples
=
dummy_database
.
all_samples
(
groups
=
None
)
assert
len
(
all_samples
)
==
400
assert
all
([
isinstance
(
s
,
DelayedSample
)
for
s
in
all_samples
])
assert
len
(
dummy_database
.
all_samples
(
groups
=
[
"world"
]))
==
200
assert
len
(
dummy_database
.
all_samples
(
groups
=
[
"dev"
]))
==
200
assert
len
(
dummy_database
.
all_samples
(
groups
=
[]))
==
400
bob/bio/base/test/test_filelist.py
View file @
193cf6fc
...
...
@@ -117,7 +117,8 @@ def test_csv_file_list_atnt():
assert
len
(
dataset
.
background_model_samples
())
==
200
assert
len
(
dataset
.
references
())
==
20
assert
len
(
dataset
.
probes
())
==
100
assert
len
(
dataset
.
all_samples
(
groups
=
[
"dev"
]))
==
400
assert
len
(
dataset
.
all_samples
(
groups
=
[
"train"
]))
==
200
assert
len
(
dataset
.
all_samples
(
groups
=
[
"dev"
]))
==
200
assert
len
(
dataset
.
all_samples
(
groups
=
None
))
==
400
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment