Commit 4930dda2
Authored Nov 30, 2020 by Tiago de Freitas Pereira
Parent: 52934fcb
Pipeline #46207: failed in 1 minute and 25 seconds
Changes: 58 files

    Adapting CSVDevEval to work with our current FileList Structure

.gitignore

@@ -13,4 +13,4 @@ sphinx
 dist
 build
 record.txt
-.DS_Store
+*.DS_Store

bob/bio/base/algorithm/__init__.py

@@ -3,11 +3,10 @@ from .Distance import Distance
 from .PCA import PCA
 from .LDA import LDA
 from .PLDA import PLDA
-from .BIC import BIC

 # gets sphinx autodoc done right - don't remove it
 def __appropriate__(*args):
     """Says object was actually declared here, and not in the import module.
     Fixing sphinx warnings of not being able to find classes, when path is shortened.

     Parameters:

@@ -17,15 +16,12 @@ def __appropriate__(*args):
     <https://github.com/sphinx-doc/sphinx/issues/3048>`
     """

     for obj in args:
         obj.__module__ = __name__

 __appropriate__(
     Algorithm,
     Distance,
     PCA,
     LDA,
     PLDA,
-    BIC,
 )

-__all__ = [_ for _ in dir() if not _.startswith('_')]
+__all__ = [_ for _ in dir() if not _.startswith("_")]
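
The `__appropriate__` idiom above is not specific to this module: any package `__init__.py` that re-exports classes can use it so that Sphinx attributes them to the shortened import path. A minimal sketch for a hypothetical package (all names illustrative, not part of this commit):

    # mypackage/__init__.py
    from .engine import Engine  # Engine really lives in mypackage/engine.py

    def __appropriate__(*args):
        """Mark re-exported objects as declared in this module for Sphinx."""
        for obj in args:
            obj.__module__ = __name__

    __appropriate__(Engine)
    __all__ = [_ for _ in dir() if not _.startswith("_")]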

bob/bio/base/database/__init__.py

@@ -3,6 +3,8 @@ from .csv_dataset import (
     CSVToSampleLoader,
     CSVDatasetCrossValidation,
     CSVBaseSampleLoader,
+    IdiapAnnotationsLoader,
+    LSTToSampleLoader,
 )
 from .file import BioFile
 from .file import BioFileSet

bob/bio/base/database/csv_dataset.py

@@ -12,9 +12,57 @@ from abc import ABCMeta, abstractmethod
 import numpy as np
 import itertools
 import logging
+import bob.db.base
+from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database

 logger = logging.getLogger(__name__)
+#####
+# ANNOTATIONS LOADERS
+####
+class IdiapAnnotationsLoader:
+    """
+    Load annotations in the Idiap format
+    """
+
+    def __init__(
+        self,
+        annotation_directory=None,
+        annotation_extension=".pos",
+        annotation_type="eyecenter",
+    ):
+        self.annotation_directory = annotation_directory
+        self.annotation_extension = annotation_extension
+        self.annotation_type = annotation_type
+
+    def __call__(self, row, header=None):
+        if self.annotation_directory is None:
+            return None
+
+        path = row[0]
+        # since the file id is equal to the file name, we can simply use it
+        annotation_file = os.path.join(
+            self.annotation_directory, path + self.annotation_extension
+        )
+
+        # return the annotations as read from file
+        annotation = {
+            "annotations": bob.db.base.read_annotation_file(
+                annotation_file, self.annotation_type
+            )
+        }
+        return annotation
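
A minimal usage sketch for the new loader (directory and row values are hypothetical; the parsing itself is delegated to `bob.db.base.read_annotation_file`, exactly as called above):

    annotations_loader = IdiapAnnotationsLoader(
        annotation_directory="/datasets/annotations",  # hypothetical path
        annotation_extension=".pos",
        annotation_type="eyecenter",
    )

    # row[0] is the file path without extension, as in the CSV protocol files;
    # this call reads /datasets/annotations/client_1/sample_1.pos
    metadata = annotations_loader(["client_1/sample_1", "client_1"])
    # metadata == {"annotations": <parsed eye-center annotations>}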

+#######
+# SAMPLE LOADERS
+# CONVERT CSV LINES TO SAMPLES
+#######
 class CSVBaseSampleLoader(metaclass=ABCMeta):
     """
     Convert CSV files in the format below to either a list of
@@ -22,10 +70,10 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
     .. code-block:: text

-       PATH,SUBJECT
-       path_1,subject_1
-       path_2,subject_2
-       path_i,subject_j
+       PATH,REFERENCE_ID
+       path_1,reference_id_1
+       path_2,reference_id_2
+       path_i,reference_id_j
        ...

     .. note::
@@ -43,10 +91,17 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
     """

-    def __init__(self, data_loader, dataset_original_directory="", extension=""):
+    def __init__(
+        self,
+        data_loader,
+        metadata_loader=None,
+        dataset_original_directory="",
+        extension="",
+    ):
         self.data_loader = data_loader
         self.extension = extension
         self.dataset_original_directory = dataset_original_directory
+        self.metadata_loader = metadata_loader

     @abstractmethod
     def __call__(self, filename):
@@ -56,11 +111,24 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
     def convert_row_to_sample(self, row, header):
         pass

-    @abstractmethod
     def convert_samples_to_samplesets(
-        self, samples, group_by_subject=True, references=None
+        self, samples, group_by_reference_id=True, references=None
     ):
-        pass
+        if group_by_reference_id:
+            # Grouping sample sets
+            sample_sets = dict()
+            for s in samples:
+                if s.reference_id not in sample_sets:
+                    sample_sets[s.reference_id] = SampleSet(
+                        [s], parent=s, references=references
+                    )
+                else:
+                    sample_sets[s.reference_id].append(s)
+            return list(sample_sets.values())
+        else:
+            return [SampleSet([s], parent=s, references=references) for s in samples]
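
Since the grouping logic now lives in the base class, every loader shares it. An illustrative run (not part of the commit), assuming `bob.pipelines.Sample` accepts arbitrary keyword attributes the way `DelayedSample` does here, and using the concrete `CSVToSampleLoader` defined below:

    from bob.pipelines import Sample

    samples = [
        Sample(1, key="s1", reference_id="A"),
        Sample(2, key="s2", reference_id="A"),
        Sample(3, key="s3", reference_id="B"),
    ]

    loader = CSVToSampleLoader(data_loader=None)  # data_loader is unused for grouping
    sets = loader.convert_samples_to_samplesets(samples, group_by_reference_id=True)
    # -> two SampleSets: one holding s1 and s2 (reference "A"), one holding s3 ("B")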

 class CSVToSampleLoader(CSVBaseSampleLoader):
@@ -71,11 +139,13 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
     def check_header(self, header):
         """
-        A header should have at least "SUBJECT" AND "PATH"
+        A header should have at least "reference_id" AND "PATH"
         """
         header = [h.lower() for h in header]
-        if not "subject" in header:
-            raise ValueError("The field `subject` is not available in your dataset.")
+        if not "reference_id" in header:
+            raise ValueError(
+                "The field `reference_id` is not available in your dataset."
+            )

         if not "path" in header:
             raise ValueError("The field `path` is not available in your dataset.")
@@ -91,42 +161,67 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
     def convert_row_to_sample(self, row, header):
         path = row[0]
-        subject = row[1]
+        reference_id = row[1]
         kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])])
+
+        if self.metadata_loader is not None:
+            metadata = self.metadata_loader(row)
+            kwargs.update(metadata)
+
         return DelayedSample(
             functools.partial(
                 self.data_loader,
                 os.path.join(self.dataset_original_directory, path + self.extension),
             ),
             key=path,
-            subject=subject,
+            reference_id=reference_id,
             **kwargs,
         )
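
The new `metadata_loader` hook is where `IdiapAnnotationsLoader` plugs in: whatever dictionary it returns for a row is merged into the attributes of the resulting `DelayedSample`. A wiring sketch with hypothetical paths:

    sample_loader = CSVToSampleLoader(
        data_loader=bob.io.base.load,
        metadata_loader=IdiapAnnotationsLoader(
            annotation_directory="/datasets/annotations"  # hypothetical
        ),
        dataset_original_directory="/datasets/raw",  # hypothetical
        extension=".hdf5",
    )
    # Each DelayedSample built by convert_row_to_sample now carries an
    # "annotations" attribute next to the plain CSV metadata columns.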
-    def convert_samples_to_samplesets(
-        self, samples, group_by_subject=True, references=None
-    ):
-        if group_by_subject:
-            # Grouping sample sets
-            sample_sets = dict()
-            for s in samples:
-                if s.subject not in sample_sets:
-                    sample_sets[s.subject] = SampleSet(
-                        [s], parent=s, references=references
-                    )
-                else:
-                    sample_sets[s.subject].append(s)
-            return list(sample_sets.values())
-        else:
-            return [SampleSet([s], parent=s, references=references) for s in samples]

+class LSTToSampleLoader(CSVBaseSampleLoader):
+    """
+    Simple mechanism to convert LST files in the format below to either a list of
+    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
+    """
+
+    def __call__(self, filename):
+        with open(filename) as cf:
+            reader = csv.reader(cf, delimiter=" ")
+            return [self.convert_row_to_sample(row) for row in reader]
+
+    def convert_row_to_sample(self, row, header=None):
+        path = row[0]
+        reference_id = str(row[1])
+
+        kwargs = dict()
+        if len(row) == 3:
+            subject = row[2]
+            kwargs = {"subject": str(subject)}
+
+        if self.metadata_loader is not None:
+            metadata = self.metadata_loader(row)
+            kwargs.update(metadata)
+
+        return DelayedSample(
+            functools.partial(
+                self.data_loader,
+                os.path.join(self.dataset_original_directory, path + self.extension),
+            ),
+            key=path,
+            reference_id=reference_id,
+            **kwargs,
+        )
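
For reference, the whitespace-delimited rows `LSTToSampleLoader` consumes look like the following (contents hypothetical): column 1 is the path, column 2 the `reference_id`, and an optional column 3 is kept as `subject`:

    data/model_03_sample_1 model_03
    data/model_03_sample_2 model_03 client_12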

 #####
 # DATABASE INTERFACES
 #####

-class CSVDatasetDevEval:
+class CSVDatasetDevEval(Database):
     """
     Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
     Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
@@ -154,17 +249,17 @@ class CSVDatasetDevEval:
       - dev_probe.csv

-    Those csv files should contain in each row i-) the path to raw data and ii-) the subject label
+    Those csv files should contain in each row i-) the path to raw data and ii-) the reference_id label
     for enrollment (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
     probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).

     The structure of each CSV file should be as below:

     .. code-block:: text

-       PATH,SUBJECT
-       path_1,subject_1
-       path_2,subject_2
-       path_i,subject_j
+       PATH,reference_id
+       path_1,reference_id_1
+       path_2,reference_id_2
+       path_i,reference_id_j
        ...
@@ -173,10 +268,10 @@ class CSVDatasetDevEval:
     .. code-block:: text

-       PATH,SUBJECT,METADATA_1,METADATA_2,METADATA_k
-       path_1,subject_1,A,B,C
-       path_2,subject_2,A,B,1
-       path_i,subject_j,2,3,4
+       PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
+       path_1,reference_id_1,A,B,C
+       path_2,reference_id_2,A,B,1
+       path_i,reference_id_j,2,3,4
        ...
@@ -206,9 +301,14 @@ class CSVDatasetDevEval:
         dataset_protocol_path,
         protocol_name,
         csv_to_sample_loader=CSVToSampleLoader(
-            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
+            data_loader=bob.io.base.load,
+            metadata_loader=None,
+            dataset_original_directory="",
+            extension="",
         ),
     ):
         self.dataset_protocol_path = dataset_protocol_path

         def get_paths():
             if not os.path.exists(dataset_protocol_path):
@@ -219,11 +319,34 @@ class CSVDatasetDevEval:
             if not os.path.exists(protocol_path):
                 raise ValueError(f"The protocol `{protocol_name}` was not found")

-            train_csv = os.path.join(protocol_path, "train.csv")
-            dev_enroll_csv = os.path.join(protocol_path, "dev_enroll.csv")
-            dev_probe_csv = os.path.join(protocol_path, "dev_probe.csv")
-            eval_enroll_csv = os.path.join(protocol_path, "eval_enroll.csv")
-            eval_probe_csv = os.path.join(protocol_path, "eval_probe.csv")
+            def path_discovery(option1, option2):
+                return option1 if os.path.exists(option1) else option2
+
+            # Here we are handling the legacy
+            train_csv = path_discovery(
+                os.path.join(protocol_path, "norm", "train_world.lst"),
+                os.path.join(protocol_path, "norm", "train_world.csv"),
+            )
+            dev_enroll_csv = path_discovery(
+                os.path.join(protocol_path, "dev", "for_models.lst"),
+                os.path.join(protocol_path, "dev", "for_models.csv"),
+            )
+            dev_probe_csv = path_discovery(
+                os.path.join(protocol_path, "dev", "for_probes.lst"),
+                os.path.join(protocol_path, "dev", "for_probes.csv"),
+            )
+            eval_enroll_csv = path_discovery(
+                os.path.join(protocol_path, "eval", "for_models.lst"),
+                os.path.join(protocol_path, "eval", "for_models.csv"),
+            )
+            eval_probe_csv = path_discovery(
+                os.path.join(protocol_path, "eval", "for_probes.lst"),
+                os.path.join(protocol_path, "eval", "for_probes.csv"),
+            )

             # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
             train_csv = train_csv if os.path.exists(train_csv) else None
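
Taken together, `path_discovery` lets a single protocol directory serve both the legacy `.lst` filelists and the CSV layout, preferring the `.lst` variant when both exist. The expected tree (file names taken from the code above; the tree itself is illustrative):

    <dataset_protocol_path>/<protocol_name>/
        norm/train_world.lst   (or train_world.csv)
        dev/for_models.lst     (or for_models.csv)
        dev/for_probes.lst     (or for_probes.csv)
        eval/for_models.lst    (or for_models.csv)
        eval/for_probes.lst    (or for_probes.csv)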
@@ -244,6 +367,8 @@ class CSVDatasetDevEval:
                 raise ValueError(
                     f"The file `{dev_probe_csv}` is required and it was not found"
                 )
+            dev_enroll_csv = dev_enroll_csv
+            dev_probe_csv = dev_probe_csv

             return (
                 train_csv,
@@ -274,7 +399,6 @@ class CSVDatasetDevEval:
         self.csv_to_sample_loader = csv_to_sample_loader

     def background_model_samples(self):
         self.cache["train"] = (
             self.csv_to_sample_loader(self.train_csv)
             if self.cache["train"] is None
@@ -283,7 +407,9 @@ class CSVDatasetDevEval:
         return self.cache["train"]

-    def _get_samplesets(self, group="dev", purpose="enroll", group_by_subject=False):
+    def _get_samplesets(
+        self, group="dev", purpose="enroll", group_by_reference_id=False
+    ):

         fetching_probes = False
         if purpose == "enroll":
@@ -297,12 +423,14 @@ class CSVDatasetDevEval:
         references = None
         if fetching_probes:
-            references = list(set([s.subject for s in self.references(group=group)]))
+            references = list(
+                set([s.reference_id for s in self.references(group=group)])
+            )

         samples = self.csv_to_sample_loader(self.__dict__[cache_label])

         sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
-            samples, group_by_subject=group_by_subject, references=references
+            samples, group_by_reference_id=group_by_reference_id, references=references
         )

         self.cache[cache_label] = sample_sets
@@ -311,12 +439,12 @@ class CSVDatasetDevEval:
     def references(self, group="dev"):
         return self._get_samplesets(
-            group=group, purpose="enroll", group_by_subject=True
+            group=group, purpose="enroll", group_by_reference_id=True
         )

     def probes(self, group="dev"):
         return self._get_samplesets(
-            group=group, purpose="probe", group_by_subject=False
+            group=group, purpose="probe", group_by_reference_id=False
         )

     def all_samples(self, groups=None):
@@ -360,6 +488,27 @@ class CSVDatasetDevEval:
             samples = samples + self.csv_to_sample_loader(self.__dict__[label])
         return samples

+    def groups(self):
+        """This function returns the list of groups for this database.
+
+        Returns
+        -------
+        [str]
+            A list of groups
+        """
+
+        # We always have dev-set
+        groups = ["dev"]
+
+        if self.train_csv is not None:
+            groups.append("train")
+
+        if self.eval_enroll_csv is not None:
+            groups.append("eval")
+
+        return groups
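
With `groups()` defined, the class honors the `Database` interface it now inherits from. A usage sketch (path and protocol name are hypothetical):

    database = CSVDatasetDevEval(
        dataset_protocol_path="/datasets/protocols",
        protocol_name="protocol_1",
    )
    database.groups()                 # ["dev"], plus "train"/"eval" when those lists exist
    database.references(group="dev")  # SampleSets grouped by reference_id
    database.probes(group="dev")      # one SampleSet per probe sample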

 class CSVDatasetCrossValidation:
     """
@@ -377,10 +526,10 @@ class CSVDatasetCrossValidation:
     .. code-block:: text

-       PATH,SUBJECT
-       path_1,subject_1
-       path_2,subject_2
-       path_i,subject_j
+       PATH,reference_id
+       path_1,reference_id_1
+       path_2,reference_id_2
+       path_i,reference_id_j
        ...

     Parameters
@@ -393,7 +542,7 @@ class CSVDatasetCrossValidation:
         Pseudo-random number generator seed

     test_size: float
-        Percentage of the subjects used for testing
+        Percentage of the reference_ids used for testing

     samples_for_enrollment: float
         Number of samples used for enrollment
@@ -435,30 +584,35 @@ class CSVDatasetCrossValidation:
     def _do_cross_validation(self):

-        # Shuffling samples by subject
-        samples_by_subject = group_samples_by_subject(
+        # Shuffling samples by reference_id
+        samples_by_reference_id = group_samples_by_reference_id(
             self.csv_to_sample_loader(self.csv_file_name)
         )
-        subjects = list(samples_by_subject.keys())
+        reference_ids = list(samples_by_reference_id.keys())
         np.random.seed(self.random_state)
-        np.random.shuffle(subjects)
+        np.random.shuffle(reference_ids)

         # Getting the training data
-        n_samples_for_training = len(subjects) - int(self.test_size * len(subjects))
+        n_samples_for_training = len(reference_ids) - int(
+            self.test_size * len(reference_ids)
+        )
         self.cache["train"] = list(
             itertools.chain(
-                *[samples_by_subject[s] for s in subjects[0:n_samples_for_training]]
+                *[
+                    samples_by_reference_id[s]
+                    for s in reference_ids[0:n_samples_for_training]
+                ]
             )
         )

         # Splitting enroll and probe
         self.cache["dev_enroll_csv"] = []
         self.cache["dev_probe_csv"] = []
-        for s in subjects[n_samples_for_training:]:
-            samples = samples_by_subject[s]
+        for s in reference_ids[n_samples_for_training:]:
+            samples = samples_by_reference_id[s]
             if len(samples) < self.samples_for_enrollment:
                 raise ValueError(
-                    f"Not enough samples ({len(samples)}) for enrollment for the subject {s}"
+                    f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
                 )

             # Enrollment samples
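
As a worked example of the split (numbers illustrative): with 100 reference_ids and test_size=0.2, n_samples_for_training = 100 - int(0.2 * 100) = 80, so the first 80 shuffled identities fill the training cache, and each of the remaining 20 contributes its first samples_for_enrollment samples to "dev_enroll_csv" and the rest to "dev_probe_csv".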
@@ -472,8 +626,8 @@ class CSVDatasetCrossValidation:
             "dev_probe_csv"
         ] += self.csv_to_sample_loader.convert_samples_to_samplesets(
             samples[self.samples_for_enrollment :],
-            group_by_subject=False,
-            references=subjects[n_samples_for_training:],
+            group_by_reference_id=False,
+            references=reference_ids[n_samples_for_training:],
         )

     def _load_from_cache(self, cache_key):
@@ -527,12 +681,12 @@ class CSVDatasetCrossValidation:
         return samples

-def group_samples_by_subject(samples):
+def group_samples_by_reference_id(samples):
     # Grouping sample sets
-    samples_by_subject = dict()
+    samples_by_reference_id = dict()
     for s in samples:
-        if s.subject not in samples_by_subject:
-            samples_by_subject[s.subject] = []
-        samples_by_subject[s.subject].append(s)
-    return samples_by_subject
+        if s.reference_id not in samples_by_reference_id:
+            samples_by_reference_id[s.reference_id] = []
+        samples_by_reference_id[s.reference_id].append(s)
+    return samples_by_reference_id

bob/bio/base/pipelines/vanilla_biometrics/abstract_classes.py

@@ -178,18 +178,18 @@ class BioAlgorithm(metaclass=ABCMeta):
         """
         for r in biometric_references:
             if (
-                str(r.subject) in probe_refererences
-                and str(r.subject) not in self.stacked_biometric_references
+                str(r.reference_id) in probe_refererences
+                and str(r.reference_id) not in self.stacked_biometric_references
             ):
-                self.stacked_biometric_references[str(r.subject)] = r.data
+                self.stacked_biometric_references[str(r.reference_id)] = r.data

         for probe_sample in sampleset:

             cache_references(sampleset.references)
             references = [
-                self.stacked_biometric_references[str(r.subject)]
+                self.stacked_biometric_references[str(r.reference_id)]
                 for r in biometric_references
-                if str(r.subject) in sampleset.references
+                if str(r.reference_id) in sampleset.references
             ]

             scores = self.score_multiple_biometric_references(
@@ -204,7 +204,7 @@ class BioAlgorithm(metaclass=ABCMeta):
                 [
                     r
                     for r in biometric_references
-                    if str(r.subject) in sampleset.references
+                    if str(r.reference_id) in sampleset.references
                 ],
                 total_scores,
             ):
@@ -328,6 +328,12 @@ class Database(metaclass=ABCMeta):
         """
         pass

+    def groups(self):
+        pass
+
+    def reference_ids(self, group):
+        return [s.reference_id for s in self.references(group=group)]
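
Both additions are convenience defaults on the abstract `Database`: concrete databases (such as `CSVDatasetDevEval` above) override `groups()`, while `reference_ids` simply projects the enrolled sample sets onto their ids. A sketch, assuming `db` is any concrete implementation:

    db.reference_ids(group="dev")  # e.g. ["reference_id_1", "reference_id_2", ...]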

 class ScoreWriter(metaclass=ABCMeta):
     """

bob/bio/base/pipelines/vanilla_biometrics/legacy.py

@@ -29,7 +29,7 @@ def _biofile_to_delayed_sample(biofile, database):
         load=functools.partial(
             biofile.load,
             database.original_directory,
             database.original_extension,
         ),
-        subject=str(biofile.client_id),
+        reference_id=str(biofile.client_id),
         key=biofile.path,
         path=biofile.path,
         delayed_attributes=dict(
@@ -138,7 +138,7 @@ class DatabaseConnector(Database):
                 [_biofile_to_delayed_sample(k, self.database) for k in objects],
                 key=str(m),
                 path=str(m),