diff --git a/bob/bio/base/database/csv_dataset.py b/bob/bio/base/database/csv_dataset.py
index c4befc78e1c98c43f8147334799b93a5835e789f..cc1660a25dd78a5944a82b589a2d3829c2d93ede 100644
--- a/bob/bio/base/database/csv_dataset.py
+++ b/bob/bio/base/database/csv_dataset.py
@@ -8,13 +8,14 @@ from bob.db.base.utils import check_parameters_for_validity
 import csv
 import bob.io.base
 import functools
-from abc import ABCMeta, abstractmethod
 import numpy as np
 import itertools
 import logging
 import bob.db.base
-from bob.extension.download import find_element_in_tarball
 from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
+from bob.extension.download import search_file
+from bob.pipelines.datasets.sample_loaders import CSVBaseSampleLoader
+

 logger = logging.getLogger(__name__)

@@ -58,92 +59,6 @@ class AnnotationsLoader:
         return annotation


-#######
-# SAMPLE LOADERS
-# CONVERT CSV LINES TO SAMPLES
-#######
-
-
-class CSVBaseSampleLoader(metaclass=ABCMeta):
-    """
-    Base class that converts the lines of a CSV file, like the one below to
-    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
-
-    .. code-block:: text
-
-       PATH,REFERENCE_ID
-       path_1,reference_id_1
-       path_2,reference_id_2
-       path_i,reference_id_j
-       ...
-
-    .. note::
-       This class should be extended
-
-    Parameters
-    ----------
-
-    data_loader:
-        A python function that can be called parameterlessly, to load the
-        sample in question from whatever medium
-
-    metadata_loader:
-        AnnotationsLoader
-
-    dataset_original_directory: str
-        Path of where data is stored
-
-    extension: str
-        Default file extension
-
-    """
-
-    def __init__(
-        self,
-        data_loader,
-        metadata_loader=None,
-        dataset_original_directory="",
-        extension="",
-    ):
-        self.data_loader = data_loader
-        self.extension = extension
-        self.dataset_original_directory = dataset_original_directory
-        self.metadata_loader = metadata_loader
-
-    @abstractmethod
-    def __call__(self, filename):
-        pass
-
-    @abstractmethod
-    def convert_row_to_sample(self, row, header):
-        pass
-
-    def convert_samples_to_samplesets(
-        self, samples, group_by_reference_id=True, references=None
-    ):
-        if group_by_reference_id:
-
-            # Grouping sample sets
-            sample_sets = dict()
-            for s in samples:
-                if s.reference_id not in sample_sets:
-                    sample_sets[s.reference_id] = (
-                        SampleSet([s], parent=s)
-                        if references is None
-                        else SampleSet([s], parent=s, references=references)
-                    )
-                else:
-                    sample_sets[s.reference_id].append(s)
-            return list(sample_sets.values())
-
-        else:
-            return (
-                [SampleSet([s], parent=s) for s in samples]
-                if references is None
-                else [SampleSet([s], parent=s, references=references) for s in samples]
-            )
-
-
 class CSVToSampleLoader(CSVBaseSampleLoader):
     """
     Simple mechanism that converts the lines of a CSV file to
@@ -239,27 +154,6 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
         )


-#####
-# DATABASE INTERFACES
-#####
-
-
-def path_discovery(dataset_protocol_path, option1, option2):
-
-    # If the input is a directory
-    if os.path.isdir(dataset_protocol_path):
-        option1 = os.path.join(dataset_protocol_path, option1)
-        option2 = os.path.join(dataset_protocol_path, option2)
-        if os.path.exists(option1):
-            return open(option1)
-        else:
-            return open(option2) if os.path.exists(option2) else None
-
-    # If it's not a directory is a tarball
-    op1 = find_element_in_tarball(dataset_protocol_path, option1)
-    return op1 if op1 else find_element_in_tarball(dataset_protocol_path, option2)
-
-
 class CSVDataset(Database):
     """
     Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
@@ -357,35 +251,45 @@ class CSVDataset(Database):
             raise ValueError(f"The path `{dataset_protocol_path}` was not found")

         # Here we are handling the legacy
-        train_csv = path_discovery(
+        train_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "norm", "train_world.lst"),
-            os.path.join(protocol_name, "norm", "train_world.csv"),
+            [
+                os.path.join(protocol_name, "norm", "train_world.lst"),
+                os.path.join(protocol_name, "norm", "train_world.csv"),
+            ],
         )

-        dev_enroll_csv = path_discovery(
+        dev_enroll_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "dev", "for_models.lst"),
-            os.path.join(protocol_name, "dev", "for_models.csv"),
+            [
+                os.path.join(protocol_name, "dev", "for_models.lst"),
+                os.path.join(protocol_name, "dev", "for_models.csv"),
+            ],
         )

         legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
-        dev_probe_csv = path_discovery(
+        dev_probe_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "dev", legacy_probe),
-            os.path.join(protocol_name, "dev", "for_probes.csv"),
+            [
+                os.path.join(protocol_name, "dev", legacy_probe),
+                os.path.join(protocol_name, "dev", "for_probes.csv"),
+            ],
         )

-        eval_enroll_csv = path_discovery(
+        eval_enroll_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "eval", "for_models.lst"),
-            os.path.join(protocol_name, "eval", "for_models.csv"),
+            [
+                os.path.join(protocol_name, "eval", "for_models.lst"),
+                os.path.join(protocol_name, "eval", "for_models.csv"),
+            ],
         )

-        eval_probe_csv = path_discovery(
+        eval_probe_csv = search_file(
             dataset_protocol_path,
-            os.path.join(protocol_name, "eval", legacy_probe),
-            os.path.join(protocol_name, "eval", "for_probes.csv"),
+            [
+                os.path.join(protocol_name, "eval", legacy_probe),
+                os.path.join(protocol_name, "eval", "for_probes.csv"),
+            ],
         )

         # The minimum required is to have `dev_enroll_csv` and `dev_probe_csv`
@@ -441,17 +345,17 @@ class CSVDataset(Database):
     def _get_samplesets(
         self,
         group="dev",
-        cache_label=None,
+        cache_key=None,
         group_by_reference_id=False,
         fetching_probes=False,
         is_sparse=False,
     ):

-        if self.cache[cache_label] is not None:
-            return self.cache[cache_label]
+        if self.cache[cache_key] is not None:
+            return self.cache[cache_key]

         # Getting samples from CSV
-        samples = self.csv_to_sample_loader(self.__getattribute__(cache_label))
+        samples = self.csv_to_sample_loader(self.__getattribute__(cache_key))

         references = None
         if fetching_probes and is_sparse:
@@ -481,23 +385,23 @@ class CSVDataset(Database):
             samples, group_by_reference_id=group_by_reference_id, references=references,
         )

-        self.cache[cache_label] = sample_sets
+        self.cache[cache_key] = sample_sets

-        return self.cache[cache_label]
+        return self.cache[cache_key]

     def references(self, group="dev"):
-        cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
+        cache_key = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"

         return self._get_samplesets(
-            group=group, cache_label=cache_label, group_by_reference_id=True
+            group=group, cache_key=cache_key, group_by_reference_id=True
         )

     def probes(self, group="dev"):
-        cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"
+        cache_key = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

         return self._get_samplesets(
             group=group,
-            cache_label=cache_label,
+            cache_key=cache_key,
             group_by_reference_id=False,
             fetching_probes=True,
             is_sparse=self.is_sparse,
@@ -610,16 +514,20 @@ class CSVDatasetZTNorm(Database):
         self.cache["znorm_csv"] = None
         self.cache["tnorm_csv"] = None

-        znorm_csv = path_discovery(
+        znorm_csv = search_file(
             self.dataset_protocol_path,
-            os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
-            os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
+            [
+                os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
+                os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
+            ],
         )

-        tnorm_csv = path_discovery(
+        tnorm_csv = search_file(
             self.dataset_protocol_path,
-            os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
-            os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
+            [
+                os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
+                os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
+            ],
         )

         if znorm_csv is None:
@@ -657,10 +565,10 @@ class CSVDatasetZTNorm(Database):
                 f"Invalid proportion value ({proportion}). Values allowed from [0-1]"
             )

-        cache_label = "znorm_csv"
+        cache_key = "znorm_csv"
         samplesets = self._get_samplesets(
             group=group,
-            cache_label=cache_label,
+            cache_key=cache_key,
             group_by_reference_id=False,
             fetching_probes=True,
             is_sparse=False,
         )
@@ -677,9 +585,9 @@ class CSVDatasetZTNorm(Database):
                 f"Invalid proportion value ({proportion}). Values allowed from [0-1]"
             )

-        cache_label = "tnorm_csv"
+        cache_key = "tnorm_csv"
         samplesets = self._get_samplesets(
-            group="dev", cache_label=cache_label, group_by_reference_id=True,
+            group="dev", cache_key=cache_key, group_by_reference_id=True,
        )

         treferences = samplesets[: int(len(samplesets) * proportion)]
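
For reviewers unfamiliar with the replacement helper: the removed `path_discovery` accepted exactly two candidate paths, whereas `bob.extension.download.search_file` accepts a list of them, which is why every call site above now wraps its legacy `.lst` and newer `.csv` alternatives in a list. The snippet below is a minimal sketch of the contract this diff assumes `search_file` honors (first existing candidate wins, the base path may be a directory or a tarball, a miss yields `None`, matching the `if znorm_csv is None:` check above); it is not the actual bob.extension implementation, and `search_file_sketch` plus all paths in it are illustrative names only.

    import io
    import os
    import tarfile


    def search_file_sketch(base_path, candidates):
        """Hypothetical stand-in for bob.extension.download.search_file.

        Returns an open, readable handle for the first candidate found under
        ``base_path`` (a directory or a tarball), or ``None`` if no candidate
        exists -- the same contract path_discovery provided, generalized from
        two options to a list.
        """
        if os.path.isdir(base_path):
            for relpath in candidates:
                full_path = os.path.join(base_path, relpath)
                if os.path.exists(full_path):
                    return open(full_path)
            return None

        # Not a directory: treat base_path as a tarball, as path_discovery did.
        with tarfile.open(base_path) as tar:
            names = set(tar.getnames())
            for relpath in candidates:
                # A real implementation may also match members nested under a
                # top-level folder; this sketch only checks exact member names.
                if relpath in names:
                    return io.StringIO(tar.extractfile(relpath).read().decode())
        return None


    # Usage mirroring the call sites in this diff (paths are illustrative):
    dev_enroll_csv = search_file_sketch(
        "my_protocols.tar.gz",
        [
            os.path.join("my_protocol", "dev", "for_models.lst"),
            os.path.join("my_protocol", "dev", "for_models.csv"),
        ],
    )

Taking the candidates as a list keeps the `.lst`-before-`.csv` precedence explicit at each call site, removes the directory-versus-tarball branching that `path_discovery` duplicated, and scales if a protocol ever grows a third file format.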