Skip to content
Snippets Groups Projects
Commit 34a0591c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

[legacy] Allowed the scoring of statefull bob.bio.base.algorithm.Algorithm

parent 92748386
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
Pipeline #39048 failed
...@@ -214,10 +214,11 @@ class Algorithm (object): ...@@ -214,10 +214,11 @@ class Algorithm (object):
score : float score : float
The fused similarity between the given ``models`` and the ``probe``. The fused similarity between the given ``models`` and the ``probe``.
""" """
if isinstance(models, list): if isinstance(models, list):
return self.model_fusion_function([self.score(model, probe) for model in models]) return [self.score(model, probe) for model in models]
elif isinstance(models, numpy.ndarray): elif isinstance(models, numpy.ndarray):
return self.model_fusion_function([self.score(models[i,:], probe) for i in range(models.shape[0])]) return [self.score(models[i,:], probe) for i in range(models.shape[0])]
else: else:
raise ValueError("The model does not have the desired format (list, array, ...)") raise ValueError("The model does not have the desired format (list, array, ...)")
......
from bob.bio.face.database import AtntBioDatabase
from bob.bio.gmm.algorithm import ISV
from bob.bio.face.preprocessor import FaceCrop
from sklearn.pipeline import make_pipeline
from bob.bio.base.pipelines.vanilla_biometrics.legacy import DatabaseConnector, Preprocessor, AlgorithmAsTransformer, AlgorithmAsBioAlg, Extractor
import functools
from bob.bio.base.pipelines.vanilla_biometrics.implemented import (
Distance,
CheckpointDistance,
)
import os
# DATABASE
database = DatabaseConnector(
AtntBioDatabase(original_directory="./atnt", protocol="Default"),
)
database.allow_scoring_with_all_biometric_references = True
base_dir = "example/isv"
# PREPROCESSOR LEGACY
# Cropping
CROPPED_IMAGE_HEIGHT = 80
CROPPED_IMAGE_WIDTH = CROPPED_IMAGE_HEIGHT * 4 // 5
# eye positions for frontal images
RIGHT_EYE_POS = (CROPPED_IMAGE_HEIGHT // 5, CROPPED_IMAGE_WIDTH // 4 - 1)
LEFT_EYE_POS = (CROPPED_IMAGE_HEIGHT // 5, CROPPED_IMAGE_WIDTH // 4 * 3)
# RANDOM EYES POSITIONS
# I JUST MADE UP THESE NUMBERS
FIXED_RIGHT_EYE_POS = (30, 30)
FIXED_LEFT_EYE_POS = (20, 50)
face_cropper = functools.partial(
FaceCrop,
cropped_image_size=(CROPPED_IMAGE_HEIGHT, CROPPED_IMAGE_WIDTH),
cropped_positions={"leye": LEFT_EYE_POS, "reye": RIGHT_EYE_POS},
fixed_positions={"leye": FIXED_LEFT_EYE_POS, "reye": FIXED_RIGHT_EYE_POS},
)
import bob.bio.face
extractor = functools.partial(
bob.bio.face.extractor.DCTBlocks,
block_size=12,
block_overlap=11,
number_of_dct_coefficients=45,
)
# ALGORITHM LEGACY
isv = functools.partial(ISV, subspace_dimension_of_u=10, number_of_gaussians=2)
model_path=os.path.join(base_dir, "ubm_u.hdf5")
transformer = make_pipeline(
Preprocessor(callable=face_cropper, features_dir=os.path.join(base_dir,"face_crop")),
Extractor(extractor, features_dir=os.path.join(base_dir, "dcts")),
AlgorithmAsTransformer(
callable=isv, features_dir=os.path.join(base_dir,"isv"), model_path=model_path
),
)
algorithm = AlgorithmAsBioAlg(callable=isv, features_dir=base_dir, model_path=model_path)
from bob.bio.base.pipelines.vanilla_biometrics import VanillaBiometrics, dask_vanilla_biometrics
#pipeline = VanillaBiometrics(transformer, algorithm)
pipeline = dask_vanilla_biometrics(VanillaBiometrics(transformer, algorithm))
...@@ -20,6 +20,7 @@ from bob.pipelines.sample import DelayedSample, SampleSet, Sample ...@@ -20,6 +20,7 @@ from bob.pipelines.sample import DelayedSample, SampleSet, Sample
from bob.pipelines.utils import is_picklable from bob.pipelines.utils import is_picklable
from sklearn.base import TransformerMixin, BaseEstimator from sklearn.base import TransformerMixin, BaseEstimator
import logging import logging
import copy
logger = logging.getLogger("bob.bio.base") logger = logging.getLogger("bob.bio.base")
...@@ -72,9 +73,7 @@ class DatabaseConnector(Database): ...@@ -72,9 +73,7 @@ class DatabaseConnector(Database):
model training. See, e.g., :py:func:`.pipelines.first`. model training. See, e.g., :py:func:`.pipelines.first`.
""" """
objects = self.database.training_files() objects = self.database.training_files()
return [_biofile_to_delayed_sample(k, self.database) for k in objects] return [_biofile_to_delayed_sample(k, self.database) for k in objects]
def references(self, group="dev"): def references(self, group="dev"):
...@@ -376,7 +375,9 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm): ...@@ -376,7 +375,9 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
""" """
def __init__(self, callable, features_dir, extension=".hdf5", **kwargs): def __init__(
self, callable, features_dir, extension=".hdf5", model_path=None, **kwargs
):
super().__init__(callable, **kwargs) super().__init__(callable, **kwargs)
self.features_dir = features_dir self.features_dir = features_dir
self.biometric_reference_dir = os.path.join( self.biometric_reference_dir = os.path.join(
...@@ -384,11 +385,49 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm): ...@@ -384,11 +385,49 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
) )
self.score_dir = os.path.join(self.features_dir, "scores") self.score_dir = os.path.join(self.features_dir, "scores")
self.extension = extension self.extension = extension
self.model_path = model_path
self.is_projector_loaded = False
def _enroll_sample_set(self, sampleset): def _enroll_sample_set(self, sampleset):
# Enroll # Enroll
return self.enroll(sampleset) return self.enroll(sampleset)
def _load_projector(self):
"""
Run :py:meth:`bob.bio.base.algorithm.Algorithm.load_projector` if necessary by
:py:class:`bob.bio.base.algorithm.Algorithm`
"""
if self.instance.performs_projection and not self.is_projector_loaded:
if self.model_path is None:
raise ValueError(
"Algorithm " + f"{self. instance} performs_projection. Hence, "
"`model_path` needs to passed in `AlgorithmAsBioAlg.__init__`"
)
else:
# Loading model
self.instance.load_projector(self.model_path)
self.is_projector_loaded = True
def _restore_state_of_ref(self, ref):
"""
There are some algorithms that :py:meth:`bob.bio.base.algorithm.Algorithm.read_model` or
:py:meth:`bob.bio.base.algorithm.Algorithm.read_feature` depends
on the state of `self` to be properly loaded.
In these cases, it's not possible to rely only in the unbounded method extracted by
:py:func:`_get_pickable_method`.
This function replaces the current state of these objects (that are not)
by bounding them with `self.instance`
"""
if isinstance(ref, DelayedSample):
new_ref = copy.copy(ref)
new_ref.load = functools.partial(ref.load.func, self.instance, ref.load.args[1])
return new_ref
else:
return ref
def _score_sample_set( def _score_sample_set(
self, self,
sampleset, sampleset,
...@@ -408,20 +447,31 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm): ...@@ -408,20 +447,31 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
data = make_four_colums_score(ref.subject, probe.subject, probe.path, score) data = make_four_colums_score(ref.subject, probe.subject, probe.path, score)
return Sample(data, parent=ref) return Sample(data, parent=ref)
self._load_projector()
retval = [] retval = []
for subprobe_id, s in enumerate(sampleset.samples): for subprobe_id, s in enumerate(sampleset.samples):
# Creating one sample per comparison # Creating one sample per comparison
subprobe_scores = [] subprobe_scores = []
if allow_scoring_with_all_biometric_references: if allow_scoring_with_all_biometric_references:
if self.stacked_biometric_references is None: if self.stacked_biometric_references is None:
if self.instance.performs_projection:
# Hydrating the state of biometric references
biometric_references = [
self._restore_state_of_ref(ref)
for ref in biometric_references
]
self.stacked_biometric_references = [ self.stacked_biometric_references = [
ref.data for ref in biometric_references ref.data for ref in biometric_references
] ]
s = self._restore_state_of_ref(s)
scores = self.score_multiple_biometric_references( scores = self.score_multiple_biometric_references(
self.stacked_biometric_references, s.data self.stacked_biometric_references, s.data
) )
# Wrapping the scores in samples # Wrapping the scores in samples
for ref, score in zip(biometric_references, scores): for ref, score in zip(biometric_references, scores):
subprobe_scores.append(_write_sample(ref, sampleset, score)) subprobe_scores.append(_write_sample(ref, sampleset, score))
...@@ -430,6 +480,7 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm): ...@@ -430,6 +480,7 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
for ref in [ for ref in [
r for r in biometric_references if r.key in sampleset.references r for r in biometric_references if r.key in sampleset.references
]: ]:
ref = self._restore_state_of_ref(ref)
score = self.score(ref.data, s.data) score = self.score(ref.data, s.data)
subprobe_scores.append(_write_sample(ref, sampleset, score)) subprobe_scores.append(_write_sample(ref, sampleset, score))
...@@ -458,6 +509,7 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm): ...@@ -458,6 +509,7 @@ class AlgorithmAsBioAlg(_NonPickableWrapper, BioAlgorithm):
path = os.path.join( path = os.path.join(
self.biometric_reference_dir, str(enroll_features.key) + self.extension self.biometric_reference_dir, str(enroll_features.key) + self.extension
) )
self._load_projector()
if path is None or not os.path.isfile(path): if path is None or not os.path.isfile(path):
# Enrolling # Enrolling
data = [s.data for s in enroll_features.samples] data = [s.data for s in enroll_features.samples]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment