Commit 243b4e70 authored by Tiago de Freitas Pereira
Browse files

Implemented ZTNorm interface

parent 9681ff8b
Pipeline #46294 passed with stage
in 13 minutes and 1 second
......@@ -5,6 +5,7 @@ from .csv_dataset import (
CSVBaseSampleLoader,
AnnotationsLoader,
LSTToSampleLoader,
CSVDatasetDevEvalZTNorm,
)
from .file import BioFile
from .file import BioFileSet
......
......@@ -244,6 +244,22 @@ class LSTToSampleLoader(CSVBaseSampleLoader):
#####
def path_discovery(dataset_protocol_path, option1, option2):
    """
    Look for one of two candidate files inside a dataset location.

    ``option1`` is tried first and ``option2`` is only used as a fallback.

    Parameters
    ----------

    dataset_protocol_path: str
        Either a directory or, when it is not a directory, a path that is
        handed to ``find_element_in_tarball``

    option1: str
        Preferred relative path of the file

    option2: str
        Alternative relative path of the file

    Returns
    -------

    For a directory: an open file object for the first candidate that
    exists, or ``None`` if neither exists.
    Otherwise: whatever ``find_element_in_tarball`` returns for
    ``option1`` if that is truthy, else its result for ``option2``.
    """

    # Anything that is not a directory is treated as a tarball
    if not os.path.isdir(dataset_protocol_path):
        return find_element_in_tarball(
            dataset_protocol_path, option1
        ) or find_element_in_tarball(dataset_protocol_path, option2)

    # Plain directory: probe the candidates in order of preference
    preferred = os.path.join(dataset_protocol_path, option1)
    if os.path.exists(preferred):
        return open(preferred)

    fallback = os.path.join(dataset_protocol_path, option2)
    if os.path.exists(fallback):
        return open(fallback)

    return None
class CSVDatasetDevEval(Database):
"""
Generic filelist dataset for :any:` bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
......@@ -333,54 +349,41 @@ class CSVDatasetDevEval(Database):
):
self.dataset_protocol_path = dataset_protocol_path
self.is_sparse = is_sparse
self.protocol_name = protocol_name
def get_paths():
if not os.path.exists(dataset_protocol_path):
raise ValueError(f"The path `{dataset_protocol_path}` was not found")
def path_discovery(option1, option2):
# If the input is a directory
if os.path.isdir(dataset_protocol_path):
option1 = os.path.join(dataset_protocol_path, option1)
option2 = os.path.join(dataset_protocol_path, option2)
if os.path.exists(option1):
return open(option1)
else:
return open(option2) if os.path.exists(option2) else None
# If it's not a directory is a tarball
op1 = find_element_in_tarball(dataset_protocol_path, option1)
return (
op1
if op1
else find_element_in_tarball(dataset_protocol_path, option2)
)
# Here we are handling the legacy
train_csv = path_discovery(
dataset_protocol_path,
os.path.join(protocol_name, "norm", "train_world.lst"),
os.path.join(protocol_name, "norm", "train_world.csv"),
)
dev_enroll_csv = path_discovery(
dataset_protocol_path,
os.path.join(protocol_name, "dev", "for_models.lst"),
os.path.join(protocol_name, "dev", "for_models.csv"),
)
legacy_probe = "for_scores.lst" if self.is_sparse else "for_probes.lst"
dev_probe_csv = path_discovery(
dataset_protocol_path,
os.path.join(protocol_name, "dev", legacy_probe),
os.path.join(protocol_name, "dev", "for_probes.csv"),
)
eval_enroll_csv = path_discovery(
dataset_protocol_path,
os.path.join(protocol_name, "eval", "for_models.lst"),
os.path.join(protocol_name, "eval", "for_models.csv"),
)
eval_probe_csv = path_discovery(
dataset_protocol_path,
os.path.join(protocol_name, "eval", legacy_probe),
os.path.join(protocol_name, "eval", "for_probes.csv"),
)
......@@ -438,24 +441,22 @@ class CSVDatasetDevEval(Database):
return self.cache["train"]
def _get_samplesets(
self, group="dev", purpose="enroll", group_by_reference_id=False
self,
group="dev",
cache_label=None,
group_by_reference_id=False,
fetching_probes=False,
is_sparse=False,
):
fetching_probes = False
if purpose == "enroll":
cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
else:
fetching_probes = True
cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"
if self.cache[cache_label] is not None:
return self.cache[cache_label]
# Getting samples from CSV
samples = self.csv_to_sample_loader(self.__dict__[cache_label])
samples = self.csv_to_sample_loader(self.__getattribute__(cache_label))
references = None
if fetching_probes and self.is_sparse:
if fetching_probes and is_sparse:
# Checking if `is_sparse` was set properly
if len(samples) > 0 and not hasattr(samples[0], "compare_reference_id"):
......@@ -487,13 +488,21 @@ class CSVDatasetDevEval(Database):
return self.cache[cache_label]
def references(self, group="dev"):
cache_label = "dev_enroll_csv" if group == "dev" else "eval_enroll_csv"
return self._get_samplesets(
group=group, purpose="enroll", group_by_reference_id=True
group=group, cache_label=cache_label, group_by_reference_id=True
)
def probes(self, group="dev"):
cache_label = "dev_probe_csv" if group == "dev" else "eval_probe_csv"
return self._get_samplesets(
group=group, purpose="probe", group_by_reference_id=False
group=group,
cache_label=cache_label,
group_by_reference_id=False,
fetching_probes=True,
is_sparse=self.is_sparse,
)
def all_samples(self, groups=None):
......@@ -534,7 +543,9 @@ class CSVDatasetDevEval(Database):
for group in groups:
for purpose in ("enroll", "probe"):
label = f"{group}_{purpose}_csv"
samples = samples + self.csv_to_sample_loader(self.__dict__[label])
samples = samples + self.csv_to_sample_loader(
self.__getattribute__(label)
)
return samples
def groups(self):
......@@ -559,6 +570,125 @@ class CSVDatasetDevEval(Database):
return groups
class CSVDatasetDevEvalZTNorm(Database):
    """
    Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.ZTNormPipeline` pipelines.
    Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset
    interface.

    This dataset interface takes a :any:`CSVDatasetDevEval` as input and adds two extra methods:
    :any:`CSVDatasetDevEvalZTNorm.zprobes` and :any:`CSVDatasetDevEvalZTNorm.treferences`.

    To create a new dataset, you need to provide a directory structure similar to the one below:

    .. code-block:: text

       my_dataset/
       my_dataset/my_protocol/norm/train_world.csv
       my_dataset/my_protocol/norm/for_znorm.csv
       my_dataset/my_protocol/norm/for_tnorm.csv
       my_dataset/my_protocol/dev/for_models.csv
       my_dataset/my_protocol/dev/for_probes.csv
       my_dataset/my_protocol/eval/for_models.csv
       my_dataset/my_protocol/eval/for_probes.csv

    Parameters
    ----------

    database: :any:`CSVDatasetDevEval`
        :any:`CSVDatasetDevEval` to be aggregated

    Raises
    ------

    ValueError
        If neither the ``.lst`` nor the ``.csv`` variant of ``for_znorm`` or
        ``for_tnorm`` is found under ``<protocol_name>/norm``

    """

    def __init__(self, database):
        self.database = database

        # The state below is shared with (not copied from) the wrapped database
        self.cache = self.database.cache
        self.csv_to_sample_loader = self.database.csv_to_sample_loader
        self.protocol_name = self.database.protocol_name
        self.dataset_protocol_path = self.database.dataset_protocol_path
        self._get_samplesets = self.database._get_samplesets

        ## create_cache
        self.cache["znorm_csv"] = None
        self.cache["tnorm_csv"] = None

        # Each file is validated right after its discovery, so that a missing
        # Z-norm file does not leave an already opened T-norm file behind
        znorm_csv = path_discovery(
            self.dataset_protocol_path,
            os.path.join(self.protocol_name, "norm", "for_znorm.lst"),
            os.path.join(self.protocol_name, "norm", "for_znorm.csv"),
        )
        if znorm_csv is None:
            raise ValueError(
                f"The file `for_znorm.lst` or `for_znorm.csv` is required and it was not found in `{self.protocol_name}/norm`"
            )

        tnorm_csv = path_discovery(
            self.dataset_protocol_path,
            os.path.join(self.protocol_name, "norm", "for_tnorm.lst"),
            os.path.join(self.protocol_name, "norm", "for_tnorm.csv"),
        )
        if tnorm_csv is None:
            raise ValueError(
                f"The file `for_tnorm.lst` or `for_tnorm.csv` is required and it was not found in `{self.protocol_name}/norm`"
            )

        # `_get_samplesets` is bound to the wrapped database and fetches the
        # file through an attribute named after the cache label, hence the
        # files are attached to the database and not to `self`
        self.database.znorm_csv = znorm_csv
        self.database.tnorm_csv = tnorm_csv

    @staticmethod
    def _check_proportion(proportion):
        """Raise ``ValueError`` unless ``0 < proportion <= 1``."""
        if proportion <= 0 or proportion > 1:
            raise ValueError(
                f"Invalid proportion value ({proportion}). Values allowed from (0-1]"
            )

    def background_model_samples(self):
        """Delegates to the wrapped database."""
        return self.database.background_model_samples()

    def references(self, group="dev"):
        """Delegates to the wrapped database."""
        return self.database.references(group=group)

    def probes(self, group="dev"):
        """Delegates to the wrapped database."""
        return self.database.probes(group=group)

    def all_samples(self, groups=None):
        """Delegates to the wrapped database."""
        return self.database.all_samples(groups=groups)

    def groups(self):
        """Delegates to the wrapped database."""
        return self.database.groups()

    def zprobes(self, group="dev", proportion=1.0):
        """
        Returns the sample sets loaded from the ``for_znorm`` file.

        Parameters
        ----------

        group: str
            Forwarded as is to the sample set loader of the wrapped database

        proportion: float
            Fraction, in the interval (0, 1], of the sample sets to return.
            The leading ``int(len(samplesets) * proportion)`` elements are kept

        Raises
        ------

        ValueError
            If ``proportion`` is outside of (0, 1]

        """
        self._check_proportion(proportion)

        samplesets = self._get_samplesets(
            group=group,
            cache_label="znorm_csv",
            group_by_reference_id=False,
            fetching_probes=True,
            is_sparse=False,
        )

        return samplesets[: int(len(samplesets) * proportion)]

    def treferences(self, covariate="sex", proportion=1.0):
        """
        Returns the sample sets loaded from the ``for_tnorm`` file, grouped by
        reference id.

        Parameters
        ----------

        covariate: str
            Currently unused; kept as part of the method interface

        proportion: float
            Fraction, in the interval (0, 1], of the sample sets to return.
            The leading ``int(len(samplesets) * proportion)`` elements are kept

        Raises
        ------

        ValueError
            If ``proportion`` is outside of (0, 1]

        """
        self._check_proportion(proportion)

        samplesets = self._get_samplesets(
            group="dev",
            cache_label="tnorm_csv",
            group_by_reference_id=True,
        )

        return samplesets[: int(len(samplesets) * proportion)]
class CSVDatasetCrossValidation:
"""
Generic filelist dataset for :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline that
......
......@@ -13,6 +13,7 @@ from bob.bio.base.database import (
CSVDatasetCrossValidation,
AnnotationsLoader,
LSTToSampleLoader,
CSVDatasetDevEvalZTNorm,
)
import nose.tools
from bob.pipelines import DelayedSample, SampleSet
......@@ -151,6 +152,68 @@ def test_csv_file_list_dev_eval():
run(example_dir + ".tar.gz")
def test_csv_file_list_dev_eval_score_norm():
    """
    Checks :any:`CSVDatasetDevEvalZTNorm` wrapped around a
    :any:`CSVDatasetDevEval`, once from a directory and once from the
    corresponding ``.tar.gz`` file.
    """

    annotation_directory = os.path.realpath(
        bob.io.base.test_utils.datafile(
            ".", __name__, "data/example_csv_filelist/annotations"
        )
    )

    def run(filename):
        dataset = CSVDatasetDevEval(
            filename,
            "protocol_dev_eval",
            csv_to_sample_loader=CSVToSampleLoader(
                data_loader=bob.io.base.load,
                metadata_loader=AnnotationsLoader(
                    annotation_directory=annotation_directory,
                    annotation_extension=".pos",
                    annotation_type="eyecenter",
                ),
                dataset_original_directory="",
                extension="",
            ),
        )

        znorm_dataset = CSVDatasetDevEvalZTNorm(dataset)

        # The regular dev/eval interface has to be available through the wrapper
        assert len(znorm_dataset.background_model_samples()) == 8
        assert check_all_true(znorm_dataset.background_model_samples(), DelayedSample)

        assert len(znorm_dataset.references()) == 2
        assert check_all_true(znorm_dataset.references(), SampleSet)

        assert len(znorm_dataset.probes()) == 8
        # The type check has to target the probes whose count was just checked
        assert check_all_true(znorm_dataset.probes(), SampleSet)

        assert len(znorm_dataset.references(group="eval")) == 6
        assert check_all_true(znorm_dataset.references(group="eval"), SampleSet)

        assert len(znorm_dataset.probes(group="eval")) == 13
        assert check_all_true(znorm_dataset.probes(group="eval"), SampleSet)

        assert len(znorm_dataset.all_samples(groups=None)) == 47
        assert check_all_true(znorm_dataset.all_samples(groups=None), DelayedSample)

        # Check the annotations
        for s in znorm_dataset.all_samples(groups=None):
            assert isinstance(s.annotations, dict)

        assert len(znorm_dataset.reference_ids(group="dev")) == 2
        assert len(znorm_dataset.reference_ids(group="eval")) == 6
        assert len(znorm_dataset.groups()) == 3

        ## Checking ZT-Norm stuff
        assert len(znorm_dataset.treferences()) == 2
        assert len(znorm_dataset.zprobes()) == 8

        # `proportion` keeps only the leading fraction of the sample sets
        assert len(znorm_dataset.treferences(proportion=0.5)) == 1
        assert len(znorm_dataset.zprobes(proportion=0.5)) == 4

    run(example_dir)
    run(example_dir + ".tar.gz")
def test_csv_file_list_dev_eval_sparse():
annotation_directory = os.path.realpath(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment