diff --git a/bob/paper/ijcb2023_caim_hfr/architectures/common.py b/bob/paper/ijcb2023_caim_hfr/architectures/common.py index 9cafd3fba9d5f1229b7939056251a71ba3f99027..2243fe7f482160ad058a49736a339a49ad21a5e1 100644 --- a/bob/paper/ijcb2023_caim_hfr/architectures/common.py +++ b/bob/paper/ijcb2023_caim_hfr/architectures/common.py @@ -21,6 +21,61 @@ import torch.nn as nn from torch.nn import init + +import torch.nn as nn + + +class IdentityPassthrough(nn.Module): + def __init__(self, dim=1, kernel=3): + super(IdentityPassthrough, self).__init__() + + self.inn= nn.Identity() + def forward(self,x,gate): + + gate=gate.unsqueeze(1).unsqueeze(2) + + return self.inn(x) + + + +class CAIM(nn.Module): + def __init__(self, in_channels): + super(CAIM, self).__init__() + + style_channels = in_channels * 2 + + self.style_net = nn.Sequential( + nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(in_channels // 2, style_channels, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + ) + + self.scale = nn.Linear(style_channels, in_channels) + self.shift = nn.Linear(style_channels, in_channels) + + self.instance_norm = nn.InstanceNorm2d(in_channels, affine=False) + + def forward(self, x, gate): + B, C, H, W = x.size() + + x_normed = self.instance_norm(x) + + # Compute the style tensor using the style_net + style = self.style_net(x) + style = style.mean(dim=[2, 3]) # Global average pooling + + scale = self.scale(style).view(B, C, 1, 1) + shift = self.shift(style).view(B, C, 1, 1) + + x_hat = scale * x_normed + shift + + ret = gate.view(B, 1, 1, 1) * x_hat + x + + return ret + + + class SEBlock(nn.Module): """ Squeeze-and-Excitation block from 'Squeeze-and-Excitation Networks,' https://arxiv.org/abs/1709.01507. diff --git a/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py b/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py index de4d68ba807699ebe06246ae1aba57604d815366..86ce39f3d31778b888ae16ab848cb23a8e8f04e3 100644 --- a/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py +++ b/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py @@ -195,6 +195,148 @@ class IResNet(nn.Module): return x + + + +class IResNetPAM(nn.Module): + fc_scale = 7 * 7 + + def __init__( + self, + block, + layers, + pam_translator, + num_features=512, + zero_init_residual=False, + groups=1, + width_per_group=64, + replace_stride_with_dilation=None, + ): + super(IResNetPAM, self).__init__() + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation) + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d( + 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=2e-05, momentum=0.9) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0] + ) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1] + ) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2] + ) + + self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=2e-05, momentum=0.9) + self.dropout = nn.Dropout2d(p=0.4, inplace=True) + self.flatten = nn.Flatten() + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=2e-05, momentum=0.9) + self.pam_translator= pam_translator + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=2e-05, momentum=0.9), + ) + + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride, + downsample, + self.groups, + self.base_width, + previous_dilation, + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + ) + ) + + return nn.Sequential(*layers) + + def forward(self, x, gate=0): + + if gate==0: + gate=torch.zeros(x.size(0),1).to(x.device) + else: + gate=torch.ones(x.size(0),1).to(x.device) + + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + + x=self.pam_translator[0](x, gate) + x = self.layer1(x) + + x=self.pam_translator[1](x, gate) + + x = self.layer2(x) + x=self.pam_translator[2](x, gate) + + x = self.layer3(x) + + x=self.pam_translator[3](x, gate) + x = self.layer4(x) + x=self.pam_translator[4](x, gate) + + x = self.bn2(x) + x = self.dropout(x) + x = self.flatten(x) + x = self.fc(x) + x = self.features(x) + + return x + + + + + + def _iresnet(arch, block, layers, checkpoint_path): model = IResNet(block, layers) if checkpoint_path is not None: @@ -214,3 +356,16 @@ def iresnet50(checkpoint_path=None): def iresnet100(checkpoint_path=None): return _iresnet("iresnet100", IBasicBlock, [3, 13, 30, 3], checkpoint_path) + + + +def _iresnetpam(arch,translator, block, layers, checkpoint_path): + model = IResNetPAM(block, layers, translator) + if checkpoint_path is not None: + model.load_state_dict(torch.load(checkpoint_path)) + model.eval() + return model + + +def iresnet100pam(translator,checkpoint_path=None): + return _iresnetpam("iresnet100pam", translator, IBasicBlock, [3, 13, 30, 3], checkpoint_path) \ No newline at end of file diff --git a/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py b/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py index fd43e00320b4228a2454af4208b45c1fb308e640..73bdba64fd686441944174a9195b746ccca4cc70 100644 --- a/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py +++ b/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py @@ -13,11 +13,14 @@ from bob.bio.face.embeddings.pytorch import IResnet100 from .iresnet import iresnet34 from .iresnet import iresnet50 from .iresnet import iresnet100 +from .iresnet import iresnet100pam + ARCHS = { "iresnet34": iresnet34, "iresnet50": iresnet50, "iresnet100": iresnet100, + "iresnet100pam": iresnet100pam, } CHECKPOINTS = { "iresnet34": IResnet34().checkpoint_path, @@ -26,6 +29,57 @@ CHECKPOINTS = { } + + +from functools import partial +class XIresnetTxlDAM(nn.Module): + + + def __init__(self, narch, translator, modality, **kwargs): + super().__init__(**kwargs) + self.vis_model = iresnet100pam(translator=translator) + self.checkpoint = CHECKPOINTS[narch.split('pam')[0]] + self.modality = modality + + self.vis_model.load_state_dict(torch.load(self.checkpoint), strict=False) + self.vis_model.requires_grad_(False) + self.vis_model.eval() + + + for name, mod in self.vis_model.named_children(): + + if 'translator' in name: + mod.requires_grad_(True) + + self.model_per_modality = {"VIS": partial(self.vis_model,gate=0), modality: partial(self.vis_model, gate=1)} # test + + def forward(self, inputs): + input_vis, input_mod = inputs["VIS"], inputs[self.modality] + with torch.no_grad(): + embedding_left = self.vis_model(input_vis, 0) + embedding_left = torch.nn.functional.normalize(embedding_left, dim=-1) + + embedding_right = self.vis_model(input_mod,1) # to activate the gate + embedding_right = torch.nn.functional.normalize(embedding_right, dim=-1) + outputs = torch.pairwise_distance(embedding_left, embedding_right) + return outputs + + def train(self, mode=True): + """ + Override the default train() to freeze the BN parameters + """ + super().train(mode) + for name, module in self.named_modules(): + if isinstance(module, torch.nn.modules.BatchNorm1d) or isinstance(module, torch.nn.modules.BatchNorm2d): + if 'vis' in name and 'translator' not in name: + module.eval() + module.weight.requires_grad = False + module.bias.requires_grad = False + + + + + class XIresnetTxl(nn.Module): """Takes an iresnet architecture and a translator module. Then, provides a multi-modal pytorch model.""" diff --git a/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py b/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py index f1191a523b28dffcd8a1b9e7e4722c8675e8e3b1..1d97fc44ce8a1fb9609f405791583a9501ccd7a1 100644 --- a/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py +++ b/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py @@ -21,13 +21,14 @@ from bob.pipelines import CheckpointWrapper from bob.pipelines import wrap +import torch.nn as nn from bob.paper.ijcb2023_caim_hfr.transformer.xlightning import XLitDSUModel from bob.paper.ijcb2023_caim_hfr.architectures.iresnet import preprocess_img from bob.paper.ijcb2023_caim_hfr.dataset.hface import AGLitDSUDataModule, AGLitDSUDataModuleNP # testing from bob.paper.ijcb2023_caim_hfr.losses import ContrastiveLoss from bob.paper.ijcb2023_caim_hfr.transformer.lightning import LightningTransformer -from bob.paper.ijcb2023_caim_hfr.architectures.light_iresnet_txl import XIresnetTxl +from bob.paper.ijcb2023_caim_hfr.architectures.light_iresnet_txl import XIresnetTxl, XIresnetTxlDAM database = globals()["database"] PROTOCOL = globals()["PROTOCOL"] @@ -47,7 +48,7 @@ OUTPUT = ( output ) = f"{rc['ptemp']}/{database.name}/{database.protocol}/" BATCH_SIZE = 90 -EPOCHS = 20 +EPOCHS = 50 pipeline = iresnet_template( @@ -114,9 +115,9 @@ def trainer_fn(): ) return trainer -from bob.paper.ijcb2023_caim_hfr.architectures.common import PDT +from bob.paper.ijcb2023_caim_hfr.architectures.common import IdentityPassthrough, CAIM -translator=PDT(pool_features=int(POOL),use_se=str2bool(SE), use_bias=str2bool(BIAS), use_cbam=str2bool(CBAM)) +translator= nn.ModuleList([CAIM(64), CAIM(64), CAIM(128), IdentityPassthrough(256), IdentityPassthrough(512)]) def model_fn(): @@ -125,22 +126,24 @@ def model_fn(): modality= modality if len(modality)>0 else {'VIS'} # hack to handle VIS-VIS protocol modality = list(modality)[0] print(f"Modality of database was found to be {modality}") - model = XIresnetTxl(ARCHITECTURE, translator, modality=modality) + model = XIresnetTxlDAM(ARCHITECTURE, translator, modality=modality) loss_fn = ContrastiveLoss(margin=2.0) model = XLitDSUModel(model, loss_fn=loss_fn) return model - def datamodule_fn(train_samples): datamodule = AGLitDSUDataModuleNP( batch_size=BATCH_SIZE, train_samples=train_samples, train_transform=train_transform, test_transform=test_transform, + test_size=0.1, + num_pairs=30000, ) return datamodule + embedding = LightningTransformer( trainer_fn=trainer_fn, model_fn=model_fn,