Commit 4930dda2 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Adapting CSVDevEval to work with our current FileList Structure

parent 52934fcb
Pipeline #46207 failed with stage in 1 minute and 25 seconds
@@ -13,4 +13,4 @@ sphinx
dist
build
record.txt
*.DS_Store
@@ -3,11 +3,10 @@ from .Distance import Distance
from .PCA import PCA
from .LDA import LDA
from .PLDA import PLDA
# gets sphinx autodoc done right - don't remove it
def __appropriate__(*args):
"""Says object was actually declared here, and not in the import module.
"""Says object was actually declared here, and not in the import module.
Fixing sphinx warnings of not being able to find classes, when path is shortened.
Parameters:
@@ -17,15 +16,12 @@ def __appropriate__(*args):
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
for obj in args:
obj.__module__ = __name__
__appropriate__(
Algorithm, Distance, PCA, LDA, PLDA,
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
@@ -3,6 +3,8 @@ from .csv_dataset import (
CSVToSampleLoader,
CSVDatasetCrossValidation,
CSVBaseSampleLoader,
IdiapAnnotationsLoader,
LSTToSampleLoader,
)
from .file import BioFile
from .file import BioFileSet
@@ -12,9 +12,57 @@ from abc import ABCMeta, abstractmethod
import numpy as np
import itertools
import logging
import bob.db.base
from bob.bio.base.pipelines.vanilla_biometrics.abstract_classes import Database
logger = logging.getLogger(__name__)
#####
# ANNOTATIONS LOADERS
#####
class IdiapAnnotationsLoader:
"""
Load annotations in the Idiap format
"""
def __init__(
self,
annotation_directory=None,
annotation_extension=".pos",
annotation_type="eyecenter",
):
self.annotation_directory = annotation_directory
self.annotation_extension = annotation_extension
self.annotation_type = annotation_type
def __call__(self, row, header=None):
if self.annotation_directory is None:
return None
path = row[0]
# since the file id is equal to the file name, we can simply use it
annotation_file = os.path.join(
self.annotation_directory, path + self.annotation_extension
)
# return the annotations as read from file
annotation = {
"annotations": bob.db.base.read_annotation_file(
annotation_file, self.annotation_type
)
}
return annotation
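# Usage sketch (editor's illustration, not part of the commit): assuming
# annotations live under a hypothetical "/data/annots" directory with one
# ".pos" file per sample path, the call below reads
# "/data/annots/subj_01/img_01.pos" and returns it under the "annotations" key:
#
#   loader = IdiapAnnotationsLoader(annotation_directory="/data/annots")
#   metadata = loader(["subj_01/img_01", "reference_id_1"])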
#######
# SAMPLE LOADERS
# CONVERT CSV LINES TO SAMPLES
#######
class CSVBaseSampleLoader(metaclass=ABCMeta):
"""
Convert CSV files in the format below to either a list of
@@ -22,10 +70,10 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
.. code-block:: text
PATH,REFERENCE_ID
path_1,reference_id_1
path_2,reference_id_2
path_i,reference_id_j
...
.. note::
@@ -43,10 +91,17 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
"""
def __init__(
self,
data_loader,
metadata_loader=None,
dataset_original_directory="",
extension="",
):
self.data_loader = data_loader
self.extension = extension
self.dataset_original_directory = dataset_original_directory
self.metadata_loader = metadata_loader
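# Construction sketch (editor's illustration, not part of the commit): wiring
# a concrete loader for images under a hypothetical root directory, with the
# IdiapAnnotationsLoader defined above attached as metadata_loader:
#
#   loader = CSVToSampleLoader(
#       data_loader=bob.io.base.load,
#       metadata_loader=IdiapAnnotationsLoader(annotation_directory="/data/annots"),
#       dataset_original_directory="/data/images",
#       extension=".png",
#   )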
@abstractmethod
def __call__(self, filename):
@@ -56,11 +111,24 @@ class CSVBaseSampleLoader(metaclass=ABCMeta):
def convert_row_to_sample(self, row, header):
pass
def convert_samples_to_samplesets(
self, samples, group_by_reference_id=True, references=None
):
if group_by_reference_id:
# Grouping sample sets
sample_sets = dict()
for s in samples:
if s.reference_id not in sample_sets:
sample_sets[s.reference_id] = SampleSet(
[s], parent=s, references=references
)
else:
sample_sets[s.reference_id].append(s)
return list(sample_sets.values())
else:
return [SampleSet([s], parent=s, references=references) for s in samples]
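# Grouping sketch (editor's illustration): three samples whose reference_id
# values are "A", "A" and "B" collapse into two SampleSets when
# group_by_reference_id=True (both "A" samples share one set) and into three
# single-sample SampleSets when it is False:
#
#   sets = loader.convert_samples_to_samplesets(samples, references=["A", "B"])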
class CSVToSampleLoader(CSVBaseSampleLoader):
@@ -71,11 +139,13 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
def check_header(self, header):
"""
A header should have at least "reference_id" and "PATH"
"""
header = [h.lower() for h in header]
if not "reference_id" in header:
raise ValueError(
"The field `reference_id` is not available in your dataset."
)
if not "path" in header:
raise ValueError("The field `path` is not available in your dataset.")
@@ -91,42 +161,67 @@ class CSVToSampleLoader(CSVBaseSampleLoader):
def convert_row_to_sample(self, row, header):
path = row[0]
reference_id = row[1]
kwargs = dict(zip(header[2:], row[2:]))
if self.metadata_loader is not None:
metadata = self.metadata_loader(row)
kwargs.update(metadata)
return DelayedSample(
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, path + self.extension),
),
key=path,
reference_id=reference_id,
**kwargs,
)
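# Worked example (editor's sketch): with header ["PATH", "REFERENCE_ID", "CAPTURE"]
# and a hypothetical row ["subj_01/img_01", "reference_id_1", "indoor"],
#
#   sample = loader.convert_row_to_sample(
#       ["subj_01/img_01", "reference_id_1", "indoor"],
#       ["PATH", "REFERENCE_ID", "CAPTURE"],
#   )
#
# yields a DelayedSample with key="subj_01/img_01",
# reference_id="reference_id_1" and CAPTURE="indoor"; the raw data is only
# read when sample.data is first accessed.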
class LSTToSampleLoader(CSVBaseSampleLoader):
"""
Simple mechanism to convert LST files in the format below to either a list of
:any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
"""
def __call__(self, filename):
with open(filename) as cf:
reader = csv.reader(cf, delimiter=" ")
return [self.convert_row_to_sample(row) for row in reader]
def convert_row_to_sample(self, row, header=None):
path = row[0]
reference_id = str(row[1])
kwargs = dict()
if len(row) == 3:
subject = row[2]
kwargs = {"subject": str(subject)}
if self.metadata_loader is not None:
metadata = self.metadata_loader(row)
kwargs.update(metadata)
return DelayedSample(
functools.partial(
self.data_loader,
os.path.join(self.dataset_original_directory, path + self.extension),
),
key=path,
reference_id=reference_id,
**kwargs,
)
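# Format sketch (editor's note): LST rows are space-separated. A hypothetical
# two-column line "subj_01/img_01 reference_id_1" sets path and reference_id
# only, while a three-column line "subj_01/img_01 reference_id_1 subj_01"
# additionally carries the subject as metadata:
#
#   samples = LSTToSampleLoader(data_loader=bob.io.base.load)("dev/for_models.lst")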
#####
# DATABASE INTERFACES
#####
class CSVDatasetDevEval(Database):
"""
Generic filelist dataset for the :any:`bob.bio.base.pipelines.vanilla_biometrics.VanillaBiometricsPipeline` pipeline.
Check :any:`vanilla_biometrics_features` for more details about the Vanilla Biometrics Dataset.
@@ -154,17 +249,17 @@ class CSVDatasetDevEval:
- dev_probe.csv
Those CSV files should contain in each row (i) the path to the raw data and (ii) the reference_id label
for enrollment (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.references`) and
probing (:any:`bob.bio.base.pipelines.vanilla_biometrics.Database.probes`).
The structure of each CSV file should be as below:
.. code-block:: text
PATH,reference_id
path_1,reference_id_1
path_2,reference_id_2
path_i,reference_id_j
...
@@ -173,10 +268,10 @@ class CSVDatasetDevEval:
.. code-block:: text
PATH,reference_id,METADATA_1,METADATA_2,METADATA_k
path_1,reference_id_1,A,B,C
path_2,reference_id_2,A,B,1
path_i,reference_id_j,2,3,4
...
@@ -206,9 +301,14 @@ class CSVDatasetDevEval:
dataset_protocol_path,
protocol_name,
csv_to_sample_loader=CSVToSampleLoader(
data_loader=bob.io.base.load,
metadata_loader=None,
dataset_original_directory="",
extension="",
),
):
self.dataset_protocol_path = dataset_protocol_path
def get_paths():
if not os.path.exists(dataset_protocol_path):
@@ -219,11 +319,34 @@ class CSVDatasetDevEval:
if not os.path.exists(protocol_path):
raise ValueError(f"The protocol `{protocol_name}` was not found")
def path_discovery(option1, option2):
return option1 if os.path.exists(option1) else option2
# Here we handle the legacy ".lst" filelist layout, preferring it over the ".csv" one when present
train_csv = path_discovery(
os.path.join(protocol_path, "norm", "train_world.lst"),
os.path.join(protocol_path, "norm", "train_world.csv"),
)
dev_enroll_csv = path_discovery(
os.path.join(protocol_path, "dev", "for_models.lst"),
os.path.join(protocol_path, "dev", "for_models.csv"),
)
dev_probe_csv = path_discovery(
os.path.join(protocol_path, "dev", "for_probes.lst"),
os.path.join(protocol_path, "dev", "for_probes.csv"),
)
eval_enroll_csv = path_discovery(
os.path.join(protocol_path, "eval", "for_models.lst"),
os.path.join(protocol_path, "eval", "for_models.csv"),
)
eval_probe_csv = path_discovery(
os.path.join(protocol_path, "eval", "for_probes.lst"),
os.path.join(protocol_path, "eval", "for_probes.csv"),
)
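# Layout sketch (editor's illustration): for every list, path_discovery
# prefers the legacy filelist layout and falls back to the CSV one:
#
#   <protocol>/norm/train_world.lst  (legacy)  else  <protocol>/norm/train_world.csv
#   <protocol>/dev/for_models.lst    (legacy)  else  <protocol>/dev/for_models.csv
#   <protocol>/dev/for_probes.lst    (legacy)  else  <protocol>/dev/for_probes.csv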
# The minimum requirement is to have `dev_enroll_csv` and `dev_probe_csv`
train_csv = train_csv if os.path.exists(train_csv) else None
@@ -244,6 +367,8 @@ class CSVDatasetDevEval:
raise ValueError(
f"The file `{dev_probe_csv}` is required and it was not found"
)
return (
train_csv,
@@ -274,7 +399,6 @@ class CSVDatasetDevEval:
self.csv_to_sample_loader = csv_to_sample_loader
def background_model_samples(self):
self.cache["train"] = (
self.csv_to_sample_loader(self.train_csv)
if self.cache["train"] is None
@@ -283,7 +407,9 @@
return self.cache["train"]
def _get_samplesets(
self, group="dev", purpose="enroll", group_by_reference_id=False
):
fetching_probes = False
if purpose == "enroll":
@@ -297,12 +423,14 @@
references = None
if fetching_probes:
references = list(
set([s.reference_id for s in self.references(group=group)])
)
samples = self.csv_to_sample_loader(self.__dict__[cache_label])
sample_sets = self.csv_to_sample_loader.convert_samples_to_samplesets(
samples, group_by_reference_id=group_by_reference_id, references=references
)
self.cache[cache_label] = sample_sets
@@ -311,12 +439,12 @@
def references(self, group="dev"):
return self._get_samplesets(
group=group, purpose="enroll", group_by_reference_id=True
)
def probes(self, group="dev"):
return self._get_samplesets(
group=group, purpose="probe", group_by_reference_id=False
)
def all_samples(self, groups=None):
@@ -360,6 +488,27 @@ class CSVDatasetDevEval:
samples = samples + self.csv_to_sample_loader(self.__dict__[label])
return samples
def groups(self):
"""This function returns the list of groups for this database.
Returns
-------
[str]
A list of groups
"""
# We always have dev-set
groups = ["dev"]
if self.train_csv is not None:
groups.append("train")
if self.eval_enroll_csv is not None:
groups.append("eval")
return groups
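# Usage sketch (editor's note, "my_protocol" is a placeholder name): for a
# protocol directory shipping train and eval lists alongside the mandatory
# dev ones,
#
#   CSVDatasetDevEval(dataset_protocol_path, "my_protocol").groups()
#
# returns ["dev", "train", "eval"].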
class CSVDatasetCrossValidation:
"""
@@ -377,10 +526,10 @@ class CSVDatasetCrossValidation:
.. code-block:: text
PATH,reference_id
path_1,reference_id_1
path_2,reference_id_2
path_i,reference_id_j
...
Parameters
@@ -393,7 +542,7 @@ class CSVDatasetCrossValidation:
Pseudo-random number generator seed
test_size: float
Percentage of the reference_ids used for testing
samples_for_enrollment: float
Number of samples used for enrollment
@@ -435,30 +584,35 @@ class CSVDatasetCrossValidation:
def _do_cross_validation(self):
# Shuffling samples by reference_id
samples_by_reference_id = group_samples_by_reference_id(
self.csv_to_sample_loader(self.csv_file_name)
)
reference_ids = list(samples_by_reference_id.keys())
np.random.seed(self.random_state)
np.random.shuffle(reference_ids)
# Getting the training data
n_samples_for_training = len(reference_ids) - int(
self.test_size * len(reference_ids)
)
self.cache["train"] = list(
itertools.chain(
*[
samples_by_reference_id[s]
for s in reference_ids[0:n_samples_for_training]
]
)
)
# Splitting enroll and probe
self.cache["dev_enroll_csv"] = []
self.cache["dev_probe_csv"] = []
for s in reference_ids[n_samples_for_training:]:
samples = samples_by_reference_id[s]
if len(samples) < self.samples_for_enrollment:
raise ValueError(
f"Not enough samples ({len(samples)}) for enrollment for the reference_id {s}"
)
# Enrollment samples
@@ -472,8 +626,8 @@ class CSVDatasetCrossValidation:
"dev_probe_csv"
] += self.csv_to_sample_loader.convert_samples_to_samplesets(
samples[self.samples_for_enrollment :],
group_by_reference_id=False,
references=reference_ids[n_samples_for_training:],
)
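# Split arithmetic (editor's example): with 100 reference_ids and
# test_size=0.2, n_samples_for_training = 100 - int(0.2 * 100) = 80; the first
# 80 shuffled identities populate cache["train"], and each of the remaining 20
# contributes its first samples_for_enrollment samples to enrollment and the
# rest to probes scored against all 20 held-out identities.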
def _load_from_cache(self, cache_key):
@@ -527,12 +681,12 @@ class CSVDatasetCrossValidation:
return samples
def group_samples_by_reference_id(samples):
# Grouping sample sets
samples_by_reference_id = dict()
for s in samples:
if s.reference_id not in samples_by_reference_id:
samples_by_reference_id[s.reference_id] = []
samples_by_reference_id[s.reference_id].append(s)
return samples_by_reference_id
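# Example (editor's sketch): samples carrying reference_ids ["A", "B", "A"]
# map to {"A": [sample_1, sample_3], "B": [sample_2]}:
#
#   by_id = group_samples_by_reference_id(samples)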
@@ -178,18 +178,18 @@ class BioAlgorithm(metaclass=ABCMeta):
"""
for r in biometric_references:
if (
str(r.reference_id) in probe_refererences
and str(r.reference_id) not in self.stacked_biometric_references
):
self.stacked_biometric_references[str(r.reference_id)] = r.data
for probe_sample in sampleset:
cache_references(sampleset.references)
references = [
self.stacked_biometric_references[str(r.reference_id)]
for r in biometric_references
if str(r.reference_id) in sampleset.references
]
scores = self.score_multiple_biometric_references(
@@ -204,7 +204,7 @@ class BioAlgorithm(metaclass=ABCMeta):
[
r
for r in biometric_references
if str(r.reference_id) in sampleset.references
if str(r.reference_id) in sampleset.references
],
total_scores,
):
@@ -328,6 +328,12 @@ class Database(metaclass=ABCMeta):
"""
pass
def groups(self):
pass
def reference_ids(self, group):
return [s.reference_id for s in self.references(group=group)]
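# Usage sketch (editor's note): since references() returns one SampleSet per
# identity, reference_ids() reduces to the list of enrolled identities, e.g.
#
#   database.reference_ids("dev")  # ["reference_id_1", "reference_id_2", ...]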
class ScoreWriter(metaclass=ABCMeta):
"""
@@ -29,7 +29,7 @@ def _biofile_to_delayed_sample(biofile, database):
load=functools.partial(
biofile.load, database.original_directory, database.original_extension,
),
reference_id=str(biofile.client_id),
key=biofile.path,
path=biofile.path,
delayed_attributes=dict(
@@ -138,7 +138,7 @@ class DatabaseConnector(Database):
[_biofile_to_delayed_sample(k, self.database) for k in objects],
key=str(m),
path=str(m),
reference_id=str(objects[0].client_id),