Commit cf557040 authored by Tiago de Freitas Pereira

Moved the ZTNorm code from the CLI command to the API

parent 1a8942e5
@@ -14,6 +14,7 @@ from bob.extension.scripts.click_helper import (
 )
 import logging
+<<<<<<< HEAD
 import os
 from bob.bio.base.pipelines.vanilla_biometrics import (
     BioAlgorithmDaskWrapper,
@@ -32,6 +33,9 @@ from .vanilla_biometrics import (
     post_process_scores,
 )
 import copy
+=======
+from bob.bio.base.pipelines.vanilla_biometrics import execute_vanilla_biometrics_ztnorm
+>>>>>>> Moved the ZTNorm code from the CLI command to the API

 logger = logging.getLogger(__name__)
@@ -142,6 +146,24 @@ It is possible to do it via configuration file
     help="If set, it will checkpoint all steps of the pipeline. Checkpoints will be saved in `--output`.",
     cls=ResourceOption,
 )
+@click.option(
+    "--dask-partition-size",
+    "-s",
+    help="If using Dask, this option defines the size of each dask.bag.partition. "
+    "Use this option if the current heuristic that sets this value doesn't suit your experiment "
+    "(https://docs.dask.org/en/latest/bag-api.html?highlight=partition_size#dask.bag.from_sequence).",
+    default=None,
+    type=int,
+)
+@click.option(
+    "--dask-n-workers",
+    "-n",
+    help="If using Dask, this option defines the number of workers with which to start your experiment. "
+    "Dask automatically scales the number of workers up and down according to the current load of tasks. "
+    "Use this option if the default number of workers at start-up doesn't suit you.",
+    default=None,
+    type=int,
+)
 @verbosity_option(cls=ResourceOption)
 def vanilla_biometrics_ztnorm(
     pipeline,
@@ -153,6 +175,8 @@ def vanilla_biometrics_ztnorm(
     write_metadata_scores,
     ztnorm_cohort_proportion,
     checkpoint,
+    dask_partition_size,
+    dask_n_workers,
     **kwargs,
 ):
     """Runs the vanilla-biometrics pipeline with ZT-Norm-like score normalizations.
@@ -206,129 +230,20 @@
     """
-
-    def _merge_references_ztnorm(biometric_references, probes, zprobes, treferences):
-        treferences_sub = [t.subject for t in treferences]
-        biometric_references_sub = [t.subject for t in biometric_references]
-        for i in range(len(zprobes)):
-            probes[i].references += treferences_sub
-
-        for i in range(len(zprobes)):
-            zprobes[i].references = biometric_references_sub + treferences_sub
-
-        return probes, zprobes
-
-    if not os.path.exists(output):
-        os.makedirs(output, exist_ok=True)
-
-    if write_metadata_scores:
-        pipeline.score_writer = CSVScoreWriter(os.path.join(output, "./tmp"))
-    else:
-        pipeline.score_writer = FourColumnsScoreWriter(os.path.join(output, "./tmp"))
-
-    # Check if it's already checkpointed
-    if checkpoint and not is_checkpointed(pipeline):
-        hash_fn = database.hash_fn if hasattr(database, "hash_fn") else None
-        pipeline = checkpoint_vanilla_biometrics(pipeline, output, hash_fn=hash_fn)
-
-    # Patching the pipeline in case of ZNorm and checkpointing it
-    pipeline = ZTNormPipeline(pipeline)
-    if checkpoint:
-        pipeline.ztnorm_solver = ZTNormCheckpointWrapper(
-            pipeline.ztnorm_solver, os.path.join(output, "normed-scores")
-        )
-
-    background_model_samples = database.background_model_samples()
-    zprobes = database.zprobes(proportion=ztnorm_cohort_proportion)
-    treferences = database.treferences(proportion=ztnorm_cohort_proportion)
-    for group in groups:
-
-        score_file_name = os.path.join(output, f"scores-{group}")
-        biometric_references = database.references(group=group)
-        probes = database.probes(group=group)
-
-        if dask_client is not None and not isinstance_nested(
-            pipeline.biometric_algorithm, "biometric_algorithm", BioAlgorithmDaskWrapper
-        ):
-            n_objects = max(
-                len(background_model_samples), len(biometric_references), len(probes)
-            )
-            pipeline = dask_vanilla_biometrics(
-                pipeline,
-                partition_size=dask_get_partition_size(dask_client.cluster, n_objects),
-            )
-
-        logger.info(f"Running vanilla biometrics for group {group}")
-        allow_scoring_with_all_biometric_references = (
-            database.allow_scoring_with_all_biometric_references
-            if hasattr(database, "allow_scoring_with_all_biometric_references")
-            else False
-        )
-
-        if consider_genuines:
-            z_probes_cpy = copy.deepcopy(zprobes)
-            zprobes += copy.deepcopy(treferences)
-            treferences += z_probes_cpy
-
-        probes, zprobes = _merge_references_ztnorm(
-            biometric_references, probes, zprobes, treferences
-        )
-
-        (
-            raw_scores,
-            z_normed_scores,
-            t_normed_scores,
-            zt_normed_scores,
-            s_normed_scores,
-        ) = pipeline(
-            background_model_samples,
-            biometric_references,
-            probes,
-            zprobes,
-            treferences,
-            allow_scoring_with_all_biometric_references=allow_scoring_with_all_biometric_references,
-        )
-
-        def _build_filename(score_file_name, suffix):
-            return os.path.join(score_file_name, suffix)
-
-        # Running RAW_SCORES
-        raw_scores = post_process_scores(
-            pipeline, raw_scores, _build_filename(score_file_name, "raw_scores")
-        )
-        _ = compute_scores(raw_scores, dask_client)
-
-        # Z-SCORES
-        z_normed_scores = post_process_scores(
-            pipeline,
-            z_normed_scores,
-            _build_filename(score_file_name, "z_normed_scores"),
-        )
-        _ = compute_scores(z_normed_scores, dask_client)
-
-        # T-SCORES
-        t_normed_scores = post_process_scores(
-            pipeline,
-            t_normed_scores,
-            _build_filename(score_file_name, "t_normed_scores"),
-        )
-        _ = compute_scores(t_normed_scores, dask_client)
-
-        # S-SCORES
-        s_normed_scores = post_process_scores(
-            pipeline,
-            s_normed_scores,
-            _build_filename(score_file_name, "s_normed_scores"),
-        )
-        _ = compute_scores(s_normed_scores, dask_client)
-
-        # ZT-SCORES
-        zt_normed_scores = post_process_scores(
-            pipeline,
-            zt_normed_scores,
-            _build_filename(score_file_name, "zt_normed_scores"),
-        )
-        _ = compute_scores(zt_normed_scores, dask_client)
-
-    logger.info("Experiment finished !!!!!")
+    logger.debug("Executing Vanilla-biometrics ZTNorm")
+
+    execute_vanilla_biometrics_ztnorm(
+        pipeline,
+        database,
+        dask_client,
+        groups,
+        output,
+        consider_genuines,
+        write_metadata_scores,
+        ztnorm_cohort_proportion,
+        checkpoint,
+        dask_partition_size,
+        dask_n_workers,
+    )
+
+    logger.info("Experiment finished !")