Commit 34bd50fb authored by Tiago de Freitas Pereira

Improved test cases

parent bfe48f7a
Pipeline #43915 failed in 9 minutes and 39 seconds
-from .csv_dataset import CSVDatasetDevEval
+from .csv_dataset import CSVDatasetDevEval, CSVToSampleLoader
from .file import BioFile
from .file import BioFileSet
from .database import BioDatabase
@@ -27,5 +27,6 @@ __appropriate__(
BioDatabase,
ZTBioDatabase,
CSVDatasetDevEval,
+    CSVToSampleLoader
)
__all__ = [_ for _ in dir() if not _.startswith('_')]
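Note: both names are now exported from the package. A minimal sketch of the corresponding import (it matches the import used in the test file further down):

    from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader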
@@ -10,7 +10,7 @@ import functools
from abc import ABCMeta, abstractmethod
-class CSVSampleLoaderAbstract(metaclass=ABCMeta):
+class CSVBaseSampleLoader(metaclass=ABCMeta):
"""
Convert CSV files in the format below to either a list of
:any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
@@ -38,9 +38,10 @@ class CSVSampleLoaderAbstract(metaclass=ABCMeta):
"""
-    def __init__(self, data_loader, extension=""):
+    def __init__(self, data_loader, dataset_original_directory="", extension=""):
self.data_loader = data_loader
self.extension = extension
+        self.dataset_original_directory = dataset_original_directory
self.excluding_attributes = ["_data", "load", "key"]
@abstractmethod
@@ -52,11 +53,13 @@ class CSVSampleLoaderAbstract(metaclass=ABCMeta):
pass
@abstractmethod
-    def convert_samples_to_samplesets(self, samples, group_by_subject=True):
+    def convert_samples_to_samplesets(
+        self, samples, group_by_subject=True, references=None
+    ):
pass
-class CSVToSampleLoader(CSVSampleLoaderAbstract):
+class CSVToSampleLoader(CSVBaseSampleLoader):
"""
Simple mechanism to convert CSV files in the format below to either a list of
:any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
@@ -86,15 +89,17 @@ class CSVToSampleLoader(CSVSampleLoaderAbstract):
def convert_row_to_sample(self, row, header):
path = row[0]
subject = row[1]
-        kwargs = dict([[h, r] for h,r in zip(header[2:], row[2:])])
+        kwargs = dict([[h, r] for h, r in zip(header[2:], row[2:])])
return DelayedSample(
-            functools.partial(self.data_loader, os.path.join(path, self.extension)),
+            functools.partial(self.data_loader, os.path.join(self.dataset_original_directory, path + self.extension)),
key=path,
subject=subject,
**kwargs,
)
-    def convert_samples_to_samplesets(self, samples, group_by_subject=True):
+    def convert_samples_to_samplesets(
+        self, samples, group_by_subject=True, references=None
+    ):
def get_attribute_from_sample(sample):
return dict(
[
@@ -117,7 +122,7 @@ class CSVToSampleLoader(CSVSampleLoaderAbstract):
return list(sample_sets.values())
else:
-            return [SampleSet([s], **get_attribute_from_sample(s)) for s in samples]
+            return [SampleSet([s], **get_attribute_from_sample(s), references=references) for s in samples]
class CSVDatasetDevEval:
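Note: to make the loader changes above concrete, here is a hedged sketch of how CSVToSampleLoader resolves one CSV row after this commit. The column layout (path, subject, then metadata) follows convert_row_to_sample; the directory, file name, and metadata values are purely illustrative:

    import bob.io.base
    from bob.bio.base.database import CSVToSampleLoader

    loader = CSVToSampleLoader(
        data_loader=bob.io.base.load,
        dataset_original_directory="/path/to/raw/data",  # illustrative path
        extension=".pgm",
    )
    # Given header = ["PATH", "SUBJECT", "METADATA_1"] and
    #       row    = ["s1/img_001", "1", "F"],
    # the sample now loads lazily from
    # /path/to/raw/data/s1/img_001.pgm (directory + path + extension),
    # with key="s1/img_001", subject="1", and METADATA_1="F" attached.
    sample = loader.convert_row_to_sample(
        ["s1/img_001", "1", "F"], header=["PATH", "SUBJECT", "METADATA_1"]
    )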
@@ -186,7 +191,7 @@ class CSVDatasetDevEval:
dataset_path: str
Absolute path of the dataset protocol description
-    protocol: str
+    protocol_name: str
The name of the protocol
csv_to_sample_loader:
@@ -196,21 +201,21 @@
def __init__(
self,
-        dataset_path,
-        protocol,
+        dataset_protocol_path,
+        protocol_name,
csv_to_sample_loader=CSVToSampleLoader(
-            data_loader=bob.io.base.load, extension=""
+            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
),
):
def get_paths():
-            if not os.path.exists(dataset_path):
-                raise ValueError(f"The path `{dataset_path}` was not found")
+            if not os.path.exists(dataset_protocol_path):
+                raise ValueError(f"The path `{dataset_protocol_path}` was not found")
# TODO: Unzip file if dataset path is a zip
-            protocol_path = os.path.join(dataset_path, protocol)
+            protocol_path = os.path.join(dataset_protocol_path, protocol_name)
if not os.path.exists(protocol_path):
raise ValueError(f"The protocol `{protocol}` was not found")
raise ValueError(f"The protocol `{protocol_name}` was not found")
train_csv = os.path.join(protocol_path, "train.csv")
dev_enroll_csv = os.path.join(protocol_path, "dev_enroll.csv")
@@ -276,20 +281,31 @@ class CSVDatasetDevEval:
return self.cache["train"]
+    def _get_subjects_from_samplesets(self, sample_sets):
+        return list(set([s.subject for s in sample_sets]))
def _get_samplesets(self, group="dev", purpose="enroll", group_by_subject=False):
+        fetching_probes = False
if purpose == "enroll":
cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
else:
+            fetching_probes = True
cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"
if self.cache[cache_label] is not None:
return self.cache[cache_label]
-        probes_data = self.csv_to_sample_loader(self.__dict__[cache_label])
+        references = None
+        if fetching_probes:
+            references = self._get_subjects_from_samplesets(
+                self.references(group=group)
+            )
+
+        samples = self.csv_to_sample_loader(self.__dict__[cache_label])

        sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
-            probes_data, group_by_subject=group_by_subject
+            samples, group_by_subject=group_by_subject, references=references
)
self.cache[cache_label] = sample_sets
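Note: the net effect of this hunk is that probe sample sets now carry the full list of enrolled subject ids, so downstream scoring knows which models each probe should be compared against. A hedged sketch of the resulting structure, with `dataset` a CSVDatasetDevEval as constructed in the tests below:

    probes = dataset.probes()
    # After this commit every probe SampleSet exposes a `references`
    # attribute: the list of all subject ids from the enroll set.
    assert all("references" in p.__dict__ for p in probes)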
@@ -5,16 +5,28 @@
"""
import os
import bob.io.base
import bob.io.base.test_utils
-from bob.bio.base.database import CSVDatasetDevEval
+from bob.bio.base.database import CSVDatasetDevEval, CSVToSampleLoader
import nose.tools
from bob.pipelines import DelayedSample, SampleSet
import numpy as np
+from .utils import atnt_database_directory
+from bob.bio.base.pipelines.vanilla_biometrics import (
+    Distance,
+    VanillaBiometricsPipeline,
+)
+from sklearn.preprocessing import FunctionTransformer
+from sklearn.pipeline import make_pipeline
+from bob.pipelines import wrap
example_dir = os.path.realpath(
bob.io.base.test_utils.datafile(".", __name__, "data/example_csv_filelist")
)
-atnt_dir = os.path.realpath(bob.io.base.test_utils.datafile(".", __name__, "data/atnt"))
+atnt_protocol_path = os.path.realpath(
+    bob.io.base.test_utils.datafile(".", __name__, "data/atnt")
+)
def check_all_true(list_of_something, something):
@@ -41,21 +53,24 @@ def test_csv_file_list_dev_only_metadata():
dataset = CSVDatasetDevEval(example_dir, "protocol_only_dev_metadata")
assert len(dataset.background_model_samples()) == 8
    assert check_all_true(dataset.background_model_samples(), DelayedSample)
-    assert np.alltrue(['METADATA_1' in s.__dict__ for s in dataset.background_model_samples()])
-    assert np.alltrue(['METADATA_2' in s.__dict__ for s in dataset.background_model_samples()])
+    assert np.alltrue(
+        ["METADATA_1" in s.__dict__ for s in dataset.background_model_samples()]
+    )
+    assert np.alltrue(
+        ["METADATA_2" in s.__dict__ for s in dataset.background_model_samples()]
+    )
assert len(dataset.references()) == 2
assert check_all_true(dataset.references(), SampleSet)
-    assert np.alltrue(['METADATA_1' in s.__dict__ for s in dataset.references()])
-    assert np.alltrue(['METADATA_2' in s.__dict__ for s in dataset.references()])
+    assert np.alltrue(["METADATA_1" in s.__dict__ for s in dataset.references()])
+    assert np.alltrue(["METADATA_2" in s.__dict__ for s in dataset.references()])
assert len(dataset.probes()) == 10
assert check_all_true(dataset.probes(), SampleSet)
-    assert np.alltrue(['METADATA_1' in s.__dict__ for s in dataset.probes()])
-    assert np.alltrue(['METADATA_2' in s.__dict__ for s in dataset.probes()])
+    assert np.alltrue(["METADATA_1" in s.__dict__ for s in dataset.probes()])
+    assert np.alltrue(["METADATA_2" in s.__dict__ for s in dataset.probes()])
+    assert np.alltrue(["references" in s.__dict__ for s in dataset.probes()])
def test_csv_file_list_dev_eval():
@@ -79,7 +94,42 @@ def test_csv_file_list_dev_eval():
def test_csv_file_list_atnt():
-    dataset = CSVDatasetDevEval(atnt_dir, "idiap_protocol")
+    dataset = CSVDatasetDevEval(atnt_protocol_path, "idiap_protocol")
assert len(dataset.background_model_samples()) == 200
assert len(dataset.references()) == 20
assert len(dataset.probes()) == 100
+def test_atnt_experiment():
+    def load(path):
+        import bob.io.image
+
+        return bob.io.base.load(path)
+
+    def linearize(X):
+        X = np.asarray(X)
+        return np.reshape(X, (X.shape[0], -1))
+
+    dataset = CSVDatasetDevEval(
+        dataset_protocol_path=atnt_protocol_path,
+        protocol_name="idiap_protocol",
+        csv_to_sample_loader=CSVToSampleLoader(
+            data_loader=load,
+            dataset_original_directory=atnt_database_directory(),
+            extension=".pgm",
+        ),
+    )
+
+    #### Testing it in a real recognition system
+    transformer = wrap(["sample"], make_pipeline(FunctionTransformer(linearize)))
+    vanilla_biometrics_pipeline = VanillaBiometricsPipeline(transformer, Distance())
+
+    scores = vanilla_biometrics_pipeline(
+        dataset.background_model_samples(),
+        dataset.references(),
+        dataset.probes(),
+    )
+
+    assert len(scores) == 100
+    assert np.alltrue([len(s) == 20 for s in scores])
\ No newline at end of file
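Note: a hedged sketch of inspecting the output of the new test's pipeline, assuming (as the final asserts suggest) one SampleSet per probe containing one score sample per enrolled reference:

    for probe in scores:      # 100 probe SampleSets, one per probe
        for score in probe:   # 20 score samples, one per reference
            print(score.data) # similarity of this probe/reference pair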