Commit 3deb5e25 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

Several updates

parent e490ef63
Pipeline #58019 failed with stages
in 1 minute and 49 seconds
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
from torch.utils.data import Dataset
from bob.bio.face.database import MEDSDatabase, MorphDatabase, RFWDatabase
import torchvision.transforms as transforms
class DemoraphicTorchDataset(Dataset):
    """Torch ``Dataset`` over a bob demographic database.

    Flattens the database's ``zprobes()`` and ``treferences()`` sample sets
    into a single sample list and assigns one integer label per subject.
    Subclasses must implement ``load_demographics()`` (attribute -> int code
    mapping) and ``get_demographics(sample)`` (code for one sample).
    """

    def __init__(self, bob_dataset, transform=None):
        self.bob_dataset = bob_dataset

        # Flatten all sample sets: probes first, then t-references.
        self.bucket = []
        for sset in self.bob_dataset.zprobes():
            self.bucket.extend(sset)
        for sset in self.bob_dataset.treferences():
            self.bucket.extend(sset)

        # One integer label per subject, in enumeration order.
        subject_ids = [sset.subject_id for sset in self.bob_dataset.zprobes()]
        subject_ids += [
            sset.subject_id for sset in self.bob_dataset.treferences()
        ]
        self.labels = {sid: idx for idx, sid in enumerate(subject_ids)}

        # Subclass hook: demographic attribute value -> integer code.
        self.demographic_keys = self.load_demographics()
        self.transform = transform

    def __len__(self):
        return len(self.bucket)

    def __getitem__(self, idx):
        sample = self.bucket[idx]
        data = sample.data
        if self.transform is not None:
            data = self.transform(data)
        return {
            "data": data,
            "label": self.labels[sample.subject_id],
            "demography": self.get_demographics(sample),
        }
class MedsTorchDataset(DemoraphicTorchDataset):
    """MEDS database exposed as a torch dataset; race is the demographic."""

    def __init__(
        self, protocol, database_path, database_extension=".h5", transform=None
    ):
        database = MEDSDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        super().__init__(database, transform=transform)

    def load_demographics(self):
        """Map every race value found in the dataset to an integer code."""
        attribute = "rac"  # race attribute name used by the MEDS samples
        values = [
            getattr(sset, attribute) for sset in self.bob_dataset.zprobes()
        ]
        values += [
            getattr(sset, attribute) for sset in self.bob_dataset.treferences()
        ]
        return {value: code for code, value in enumerate(set(values))}

    def get_demographics(self, sample):
        """Return the integer demographic code of ``sample``."""
        return self.demographic_keys[sample.rac]
class MorphTorchDataset(DemoraphicTorchDataset):
    """MORPH database as a torch dataset; demographics are race-sex pairs.

    Does not delegate to ``super().__init__`` because MORPH has subjects
    present in both ``zprobes`` and ``treferences``; those are removed from
    the t-references side here.
    """

    def __init__(
        self, protocol, database_path, database_extension=".h5", transform=None
    ):
        self.bob_dataset = MorphDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )

        # Subjects appearing in both zprobes and treferences; keep them once.
        self.excluding_list = [
            "190276", "332158", "111942", "308129", "334074", "350814",
            "131677", "168724", "276055", "275589", "286810",
        ]

        self.bucket = [s for sset in self.bob_dataset.zprobes() for s in sset]
        for sset in self.bob_dataset.treferences():
            if sset.subject_id in self.excluding_list:
                continue
            self.bucket.extend(sset)

        # NOTE(review): keys are taken per-sample here (unlike the base
        # class), so repeated subjects overwrite earlier entries and the
        # label values are not contiguous — confirm this is intended.
        subject_ids = [sample.subject_id for sample in self.bucket]
        self.labels = {sid: idx for idx, sid in enumerate(subject_ids)}

        self.demographic_keys = self.load_demographics()
        self.transform = transform

    def load_demographics(self):
        """Map every ``race-sex`` combination to an integer code."""
        combos = [
            f"{sset.rac}-{sset.sex}" for sset in self.bob_dataset.zprobes()
        ]
        combos += [
            f"{sset.rac}-{sset.sex}"
            for sset in self.bob_dataset.treferences()
            if sset.subject_id not in self.excluding_list
        ]
        return {combo: code for code, combo in enumerate(set(combos))}

    def get_demographics(self, sample):
        """Return the integer code of ``sample``'s race-sex combination."""
        return self.demographic_keys[f"{sample.rac}-{sample.sex}"]
class RFWTorchDataset(DemoraphicTorchDataset):
    """RFW database exposed as a torch dataset; race is the demographic."""

    def __init__(
        self, protocol, database_path, database_extension=".h5", transform=None
    ):
        database = RFWDatabase(
            protocol=protocol,
            dataset_original_directory=database_path,
            dataset_original_extension=database_extension,
        )
        super().__init__(database, transform=transform)

    def load_demographics(self):
        """Map every race value found in the dataset to an integer code."""
        attribute = "race"  # race attribute name used by the RFW samples
        values = [
            getattr(sset, attribute) for sset in self.bob_dataset.zprobes()
        ]
        values += [
            getattr(sset, attribute) for sset in self.bob_dataset.treferences()
        ]
        return {value: code for code, value in enumerate(set(values))}

    def get_demographics(self, sample):
        """Return the integer demographic code of ``sample``."""
        return self.demographic_keys[sample.race]
......@@ -24,7 +24,10 @@ from bob.bio.demographics.regularizers.trainers import balance_trainer
@click.option("--batch-size", default=64, help="Batch size")
@click.option("--backbone", default="iresnet100", help="Backbone")
def balance_meds(
output_dir, max_epochs, batch_size, backbone,
output_dir,
max_epochs,
batch_size,
backbone,
):
from bob.bio.demographics.regularizers import AVAILABLE_BACKBONES
......@@ -46,7 +49,9 @@ def balance_meds(
)
dataset = MedsTorchDataset(
protocol="verification_fold1", database_path=database_path, transform=transform,
protocol="verification_fold1",
database_path=database_path,
transform=transform,
)
train_dataloader = DataLoader(
......@@ -57,7 +62,12 @@ def balance_meds(
backbone_model = AVAILABLE_BACKBONES[backbone]["prior"]()
balance_trainer(
output_dir, max_epochs, batch_size, train_dataloader, backbone_model, transform,
output_dir,
max_epochs,
batch_size,
train_dataloader,
backbone_model,
transform,
)
......
# from bob.bio.demographics.datasets import MedsTorchDataset
from bob.bio.face.pytorch.datasets import MSCelebTorchDataset
# https://pytorch.org/docs/stable/data.html
from torch.utils.data import DataLoader
from bob.extension import rc
import os
import bob.io.image
import torch
from functools import partial
import torchvision.transforms as transforms
import click
import yaml
from bob.bio.demographics.regularizers.trainers import balance_trainer
@click.command()
@click.argument("OUTPUT_DIR")
@click.option("--max-epochs", default=600, help="Max number of epochs")
@click.option("--batch-size", default=64, help="Batch size")
@click.option("--backbone", default="iresnet100", help="Backbone")
def balance_msceleb(
    output_dir,
    max_epochs,
    batch_size,
    backbone,
):
    """Train the ``prior`` variant of *backbone* on MSCeleb with the
    balance regularizer, writing checkpoints/logs under ``OUTPUT_DIR``."""
    from bob.bio.demographics.regularizers import AVAILABLE_BACKBONES

    # NOTE(review): hard-coded, site-specific dataset location.
    database_path = (
        "/idiap/temp/tpereira/databases/msceleb/112x112-arcface-crop-mtcnn-crop/"
    )

    # Convert bob (CxHxW) images to matplotlib layout, then to tensors
    # normalized to roughly [-1, 1].
    transform = transforms.Compose(
        [
            lambda x: bob.io.image.to_matplotlib(x.astype("float32")),
            transforms.ToTensor(),
            lambda x: (x - 127.5) / 128.0,
        ]
    )

    dataset = MSCelebTorchDataset(
        database_path=database_path,
        transform=transform,
    )
    train_dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=8
    )

    backbone_model = AVAILABLE_BACKBONES[backbone]["prior"]()
    balance_trainer(
        output_dir,
        max_epochs,
        batch_size,
        train_dataloader,
        backbone_model,
        transform,
    )


if __name__ == "__main__":
    balance_msceleb()
from bob.bio.demographics.datasets import MedsTorchDataset, MorphTorchDataset
from bob.bio.face.pytorch.datasets import MedsTorchDataset, MorphTorchDataset
# https://pytorch.org/docs/stable/data.html
from torch.utils.data import DataLoader
......@@ -65,15 +65,16 @@ def ortogonality_meds(
)
dataset = MedsTorchDataset(
protocol="verification_fold1", database_path=database_path, transform=transform,
protocol="verification_fold1",
database_path=database_path,
transform=transform,
)
train_dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2
)
# train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
)
backbone_model = AVAILABLE_BACKBONES[backbone]()
backbone_model = AVAILABLE_BACKBONES[backbone]["prior"]()
ortogonality_trainer(
output_dir,
......
from bob.bio.demographics.datasets import MedsTorchDataset, MorphTorchDataset
from bob.bio.face.pytorch.datasets import MedsTorchDataset, MorphTorchDataset
# https://pytorch.org/docs/stable/data.html
from torch.utils.data import DataLoader
......@@ -65,7 +65,9 @@ def ortogonality_morph(
)
dataset = MorphTorchDataset(
protocol="verification_fold1", database_path=database_path, transform=transform,
protocol="verification_fold1",
database_path=database_path,
transform=transform,
)
train_dataloader = DataLoader(
......@@ -73,7 +75,7 @@ def ortogonality_morph(
)
# train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
backbone_model = AVAILABLE_BACKBONES[backbone]()
backbone_model = AVAILABLE_BACKBONES[backbone]["prior"]()
ortogonality_trainer(
output_dir,
......
......@@ -40,16 +40,16 @@ def plot_demographic_boxplot(
Pandas Dataframe containing the positive scores (or genuines scores, or even mated scores)
fmr_thresholds: list
List containing the FMR operational points
List containing the FMR operational points
label_lookup_table: dict
Lookup table mapping `variable` to the actual label of the variable
percentile: float
If set, it will plit
title="",
"""
......@@ -387,7 +387,9 @@ def plot_fmr_fnmr_tradeoff(
fnmrs[key].append(fnmr)
else:
fmr, _ = bob.measure.farfrr(
negatives_as_dict[key]["score"].compute().to_numpy(), [0.0], t,
negatives_as_dict[key]["score"].compute().to_numpy(),
[0.0],
t,
)
fmrs[key].append(fmr)
......
AVAILABLE_BACKBONES = dict()
from bob.learn.pytorch.architectures.iresnet import iresnet34, iresnet100, iresnet50
from bob.bio.face.pytorch.backbones.iresnet import iresnet34, iresnet100, iresnet50
from functools import partial
# Organize these checkpoints
......@@ -28,4 +28,3 @@ AVAILABLE_BACKBONES["iresnet34"] = {
"/idiap/temp/tpereira/bob/data/pytorch/iresnet-91a5de61/iresnet34-5b0d0e90.pth",
),
}
from bob.learn.pytorch.trainers import BackboneHeadModel
from bob.bio.face.pytorch.lightning import BackboneHeadModel
from torch.nn import Module, Linear
import pytorch_lightning as pl
import torch
......@@ -74,4 +74,3 @@ class SimpleBalanceModel(BackboneHeadModel):
def configure_optimizers(self):
# optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return self.optimizer_fn(params=self.parameters())
from bob.learn.pytorch.trainers import BackboneHeadModel
from bob.bio.face.pytorch.lightning import BackboneHeadModel
from torch.nn import Module, Linear
import pytorch_lightning as pl
import torch
......@@ -42,7 +42,7 @@ def switch(model, flag):
class OrthogonalityModel(BackboneHeadModel):
"""
Here we hypothesize that the sensitive attribute is orthogonal
Here we hypothesize that the sensitive attribute is orthogonal
to the identity attribute
"""
......@@ -223,7 +223,8 @@ class OrthogonalityModel(BackboneHeadModel):
loss_orthogonality = torch.mean(
torch.abs(
torch.sum(
F.normalize(demographic_embeding) * F.normalize(embedding), axis=1,
F.normalize(demographic_embeding) * F.normalize(embedding),
axis=1,
)
)
)
......
from bob.bio.demographics.datasets import MedsTorchDataset, MorphTorchDataset
# from bob.bio.demographics.datasets import MedsTorchDataset, MorphTorchDataset
from bob.bio.face.pytorch.datasets import MedsTorchDataset, MorphTorchDataset
# https://pytorch.org/docs/stable/data.html
from torch.utils.data import DataLoader
import pytest
from bob.extension import rc
import os
from bob.learn.pytorch.architectures.iresnet import iresnet34, iresnet100
from bob.learn.pytorch.head import ArcFace, Regular
from bob.bio.face.pytorch.backbones.iresnet import iresnet34, iresnet100, iresnet50
from bob.bio.face.pytorch.head import ArcFace, Regular
from bob.bio.demographics.regularizers.independence import (
DemographicRegularHead,
OrthogonalityModel,
......@@ -61,7 +62,8 @@ def ortogonality_trainer(
#####################
## IDENTITY
num_class = len(list(train_dataloader.dataset.labels.values()))
num_class = train_dataloader.dataset.n_classes
identity_head = ArcFace(
feat_dim=backbone_model.features.num_features, num_class=num_class
)
......@@ -127,12 +129,13 @@ def ortogonality_trainer(
# debug flags
# limit_train_batches=10, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
# limit_val_batches=1,
amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
log_every_n_steps=5,
# amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
log_every_n_steps=50,
)
trainer.fit(
model=model, train_dataloaders=train_dataloader,
model=model,
train_dataloaders=train_dataloader,
)
......@@ -238,12 +241,18 @@ def mine_trainer(
)
trainer.fit(
model=model, train_dataloaders=train_dataloader,
model=model,
train_dataloaders=train_dataloader,
)
def balance_trainer(
output_dir, max_epochs, batch_size, train_dataloader, backbone_model, transform,
output_dir,
max_epochs,
batch_size,
train_dataloader,
backbone_model,
transform,
):
"""
......@@ -264,18 +273,20 @@ def balance_trainer(
#####################
## IDENTITY
num_class = len(list(train_dataloader.dataset.labels.values()))
num_class = train_dataloader.dataset.n_classes
weight = train_dataloader.dataset.get_demographic_class_weights()
identity_head = ArcFace(
feat_dim=backbone_model.features.num_features, num_class=num_class
)
optimizer = partial(torch.optim.SGD, lr=0.001, momentum=0.9)
optimizer = partial(torch.optim.SGD, lr=0.1, momentum=0.9)
# loss_fn=torch.nn.CrossEntropyLoss(weight=weight),
model = SimpleBalanceModel(
backbone=backbone_model,
identity_head=identity_head,
loss_fn=torch.nn.CrossEntropyLoss(),
loss_fn=torch.nn.CrossEntropyLoss(weight=weight),
optimizer_fn=optimizer,
backbone_checkpoint_path=backbone_checkpoint_path,
max_epochs=max_epochs,
......@@ -311,13 +322,13 @@ def balance_trainer(
resume_from_checkpoint=resume_from_checkpoint,
# resume_from_checkpoint=resume_from_checkpoint, #https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#resume-from-checkpoint
# debug flags
# limit_train_batches=10, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
# limit_train_batches=2, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
# limit_val_batches=1,
amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
log_every_n_steps=5,
# amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
log_every_n_steps=50,
)
trainer.fit(
model=model, train_dataloaders=train_dataloader,
model=model,
train_dataloaders=train_dataloader,
)
......@@ -241,7 +241,8 @@ def mobio_report(
}
variable_suffix = "gender"
print(percentile)
print("########################")
pdf = PdfPages(output_filename)
negatives_dev, positives_dev, negatives_eval, positives_eval = load_dev_eval_scores(
......@@ -460,3 +461,106 @@ def rfw_report(
pdf.savefig(fig)
pdf.close()
def vgg2_report(
scores_dev,
output_filename,
scores_eval=None,
fmr_thresholds=[10 ** i for i in list(range(-8, 0))],
percentile=0.05,
titles=None,
possible_races=["A", "B", "I", "U", "W"],
genders_considered=["m"],
):
# possible_races=["A", "B", "I", "U", "W", "N"],
variables = {
"A": "Asian",
"B": "Black",
"I": "Indian",
"U": "Indet.",
"W": "White",
}
label_lookup_table = dict()
for a in list(variables.keys()):
for b in list(variables.keys()):
label_lookup_table[f"{a}__{b}"] = f"{variables[a]}-{variables[b]}"
variable_suffix = "race"
pdf = PdfPages(output_filename)
negatives_dev, positives_dev, negatives_eval, positives_eval = load_dev_eval_scores(
scores_dev, scores_eval
)
def filter_out(dataframe, genders):
return dataframe[
(dataframe.bio_ref_race.isin(possible_races))
& (dataframe.bio_ref_gender.isin(genders))
& (dataframe.probe_race.isin(possible_races))
]
negatives_dev = [filter_out(n, ["m", "f"]) for n in negatives_dev]
positives_dev = [filter_out(n, ["m", "f"]) for n in positives_dev]
if negatives_eval[0] is None:
# Compute FDR on the same set if there's no evaluation set
fig = plot_fdr(
negatives_dev, positives_dev, titles, variable_suffix, fmr_thresholds
)
else:
# If there is evaluation set
# compute the decision thresholds
negatives_eval = [filter_out(n, genders_considered) for n in negatives_eval]
positives_eval = [filter_out(n, genders_considered) for n in positives_eval]
taus = [compute_fmr_thresholds(d, fmr_thresholds) for d in negatives_dev]
fig = plot_fdr(
negatives_eval,
positives_eval,