Commit fa042420 authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira

Orthogonality hypothesis

parent d10d3d24
@@ -7,10 +7,11 @@ from torch.utils.data import Dataset
from bob.bio.face.database import MEDSDatabase, MorphDatabase
import torchvision.transforms as transforms
class DemoraphicTorchDataset(Dataset):
-    def __init__(
-        self, bob_dataset,
-    ):
+    def __init__(self, bob_dataset, transform=None):
self.bob_dataset = bob_dataset
self.bucket = [s for sset in self.bob_dataset.zprobes() for s in sset]
@@ -22,7 +23,8 @@ class DemoraphicTorchDataset(Dataset):
]
self.labels = dict(zip(keys, range(len(keys))))
-        self.metadata_keys = self.load_demographics()
+        self.demographic_keys = self.load_demographics()
+        self.transform = transform
def __len__(self):
return len(self.bucket)
@@ -30,7 +32,10 @@ class DemoraphicTorchDataset(Dataset):
def __getitem__(self, idx):
sample = self.bucket[idx]
-        image = sample.data
+        image = sample.data if self.transform is None else self.transform(sample.data)
# image = image.astype("float32")
label = self.labels[sample.subject_id]
demography = self.get_demographics(sample)
@@ -40,7 +45,7 @@ class DemoraphicTorchDataset(Dataset):
class MedsTorchDataset(DemoraphicTorchDataset):
def __init__(
-        self, protocol, database_path, database_extension=".h5",
+        self, protocol, database_path, database_extension=".h5", transform=None
):
bob_dataset = MEDSDatabase(
@@ -48,7 +53,7 @@ class MedsTorchDataset(DemoraphicTorchDataset):
dataset_original_directory=database_path,
dataset_original_extension=database_extension,
)
-        super().__init__(bob_dataset)
+        super().__init__(bob_dataset, transform=transform)
def load_demographics(self):
@@ -65,12 +70,12 @@ class MedsTorchDataset(DemoraphicTorchDataset):
def get_demographics(self, sample):
demographic_key = getattr(sample, "rac")
-        return self.metadata_keys[demographic_key]
+        return self.demographic_keys[demographic_key]
class MorphTorchDataset(DemoraphicTorchDataset):
def __init__(
-        self, protocol, database_path, database_extension=".h5",
+        self, protocol, database_path, database_extension=".h5", transform=None
):
bob_dataset = MorphDatabase(
@@ -78,7 +83,7 @@ class MorphTorchDataset(DemoraphicTorchDataset):
dataset_original_directory=database_path,
dataset_original_extension=database_extension,
)
-        super().__init__(bob_dataset)
+        super().__init__(bob_dataset, transform=transform)
def load_demographics(self):
@@ -93,5 +98,5 @@ class MorphTorchDataset(DemoraphicTorchDataset):
def get_demographics(self, sample):
demographic_key = f"{sample.rac}-{sample.sex}"
-        return self.metadata_keys[demographic_key]
+        return self.demographic_keys[demographic_key]
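
# A minimal usage sketch of the new ``transform`` argument (the path below is
# hypothetical; the actual Compose used for training appears in the script
# further down). ``transform`` is applied to ``sample.data`` in ``__getitem__``:
#
#   dataset = MedsTorchDataset(
#       protocol="verification_fold1",
#       database_path="/path/to/meds/samplewrapper",  # hypothetical
#       transform=transform,
#   )
#   item = dataset[0]  # dict with "data", "label" and "demography" entries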
from bob.bio.face.embeddings.pytorch import PyTorchModel, iresnet_template
from bob.learn.pytorch.architectures.iresnet import iresnet100
from bob.bio.demographics.fair_transformers import RunnableTransformer
from functools import partial
annotation_type = "eyes-center"
fixed_positions = None
memory_demanding = False
# checkpoint_path = "/idiap/temp/tpereira/2.FRDemographics/regularization/models/orthogonality_hypothesis/meds/iresnet100.pth"
checkpoint_path = "/idiap/temp/tpereira/2.FRDemographics/regularization/models/orthogonality_hypothesis/meds_identity-10.0_orthogonality-1.0/iresnet100.pth"
pipeline = iresnet_template(
embedding=RunnableTransformer(
partial(iresnet100, pretrained=checkpoint_path),
memory_demanding=memory_demanding,
),
annotation_type=annotation_type,
fixed_positions=fixed_positions,
)
#### DATABASE
from bob.bio.face.database import MEDSDatabase
protocol = "verification_fold1"
database = MEDSDatabase(protocol=protocol)
# output = (
# "/remote/idiap.svm/user.active/tpereira/gitlab/bob/bob.nightlies/vanilla-callback"
# )
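
# Sketch of how this config is typically consumed, assuming the
# vanilla-biometrics CLI from bob.bio.base (exact flags may vary by version):
#
#   bob bio pipelines vanilla-biometrics <this-config>.py --output <output-dir>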
from bob.bio.demographics.datasets import MedsTorchDataset, MorphTorchDataset
# https://pytorch.org/docs/stable/data.html
from torch.utils.data import DataLoader
import pytest
from bob.extension import rc
import os
from bob.learn.pytorch.architectures.iresnet import iresnet34, iresnet100
from bob.learn.pytorch.head import ArcFace, Regular
from bob.bio.demographics.regularizers.independence import (
DemographicRegularHead,
OrthogonalityModel,
)
import pytorch_lightning as pl
import torch
from functools import partial
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import torchvision.transforms as transforms
import click
import yaml
# demographic_epochs = 50
# identity_epochs = 200
@click.command()
@click.argument("OUTPUT_DIR")
@click.option("--identity-factor", default=1.0, help="Identity factor")
@click.option("--orthogonality-factor", default=1.0, help="Ortogonality factor")
@click.option("--max-epochs", default=600, help="Max number of epochs")
@click.option(
"--demographic-epochs",
default=100,
help="Number of epochs to train the demographic classifier",
)
@click.option(
"--identity-epochs",
default=200,
help="Number of epochs to train the identity classifier",
)
@click.option("--batch-size", default=64, help="Batch size")
def ortogonality_meds(
output_dir,
identity_factor,
orthogonality_factor,
max_epochs,
demographic_epochs,
identity_epochs,
batch_size,
):
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/config.yml", "w") as file:
dict_file = dict()
dict_file["demographic_epochs"] = demographic_epochs
dict_file["identity_epochs"] = identity_epochs
dict_file["max_epochs"] = max_epochs
dict_file["batch_size"] = batch_size
yaml.dump(dict_file, file)
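    # With the defaults above, config.yml contains (yaml.dump sorts keys):
    #
    #   batch_size: 64
    #   demographic_epochs: 100
    #   identity_epochs: 200
    #   max_epochs: 600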
backbone_checkpoint_path = f"{output_dir}/iresnet100.pth"
checkpoint_dir = f"{output_dir}/last.ckpt"
database_path = os.path.join(
rc.get("bob.bio.demographics.directory"), "meds", "samplewrapper"
)
import bob.io.image
transform = transforms.Compose(
[
lambda x: bob.io.image.to_matplotlib(x),
transforms.ToPILImage(mode="RGB"),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=(-3, 3)),
# transforms.RandomAutocontrast(p=0.1),
transforms.ToTensor(),
lambda x: (x - 127.5) / 128.0,
]
)
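    # The final lambda maps uint8 pixels from [0, 255] to roughly [-1, 1]
    # ((0 - 127.5) / 128 = -0.996, (255 - 127.5) / 128 = 0.996), matching the
    # default preprocessor of RunnableTransformer used at scoring time.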
dataset = MedsTorchDataset(
protocol="verification_fold1", database_path=database_path, transform=transform,
)
# train_dataloader = DataLoader(
# dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=2
# )
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# backbone = iresnet34(
# pretrained="/idiap/temp/tpereira/bob/data/pytorch/iresnet-91a5de61/iresnet34-5b0d0e90.pth"
# )
# Add this argument
backbone = iresnet100(
pretrained="/idiap/temp/tpereira/bob/data/pytorch/iresnet-91a5de61/iresnet100-73e07ba7.pth"
)
# list(dataloader.dataset.labels.values())
#####################
## IDENTITY
num_class = len(list(train_dataloader.dataset.labels.values()))
identity_head = ArcFace(
feat_dim=backbone.features.num_features, num_class=num_class
)
######################
## DEMOGRAPHIC
num_class = len(list(train_dataloader.dataset.demographic_keys.values()))
demographic_head = DemographicRegularHead(
feat_dim=backbone.features.num_features, num_class=num_class
)
################
## Trainer
optimizer = partial(torch.optim.SGD, lr=0.001, momentum=0.9)
# demographic_epochs = 50
# identity_epochs = 200
# ortogonality_epochs = 400
    # Preparing the lightning model
model = OrthogonalityModel(
backbone=backbone,
identity_head=identity_head,
demographic_head=demographic_head,
loss_fn=torch.nn.CrossEntropyLoss(),
optimizer_fn=optimizer,
identity_factor=identity_factor,
orthogonality_factor=orthogonality_factor,
backbone_checkpoint_path=backbone_checkpoint_path,
demographic_epochs=demographic_epochs,
identity_epochs=identity_epochs,
)
"""
from bob.bio.face.pytorch.callbacks import VanillaBiometricsCallback
vanilla_callback = VanillaBiometricsCallback(
config="/remote/idiap.svm/user.active/tpereira/gitlab/bob/bob.nightlies/src/bob.bio.demographics/bob/bio/demographics/fair_transformers/transformers.py",
output_path="./vanilla-callback",
)
"""
model_checkpoint = ModelCheckpoint(
output_dir, every_n_train_steps=100, save_last=True
)
logger = TensorBoardLogger(os.path.join(output_dir, "tb_logs"))
# Be careful with
# https://github.com/PyTorchLightning/pytorch-lightning/issues/5325
resume_from_checkpoint = checkpoint_dir if os.path.exists(checkpoint_dir) else None
# TODO: using this code to learn too
# so, be nice with my comments
# callbacks=[model_checkpoint, vanilla_callback],
callbacks = [model_checkpoint]
trainer = pl.Trainer(
callbacks=callbacks,
logger=logger,
max_epochs=max_epochs,
gpus=-1 if torch.cuda.is_available() else None,
resume_from_checkpoint=resume_from_checkpoint,
# resume_from_checkpoint=resume_from_checkpoint, #https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#resume-from-checkpoint
# debug flags
# limit_train_batches=10, # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#limit-train-batches
# limit_val_batches=1,
amp_level="00", # https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#amp-level
log_every_n_steps=5,
)
trainer.fit(
model=model, train_dataloaders=train_dataloader,
)
if __name__ == "__main__":
ortogonality_meds()
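
# Example invocation (sketch; the script name is hypothetical, the factors
# match the checkpoint naming used in the pipeline config above):
#
#   python orthogonality_meds.py \
#       <output-dir>/meds_identity-10.0_orthogonality-1.0 \
#       --identity-factor 10.0 --orthogonality-factor 1.0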
# from .facecrop import facecrop_pipeline
from .transformers import RunnableTransformer
from bob.bio.face.embeddings.pytorch import PyTorchModel, iresnet_template
class RunnableTransformer(PyTorchModel):
"""
    Transformer that builds its PyTorch model from a runnable (e.g., a
    ``functools.partial`` over a model constructor) instead of loading it
    from a checkpoint path
"""
def __init__(
self,
runnable_pytorch_model,
preprocessor=lambda x: (x - 127.5) / 128.0,
memory_demanding=False,
device=None,
**kwargs,
):
super(RunnableTransformer, self).__init__(
checkpoint_path="",
config="",
memory_demanding=memory_demanding,
preprocessor=preprocessor,
device=device,
**kwargs,
)
self.runnable_pytorch_model = runnable_pytorch_model
def _load_model(self):
self.model = self.runnable_pytorch_model()
self.model.eval()
self.place_model_on_device()
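
# Usage sketch, mirroring the pipeline config above (path hypothetical): the
# runnable is only called, and the model placed on the device, when
# ``_load_model`` runs.
#
#   from functools import partial
#   from bob.learn.pytorch.architectures.iresnet import iresnet100
#
#   embedding = RunnableTransformer(
#       partial(iresnet100, pretrained="/path/to/iresnet100.pth"),
#   )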
@@ -302,6 +302,33 @@ def plot_fdr(
plt.legend()
### PARETO PLOT
"""
FMRS = []
FNMRS = []
for neg, pos, tau in zip(negatives, positives, taus):
fmrs = []
fnmrs = []
for t in tau:
fmr, fnmr = bob.measure.farfrr(
neg["score"].compute().to_numpy(), pos["score"].compute().to_numpy(), t
)
fmrs.append(fmr)
fnmrs.append(fnmr)
FMRS.append(fmrs)
FNMRS.append(fnmrs)
ax = plt.axes(projection="3d")
for fmr, fnmr, fdr in zip(FMRS, FNMRS, fdrs):
ax.plot3D(fmr, fnmr, fdr)
ax.scatter3D(fmr, fnmr, fdr)
ax.set_xlabel("$FMR$")
ax.set_ylabel("$FNMR$")
ax.set_zlabel("$FDR$")
"""
return fig
from bob.learn.pytorch.trainers import BackboneHeadModel
from torch.nn import Module, Linear
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import numpy as np
import copy
class DemographicRegularHead(Module):
"""
    Plain linear head used for softmax classification of the demographic
    attribute: ``fc1`` projects the backbone features into the demographic
    embedding (where the orthogonality constraint is applied) and ``fc2``
    produces the class logits
"""
def __init__(self, feat_dim, num_class):
super(DemographicRegularHead, self).__init__()
# We will apply the orthogonality here
self.fc1 = Linear(feat_dim, feat_dim, bias=False)
self.fc2 = Linear(feat_dim, num_class, bias=False)
def forward(self, feats, labels):
        demographic_embedding = self.fc1(feats)
        demographic_logits = self.fc2(demographic_embedding)
        return demographic_embedding, demographic_logits
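
# A sketch (an assumption, for illustration) of the orthogonality penalty this
# head enables. The actual loss used by OrthogonalityModel is computed further
# below; the idea is to drive the per-sample dot product between the
# demographic embedding and the identity embedding towards zero, e.g.:
#
#   def orthogonality_penalty(identity_embedding, demographic_embedding):
#       # mean squared dot product; zero when the embeddings are orthogonal
#       return torch.mean(
#           torch.sum(identity_embedding * demographic_embedding, dim=1) ** 2
#       )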
def switch(model, flag):
model.train(flag)
# model.requires_grad = flag
for p in model.parameters():
p.requires_grad = flag
return model
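
# Example: freeze the backbone so only the heads receive gradient updates.
#
#   backbone = switch(backbone, False)  # eval mode + requires_grad=False
#   backbone = switch(backbone, True)   # back to training mode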
class OrthogonalityModel(BackboneHeadModel):
"""
Here we hypothesize that the sensitive attribute is orthogonal
to the identity
"""
def __init__(
self,
backbone,
identity_head,
demographic_head,
loss_fn,
optimizer_fn,
identity_factor=1.0,
orthogonality_factor=1.0,
backbone_checkpoint_path=None,
demographic_epochs=30,
identity_epochs=50,
**kwargs,
):
# super(pl.LightningModule, self).__init__(**kwargs)
pl.LightningModule.__init__(self, **kwargs)
self.backbone = backbone
self.identity_head = identity_head
self.demographic_head = demographic_head
self.loss_fn = loss_fn
self.optimizer_fn = optimizer_fn
self.identity_factor = identity_factor
self.orthogonality_factor = orthogonality_factor
self.demographic_epochs = demographic_epochs # 1. First train demographics
self.identity_epochs = (
demographic_epochs + identity_epochs
) # 2. Train identities
# self.ortogonality_epochs = (
# demographic_epochs + identity_epochs + ortogonality_epochs
# ) # 3. Train ortogonality
self.backbone_checkpoint_path = backbone_checkpoint_path
self.last_op = None
# Control the networks that will be updated
self.demographic_switch = False
self.identity_switch = False
self.orthogonality_switch = False
def training_epoch_end(self, training_step_outputs):
if self.backbone_checkpoint_path:
state = self.backbone.state_dict()
torch.save(state, self.backbone_checkpoint_path)
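    # Training schedule, driven by ``current_epoch`` below:
    #   1. epochs [0, demographic_epochs): train the demographic head only
    #   2. epochs [demographic_epochs, demographic_epochs + identity_epochs):
    #      train the identity head only
    #   3. afterwards: train the backbone and the identity head under the
    #      orthogonality constraint, keeping the demographic head frozen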
def training_step(self, batch, batch_idx):
data = batch["data"]
label = batch["label"]
demography = batch["demography"]
embedding = self.backbone(data)
if self.current_epoch < self.demographic_epochs:
            ## Phase 1: learn the demographic classifier
            # Freeze the backbone and the identity head; update only the
            # demographic head
            if not self.demographic_switch:
                self.demographic_switch = True
                self.backbone = switch(self.backbone, False)
                self.demographic_head = switch(self.demographic_head, True)
                self.identity_head = switch(self.identity_head, False)
# Demographic CLASSIFICATION loss
_, demographic_logits = self.demographic_head(embedding, demography)
loss_demography = self.loss_fn(demographic_logits, demography)
self.log("train/loss_demography", loss_demography)
acc = (
sum(
np.argmax(demographic_logits.cpu().detach().numpy(), axis=1)
== demography.cpu().detach().numpy()
)
/ demography.shape[0]
)
self.log("train/acc_demography_before_orthogonalization", acc)
return loss_demography
if self.current_epoch < self.identity_epochs:
            ## Phase 2: learn the identity classifier
            # Keep the backbone and the demographic head frozen; update only
            # the identity head
            if not self.identity_switch:
                self.identity_switch = True
                self.backbone = switch(self.backbone, False)
                self.demographic_head = switch(self.demographic_head, False)
                self.identity_head = switch(self.identity_head, True)
            # Identity loss
            logits_identity = self.identity_head(embedding, label)
            loss_identity = self.loss_fn(logits_identity, label)
            self.log("train/loss_identity", loss_identity)
            acc = (
                sum(
                    np.argmax(logits_identity.cpu().detach().numpy(), axis=1)
                    == label.cpu().detach().numpy()
                )
                / label.shape[0]
            )
            self.log("train/acc_identity_before_orthogonalization", acc)
            return loss_identity
        #########################################
        # Phase 3: learn the orthogonalization
        #########################################
        # Unfreeze the backbone, keep the demographic head frozen and keep
        # updating the identity head
        if not self.orthogonality_switch:
            self.orthogonality_switch = True
            self.backbone = switch(self.backbone, True)
            self.demographic_head = switch(self.demographic_head, False)
            self.identity_head = switch(self.identity_head, True)
        # Identity loss
        logits_identity = self.identity_head(embedding, label)
        loss_identity = self.loss_fn(logits_identity, label)
        # self.log("train/loss_identity", loss_identity)
        # Demographic CLASSIFICATION loss
        demographic_embedding, demographic_logits = self.demographic_head(
            embedding, demography
        )
# Demographic
acc = (
sum(
np.argmax(demographic_logits.cpu().detach().numpy(), axis=1)
== demography.cpu().detach().numpy()
)
/ demography.shape[0]
)
self.log("train/acc_demography_after_orthogonalization", acc)
# Identity
acc = (
sum(
                np.argmax(logits_identity.cpu().detach().numpy(), axis=1)
== label.cpu().detach().numpy()
)
/ label.shape[0]
)
self.log("train/acc_identity_after_orthogonalization", acc)
        # ORTHOGONALITY LOSS
        # The dot product between the demographic embedding and the identity
        # embedding should be zero
loss_orthogonality = torch.mean(