Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
bob.bio.base
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
bob
bob.bio.base
Commits
9162d84b
Commit
9162d84b
authored
4 years ago
by
Tiago de Freitas Pereira
Browse files
Options
Downloads
Plain Diff
Merge branch 'move-code' into 'master'
Move code See merge request
!232
parents
6dcf9fec
1b7cf031
Branches
Branches containing commit
Tags
v4.1.2b0
Tags containing commit
1 merge request
!232
Move code
Pipeline
#46475
skipped
Changes
1
Pipelines
2
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
bob/bio/base/database/csv_dataset.py
+52
-144
52 additions, 144 deletions
bob/bio/base/database/csv_dataset.py
with
52 additions
and
144 deletions
bob/bio/base/database/csv_dataset.py
+
52
−
144
View file @
9162d84b
...
...
@@ -8,13 +8,14 @@ from bob.db.base.utils import check_parameters_for_validity
import
csv
import
bob.io.base
import
functools
from
abc
import
ABCMeta
,
abstractmethod
import
numpy
as
np
import
itertools
import
logging
import
bob.db.base
from
bob.extension.download
import
find_element_in_tarball
from
bob.bio.base.pipelines.vanilla_biometrics.abstract_classes
import
Database
from
bob.extension.download
import
search_file
from
bob.pipelines.datasets.sample_loaders
import
CSVBaseSampleLoader
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -58,92 +59,6 @@ class AnnotationsLoader:
return
annotation
#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######
class CSVBaseSampleLoader(metaclass=ABCMeta):
    """
    Abstract base for loaders that turn the rows of a CSV file, such as the
    one below, into :any:`bob.pipelines.DelayedSample` or
    :any:`bob.pipelines.SampleSet` objects.

    .. code-block:: text

       PATH,REFERENCE_ID
       path_1,reference_id_1
       path_2,reference_id_2
       path_i,reference_id_j
       ...

    .. note::
       This class should be extended

    Parameters
    ----------

        data_loader:
            A python function that can be called parameterlessly, to load the
            sample in question from whatever medium

        metadata_loader:
            AnnotationsLoader

        dataset_original_directory: str
            Path of where data is stored

        extension: str
            Default file extension

    """

    def __init__(
        self,
        data_loader,
        metadata_loader=None,
        dataset_original_directory="",
        extension="",
    ):
        # The constructor only stores its arguments; no I/O happens here.
        self.metadata_loader = metadata_loader
        self.dataset_original_directory = dataset_original_directory
        self.extension = extension
        self.data_loader = data_loader

    @abstractmethod
    def __call__(self, filename):
        """Must be implemented by subclasses; receives ``filename``."""
        pass

    @abstractmethod
    def convert_row_to_sample(self, row, header):
        """Must be implemented by subclasses; receives ``row`` and ``header``."""
        pass

    def convert_samples_to_samplesets(
        self, samples, group_by_reference_id=True, references=None
    ):
        """
        Wrap ``samples`` into a list of ``SampleSet`` objects.

        Parameters
        ----------

            samples:
                Iterable of samples. Each one is used as the ``parent`` of the
                ``SampleSet`` it creates; ``reference_id`` is read from it
                when grouping.

            group_by_reference_id: bool
                If true, samples sharing the same ``reference_id`` are placed
                in the same ``SampleSet`` (the first sample seen for an id is
                the parent). Otherwise one ``SampleSet`` is built per sample.

            references:
                When not ``None``, forwarded as the ``references`` keyword to
                every ``SampleSet`` that is created.

        Returns
        -------

            list
                The sample sets, in order of first appearance in ``samples``.

        """
        # ``references`` is only passed along when the caller supplied it.
        optional_kwargs = {} if references is None else {"references": references}

        if not group_by_reference_id:
            return [
                SampleSet([sample], parent=sample, **optional_kwargs)
                for sample in samples
            ]

        # One SampleSet per reference_id; dict keeps first-seen ordering.
        by_reference = dict()
        for sample in samples:
            key = sample.reference_id
            if key in by_reference:
                by_reference[key].append(sample)
            else:
                by_reference[key] = SampleSet(
                    [sample], parent=sample, **optional_kwargs
                )

        return list(by_reference.values())
class
CSVToSampleLoader
(
CSVBaseSampleLoader
):
"""
Simple mechanism that converts the lines of a CSV file to
...
...
@@ -239,27 +154,6 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
)
#####
# DATABASE INTERFACES
#####
def path_discovery(dataset_protocol_path, option1, option2):
    """
    Look for ``option1`` and, failing that, ``option2`` under
    ``dataset_protocol_path``.

    Parameters
    ----------

        dataset_protocol_path: str
            Either a directory or, when it is not a directory, a path that is
            handed to ``find_element_in_tarball``.

        option1: str
            Relative path tried first.

        option2: str
            Relative path tried when ``option1`` is not found.

    Returns
    -------

        For a directory: an open file object (``open`` with default
        arguments) for the first option that exists, or ``None`` if neither
        exists. Otherwise: whatever ``find_element_in_tarball`` returns for
        ``option1`` if that is truthy, else its result for ``option2``.

    """
    # Anything that is not a directory is treated as a tarball
    if not os.path.isdir(dataset_protocol_path):
        first_hit = find_element_in_tarball(dataset_protocol_path, option1)
        if first_hit:
            return first_hit
        return find_element_in_tarball(dataset_protocol_path, option2)

    # Directory input: return a handle to the first candidate that exists
    for relative_path in (option1, option2):
        candidate = os.path.join(dataset_protocol_path, relative_path)
        if os.path.exists(candidate):
            return open(candidate)

    return None
class
CSVDataset
(
Database
):
"""
Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
...
...
@@ -357,35 +251,45 @@ class CSVDataset(Database):
raise
ValueError
(
f
"
The path `
{
dataset_protocol_path
}
` was not found
"
)
# Here we are handling the legacy
train_csv
=
path_discovery
(
train_csv
=
search_file
(
dataset_protocol_path
,
os
.
path
.
join
(
protocol_name
,
"
norm
"
,
"
train_world.lst
"
),
os
.
path
.
join
(
protocol_name
,
"
norm
"
,
"
train_world.csv
"
),
[
os
.
path
.
join
(
protocol_name
,
"
norm
"
,
"
train_world.lst
"
),
os
.
path
.
join
(
protocol_name
,
"
norm
"
,
"
train_world.csv
"
),
],
)
dev_enroll_csv
=
path_discovery
(
dev_enroll_csv
=
search_file
(
dataset_protocol_path
,
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
"
for_models.lst
"
),
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
"
for_models.csv
"
),
[
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
"
for_models.lst
"
),
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
"
for_models.csv
"
),
],
)
legacy_probe
=
"
for_scores.lst
"
if
self
.
is_sparse
else
"
for_probes.lst
"
dev_probe_csv
=
path_discovery
(
dev_probe_csv
=
search_file
(
dataset_protocol_path
,
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
legacy_probe
),
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
"
for_probes.csv
"
),
[
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
legacy_probe
),
os
.
path
.
join
(
protocol_name
,
"
dev
"
,
"
for_probes.csv
"
),
],
)
eval_enroll_csv
=
path_discovery
(
eval_enroll_csv
=
search_file
(
dataset_protocol_path
,
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
"
for_models.lst
"
),
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
"
for_models.csv
"
),
[
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
"
for_models.lst
"
),
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
"
for_models.csv
"
),
],
)
eval_probe_csv
=
path_discovery
(
eval_probe_csv
=
search_file
(
dataset_protocol_path
,
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
legacy_probe
),
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
"
for_probes.csv
"
),
[
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
legacy_probe
),
os
.
path
.
join
(
protocol_name
,
"
eval
"
,
"
for_probes.csv
"
),
],
)
# The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
...
...
@@ -441,17 +345,17 @@ class CSVDataset(Database):
def
_get_samplesets
(
self
,
group
=
"
dev
"
,
cache_
label
=
None
,
cache_
key
=
None
,
group_by_reference_id
=
False
,
fetching_probes
=
False
,
is_sparse
=
False
,
):
if
self
.
cache
[
cache_
label
]
is
not
None
:
return
self
.
cache
[
cache_
label
]
if
self
.
cache
[
cache_
key
]
is
not
None
:
return
self
.
cache
[
cache_
key
]
# Getting samples from CSV
samples
=
self
.
csv_to_sample_loader
(
self
.
__getattribute__
(
cache_
label
))
samples
=
self
.
csv_to_sample_loader
(
self
.
__getattribute__
(
cache_
key
))
references
=
None
if
fetching_probes
and
is_sparse
:
...
...
@@ -481,23 +385,23 @@ class CSVDataset(Database):
samples
,
group_by_reference_id
=
group_by_reference_id
,
references
=
references
,
)
self
.
cache
[
cache_
label
]
=
sample_sets
self
.
cache
[
cache_
key
]
=
sample_sets
return
self
.
cache
[
cache_
label
]
return
self
.
cache
[
cache_
key
]
def
references
(
self
,
group
=
"
dev
"
):
cache_
label
=
"
dev_enroll_csv
"
if
group
==
"
dev
"
else
"
eval_enroll_csv
"
cache_
key
=
"
dev_enroll_csv
"
if
group
==
"
dev
"
else
"
eval_enroll_csv
"
return
self
.
_get_samplesets
(
group
=
group
,
cache_
label
=
cache_
label
,
group_by_reference_id
=
True
group
=
group
,
cache_
key
=
cache_
key
,
group_by_reference_id
=
True
)
def
probes
(
self
,
group
=
"
dev
"
):
cache_
label
=
"
dev_probe_csv
"
if
group
==
"
dev
"
else
"
eval_probe_csv
"
cache_
key
=
"
dev_probe_csv
"
if
group
==
"
dev
"
else
"
eval_probe_csv
"
return
self
.
_get_samplesets
(
group
=
group
,
cache_
label
=
cache_
label
,
cache_
key
=
cache_
key
,
group_by_reference_id
=
False
,
fetching_probes
=
True
,
is_sparse
=
self
.
is_sparse
,
...
...
@@ -610,16 +514,20 @@ class CSVDatasetZTNorm(Database):
self
.
cache
[
"
znorm_csv
"
]
=
None
self
.
cache
[
"
tnorm_csv
"
]
=
None
znorm_csv
=
path_discovery
(
znorm_csv
=
search_file
(
self
.
dataset_protocol_path
,
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_znorm.lst
"
),
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_znorm.csv
"
),
[
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_znorm.lst
"
),
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_znorm.csv
"
),
],
)
tnorm_csv
=
path_discovery
(
tnorm_csv
=
search_file
(
self
.
dataset_protocol_path
,
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_tnorm.lst
"
),
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_tnorm.csv
"
),
[
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_tnorm.lst
"
),
os
.
path
.
join
(
self
.
protocol_name
,
"
norm
"
,
"
for_tnorm.csv
"
),
],
)
if
znorm_csv
is
None
:
...
...
@@ -657,10 +565,10 @@ class CSVDatasetZTNorm(Database):
f
"
Invalid proportion value (
{
proportion
}
). Values allowed from [0-1]
"
)
cache_
label
=
"
znorm_csv
"
cache_
key
=
"
znorm_csv
"
samplesets
=
self
.
_get_samplesets
(
group
=
group
,
cache_
label
=
cache_
label
,
cache_
key
=
cache_
key
,
group_by_reference_id
=
False
,
fetching_probes
=
True
,
is_sparse
=
False
,
...
...
@@ -677,9 +585,9 @@ class CSVDatasetZTNorm(Database):
f
"
Invalid proportion value (
{
proportion
}
). Values allowed from [0-1]
"
)
cache_
label
=
"
tnorm_csv
"
cache_
key
=
"
tnorm_csv
"
samplesets
=
self
.
_get_samplesets
(
group
=
"
dev
"
,
cache_
label
=
cache_
label
,
group_by_reference_id
=
True
,
group
=
"
dev
"
,
cache_
key
=
cache_
key
,
group_by_reference_id
=
True
,
)
treferences
=
samplesets
[:
int
(
len
(
samplesets
)
*
proportion
)]
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment