Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.bio.base
Commits
14a58698
Commit
14a58698
authored
Dec 02, 2020
by
Tiago de Freitas Pereira
Browse files
Make the databases work transparently with with either tarballs or csv files
parent
31e99d38
Pipeline
#46251
passed with stage
in 5 minutes and 41 seconds
Changes
3
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/bio/base/database/csv_dataset.py
View file @
14a58698
...
...
@@ -13,7 +13,7 @@ import numpy as np
import
itertools
import
logging
import
bob.db.base
from
bob.extension.download
import
find_element_in_tarball
from
bob.bio.base.pipelines.vanilla_biometrics.abstract_classes
import
Database
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -156,14 +156,13 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
if
not
"path"
in
header
:
raise
ValueError
(
"The field `path` is not available in your dataset."
)
def
__call__
(
self
,
filename
):
with
open
(
filename
)
as
cf
:
reader
=
csv
.
reader
(
cf
)
header
=
next
(
reader
)
def
__call__
(
self
,
f
):
f
.
seek
(
0
)
reader
=
csv
.
reader
(
f
)
header
=
next
(
reader
)
self
.
check_header
(
header
)
return
[
self
.
convert_row_to_sample
(
row
,
header
)
for
row
in
reader
]
self
.
check_header
(
header
)
return
[
self
.
convert_row_to_sample
(
row
,
header
)
for
row
in
reader
]
def
convert_row_to_sample
(
self
,
row
,
header
):
path
=
row
[
0
]
...
...
@@ -192,17 +191,16 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
:any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
"""
def
__call__
(
self
,
filename
):
with
open
(
filename
)
as
cf
:
reader
=
csv
.
reader
(
cf
,
delimiter
=
" "
)
samples
=
[]
for
row
in
reader
:
if
row
[
0
][
0
]
==
"#"
:
continue
samples
.
append
(
self
.
convert_row_to_sample
(
row
))
def
__call__
(
self
,
f
):
f
.
seek
(
0
)
reader
=
csv
.
reader
(
f
,
delimiter
=
" "
)
samples
=
[]
for
row
in
reader
:
if
row
[
0
][
0
]
==
"#"
:
continue
samples
.
append
(
self
.
convert_row_to_sample
(
row
))
return
samples
return
samples
def
convert_row_to_sample
(
self
,
row
,
header
=
None
):
...
...
@@ -333,57 +331,61 @@ class CSVDatasetDevEval(Database):
if
not
os
.
path
.
exists
(
dataset_protocol_path
):
raise
ValueError
(
f
"The path `
{
dataset_protocol_path
}
` was not found"
)
# TODO: Unzip file if dataset path is a zip
protocol_path
=
os
.
path
.
join
(
dataset_protocol_path
,
protocol_name
)
if
not
os
.
path
.
exists
(
protocol_path
):
raise
ValueError
(
f
"The protocol `
{
protocol_name
}
` was not found"
)
def
path_discovery
(
option1
,
option2
):
return
option1
if
os
.
path
.
exists
(
option1
)
else
option2
# If the input is a directory
if
os
.
path
.
isdir
(
dataset_protocol_path
):
option1
=
os
.
path
.
join
(
dataset_protocol_path
,
option1
)
option2
=
os
.
path
.
join
(
dataset_protocol_path
,
option2
)
if
os
.
path
.
exists
(
option1
):
return
open
(
option1
)
else
:
return
open
(
option2
)
if
os
.
path
.
exists
(
option2
)
else
None
# If it's not a directory is a tarball
op1
=
find_element_in_tarball
(
dataset_protocol_path
,
option1
)
return
(
op1
if
op1
else
find_element_in_tarball
(
dataset_protocol_path
,
option2
)
)
# Here we are handling the legacy
train_csv
=
path_discovery
(
os
.
path
.
join
(
protocol_
path
,
"norm"
,
"train_world.lst"
),
os
.
path
.
join
(
protocol_
path
,
"norm"
,
"train_world.csv"
),
os
.
path
.
join
(
protocol_
name
,
"norm"
,
"train_world.lst"
),
os
.
path
.
join
(
protocol_
name
,
"norm"
,
"train_world.csv"
),
)
dev_enroll_csv
=
path_discovery
(
os
.
path
.
join
(
protocol_
path
,
"dev"
,
"for_models.lst"
),
os
.
path
.
join
(
protocol_
path
,
"dev"
,
"for_models.csv"
),
os
.
path
.
join
(
protocol_
name
,
"dev"
,
"for_models.lst"
),
os
.
path
.
join
(
protocol_
name
,
"dev"
,
"for_models.csv"
),
)
legacy_probe
=
"for_scores.lst"
if
self
.
is_sparse
else
"for_probes.lst"
dev_probe_csv
=
path_discovery
(
os
.
path
.
join
(
protocol_
path
,
"dev"
,
legacy_probe
),
os
.
path
.
join
(
protocol_
path
,
"dev"
,
"for_probes.csv"
),
os
.
path
.
join
(
protocol_
name
,
"dev"
,
legacy_probe
),
os
.
path
.
join
(
protocol_
name
,
"dev"
,
"for_probes.csv"
),
)
eval_enroll_csv
=
path_discovery
(
os
.
path
.
join
(
protocol_
path
,
"eval"
,
"for_models.lst"
),
os
.
path
.
join
(
protocol_
path
,
"eval"
,
"for_models.csv"
),
os
.
path
.
join
(
protocol_
name
,
"eval"
,
"for_models.lst"
),
os
.
path
.
join
(
protocol_
name
,
"eval"
,
"for_models.csv"
),
)
eval_probe_csv
=
path_discovery
(
os
.
path
.
join
(
protocol_
path
,
"eval"
,
legacy_probe
),
os
.
path
.
join
(
protocol_
path
,
"eval"
,
"for_probes.csv"
),
os
.
path
.
join
(
protocol_
name
,
"eval"
,
legacy_probe
),
os
.
path
.
join
(
protocol_
name
,
"eval"
,
"for_probes.csv"
),
)
# The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
train_csv
=
train_csv
if
os
.
path
.
exists
(
train_csv
)
else
None
# Eval
eval_enroll_csv
=
(
eval_enroll_csv
if
os
.
path
.
exists
(
eval_enroll_csv
)
else
None
)
eval_probe_csv
=
eval_probe_csv
if
os
.
path
.
exists
(
eval_probe_csv
)
else
None
# Dev
if
not
os
.
path
.
exists
(
dev_enroll_csv
)
:
if
dev_enroll_csv
is
None
:
raise
ValueError
(
f
"The file `
{
dev_enroll_csv
}
` is required and it was not found"
)
if
not
os
.
path
.
exists
(
dev_probe_csv
)
:
if
dev_probe_csv
is
None
:
raise
ValueError
(
f
"The file `
{
dev_probe_csv
}
` is required and it was not found"
)
...
...
@@ -612,7 +614,7 @@ class CSVDatasetCrossValidation:
self
.
random_state
=
random_state
self
.
cache
=
get_dict_cache
()
self
.
csv_to_sample_loader
=
csv_to_sample_loader
self
.
csv_file_name
=
csv_file_name
self
.
csv_file_name
=
open
(
csv_file_name
)
self
.
samples_for_enrollment
=
samples_for_enrollment
self
.
test_size
=
test_size
...
...
bob/bio/base/test/data/example_csv_filelist.tar.gz
0 → 100644
View file @
14a58698
File added
bob/bio/base/test/test_filelist.py
View file @
14a58698
...
...
@@ -105,46 +105,50 @@ def test_csv_file_list_dev_eval():
)
)
dataset
=
CSVDatasetDevEval
(
example_dir
,
"protocol_dev_eval"
,
csv_to_sample_loader
=
CSVToSampleLoader
(
data_loader
=
bob
.
io
.
base
.
load
,
metadata_loader
=
AnnotationsLoader
(
annotation_directory
=
annotation_directory
,
annotation_extension
=
".pos"
,
annotation_type
=
"eyecenter"
,
def
run
(
filename
):
dataset
=
CSVDatasetDevEval
(
filename
,
"protocol_dev_eval"
,
csv_to_sample_loader
=
CSVToSampleLoader
(
data_loader
=
bob
.
io
.
base
.
load
,
metadata_loader
=
AnnotationsLoader
(
annotation_directory
=
annotation_directory
,
annotation_extension
=
".pos"
,
annotation_type
=
"eyecenter"
,
),
dataset_original_directory
=
""
,
extension
=
""
,
),
dataset_original_directory
=
""
,
extension
=
""
,
),
)
assert
len
(
dataset
.
background_model_samples
())
==
8
assert
check_all_true
(
dataset
.
background_model_samples
(),
DelayedSample
)
)
assert
len
(
dataset
.
background_model_samples
())
==
8
assert
check_all_true
(
dataset
.
background_model_samples
(),
DelayedSample
)
assert
len
(
dataset
.
references
())
==
2
assert
check_all_true
(
dataset
.
references
(),
SampleSet
)
assert
len
(
dataset
.
references
())
==
2
assert
check_all_true
(
dataset
.
references
(),
SampleSet
)
assert
len
(
dataset
.
probes
())
==
8
assert
check_all_true
(
dataset
.
references
(),
SampleSet
)
assert
len
(
dataset
.
probes
())
==
8
assert
check_all_true
(
dataset
.
references
(),
SampleSet
)
assert
len
(
dataset
.
references
(
group
=
"eval"
))
==
6
assert
check_all_true
(
dataset
.
references
(
group
=
"eval"
),
SampleSet
)
assert
len
(
dataset
.
references
(
group
=
"eval"
))
==
6
assert
check_all_true
(
dataset
.
references
(
group
=
"eval"
),
SampleSet
)
assert
len
(
dataset
.
probes
(
group
=
"eval"
))
==
13
assert
check_all_true
(
dataset
.
probes
(
group
=
"eval"
),
SampleSet
)
assert
len
(
dataset
.
probes
(
group
=
"eval"
))
==
13
assert
check_all_true
(
dataset
.
probes
(
group
=
"eval"
),
SampleSet
)
assert
len
(
dataset
.
all_samples
(
groups
=
None
))
==
47
assert
check_all_true
(
dataset
.
all_samples
(
groups
=
None
),
DelayedSample
)
assert
len
(
dataset
.
all_samples
(
groups
=
None
))
==
47
assert
check_all_true
(
dataset
.
all_samples
(
groups
=
None
),
DelayedSample
)
# Check the annotations
for
s
in
dataset
.
all_samples
(
groups
=
None
):
assert
isinstance
(
s
.
annotations
,
dict
)
# Check the annotations
for
s
in
dataset
.
all_samples
(
groups
=
None
):
assert
isinstance
(
s
.
annotations
,
dict
)
assert
len
(
dataset
.
reference_ids
(
group
=
"dev"
))
==
2
assert
len
(
dataset
.
reference_ids
(
group
=
"eval"
))
==
6
assert
len
(
dataset
.
reference_ids
(
group
=
"dev"
))
==
2
assert
len
(
dataset
.
reference_ids
(
group
=
"eval"
))
==
6
assert
len
(
dataset
.
groups
())
==
3
assert
len
(
dataset
.
groups
())
==
3
run
(
example_dir
)
run
(
example_dir
+
".tar.gz"
)
def
test_csv_file_list_dev_eval_sparse
():
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment