Commit 34bd50fb authored by Tiago de Freitas Pereira

Improved test cases

parent bfe48f7a
Merge requests: !200 Database interface, !180 [dask] Preparing bob.bio.base for dask pipelines
Pipeline #43915 failed
-from .csv_dataset import CSVDatasetDevEval
+from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader
 from .file import BioFile
 from .file import BioFileSet
 from .database import BioDatabase
@@ -27,5 +27,6 @@ __appropriate__(
     BioDatabase,
     ZTBioDatabase,
     CSVDatasetDevEval,
+    CSVToSampleLoader
 )
 __all__ = [_ for _ in dir() if not _.startswith('_')]
@@ -10,7 +10,7 @@ import functools
 from abc import ABCMeta, abstractmethod


-class CSVSampleLoaderAbstract(metaclass=ABCMeta):
+class CSVBaseSampleLoader(metaclass=ABCMeta):
     """
     Convert CSV files in the format below to either a list of
     :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
@@ -38,9 +38,10 @@ class CSVSampleLoaderAbstract(metaclass=ABCMeta):
     """

-    def __init__(self, data_loader, extension=""):
+    def __init__(self, data_loader, dataset_original_directory="", extension=""):
         self.data_loader = data_loader
         self.extension = extension
+        self.dataset_original_directory = dataset_original_directory
         self.excluding_attributes = ["_data", "load", "key"]

     @abstractmethod
@@ -52,11 +53,13 @@ class CSVSampleLoaderAbstract(metaclass=ABCMeta):
         pass

     @abstractmethod
-    def convert_samples_to_samplesets(self, samples, group_by_subject=True):
+    def convert_samples_to_samplesets(
+        self, samples, group_by_subject=True, references=None
+    ):
         pass


-class CSVToSampleLoader(CSVSampleLoaderAbstract):
+class CSVToSampleLoader(CSVBaseSampleLoader):
     """
     Simple mechanism to convert CSV files in the format below to either a list of
     :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
@@ -86,15 +89,17 @@ class CSVToSampleLoader(CSVSampleLoaderAbstract):
     def convert_row_to_sample(self, row, header):
         path = row[0]
         subject = row[1]
-        kwargs = dict([[h, r] for h,r in zip(header[2:], row[2:])])
+        kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])])
         return DelayedSample(
-            functools.partial(self.data_loader, os.path.join(path, self.extension)),
+            functools.partial(self.data_loader, os.path.join(self.dataset_original_directory, path + self.extension)),
             key=path,
             subject=subject,
             **kwargs,
         )

-    def convert_samples_to_samplesets(self, samples, group_by_subject=True):
+    def convert_samples_to_samplesets(
+        self, samples, group_by_subject=True, references=None
+    ):
         def get_attribute_from_sample(sample):
             return dict(
                 [
@@ -117,7 +122,7 @@ class CSVToSampleLoader(CSVSampleLoaderAbstract):
             return list(sample_sets.values())
         else:
-            return [SampleSet([s], **get_attribute_from_sample(s)) for s in samples]
+            return [SampleSet([s], **get_attribute_from_sample(s), references=references) for s in samples]
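A minimal sketch of what this loader consumes and produces, assuming a hypothetical protocol CSV (the header names PATH and SUBJECT and all file names below are made up; per convert_row_to_sample, column 0 is the file path and sample key, column 1 the subject, and any remaining columns become sample metadata):

    # Hypothetical dev_enroll.csv:
    #
    #   PATH,SUBJECT,METADATA_1
    #   s1/img_1,1,A
    #   s2/img_1,2,B

    import bob.io.base
    from bob.bio.base.database import CSVToSampleLoader

    loader = CSVToSampleLoader(
        data_loader=bob.io.base.load,
        dataset_original_directory="/path/to/raw/files",  # assumed raw-data root
        extension=".pgm",
    )

    # Calling the loader on the CSV (as _get_samplesets does below) yields
    # DelayedSamples whose data is loaded lazily from, e.g.,
    # /path/to/raw/files/s1/img_1.pgm
    samples = loader("dev_enroll.csv")

    # Probe-style SampleSets: one set per sample, each carrying the list of
    # reference subjects it should be compared against.
    sample_sets = loader.convert_samples_to_samplesets(
        samples, group_by_subject=False, references=["1", "2"]
    )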
 class CSVDatasetDevEval:
@@ -186,7 +191,7 @@ class CSVDatasetDevEval:
     dataset_path: str
         Absolute path of the dataset protocol description

-    protocol: str
+    protocol_name: str
         The name of the protocol

     csv_to_sample_loader:
@@ -196,21 +201,21 @@ class CSVDatasetDevEval:
     def __init__(
         self,
-        dataset_path,
-        protocol,
+        dataset_protocol_path,
+        protocol_name,
         csv_to_sample_loader=CSVToSampleLoader(
-            data_loader=bob.io.base.load, extension=""
+            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
         ),
     ):
         def get_paths():
-            if not os.path.exists(dataset_path):
-                raise ValueError(f"The path `{dataset_path}` was not found")
+            if not os.path.exists(dataset_protocol_path):
+                raise ValueError(f"The path `{dataset_protocol_path}` was not found")

             # TODO: Unzip file if dataset path is a zip
-            protocol_path = os.path.join(dataset_path, protocol)
+            protocol_path = os.path.join(dataset_protocol_path, protocol_name)
             if not os.path.exists(protocol_path):
-                raise ValueError(f"The protocol `{protocol}` was not found")
+                raise ValueError(f"The protocol `{protocol_name}` was not found")

             train_csv = os.path.join(protocol_path, "train.csv")
             dev_enroll_csv = os.path.join(protocol_path, "dev_enroll.csv")
@@ -276,20 +281,31 @@ class CSVDatasetDevEval:
         return self.cache["train"]

+    def _get_subjects_from_samplesets(self, sample_sets):
+        return list(set([s.subject for s in sample_sets]))
+
     def _get_samplesets(self, group="dev", purpose="enroll", group_by_subject=False):

+        fetching_probes = False
         if purpose == "enroll":
             cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
         else:
+            fetching_probes = True
             cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"

         if self.cache[cache_label] is not None:
             return self.cache[cache_label]

-        probes_data = self.csv_to_sample_loader(self.__dict__[cache_label])
+        references = None
+        if fetching_probes:
+            references = self._get_subjects_from_samplesets(
+                self.references(group=group)
+            )
+
+        samples = self.csv_to_sample_loader(self.__dict__[cache_label])

         sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
-            probes_data, group_by_subject=group_by_subject
+            samples, group_by_subject=group_by_subject, references=references
         )

         self.cache[cache_label] = sample_sets
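For orientation, the paths assembled in get_paths and the cache labels used in _get_samplesets imply a protocol tree along these lines (treating the eval files as optional is an assumption, suggested by the dev-only protocol exercised in the tests below):

    dataset_protocol_path/
    └── protocol_name/
        ├── train.csv        # background_model_samples()
        ├── dev_enroll.csv   # references(group="dev")
        ├── dev_probe.csv    # probes for the dev group
        ├── eval_enroll.csv  # references(group="eval"), optional
        └── eval_probe.csv   # probes for the eval group, optional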
@@ -5,16 +5,28 @@
 """

 import os
 import bob.io.base
 import bob.io.base.test_utils
-from bob.bio.base.database import CSVDatasetDevEval
+from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader
 import nose.tools
 from bob.pipelines import DelayedSample, SampleSet
 import numpy as np
+from .utils import atnt_database_directory
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.pipeline import make_pipeline
+from bob.pipelines import wrap

 example_dir = os.path.realpath(
     bob.io.base.test_utils.datafile(".", __name__, "data/example_csv_filelist")
 )
-atnt_dir = os.path.realpath(bob.io.base.test_utils.datafile(".", __name__, "data/atnt"))
+atnt_protocol_path = os.path.realpath(
+    bob.io.base.test_utils.datafile(".", __name__, "data/atnt")
+)
 def check_all_true(list_of_something, something):
@@ -41,21 +53,24 @@ def test_csv_file_list_dev_only_metadata():

     dataset = CSVDatasetDevEval(example_dir, "protocol_only_dev_metadata")
     assert len(dataset.background_model_samples()) == 8
-    assert check_all_true(dataset.background_model_samples(), DelayedSample)
-    assert np.alltrue(['METADATA_1' in s.__dict__ for s in dataset.background_model_samples()])
-    assert np.alltrue(['METADATA_2' in s.__dict__ for s in dataset.background_model_samples()])
+    assert check_all_true(dataset.background_model_samples(), DelayedSample)
+    assert np.alltrue(
+        ["METADATA_1" in s.__dict__ for s in dataset.background_model_samples()]
+    )
+    assert np.alltrue(
+        ["METADATA_2" in s.__dict__ for s in dataset.background_model_samples()]
+    )

     assert len(dataset.references()) == 2
     assert check_all_true(dataset.references(), SampleSet)
-    assert np.alltrue(['METADATA_1' in s.__dict__ for s in dataset.references()])
-    assert np.alltrue(['METADATA_2' in s.__dict__ for s in dataset.references()])
+    assert np.alltrue(["METADATA_1" in s.__dict__ for s in dataset.references()])
+    assert np.alltrue(["METADATA_2" in s.__dict__ for s in dataset.references()])

     assert len(dataset.probes()) == 10
     assert check_all_true(dataset.probes(), SampleSet)
-    assert np.alltrue(['METADATA_1' in s.__dict__ for s in dataset.probes()])
-    assert np.alltrue(['METADATA_2' in s.__dict__ for s in dataset.probes()])
+    assert np.alltrue(["METADATA_1" in s.__dict__ for s in dataset.probes()])
+    assert np.alltrue(["METADATA_2" in s.__dict__ for s in dataset.probes()])
+    assert np.alltrue(["references" in s.__dict__ for s in dataset.probes()])
 def test_csv_file_list_dev_eval():
@@ -79,7 +94,42 @@ def test_csv_file_list_dev_eval():

 def test_csv_file_list_atnt():

-    dataset = CSVDatasetDevEval(atnt_dir, "idiap_protocol")
+    dataset = CSVDatasetDevEval(atnt_protocol_path, "idiap_protocol")
     assert len(dataset.background_model_samples()) == 200
     assert len(dataset.references()) == 20
     assert len(dataset.probes()) == 100
+
+
+def test_atnt_experiment():
+    def load(path):
+        import bob.io.image
+
+        return bob.io.base.load(path)
+
+    def linearize(X):
+        X = np.asarray(X)
+        return np.reshape(X, (X.shape[0], -1))
+
+    dataset = CSVDatasetDevEval(
+        dataset_protocol_path=atnt_protocol_path,
+        protocol_name="idiap_protocol",
+        csv_to_sample_loader=CSVToSampleLoader(
+            data_loader=load,
+            dataset_original_directory=atnt_database_directory(),
+            extension=".pgm",
+        ),
+    )
+
+    # Test it in a real recognition pipeline
+    transformer = wrap(["sample"], make_pipeline(FunctionTransformer(linearize)))
+    vanilla_biometrics_pipeline = VanillaBiometricsPipeline(transformer, Distance())
+
+    scores = vanilla_biometrics_pipeline(
+        dataset.background_model_samples(),
+        dataset.references(),
+        dataset.probes(),
+    )
+
+    assert len(scores) == 100
+    assert np.alltrue([len(s) == 20 for s in scores])
\ No newline at end of file
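A subtle point in the final check: np.alltrue applied to a generator of one-element lists passes vacuously, whatever the scores contain, so the boolean list has to be built by the comprehension itself. A standalone repro with made-up score shapes:

    import numpy as np

    scores = [[0.1] * 5, [0.2] * 5]  # pretend: 2 probes, 5 reference scores each

    # Generator of one-element lists: np.alltrue only sees a truthy generator object.
    print(np.alltrue([len(s) == 20] for s in scores))   # True, vacuously
    # List comprehension over the actual lengths: the real check.
    print(np.alltrue([len(s) == 20 for s in scores]))   # False here (5 != 20)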