Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.bio.base
Commits
4de7cc3b
Commit
4de7cc3b
authored
Oct 07, 2020
by
Tiago de Freitas Pereira
Browse files
Implemented CrossValidation Filelist dataset
parent
34bd50fb
Changes
4
Hide whitespace changes
Inline
Side-by-side
bob/bio/base/database/__init__.py
View file @
4de7cc3b
from
.csv_dataset
import
CSVDatasetDevEval
,
CSVToSampleLoader
from
.csv_dataset
import
CSVDatasetDevEval
,
CSVToSampleLoader
,
CSVDatasetCrossValidation
from
.file
import
BioFile
from
.file
import
BioFileSet
from
.database
import
BioDatabase
...
...
bob/bio/base/database/csv_dataset.py
View file @
4de7cc3b
...
...
@@ -8,6 +8,8 @@ import csv
import
bob.io.base
import
functools
from
abc
import
ABCMeta
,
abstractmethod
import
numpy
as
np
import
itertools
class
CSVBaseSampleLoader
(
metaclass
=
ABCMeta
):
...
...
@@ -91,7 +93,10 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
subject
=
row
[
1
]
kwargs
=
dict
([[
h
,
r
]
for
h
,
r
in
zip
(
header
[
2
:],
row
[
2
:])])
return
DelayedSample
(
functools
.
partial
(
self
.
data_loader
,
os
.
path
.
join
(
self
.
dataset_original_directory
,
path
+
self
.
extension
)),
functools
.
partial
(
self
.
data_loader
,
os
.
path
.
join
(
self
.
dataset_original_directory
,
path
+
self
.
extension
),
),
key
=
path
,
subject
=
subject
,
**
kwargs
,
...
...
@@ -118,11 +123,15 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
sample_sets
[
s
.
subject
]
=
SampleSet
(
[
s
],
**
get_attribute_from_sample
(
s
)
)
sample_sets
[
s
.
subject
].
append
(
s
)
else
:
sample_sets
[
s
.
subject
].
append
(
s
)
return
list
(
sample_sets
.
values
())
else
:
return
[
SampleSet
([
s
],
**
get_attribute_from_sample
(
s
),
references
=
references
)
for
s
in
samples
]
return
[
SampleSet
([
s
],
**
get_attribute_from_sample
(
s
),
references
=
references
)
for
s
in
samples
]
class
CSVDatasetDevEval
:
...
...
@@ -194,8 +203,9 @@ class CSVDatasetDevEval:
protocol_na,e: str
The name of the protocol
csv_to_sample_loader:
csv_to_sample_loader: :any:`CSVBaseSampleLoader`
Base class that whose objective is to generate :any:`bob.pipelines.Samples`
and/or :any:`bob.pipelines.SampleSet` from csv rows
"""
...
...
@@ -281,9 +291,6 @@ class CSVDatasetDevEval:
return
self
.
cache
[
"train"
]
def
_get_subjects_from_samplesets
(
self
,
sample_sets
):
return
list
(
set
([
s
.
subject
for
s
in
sample_sets
]))
def
_get_samplesets
(
self
,
group
=
"dev"
,
purpose
=
"enroll"
,
group_by_subject
=
False
):
fetching_probes
=
False
...
...
@@ -298,9 +305,7 @@ class CSVDatasetDevEval:
references
=
None
if
fetching_probes
:
references
=
self
.
_get_subjects_from_samplesets
(
self
.
references
(
group
=
group
)
)
references
=
list
(
set
([
s
.
subject
for
s
in
self
.
references
(
group
=
group
)]))
samples
=
self
.
csv_to_sample_loader
(
self
.
__dict__
[
cache_label
])
...
...
@@ -321,3 +326,144 @@ class CSVDatasetDevEval:
return
self
.
_get_samplesets
(
group
=
group
,
purpose
=
"probe"
,
group_by_subject
=
False
)
class
CSVDatasetCrossValidation
:
"""
Generic filelist dataset for :any:`bob.bio.base.pipelines.VanillaBiometrics` pipeline that
handles **CROSS VALIDATION**.
Check :ref:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
interface.
This interface will take one `csv_file` as input and split into i-) data for training and
ii-) data for testing.
The data for testing will be further split in data for enrollment and data for probing.
The input CSV file should be casted in the following format:
.. code-block:: text
PATH,SUBJECT
path_1,subject_1
path_2,subject_2
path_i,subject_j
...
Parameters
----------
csv_file_name: str
CSV file containing all the samples from your database
random_state: int
Pseudo-random number generator seed
test_size: float
Percentage of the subjects used for testing
samples_for_enrollment: float
Number of samples used for enrollment
csv_to_sample_loader: :any:`CSVBaseSampleLoader`
Base class that whose objective is to generate :any:`bob.pipelines.Samples`
and/or :any:`bob.pipelines.SampleSet` from csv rows
"""
def
__init__
(
self
,
csv_file_name
=
"metadata.csv"
,
random_state
=
0
,
test_size
=
0.8
,
samples_for_enrollment
=
1
,
csv_to_sample_loader
=
CSVToSampleLoader
(
data_loader
=
bob
.
io
.
base
.
load
,
dataset_original_directory
=
""
,
extension
=
""
),
):
def
get_dict_cache
():
cache
=
dict
()
cache
[
"train"
]
=
None
cache
[
"dev_enroll_csv"
]
=
None
cache
[
"dev_probe_csv"
]
=
None
return
cache
self
.
random_state
=
random_state
self
.
cache
=
get_dict_cache
()
self
.
csv_to_sample_loader
=
csv_to_sample_loader
self
.
csv_file_name
=
csv_file_name
self
.
samples_for_enrollment
=
samples_for_enrollment
self
.
test_size
=
test_size
if
self
.
test_size
<
0
and
self
.
test_size
>
1
:
raise
ValueError
(
f
"`test_size` should be between 0 and 1.
{
test_size
}
is provided"
)
def
_do_cross_validation
(
self
):
# Shuffling samples by subject
samples_by_subject
=
group_samples_by_subject
(
self
.
csv_to_sample_loader
(
self
.
csv_file_name
)
)
subjects
=
list
(
samples_by_subject
.
keys
())
np
.
random
.
seed
(
self
.
random_state
)
np
.
random
.
shuffle
(
subjects
)
# Getting the training data
n_samples_for_training
=
len
(
subjects
)
-
int
(
self
.
test_size
*
len
(
subjects
))
self
.
cache
[
"train"
]
=
list
(
itertools
.
chain
(
*
[
samples_by_subject
[
s
]
for
s
in
subjects
[
0
:
n_samples_for_training
]]
)
)
# Splitting enroll and probe
self
.
cache
[
"dev_enroll_csv"
]
=
[]
self
.
cache
[
"dev_probe_csv"
]
=
[]
for
s
in
subjects
[
n_samples_for_training
:]:
samples
=
samples_by_subject
[
s
]
if
len
(
samples
)
<
self
.
samples_for_enrollment
:
raise
ValueError
(
f
"Not enough samples (
{
len
(
samples
)
}
) for enrollment for the subject
{
s
}
"
)
# Enrollment samples
self
.
cache
[
"dev_enroll_csv"
].
append
(
self
.
csv_to_sample_loader
.
convert_samples_to_samplesets
(
samples
[
0
:
self
.
samples_for_enrollment
]
)[
0
]
)
self
.
cache
[
"dev_probe_csv"
]
+=
self
.
csv_to_sample_loader
.
convert_samples_to_samplesets
(
samples
[
self
.
samples_for_enrollment
:],
group_by_subject
=
False
,
references
=
subjects
[
n_samples_for_training
:],
)
def
_load_from_cache
(
self
,
cache_key
):
if
self
.
cache
[
cache_key
]
is
None
:
self
.
_do_cross_validation
()
return
self
.
cache
[
cache_key
]
def
background_model_samples
(
self
):
return
self
.
_load_from_cache
(
"train"
)
def
references
(
self
,
group
=
"dev"
):
return
self
.
_load_from_cache
(
"dev_enroll_csv"
)
def
probes
(
self
,
group
=
"dev"
):
return
self
.
_load_from_cache
(
"dev_probe_csv"
)
def
group_samples_by_subject
(
samples
):
# Grouping sample sets
samples_by_subject
=
dict
()
for
s
in
samples
:
if
s
.
subject
not
in
samples_by_subject
:
samples_by_subject
[
s
.
subject
]
=
[]
samples_by_subject
[
s
.
subject
].
append
(
s
)
return
samples_by_subject
bob/bio/base/test/data/atnt/cross_validation/metadata.csv
0 → 100644
View file @
4de7cc3b
PATH,SUBJECT
s1/9,1
s1/2,1
s1/4,1
s1/5,1
s1/7,1
s1/8,1
s1/1,1
s1/10,1
s1/3,1
s1/6,1
s2/9,2
s2/2,2
s2/4,2
s2/5,2
s2/7,2
s2/8,2
s2/1,2
s2/10,2
s2/3,2
s2/6,2
s5/9,5
s5/2,5
s5/4,5
s5/5,5
s5/7,5
s5/8,5
s5/1,5
s5/10,5
s5/3,5
s5/6,5
s6/9,6
s6/2,6
s6/4,6
s6/5,6
s6/7,6
s6/8,6
s6/1,6
s6/10,6
s6/3,6
s6/6,6
s10/9,10
s10/2,10
s10/4,10
s10/5,10
s10/7,10
s10/8,10
s10/1,10
s10/10,10
s10/3,10
s10/6,10
s11/9,11
s11/2,11
s11/4,11
s11/5,11
s11/7,11
s11/8,11
s11/1,11
s11/10,11
s11/3,11
s11/6,11
s12/9,12
s12/2,12
s12/4,12
s12/5,12
s12/7,12
s12/8,12
s12/1,12
s12/10,12
s12/3,12
s12/6,12
s14/9,14
s14/2,14
s14/4,14
s14/5,14
s14/7,14
s14/8,14
s14/1,14
s14/10,14
s14/3,14
s14/6,14
s16/9,16
s16/2,16
s16/4,16
s16/5,16
s16/7,16
s16/8,16
s16/1,16
s16/10,16
s16/3,16
s16/6,16
s17/9,17
s17/2,17
s17/4,17
s17/5,17
s17/7,17
s17/8,17
s17/1,17
s17/10,17
s17/3,17
s17/6,17
s20/9,20
s20/2,20
s20/4,20
s20/5,20
s20/7,20
s20/8,20
s20/1,20
s20/10,20
s20/3,20
s20/6,20
s21/9,21
s21/2,21
s21/4,21
s21/5,21
s21/7,21
s21/8,21
s21/1,21
s21/10,21
s21/3,21
s21/6,21
s24/9,24
s24/2,24
s24/4,24
s24/5,24
s24/7,24
s24/8,24
s24/1,24
s24/10,24
s24/3,24
s24/6,24
s26/9,26
s26/2,26
s26/4,26
s26/5,26
s26/7,26
s26/8,26
s26/1,26
s26/10,26
s26/3,26
s26/6,26
s27/9,27
s27/2,27
s27/4,27
s27/5,27
s27/7,27
s27/8,27
s27/1,27
s27/10,27
s27/3,27
s27/6,27
s29/9,29
s29/2,29
s29/4,29
s29/5,29
s29/7,29
s29/8,29
s29/1,29
s29/10,29
s29/3,29
s29/6,29
s33/9,33
s33/2,33
s33/4,33
s33/5,33
s33/7,33
s33/8,33
s33/1,33
s33/10,33
s33/3,33
s33/6,33
s34/9,34
s34/2,34
s34/4,34
s34/5,34
s34/7,34
s34/8,34
s34/1,34
s34/10,34
s34/3,34
s34/6,34
s36/9,36
s36/2,36
s36/4,36
s36/5,36
s36/7,36
s36/8,36
s36/1,36
s36/10,36
s36/3,36
s36/6,36
s39/9,39
s39/2,39
s39/4,39
s39/5,39
s39/7,39
s39/8,39
s39/1,39
s39/10,39
s39/3,39
s39/6,39
s3/9,3
s3/2,3
s3/4,3
s3/5,3
s3/7,3
s4/9,4
s4/2,4
s4/4,4
s4/5,4
s4/7,4
s7/9,7
s7/2,7
s7/4,7
s7/5,7
s7/7,7
s8/9,8
s8/2,8
s8/4,8
s8/5,8
s8/7,8
s9/9,9
s9/2,9
s9/4,9
s9/5,9
s9/7,9
s13/9,13
s13/2,13
s13/4,13
s13/5,13
s13/7,13
s15/9,15
s15/2,15
s15/4,15
s15/5,15
s15/7,15
s18/9,18
s18/2,18
s18/4,18
s18/5,18
s18/7,18
s19/9,19
s19/2,19
s19/4,19
s19/5,19
s19/7,19
s22/9,22
s22/2,22
s22/4,22
s22/5,22
s22/7,22
s23/9,23
s23/2,23
s23/4,23
s23/5,23
s23/7,23
s25/9,25
s25/2,25
s25/4,25
s25/5,25
s25/7,25
s28/9,28
s28/2,28
s28/4,28
s28/5,28
s28/7,28
s30/9,30
s30/2,30
s30/4,30
s30/5,30
s30/7,30
s31/9,31
s31/2,31
s31/4,31
s31/5,31
s31/7,31
s32/9,32
s32/2,32
s32/4,32
s32/5,32
s32/7,32
s35/9,35
s35/2,35
s35/4,35
s35/5,35
s35/7,35
s37/9,37
s37/2,37
s37/4,37
s37/5,37
s37/7,37
s38/9,38
s38/2,38
s38/4,38
s38/5,38
s38/7,38
s40/9,40
s40/2,40
s40/4,40
s40/5,40
s40/7,40
s3/8,3
s3/1,3
s3/10,3
s3/3,3
s3/6,3
s4/8,4
s4/1,4
s4/10,4
s4/3,4
s4/6,4
s7/8,7
s7/1,7
s7/10,7
s7/3,7
s7/6,7
s8/8,8
s8/1,8
s8/10,8
s8/3,8
s8/6,8
s9/8,9
s9/1,9
s9/10,9
s9/3,9
s9/6,9
s13/8,13
s13/1,13
s13/10,13
s13/3,13
s13/6,13
s15/8,15
s15/1,15
s15/10,15
s15/3,15
s15/6,15
s18/8,18
s18/1,18
s18/10,18
s18/3,18
s18/6,18
s19/8,19
s19/1,19
s19/10,19
s19/3,19
s19/6,19
s22/8,22
s22/1,22
s22/10,22
s22/3,22
s22/6,22
s23/8,23
s23/1,23
s23/10,23
s23/3,23
s23/6,23
s25/8,25
s25/1,25
s25/10,25
s25/3,25
s25/6,25
s28/8,28
s28/1,28
s28/10,28
s28/3,28
s28/6,28
s30/8,30
s30/1,30
s30/10,30
s30/3,30
s30/6,30
s31/8,31
s31/1,31
s31/10,31
s31/3,31
s31/6,31
s32/8,32
s32/1,32
s32/10,32
s32/3,32
s32/6,32
s35/8,35
s35/1,35
s35/10,35
s35/3,35
s35/6,35
s37/8,37
s37/1,37
s37/10,37
s37/3,37
s37/6,37
s38/8,38
s38/1,38
s38/10,38
s38/3,38
s38/6,38
s40/8,40
s40/1,40
s40/10,40
s40/3,40
s40/6,40
bob/bio/base/test/test_filelist.py
View file @
4de7cc3b
...
...
@@ -7,7 +7,7 @@
import
os
import
bob.io.base
import
bob.io.base.test_utils
from
bob.bio.base.database
import
CSVDatasetDevEval
,
CSVToSampleLoader