bob / bob.bio.base · Commit 243b4e70

Authored Dec 03, 2020 by Tiago de Freitas Pereira

Implemented ZTNorm interface

Parent: 9681ff8b
Pipeline #46294 passed in 13 minutes and 1 second
Changes: 3 files · Pipelines: 1
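In ZT-Norm score normalization, Z-norm scores each enrolled model against a cohort of impostor probes and T-norm scores each probe against a cohort of impostor models; this commit exposes both cohorts through a dataset wrapper. A minimal usage sketch based on the diffs below; the dataset path and protocol name are hypothetical placeholders, not values from the repository:

```python
# Sketch of the interface introduced by this commit; the dataset path and
# protocol name below are hypothetical placeholders.
from bob.bio.base.database import CSVDatasetDevEval, CSVDatasetDevEvalZTNorm

# Wrap a plain dev/eval filelist dataset to expose the ZT-Norm cohorts
dataset = CSVDatasetDevEval("/path/to/my_dataset", "my_protocol")
znorm_dataset = CSVDatasetDevEvalZTNorm(dataset)

# Impostor probes for Z-norm and impostor models for T-norm,
# optionally subsampled with `proportion`
zprobes = znorm_dataset.zprobes(group="dev", proportion=1.0)
treferences = znorm_dataset.treferences(proportion=1.0)
```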
bob/bio/base/database/__init__.py

```diff
@@ -5,6 +5,7 @@ from .csv_dataset import (
     CSVBaseSampleLoader,
     AnnotationsLoader,
     LSTToSampleLoader,
+    CSVDatasetDevEvalZTNorm,
 )
 from .file import BioFile
 from .file import BioFileSet
```
bob/bio/base/database/csv_dataset.py

```diff
@@ -244,6 +244,22 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
 #####
+def path_discovery(dataset_protocol_path, option1, option2):
+    # If the input is a directory, join the options to it and open the
+    # first one that exists
+    if os.path.isdir(dataset_protocol_path):
+        option1 = os.path.join(dataset_protocol_path, option1)
+        option2 = os.path.join(dataset_protocol_path, option2)
+        if os.path.exists(option1):
+            return open(option1)
+        else:
+            return open(option2) if os.path.exists(option2) else None
+
+    # If it's not a directory, it's a tarball
+    op1 = find_element_in_tarball(dataset_protocol_path, option1)
+    return op1 if op1 else find_element_in_tarball(dataset_protocol_path, option2)
+
+
 class CSVDatasetDevEval(Database):
     """
     Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
```
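A short illustration of the fallback order this helper implements; the paths are hypothetical and mirror the calls made in `__init__` in the next hunk:

```python
# Hypothetical call illustrating path_discovery's resolution order:
# inside a directory it opens the first of the two files that exists;
# for a tarball it delegates to find_element_in_tarball with the same fallback.
f = path_discovery(
    "/path/to/my_dataset",               # a directory or a .tar.gz archive
    "my_protocol/norm/train_world.lst",  # preferred legacy file list
    "my_protocol/norm/train_world.csv",  # CSV fallback
)
# `f` is an open file object, or None when neither entry exists
```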
```diff
@@ -333,54 +349,41 @@ class CSVDatasetDevEval(Database):
     ):
         self.dataset_protocol_path = dataset_protocol_path
         self.is_sparse = is_sparse
         self.protocol_name = protocol_name

         def get_paths():
             if not os.path.exists(dataset_protocol_path):
                 raise ValueError(f"The path `{dataset_protocol_path}` was not found")

-            def path_discovery(option1, option2):
-                # If the input is a directory
-                if os.path.isdir(dataset_protocol_path):
-                    option1 = os.path.join(dataset_protocol_path, option1)
-                    option2 = os.path.join(dataset_protocol_path, option2)
-                    if os.path.exists(option1):
-                        return open(option1)
-                    else:
-                        return open(option2) if os.path.exists(option2) else None
-
-                # If it's not a directory is a tarball
-                op1 = find_element_in_tarball(dataset_protocol_path, option1)
-                return (
-                    op1
-                    if op1
-                    else find_element_in_tarball(dataset_protocol_path, option2)
-                )
-
             # Here we are handling the legacy
             train_csv = path_discovery(
+                dataset_protocol_path,
                 os.path.join(protocol_name, "norm", "train_world.lst"),
                 os.path.join(protocol_name, "norm", "train_world.csv"),
             )

             dev_enroll_csv = path_discovery(
+                dataset_protocol_path,
                 os.path.join(protocol_name, "dev", "for_models.lst"),
                 os.path.join(protocol_name, "dev", "for_models.csv"),
             )

             legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
             dev_probe_csv = path_discovery(
+                dataset_protocol_path,
                 os.path.join(protocol_name, "dev", legacy_probe),
                 os.path.join(protocol_name, "dev", "for_probes.csv"),
             )

             eval_enroll_csv = path_discovery(
+                dataset_protocol_path,
                 os.path.join(protocol_name, "eval", "for_models.lst"),
                 os.path.join(protocol_name, "eval", "for_models.csv"),
             )

             eval_probe_csv = path_discovery(
+                dataset_protocol_path,
                 os.path.join(protocol_name, "eval", legacy_probe),
                 os.path.join(protocol_name, "eval", "for_probes.csv"),
             )
```
```diff
@@ -438,24 +441,22 @@ class CSVDatasetDevEval(Database):
         return self.cache["train"]

     def _get_samplesets(
-        self, group="dev", purpose="enroll", group_by_reference_id=False
+        self,
+        group="dev",
+        cache_label=None,
+        group_by_reference_id=False,
+        fetching_probes=False,
+        is_sparse=False,
     ):
-        fetching_probes = False
-        if purpose == "enroll":
-            cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
-        else:
-            fetching_probes = True
-            cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

         if self.cache[cache_label] is not None:
             return self.cache[cache_label]

         # Getting samples from CSV
-        samples = self.csv_to_sample_loader(self.__dict__[cache_label])
+        samples = self.csv_to_sample_loader(self.__getattribute__(cache_label))

         references = None
-        if fetching_probes and self.is_sparse:
+        if fetching_probes and is_sparse:

             # Checking if `is_sparse` was set properly
             if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
```
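The switch from `self.__dict__[cache_label]` to `self.__getattribute__(cache_label)` in this hunk is not cosmetic: `__dict__` only sees attributes stored directly on the instance, while `__getattribute__` follows Python's full attribute lookup (class attributes, properties, descriptors). A standalone sketch of the difference; the `Example` class is illustrative only, not code from the repository:

```python
# Illustrative only: attribute lookup by name vs. direct __dict__ access
class Example:
    eval_probe_csv = None  # class-level default, absent from the instance dict

e = Example()
print(e.__getattribute__("eval_probe_csv"))  # None: full attribute protocol
print("eval_probe_csv" in e.__dict__)        # False: instance dict only
```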
```diff
@@ -487,13 +488,21 @@ class CSVDatasetDevEval(Database):
             return self.cache[cache_label]

     def references(self, group="dev"):
+        cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
         return self._get_samplesets(
-            group=group, purpose="enroll", group_by_reference_id=True
+            group=group, cache_label=cache_label, group_by_reference_id=True
         )

     def probes(self, group="dev"):
+        cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"
         return self._get_samplesets(
-            group=group, purpose="probe", group_by_reference_id=False
+            group=group,
+            cache_label=cache_label,
+            group_by_reference_id=False,
+            fetching_probes=True,
+            is_sparse=self.is_sparse,
         )

     def all_samples(self, groups=None):
@@ -534,7 +543,9 @@ class CSVDatasetDevEval(Database):
         for group in groups:
             for purpose in ("enroll", "probe"):
                 label = f"{group}_{purpose}_csv"
-                samples = samples + self.csv_to_sample_loader(self.__dict__[label])
+                samples = samples + self.csv_to_sample_loader(
+                    self.__getattribute__(label)
+                )
         return samples

     def groups(self):
```
```diff
@@ -559,6 +570,125 @@ class CSVDatasetDevEval(Database):
         return groups


+class CSVDatasetDevEvalZTNorm(Database):
+    """
+    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.ZTNormPipeline` pipelines.
+    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
+    interface.
+
+    This dataset interface takes a :any:`CSVDatasetDevEval` as input and adds two extra methods:
+    :any:`CSVDatasetDevEvalZTNorm.zprobes` and :any:`CSVDatasetDevEvalZTNorm.treferences`.
+
+    To create a new dataset, you need to provide a directory structure similar to the one below:
+
+    .. code-block:: text
+
+       my_dataset/
+       my_dataset/my_protocol/norm/train_world.csv
+       my_dataset/my_protocol/norm/for_znorm.csv
+       my_dataset/my_protocol/norm/for_tnorm.csv
+       my_dataset/my_protocol/dev/for_models.csv
+       my_dataset/my_protocol/dev/for_probes.csv
+       my_dataset/my_protocol/eval/for_models.csv
+       my_dataset/my_protocol/eval/for_probes.csv
+
+    Parameters
+    ----------
+
+    database: :any:`CSVDatasetDevEval`
+        :any:`CSVDatasetDevEval` to be aggregated
+    """
+
+    def __init__(self, database):
+        self.database = database
+        self.cache = self.database.cache
+        self.csv_to_sample_loader = self.database.csv_to_sample_loader
+        self.protocol_name = self.database.protocol_name
+        self.dataset_protocol_path = self.database.dataset_protocol_path
+        self._get_samplesets = self.database._get_samplesets
+
+        # Extend the cache with the score-normalization entries
+        self.cache["znorm_csv"] = None
+        self.cache["tnorm_csv"] = None
+
+        znorm_csv = path_discovery(
+            self.dataset_protocol_path,
+            os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
+            os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
+        )
+
+        tnorm_csv = path_discovery(
+            self.dataset_protocol_path,
+            os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
+            os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
+        )
+
+        if znorm_csv is None:
+            raise ValueError(
+                f"The file `for_znorm.lst` is required and was not found in `{self.protocol_name}/norm`"
+            )
+
+        if tnorm_csv is None:
+            raise ValueError(
+                f"The file `for_tnorm.csv` is required and was not found in `{self.protocol_name}/norm`"
+            )
+
+        self.database.znorm_csv = znorm_csv
+        self.database.tnorm_csv = tnorm_csv
+
+    def background_model_samples(self):
+        return self.database.background_model_samples()
+
+    def references(self, group="dev"):
+        return self.database.references(group=group)
+
+    def probes(self, group="dev"):
+        return self.database.probes(group=group)
+
+    def all_samples(self, groups=None):
+        return self.database.all_samples(groups=groups)
+
+    def groups(self):
+        return self.database.groups()
+
+    def zprobes(self, group="dev", proportion=1.0):
+        if proportion <= 0 or proportion > 1:
+            raise ValueError(
+                f"Invalid proportion value ({proportion}). Allowed values are in (0, 1]"
+            )
+
+        cache_label = "znorm_csv"
+        samplesets = self._get_samplesets(
+            group=group,
+            cache_label=cache_label,
+            group_by_reference_id=False,
+            fetching_probes=True,
+            is_sparse=False,
+        )
+        zprobes = samplesets[: int(len(samplesets) * proportion)]
+
+        return zprobes
+
+    def treferences(self, covariate="sex", proportion=1.0):
+        if proportion <= 0 or proportion > 1:
+            raise ValueError(
+                f"Invalid proportion value ({proportion}). Allowed values are in (0, 1]"
+            )
+
+        cache_label = "tnorm_csv"
+        samplesets = self._get_samplesets(
+            group="dev",
+            cache_label=cache_label,
+            group_by_reference_id=True,
+        )
+        treferences = samplesets[: int(len(samplesets) * proportion)]
+
+        return treferences
+
+
 class CSVDatasetCrossValidation:
     """
     Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
```
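The `proportion` subsampling in `zprobes`/`treferences` keeps the leading fraction of the sample sets via a simple slice; a worked sketch with stand-in data, whose numbers match the test assertions in the next file:

```python
# Stand-in data: 8 sample sets, as in the dev probes of the test below
samplesets = list(range(8))
proportion = 0.5

# samplesets[: int(len(samplesets) * proportion)] keeps the first half
subset = samplesets[: int(len(samplesets) * proportion)]
assert len(subset) == 4  # matches zprobes(proportion=0.5) in the test
```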
bob/bio/base/test/test_filelist.py

```diff
@@ -13,6 +13,7 @@ from bob.bio.base.database import (
     CSVDatasetCrossValidation,
     AnnotationsLoader,
     LSTToSampleLoader,
+    CSVDatasetDevEvalZTNorm,
 )
 import nose.tools
 from bob.pipelines import DelayedSample, SampleSet
```
```diff
@@ -151,6 +152,68 @@ def test_csv_file_list_dev_eval():
     run(example_dir + ".tar.gz")


+def test_csv_file_list_dev_eval_score_norm():
+    annotation_directory = os.path.realpath(
+        bob.io.base.test_utils.datafile(
+            ".", __name__, "data/example_csv_filelist/annotations"
+        )
+    )
+
+    def run(filename):
+        dataset = CSVDatasetDevEval(
+            filename,
+            "protocol_dev_eval",
+            csv_to_sample_loader=CSVToSampleLoader(
+                data_loader=bob.io.base.load,
+                metadata_loader=AnnotationsLoader(
+                    annotation_directory=annotation_directory,
+                    annotation_extension=".pos",
+                    annotation_type="eyecenter",
+                ),
+                dataset_original_directory="",
+                extension="",
+            ),
+        )
+        znorm_dataset = CSVDatasetDevEvalZTNorm(dataset)
+
+        assert len(znorm_dataset.background_model_samples()) == 8
+        assert check_all_true(znorm_dataset.background_model_samples(), DelayedSample)
+
+        assert len(znorm_dataset.references()) == 2
+        assert check_all_true(znorm_dataset.references(), SampleSet)
+
+        assert len(znorm_dataset.probes()) == 8
+        assert check_all_true(znorm_dataset.references(), SampleSet)
+
+        assert len(znorm_dataset.references(group="eval")) == 6
+        assert check_all_true(znorm_dataset.references(group="eval"), SampleSet)
+
+        assert len(znorm_dataset.probes(group="eval")) == 13
+        assert check_all_true(znorm_dataset.probes(group="eval"), SampleSet)
+
+        assert len(znorm_dataset.all_samples(groups=None)) == 47
+        assert check_all_true(znorm_dataset.all_samples(groups=None), DelayedSample)
+
+        # Check the annotations
+        for s in znorm_dataset.all_samples(groups=None):
+            assert isinstance(s.annotations, dict)
+
+        assert len(znorm_dataset.reference_ids(group="dev")) == 2
+        assert len(znorm_dataset.reference_ids(group="eval")) == 6
+        assert len(znorm_dataset.groups()) == 3
+
+        # Checking the ZT-Norm methods
+        assert len(znorm_dataset.treferences()) == 2
+        assert len(znorm_dataset.zprobes()) == 8
+
+        assert len(znorm_dataset.treferences(proportion=0.5)) == 1
+        assert len(znorm_dataset.zprobes(proportion=0.5)) == 4
+
+    run(example_dir)
+    run(example_dir + ".tar.gz")
+
+
 def test_csv_file_list_dev_eval_sparse():
     annotation_directory = os.path.realpath(
```