Skip to content
Snippets Groups Projects
Commit 6ae969d4 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Check if there are samples available for the function train_background_model

parent 076f80c2
No related branches found
No related tags found
1 merge request!180[dask] Preparing bob.bio.base for dask pipelines
...@@ -9,6 +9,7 @@ for bob.bio experiments ...@@ -9,6 +9,7 @@ for bob.bio experiments
""" """
import logging import logging
import numpy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -19,6 +20,7 @@ def biometric_pipeline( ...@@ -19,6 +20,7 @@ def biometric_pipeline(
probe_samples, probe_samples,
transformer, transformer,
biometric_algorithm, biometric_algorithm,
allow_scoring_with_all_biometric_references=False,
): ):
logger.info( logger.info(
f" >> Vanilla Biometrics: Training background model with pipeline {transformer}" f" >> Vanilla Biometrics: Training background model with pipeline {transformer}"
...@@ -43,12 +45,25 @@ def biometric_pipeline( ...@@ -43,12 +45,25 @@ def biometric_pipeline(
# Scores all probes # Scores all probes
return compute_scores( return compute_scores(
probe_samples, biometric_references, transformer, biometric_algorithm probe_samples,
biometric_references,
transformer,
biometric_algorithm,
allow_scoring_with_all_biometric_references,
) )
def train_background_model(background_model_samples, transformer): def train_background_model(background_model_samples, transformer):
# background_model_samples is a list of Samples # background_model_samples is a list of Samples
# We might have algorithms that has no data for training
if len(background_model_samples) <= 0:
logger.warning(
"There's no data to train background model."
"For the rest of the execution it will be assumed that the pipeline is stateless."
)
return transformer
transformer = transformer.fit(background_model_samples) transformer = transformer.fit(background_model_samples)
return transformer return transformer
...@@ -67,13 +82,21 @@ def create_biometric_reference( ...@@ -67,13 +82,21 @@ def create_biometric_reference(
def compute_scores( def compute_scores(
probe_samples, biometric_references, transformer, biometric_algorithm probe_samples,
biometric_references,
transformer,
biometric_algorithm,
allow_scoring_with_all_biometric_references=False,
): ):
# probes is a list of SampleSets # probes is a list of SampleSets
probe_features = transformer.transform(probe_samples) probe_features = transformer.transform(probe_samples)
scores = biometric_algorithm.score_samples(probe_features, biometric_references) scores = biometric_algorithm.score_samples(
probe_features,
biometric_references,
allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
)
# scores is a list of Samples # scores is a list of Samples
return scores return scores
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment