Commit bfcddbe1 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Increased calibration commands

parent fe75d01e
......@@ -8,6 +8,10 @@ from bob.bio.base.pipelines.vanilla_biometrics import (
)
import numpy as np
from bob.bio.demographics.regularizers import demographic
from bob.pipelines.distributed import VALID_DASK_CLIENT_STRINGS
from bob.extension.scripts.click_helper import ResourceOption
def _arg_split(ctx, param, value):
# split columns by ',' and remove whitespace
......@@ -23,8 +27,10 @@ def calibrate(
output_score_files,
calibrator,
score_selection_method,
field_name,
field_names,
field_values,
under_represented_demographic_maps=None,
calibrate_using_biometric_reference_demographics=True,
):
if calibrator == "linear":
calibrator = LLRCalibration
......@@ -43,11 +49,13 @@ def calibrate(
# quantile_factor = float(quantile_factor)
calibrator = CategoricalCalibration(
field_name=field_name,
field_names=field_names,
field_values=field_values,
score_selection_method=score_selection_method,
fit_estimator=calibrator,
reduction_function=np.mean,
under_represented_demographic_maps=under_represented_demographic_maps,
calibrate_using_biometric_reference_demographics=calibrate_using_biometric_reference_demographics,
)
calibrator = calibrator.fit(fit_score_file)
calibrator.transform(transform_score_files, output_score_files)
......@@ -82,19 +90,27 @@ def calibrate(
"`all`: It will select all the scores."
" Default to `median-q3`",
)
@click.option(
"-c",
"--calibrate-using-biometric-reference",
is_flag=True,
help="If set, calibrate using the biometric reference demographics."
"If not set, the calibration is performed using all calibrators and the final score is computed using `reduction_function",
)
def mobio(
fit_score_file,
transform_score_files,
output_score_files,
calibrator,
score_selection_method,
calibrate_using_biometric_reference,
):
"""
Calibrates scores coming from experiments using MOBIO database.
"""
field_values = ["m", "f"]
field_name = "gender"
field_values = [["m", "f"]]
field_name = ["gender"]
calibrate(
fit_score_file,
......@@ -104,6 +120,7 @@ def mobio(
score_selection_method,
field_name,
field_values,
calibrate_using_biometric_reference_demographics=calibrate_using_biometric_reference,
)
......@@ -136,19 +153,27 @@ def mobio(
"`all`: It will select all the scores."
" Default to `median-q3`",
)
@click.option(
"-c",
"--calibrate-using-biometric-reference",
is_flag=True,
help="If set, calibrate using the biometric reference demographics."
"If not set, the calibration is performed using all calibrators and the final score is computed using `reduction_function",
)
def meds(
fit_score_file,
transform_score_files,
output_score_files,
calibrator,
score_selection_method,
calibrate_using_biometric_reference,
):
"""
Calibrates scores coming from experiments using MEDS database.
"""
field_values = ["B", "W"]
field_name = "rac"
field_values = [["B", "W"]]
field_name = ["rac"]
calibrate(
fit_score_file,
......@@ -158,6 +183,7 @@ def meds(
score_selection_method,
field_name,
field_values,
calibrate_using_biometric_reference_demographics=calibrate_using_biometric_reference,
)
......@@ -190,19 +216,27 @@ def meds(
"`all`: It will select all the scores."
" Default to `median-q3`",
)
@click.option(
"-c",
"--calibrate-using-biometric-reference",
is_flag=True,
help="If set, calibrate using the biometric reference demographics."
"If not set, the calibration is performed using all calibrators and the final score is computed using `reduction_function",
)
def rfw(
fit_score_file,
transform_score_files,
output_score_files,
calibrator,
score_selection_method,
calibrate_using_biometric_reference,
):
"""
Calibrates scores coming from experiments using RFW database.
"""
field_values = ["Asian", "African", "Indian", "Caucasian"]
field_name = "race"
field_values = [["Asian", "African", "Indian", "Caucasian"]]
field_name = ["race"]
calibrate(
fit_score_file,
......@@ -212,6 +246,7 @@ def rfw(
score_selection_method,
field_name,
field_values,
calibrate_using_biometric_reference_demographics=calibrate_using_biometric_reference,
)
......@@ -244,19 +279,27 @@ def rfw(
"`all`: It will select all the scores."
" Default to `median-q3`",
)
@click.option(
"-c",
"--calibrate-using-biometric-reference",
is_flag=True,
help="If set, calibrate using the biometric reference demographics."
"If not set, the calibration is performed using all calibrators and the final score is computed using `reduction_function",
)
def morph_race(
fit_score_file,
transform_score_files,
output_score_files,
calibrator,
score_selection_method,
calibrate_using_biometric_reference,
):
"""
Calibrates scores coming from experiments using RFW database.
"""
field_values = ["A", "W", "B", "H"]
field_name = "rac"
field_values = [["A", "W", "B", "H"]]
field_name = ["rac"]
calibrate(
fit_score_file,
......@@ -266,6 +309,7 @@ def morph_race(
score_selection_method,
field_name,
field_values,
calibrate_using_biometric_reference_demographics=calibrate_using_biometric_reference,
)
......@@ -298,19 +342,39 @@ def morph_race(
"`all`: It will select all the scores."
" Default to `median-q3`",
)
@click.option(
"--dask-client",
"-l",
entry_point_group="dask.client",
string_exceptions=VALID_DASK_CLIENT_STRINGS,
default="single-threaded",
help="Dask client for the execution of the pipeline.",
cls=ResourceOption,
)
@click.option(
"-c",
"--calibrate-using-biometric-reference",
is_flag=True,
help="If set, calibrate using the biometric reference demographics."
"If not set, the calibration is performed using all calibrators and the final score is computed using `reduction_function",
)
def vgg2_race(
fit_score_file,
transform_score_files,
output_score_files,
calibrator,
score_selection_method,
dask_client,
calibrate_using_biometric_reference,
):
"""
Calibrates scores coming from experiments using VGG2 database.
"""
field_values = ["A", "B", "I", "W"]
field_name = "race"
field_values = [["A", "B", "I", "W"], ["m", "f"]]
field_names = ["race", "gender"]
under_represented_demographic_maps = [{"nan": "W", "U": "W"}, {}]
calibrate(
fit_score_file,
......@@ -318,6 +382,8 @@ def vgg2_race(
output_score_files,
calibrator,
score_selection_method,
field_name,
field_names,
field_values,
under_represented_demographic_maps,
calibrate_using_biometric_reference_demographics=calibrate_using_biometric_reference,
)
......@@ -32,6 +32,7 @@ from pytorch_lightning.callbacks import (
ModelCheckpoint,
LearningRateMonitor,
EarlyStopping,
DeviceStatsMonitor,
)
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
......@@ -41,7 +42,7 @@ import logging
from bob.bio.face.pytorch.datasets import SiameseDemographicWrapper
logger = logging.getLogger(__name__)
logging_logger = logging.getLogger(__name__)
EPILOG = """\b
......@@ -138,6 +139,7 @@ def train_balance(
shuffle=False,
pin_memory=False,
num_workers=1,
drop_last=True,
)
WEIGHT = config.WEIGHT
......@@ -825,6 +827,8 @@ def train_contrastive_calibration(
Siamese Calibration hypothesis.
"""
weight_contrastive_loss = 2.0
weight_calibration_loss = 0.5
config = chain_load([identity_backbone, database, training_config])
......@@ -835,6 +839,7 @@ def train_contrastive_calibration(
config.train_dataset,
max_positive_pairs_per_subject=max_positive_pairs_per_subject,
negative_pairs_per_subject=negative_pairs_per_subject,
train=True,
)
TRAIN_DATALOADER = DataLoader(
TRAIN_DATASET,
......@@ -842,6 +847,7 @@ def train_contrastive_calibration(
shuffle=True,
pin_memory=True,
num_workers=4,
drop_last=True,
)
VALIDATION_DATASET = config.validation_dataset
......@@ -851,6 +857,7 @@ def train_contrastive_calibration(
VALIDATION_DATASET,
max_positive_pairs_per_subject=1,
negative_pairs_per_subject=1,
train=False,
)
VALIDATION_DATALOADER = DataLoader(
VALIDATION_DATASET,
......@@ -858,20 +865,24 @@ def train_contrastive_calibration(
shuffle=False,
pin_memory=False,
num_workers=1,
drop_last=True,
)
IDENTITY_BACKBONE = config.backbone
###
# Defing the variables of the experiment, so we don't get lost
# Defing the variables of the experiment, so we don't get lost
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/config.yml", "w") as file:
dict_file = dict()
dict_file["max_epochs"] = max_epochs
dict_file["batch_size"] = batch_size
dict_file["hypothesis"] = "Siamese"
dict_file["weight_calibration_loss"] = weight_calibration_loss
dict_file["weight_contrastive_loss"] = weight_contrastive_loss
dict_file["max_positive_pairs_per_subject"] = max_positive_pairs_per_subject
dict_file["negative_pairs_per_subject"] = negative_pairs_per_subject
dict_file["hypothesis"] = "SiameseCalibration"
yaml.dump(dict_file, file)
backbone_checkpoint_file = f"{output_dir}/model.pth"
......@@ -890,6 +901,8 @@ def train_contrastive_calibration(
n_demographics=n_demographics,
backbone_checkpoint_file=backbone_checkpoint_file,
demographic_weights=config.train_dataset.get_demographic_weights(as_dict=False),
weight_contrastive_loss=weight_contrastive_loss,
weight_calibration_loss=weight_calibration_loss,
)
# LOGGERS
......@@ -1012,6 +1025,7 @@ def train_contrastive_independence(
config.train_dataset,
max_positive_pairs_per_subject=max_positive_pairs_per_subject,
negative_pairs_per_subject=negative_pairs_per_subject,
train=True,
)
TRAIN_DATALOADER = DataLoader(
TRAIN_DATASET,
......@@ -1019,6 +1033,7 @@ def train_contrastive_independence(
shuffle=True,
pin_memory=True,
num_workers=4,
drop_last=True,
)
VALIDATION_DATASET = config.validation_dataset
......@@ -1028,6 +1043,7 @@ def train_contrastive_independence(
VALIDATION_DATASET,
max_positive_pairs_per_subject=1,
negative_pairs_per_subject=1,
train=False,
)
VALIDATION_DATALOADER = DataLoader(
VALIDATION_DATASET,
......@@ -1035,6 +1051,7 @@ def train_contrastive_independence(
shuffle=False,
pin_memory=False,
num_workers=1,
drop_last=True,
)
IDENTITY_BACKBONE = config.backbone
......@@ -1066,9 +1083,8 @@ def train_contrastive_independence(
model = ContrastiveIndependenceModel(
facerec_backbone=IDENTITY_BACKBONE,
demographic_backbone=DEMOGRAPHIC_BACKBONE ,
demographic_backbone=DEMOGRAPHIC_BACKBONE,
backbone_checkpoint_file=backbone_checkpoint_file,
)
# LOGGERS
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment