Commit cea347c0 authored by Tiago de Freitas Pereira

Merge branch 'fixes' into 'master'

The legacy database wrapper was not properly supporting the use case where `model_id<>client_id`

See merge request !241
parents e26eebe4 a429a362
Pipeline #48052 passed with stages in 10 minutes and 55 seconds
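For context, the use case this MR fixes is a legacy database in which one client enrolls several models, so the model identifier can no longer double as the client identifier. Below is a minimal sketch of such a wrapper; the class and its values are hypothetical, though `model_ids` and `client_id_from_model_id` are real legacy-interface methods (the latter is the one this MR starts calling):

```python
# Hypothetical legacy database: each client enrolls two models, so
# model_id <> client_id.  Names and values here are illustrative only.
class MyLegacyDatabase:
    def model_ids(self, groups=None):
        # client "1" owns models "1_left" and "1_right", etc.
        return ["1_left", "1_right", "2_left", "2_right"]

    def client_id_from_model_id(self, model_id, group=None):
        # Recover the owning client from a model identifier.
        return model_id.split("_")[0]

db = MyLegacyDatabase()
assert db.client_id_from_model_id("1_left") == "1"  # model_id <> client_id
```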
from .csv_dataset import (
    CSVDataset,
-    CSVToSampleLoader,
+    CSVToSampleLoaderBiometrics,
    CSVDatasetCrossValidation,
    LSTToSampleLoader,
    CSVDatasetZTNorm,
@@ -35,7 +35,7 @@ __appropriate__(
    BioDatabase,
    ZTBioDatabase,
    CSVDataset,
-    CSVToSampleLoader,
+    CSVToSampleLoaderBiometrics,
    CSVDatasetCrossValidation,
)
__all__ = [_ for _ in dir() if not _.startswith("_")]
@@ -74,7 +74,71 @@ class LSTToSampleLoader(CSVToSampleLoader):
        kwargs = dict()
        if len(row) == 3:
            subject = row[2]
-            kwargs = {"subject": str(subject)}
+            kwargs = {"subject_id": str(subject)}
        return DelayedSample(
            functools.partial(
                self.data_loader,
                os.path.join(self.dataset_original_directory, path + self.extension),
            ),
            key=path,
            reference_id=reference_id,
            **kwargs,
        )
+
+
+class CSVToSampleLoaderBiometrics(CSVToSampleLoader):
+    """
+    Base class that converts the lines of a CSV file, like the one below to
+    :any:`bob.pipelines.DelayedSample` or :any:`bob.pipelines.SampleSet`
+
+    .. code-block:: text
+
+       PATH,REFERENCE_ID
+       path_1,reference_id_1
+       path_2,reference_id_2
+       path_i,reference_id_j
+       ...
+
+    Parameters
+    ----------
+
+    data_loader:
+        A python function that can be called parameterlessly, to load the
+        sample in question from whatever medium
+
+    dataset_original_directory: str
+        Path of where data is stored
+
+    extension: str
+        Default file extension
+
+    """
+
+    def __init__(
+        self,
+        data_loader,
+        dataset_original_directory="",
+        extension="",
+        reference_id_equal_subject_id=True,
+    ):
+        super().__init__(
+            data_loader=data_loader,
+            extension=extension,
+            dataset_original_directory=dataset_original_directory,
+        )
+        self.reference_id_equal_subject_id = reference_id_equal_subject_id
+
+    def convert_row_to_sample(self, row, header):
+        path = row[0]
+        reference_id = row[1]
+
+        kwargs = dict([[str(h).lower(), r] for h, r in zip(header[2:], row[2:])])
+        if self.reference_id_equal_subject_id:
+            kwargs["subject_id"] = reference_id
+        else:
+            if "subject_id" not in kwargs:
+                raise ValueError(f"`subject_id` not available in {header}")
+        return DelayedSample(
+            functools.partial(
@@ -166,7 +230,7 @@ class CSVDataset(Database):
        self,
        dataset_protocol_path,
        protocol_name,
-        csv_to_sample_loader=CSVToSampleLoader(
+        csv_to_sample_loader=CSVToSampleLoaderBiometrics(
            data_loader=bob.io.base.load, dataset_original_directory="", extension="",
        ),
        is_sparse=False,
@@ -574,7 +638,7 @@ class CSVDatasetCrossValidation:
        random_state=0,
        test_size=0.8,
        samples_for_enrollment=1,
-        csv_to_sample_loader=CSVToSampleLoader(
+        csv_to_sample_loader=CSVToSampleLoaderBiometrics(
            data_loader=bob.io.base.load, dataset_original_directory="", extension=""
        ),
    ):
......
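A minimal usage sketch of the new loader (paths, extension and CSV values are illustrative): with the default `reference_id_equal_subject_id=True` the `subject_id` is simply copied from the `REFERENCE_ID` column; with `False`, the CSV must carry an explicit `SUBJECT_ID` column, otherwise `convert_row_to_sample` raises a `ValueError`.

```python
import bob.io.base
from bob.bio.base.database import CSVToSampleLoaderBiometrics

# Default case: reference_id doubles as subject_id.
loader = CSVToSampleLoaderBiometrics(
    data_loader=bob.io.base.load,
    dataset_original_directory="/illustrative/path",
    extension=".hdf5",
)
sample = loader.convert_row_to_sample(
    ["dir/file_1", "ref_1"], header=["PATH", "REFERENCE_ID"]
)
# Loading is delayed, so no file is touched here.
assert sample.reference_id == sample.subject_id == "ref_1"

# model_id <> client_id case: SUBJECT_ID must come from the CSV itself.
loader = CSVToSampleLoaderBiometrics(
    data_loader=bob.io.base.load, reference_id_equal_subject_id=False,
)
sample = loader.convert_row_to_sample(
    ["dir/file_1", "model_1", "client_1"],
    header=["PATH", "REFERENCE_ID", "SUBJECT_ID"],
)
assert sample.subject_id == "client_1"
```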
@@ -146,7 +146,6 @@ class BioAlgorithm(metaclass=ABCMeta):
            # static batch of biometric references
            total_scores = []
            for probe_sample in sampleset:
-
                # Multiple scoring
                if self.stacked_biometric_references is None:
                    self.stacked_biometric_references = [
@@ -184,7 +183,6 @@ class BioAlgorithm(metaclass=ABCMeta):
                self.stacked_biometric_references[str(r.reference_id)] = r.data

        for probe_sample in sampleset:
            cache_references(sampleset.references)
-
            references = [
                self.stacked_biometric_references[str(r.reference_id)]
@@ -192,6 +190,12 @@ class BioAlgorithm(metaclass=ABCMeta):
                if str(r.reference_id) in sampleset.references
            ]

+            if len(references) == 0:
+                raise ValueError(
+                    f"The probe {sampleset} can't be compared with any biometric reference. "
+                    "Something is probably wrong with your database interface."
+                )
+
            scores = self.score_multiple_biometric_references(
                references, probe_sample.data
            )
......
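The added guard makes an identifier mismatch fail loudly: if none of the `reference_id`s declared in a probe's `references` can be found among the enrolled references, scoring now raises instead of silently producing no scores. A standalone rendition of the check (a sketch, not the actual `BioAlgorithm` method):

```python
# Enrolled references keyed by model id (illustrative data).
stacked_biometric_references = {"1_left": [0.1, 0.2], "2_left": [0.3, 0.4]}

def select_references(declared_references):
    references = [
        data
        for ref_id, data in stacked_biometric_references.items()
        if ref_id in declared_references
    ]
    if len(references) == 0:
        # Typically means the probes declare client_ids while the
        # references are keyed by model_ids -- the mismatch this MR fixes.
        raise ValueError(
            "The probe can't be compared with any biometric reference. "
            "Something is probably wrong with your database interface."
        )
    return references

select_references(["1_left"])  # OK
# select_references(["1"])     # would raise ValueError
```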
@@ -10,7 +10,7 @@ import functools

class Distance(BioAlgorithm):
    def __init__(
-        self, distance_function=scipy.spatial.distance.euclidean, factor=-1, **kwargs
+        self, distance_function=scipy.spatial.distance.cosine, factor=-1, **kwargs
    ):
        super().__init__(**kwargs)
        self.distance_function = distance_function
......
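Note this also flips the default `distance_function` of `Distance` from euclidean to cosine; pipelines that relied on the old default now have to pass euclidean explicitly, exactly as the updated `test_norm_mechanics` does below. A sketch (the import path is assumed from the package layout at the time of this MR):

```python
import scipy.spatial.distance
from bob.bio.base.pipelines.vanilla_biometrics import Distance

algorithm = Distance(factor=-1)  # now defaults to cosine distance

# Preserve the previous behaviour by being explicit:
legacy_algorithm = Distance(
    distance_function=scipy.spatial.distance.euclidean, factor=-1
)
```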
@@ -171,7 +171,8 @@ class DatabaseConnector(Database):
                    [_biofile_to_delayed_sample(k, self.database) for k in objects],
                    key=str(m),
                    path=str(m),
-                    reference_id=str(objects[0].client_id),
+                    reference_id=(str(m)),
+                    subject_id=str(self.database.client_id_from_model_id(m)),
                )
            )
@@ -199,7 +200,6 @@ class DatabaseConnector(Database):
        """
        probes = dict()
-
        for m in self.database.model_ids(groups=group):

            # Getting all the probe objects from a particular biometric
@@ -212,12 +212,12 @@ class DatabaseConnector(Database):
                        [_biofile_to_delayed_sample(o, self.database)],
                        key=str(o.path),
                        path=o.path,
-                        reference_id=str(o.client_id),
+                        reference_id=str(m),
                        references=[str(m)],
+                        subject_id=o.client_id,
                    )
                else:
-                    probes[o.id].references.append(str(m))
+                    probes[o.id].references.append(str(str(m)))

        return list(probes.values())

    def all_samples(self, groups=None):
......
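This is the heart of the fix: `DatabaseConnector` now keys both reference and probe `SampleSet`s by the model id (`str(m)`), and carries the client identity separately in `subject_id`, obtained through the legacy `client_id_from_model_id` API. A schematic of the resulting identifier layout (values are illustrative):

```python
# Identifier layout after this change, for a client "1" that enrolled
# the model "1_left" (values are illustrative, not from the codebase).
reference_sampleset = dict(
    key="1_left",
    reference_id="1_left",  # was str(objects[0].client_id), i.e. "1"
    subject_id="1",         # str(database.client_id_from_model_id("1_left"))
)
probe_sampleset = dict(
    reference_id="1_left",  # was str(o.client_id)
    references=["1_left"],  # probes now point at model ids
    subject_id="1",         # o.client_id, kept for score files
)
```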
@@ -37,8 +37,8 @@ class FourColumnsScoreWriter(ScoreWriter):
            lines = [
                "{0} {1} {2} {3}\n".format(
-                    biometric_reference.reference_id,
-                    probe.reference_id,
+                    biometric_reference.subject_id,
+                    probe.subject_id,
                    probe.key,
                    biometric_reference.data,
                )
......
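Since `reference_id` may now hold a model id, the four-column score writer switches to `subject_id` so that score files keep reporting client identities. A sketch of the resulting line format (the objects and values below are made up):

```python
from types import SimpleNamespace

# The reference belongs to client "1" even though it was enrolled as
# model "1_left"; the probe comes from client "2" (illustrative values).
biometric_reference = SimpleNamespace(subject_id="1", data=0.114)
probe = SimpleNamespace(subject_id="2", key="data/client2_probe1")

line = "{0} {1} {2} {3}\n".format(
    biometric_reference.subject_id,
    probe.subject_id,
    probe.key,
    biometric_reference.data,
)
print(line, end="")  # -> 1 2 data/client2_probe1 0.114
```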
@@ -12,6 +12,7 @@ from bob.bio.base.database import (
    CSVDatasetCrossValidation,
    LSTToSampleLoader,
    CSVDatasetZTNorm,
+    CSVToSampleLoaderBiometrics,
)

import nose.tools
from bob.pipelines import DelayedSample, SampleSet
@@ -25,7 +26,7 @@ from bob.bio.base.database import FileListBioDatabase
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import make_pipeline
from bob.pipelines import wrap
-from bob.pipelines.datasets import AnnotationsLoader, CSVToSampleLoader
+from bob.pipelines.datasets import AnnotationsLoader

legacy_example_dir = os.path.realpath(
@@ -110,7 +111,7 @@ def test_csv_file_list_dev_eval():
        filename,
        "protocol_dev_eval",
        csv_to_sample_loader=make_pipeline(
-            CSVToSampleLoader(
+            CSVToSampleLoaderBiometrics(
                data_loader=bob.io.base.load,
                dataset_original_directory="",
                extension="",
@@ -166,7 +167,7 @@ def test_csv_file_list_dev_eval_score_norm():
        filename,
        "protocol_dev_eval",
        csv_to_sample_loader=make_pipeline(
-            CSVToSampleLoader(
+            CSVToSampleLoaderBiometrics(
                data_loader=bob.io.base.load,
                dataset_original_directory="",
                extension="",
@@ -230,7 +231,7 @@ def test_csv_file_list_dev_eval_sparse():
        example_dir,
        "protocol_dev_eval_sparse",
        csv_to_sample_loader=make_pipeline(
-            CSVToSampleLoader(
+            CSVToSampleLoaderBiometrics(
                data_loader=bob.io.base.load,
                dataset_original_directory="",
                extension="",
@@ -396,7 +397,7 @@ def test_csv_cross_validation_atnt():
        csv_file_name=atnt_protocol_path_cross_validation,
        random_state=0,
        test_size=0.8,
-        csv_to_sample_loader=CSVToSampleLoader(
+        csv_to_sample_loader=CSVToSampleLoaderBiometrics(
            data_loader=data_loader,
            dataset_original_directory=atnt_database_directory(),
            extension=".pgm",
@@ -428,7 +429,7 @@ def test_atnt_experiment():
    dataset = CSVDataset(
        dataset_protocol_path=atnt_protocol_path,
        protocol_name="idiap_protocol",
-        csv_to_sample_loader=CSVToSampleLoader(
+        csv_to_sample_loader=CSVToSampleLoaderBiometrics(
            data_loader=data_loader,
            dataset_original_directory=atnt_database_directory(),
            extension=".pgm",
@@ -451,7 +452,7 @@ def test_atnt_experiment_cross_validation():
        csv_file_name=atnt_protocol_path_cross_validation,
        random_state=0,
        test_size=test_size,
-        csv_to_sample_loader=CSVToSampleLoader(
+        csv_to_sample_loader=CSVToSampleLoaderBiometrics(
            data_loader=data_loader,
            dataset_original_directory=atnt_database_directory(),
            extension=".pgm",
@@ -691,6 +692,7 @@ def test_multiple_extensions():
    file = bob.bio.base.database.BioFile(
        4, "data/model4_session1_sample2", "data/model4_session1_sample2"
    )
+
    file_name = db.original_file_name(file, True)
    assert file_name == os.path.join(legacy_example_dir, file.path + ".pos")
......
@@ -46,6 +46,7 @@ class DummyDatabase:
                key=str(uuid.uuid4()),
                annotations=1,
                reference_id=str(i),
+                subject_id=str(i),
            )
            for i in range(offset, offset + n_samples)
        ]
@@ -57,6 +58,7 @@ class DummyDatabase:
                key=str(uuid.uuid4()),
                annotations=1,
                reference_id=str(i),
+                subject_id=str(i),
            )
            for i in range(offset, offset + n_samples)
        ]
@@ -70,6 +72,7 @@ class DummyDatabase:
                samples=[],
                key=str(i),
                reference_id=str(i),
+                subject_id=str(i),
                gender=np.random.choice(self.gender_choices),
                metadata_1=np.random.choice(self.metadata_1_choices),
            )
......
@@ -33,7 +33,7 @@ import bob.pipelines as mario
import uuid
import shutil
import itertools
-from scipy.spatial.distance import cdist
+from scipy.spatial.distance import cdist, euclidean
from sklearn.preprocessing import FunctionTransformer
import copy
@@ -147,6 +147,7 @@ def test_norm_mechanics():
            [Sample(s, reference_id=str(i + offset), key=str(uuid.uuid4()))],
            key=str(i + offset),
            reference_id=str(i + offset),
+            subject_id=str(i + offset),
        )
        for i, s in enumerate(raw_data)
    ]
@@ -156,6 +157,7 @@ def test_norm_mechanics():
            [Sample(s, reference_id=str(i + offset), key=str(uuid.uuid4()))],
            key=str(i + offset),
            reference_id=str(i + offset),
+            subject_id=str(i + offset),
            references=references,
        )
        for i, s in enumerate(raw_data)
@@ -226,11 +228,11 @@ def test_norm_mechanics():
    #############
    transformer = make_pipeline(FunctionTransformer(func=_do_nothing_fn))

-    biometric_algorithm = Distance(factor=1)
+    biometric_algorithm = Distance(euclidean, factor=1)

    if with_checkpoint:
        biometric_algorithm = BioAlgorithmCheckpointWrapper(
-            Distance(factor=1), dir_name
+            Distance(distance_function=euclidean, factor=1), dir_name,
        )

    vanilla_pipeline = VanillaBiometricsPipeline(
@@ -252,6 +254,7 @@ def test_norm_mechanics():
    raw_scores = _dump_scores_from_samples(
        score_samples, shape=(n_probes, n_references)
    )
+    assert np.allclose(raw_scores, raw_scores_ref)

    ############
......