Commit 9162d84b authored by Tiago de Freitas Pereira

Merge branch 'move-code' into 'master'

Move code

See merge request !232
parents 6dcf9fec 1b7cf031
@@ -8,13 +8,14 @@ from bob.db.base.utils import check_parameters_for_validity
 import csv
 import bob.io.base
 import functools
-from abc import ABCMeta, abstractmethod
 import numpy as np
 import itertools
 import logging
 
 import bob.db.base
 from bob.extension.download import find_element_in_tarball
 from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
+from bob.extension.download import search_file
+from bob.pipelines.datasets.sample_loaders import CSVBaseSampleLoader
 
 logger = logging.getLogger(__name__)
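The net effect of this import hunk is that CSVBaseSampleLoader now comes from bob.pipelines instead of being defined in this module. A minimal sketch of what that means for downstream imports; the old module path shown is an assumption based on this file's location, not taken from the diff:

    # Old location, removed by this commit (module path assumed):
    # from bob.bio.base.database.csv_dataset import CSVBaseSampleLoader

    # New location, as imported above:
    from bob.pipelines.datasets.sample_loaders import CSVBaseSampleLoader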
@@ -58,92 +59,6 @@ class AnnotationsLoader:
         return annotation
-#######
-# SAMPLE LOADERS
-# CONVERT CSV LINES TO SAMPLES
-#######
-class CSVBaseSampleLoader(metaclass=ABCMeta):
-    """
-    Base class that converts the lines of a CSV file, like the one below, to
-    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
-
-    .. code-block:: text
-
-       PATH,REFERENCE_ID
-       path_1,reference_id_1
-       path_2,reference_id_2
-       path_i,reference_id_j
-       ...
-
-    .. note::
-       This class should be extended
-
-    Parameters
-    ----------
-
-    data_loader:
-        A python function that can be called parameterlessly, to load the
-        sample in question from whatever medium
-
-    metadata_loader:
-        AnnotationsLoader
-
-    dataset_original_directory: str
-        Path of where data is stored
-
-    extension: str
-        Default file extension
-    """
-
-    def __init__(
-        self,
-        data_loader,
-        metadata_loader=None,
-        dataset_original_directory="",
-        extension="",
-    ):
-        self.data_loader = data_loader
-        self.extension = extension
-        self.dataset_original_directory = dataset_original_directory
-        self.metadata_loader = metadata_loader
-
-    @abstractmethod
-    def __call__(self, filename):
-        pass
-
-    @abstractmethod
-    def convert_row_to_sample(self, row, header):
-        pass
-
-    def convert_samples_to_samplesets(
-        self, samples, group_by_reference_id=True, references=None
-    ):
-        if group_by_reference_id:
-            # Grouping sample sets
-            sample_sets = dict()
-            for s in samples:
-                if s.reference_id not in sample_sets:
-                    sample_sets[s.reference_id] = (
-                        SampleSet([s], parent=s)
-                        if references is None
-                        else SampleSet([s], parent=s, references=references)
-                    )
-                else:
-                    sample_sets[s.reference_id].append(s)
-            return list(sample_sets.values())
-        else:
-            return (
-                [SampleSet([s], parent=s) for s in samples]
-                if references is None
-                else [SampleSet([s], parent=s, references=references) for s in samples]
-            )
 
 
 class CSVToSampleLoader(CSVBaseSampleLoader):
     """
     Simple mechanism that converts the lines of a CSV file to
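The convert_samples_to_samplesets logic removed above (it moves out with the class) groups samples into one SampleSet per reference_id when grouping is enabled, and one SampleSet per sample otherwise. A self-contained sketch of that grouping; Sample and SampleSet here are simplified stand-ins, not the real bob.pipelines types:

    class Sample:
        # Stand-in: just a path and an identity
        def __init__(self, path, reference_id):
            self.path = path
            self.reference_id = reference_id

    class SampleSet(list):
        # Stand-in: a list of samples that copies metadata from a parent sample
        def __init__(self, samples, parent=None, references=None):
            super().__init__(samples)
            self.reference_id = parent.reference_id
            self.references = references

    def convert_samples_to_samplesets(samples, group_by_reference_id=True, references=None):
        if not group_by_reference_id:
            # One SampleSet per sample (the probe case)
            return [SampleSet([s], parent=s, references=references) for s in samples]
        # One SampleSet per reference_id (the enrollment case)
        sample_sets = {}
        for s in samples:
            if s.reference_id not in sample_sets:
                sample_sets[s.reference_id] = SampleSet([s], parent=s, references=references)
            else:
                sample_sets[s.reference_id].append(s)
        return list(sample_sets.values())

    samples = [Sample("path_1", "id_a"), Sample("path_2", "id_a"), Sample("path_3", "id_b")]
    grouped = convert_samples_to_samplesets(samples)
    print([(ss.reference_id, len(ss)) for ss in grouped])  # [('id_a', 2), ('id_b', 1)]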
@@ -239,27 +154,6 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
         )
 
 
 #####
 # DATABASE INTERFACES
 #####
 
-def path_discovery(dataset_protocol_path, option1, option2):
-
-    # If the input is a directory
-    if os.path.isdir(dataset_protocol_path):
-        option1 = os.path.join(dataset_protocol_path, option1)
-        option2 = os.path.join(dataset_protocol_path, option2)
-        if os.path.exists(option1):
-            return open(option1)
-        else:
-            return open(option2) if os.path.exists(option2) else None
-
-    # If it's not a directory, it is a tarball
-    op1 = find_element_in_tarball(dataset_protocol_path, option1)
-    return op1 if op1 else find_element_in_tarball(dataset_protocol_path, option2)
 
 
 class CSVDataset(Database):
     """
     Generic filelist dataset for the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
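The hunks below apply one mechanical rewrite at every call site: the removed path_discovery helper took exactly two candidate paths as separate positional arguments, while search_file takes a list of candidates, so each pair is wrapped in a list. Judging from the removed implementation and the `is None` checks kept later in the diff, both return an open file object or None; keeping the .lst candidate first preserves the legacy-format-first lookup. A sketch of the shape of the change, with illustrative paths:

    import os
    from bob.extension.download import search_file  # the replacement helper

    dataset_protocol_path = "/path/to/protocol"     # directory or tarball (illustrative)
    protocol_name = "my_protocol"                   # illustrative

    # Before: exactly two candidates, passed positionally to the local helper
    # f = path_discovery(
    #     dataset_protocol_path,
    #     os.path.join(protocol_name, "dev", "for_models.lst"),
    #     os.path.join(protocol_name, "dev", "for_models.csv"),
    # )

    # After: a list of candidates in priority order; returns an open file
    # object, or None when no candidate exists
    f = search_file(
        dataset_protocol_path,
        [
            os.path.join(protocol_name, "dev", "for_models.lst"),
            os.path.join(protocol_name, "dev", "for_models.csv"),
        ],
    )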
@@ -357,35 +251,45 @@ class CSVDataset(Database):
             raise ValueError(f"The path `{dataset_protocol_path}` was not found")
 
         # Here we are handling the legacy
-        train_csv = path_discovery(
+        train_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "norm", "train_world.lst"),
-            os.path.join(protocol_name, "norm", "train_world.csv"),
+            [
+                os.path.join(protocol_name, "norm", "train_world.lst"),
+                os.path.join(protocol_name, "norm", "train_world.csv"),
+            ],
         )
 
-        dev_enroll_csv = path_discovery(
+        dev_enroll_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "dev", "for_models.lst"),
-            os.path.join(protocol_name, "dev", "for_models.csv"),
+            [
+                os.path.join(protocol_name, "dev", "for_models.lst"),
+                os.path.join(protocol_name, "dev", "for_models.csv"),
+            ],
         )
 
         legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
-        dev_probe_csv = path_discovery(
+        dev_probe_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "dev", legacy_probe),
-            os.path.join(protocol_name, "dev", "for_probes.csv"),
+            [
+                os.path.join(protocol_name, "dev", legacy_probe),
+                os.path.join(protocol_name, "dev", "for_probes.csv"),
+            ],
        )
 
-        eval_enroll_csv = path_discovery(
+        eval_enroll_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "eval", "for_models.lst"),
-            os.path.join(protocol_name, "eval", "for_models.csv"),
+            [
+                os.path.join(protocol_name, "eval", "for_models.lst"),
+                os.path.join(protocol_name, "eval", "for_models.csv"),
+            ],
        )
 
-        eval_probe_csv = path_discovery(
+        eval_probe_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "eval", legacy_probe),
-            os.path.join(protocol_name, "eval", "for_probes.csv"),
+            [
+                os.path.join(protocol_name, "eval", legacy_probe),
+                os.path.join(protocol_name, "eval", "for_probes.csv"),
+            ],
        )
 
         # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
@@ -441,17 +345,17 @@ class CSVDataset(Database):
     def _get_samplesets(
         self,
         group="dev",
-        cache_label=None,
+        cache_key=None,
         group_by_reference_id=False,
         fetching_probes=False,
         is_sparse=False,
     ):
 
-        if self.cache[cache_label] is not None:
-            return self.cache[cache_label]
+        if self.cache[cache_key] is not None:
+            return self.cache[cache_key]
 
         # Getting samples from CSV
-        samples = self.csv_to_sample_loader(self.__getattribute__(cache_label))
+        samples = self.csv_to_sample_loader(self.__getattribute__(cache_key))
 
         references = None
         if fetching_probes and is_sparse:
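Beyond the cache_label-to-cache_key rename, note the double duty the key does in the lines above: the same string indexes the self.cache dict and, through __getattribute__, names the attribute holding the CSV file handle. A simplified stand-alone sketch of that idiom, not the real class:

    class CachedCSVSource:
        def __init__(self):
            self.dev_enroll_csv = "dev/for_models.csv"  # attribute named by the key
            self.cache = {"dev_enroll_csv": None}

        def _get_samplesets(self, cache_key):
            if self.cache[cache_key] is not None:
                return self.cache[cache_key]            # cache hit: reuse parsed sets
            source = getattr(self, cache_key)           # key doubles as attribute name
            self.cache[cache_key] = ["samplesets parsed from " + source]
            return self.cache[cache_key]

    db = CachedCSVSource()
    assert db._get_samplesets("dev_enroll_csv") is db._get_samplesets("dev_enroll_csv")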
@@ -481,23 +385,23 @@ class CSVDataset(Database):
             samples, group_by_reference_id=group_by_reference_id, references=references,
         )
 
-        self.cache[cache_label] = sample_sets
+        self.cache[cache_key] = sample_sets
 
-        return self.cache[cache_label]
+        return self.cache[cache_key]
 
     def references(self, group="dev"):
-        cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
+        cache_key = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
         return self._get_samplesets(
-            group=group, cache_label=cache_label, group_by_reference_id=True
+            group=group, cache_key=cache_key, group_by_reference_id=True
         )
 
     def probes(self, group="dev"):
-        cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"
+        cache_key = "dev_probe_csv" if group == "dev" else "eval_probe_csv"
         return self._get_samplesets(
             group=group,
-            cache_label=cache_label,
+            cache_key=cache_key,
             group_by_reference_id=False,
             fetching_probes=True,
             is_sparse=self.is_sparse,
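Callers never see the key; references() and probes() derive it from the requested group and delegate to the cached loader. A hypothetical usage sketch; the constructor arguments shown are placeholders, not the documented signature:

    # Hypothetical usage of the public entry points
    db = CSVDataset("/path/to/protocol_definition", "my_protocol")

    enroll_sets = db.references(group="dev")  # parsed once, cached as "dev_enroll_csv"
    probe_sets = db.probes(group="dev")       # parsed once, cached as "dev_probe_csv"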
@@ -610,16 +514,20 @@ class CSVDatasetZTNorm(Database):
         self.cache["znorm_csv"] = None
         self.cache["tnorm_csv"] = None
 
-        znorm_csv = path_discovery(
+        znorm_csv = search_file(
             self.dataset_protocol_path,
-            os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
-            os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
+            [
+                os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
+                os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
+            ],
        )
 
-        tnorm_csv = path_discovery(
+        tnorm_csv = search_file(
             self.dataset_protocol_path,
-            os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
-            os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
+            [
+                os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
+                os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
+            ],
        )
 
         if znorm_csv is None:
@@ -657,10 +565,10 @@ class CSVDatasetZTNorm(Database):
                 f"Invalid proportion value ({proportion}). Values allowed from [0-1]"
             )
 
-        cache_label = "znorm_csv"
+        cache_key = "znorm_csv"
         samplesets = self._get_samplesets(
             group=group,
-            cache_label=cache_label,
+            cache_key=cache_key,
             group_by_reference_id=False,
             fetching_probes=True,
             is_sparse=False,
@@ -677,9 +585,9 @@ class CSVDatasetZTNorm(Database):
                 f"Invalid proportion value ({proportion}). Values allowed from [0-1]"
             )
 
-        cache_label = "tnorm_csv"
+        cache_key = "tnorm_csv"
         samplesets = self._get_samplesets(
-            group="dev", cache_label=cache_label, group_by_reference_id=True,
+            group="dev", cache_key=cache_key, group_by_reference_id=True,
         )
 
         treferences = samplesets[: int(len(samplesets) * proportion)]
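For reference, the proportion argument kept in the context lines simply truncates the ordered list of sample sets; for example, proportion=0.5 keeps the first half:

    samplesets = list("abcdefghij")  # stand-ins for ten SampleSets
    proportion = 0.5
    treferences = samplesets[: int(len(samplesets) * proportion)]
    print(treferences)  # ['a', 'b', 'c', 'd', 'e']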