diff --git a/bob/paper/ijcb2023_caim_hfr/architectures/common.py b/bob/paper/ijcb2023_caim_hfr/architectures/common.py
index 9cafd3fba9d5f1229b7939056251a71ba3f99027..2243fe7f482160ad058a49736a339a49ad21a5e1 100644
--- a/bob/paper/ijcb2023_caim_hfr/architectures/common.py
+++ b/bob/paper/ijcb2023_caim_hfr/architectures/common.py
@@ -21,6 +21,61 @@ import torch.nn as nn
 
 from torch.nn import init
 
+
+class IdentityPassthrough(nn.Module):
+    """Gate-aware no-op used in place of CAIM at stages left unmodulated.
+
+    The ``dim`` and ``kernel`` arguments are unused; they are kept so the module
+    can be constructed with a single positional argument like the other
+    translator blocks (e.g. ``IdentityPassthrough(256)``).
+    """
+
+    def __init__(self, dim=1, kernel=3):
+        super(IdentityPassthrough, self).__init__()
+        self.inn = nn.Identity()
+
+    def forward(self, x, gate):
+        # The gate is accepted for interface compatibility but ignored:
+        # the input is passed through unchanged.
+        return self.inn(x)
+
+
+class CAIM(nn.Module):
+    """Conditional Adaptive Instance Modulation block.
+
+    A small style network predicts per-channel scale and shift parameters from
+    the input, which modulate the instance-normalized features.  A per-sample
+    gate blends the modulated features back into the input: gate=0 returns the
+    input unchanged, gate=1 adds the full modulated signal.
+    """
+
+    def __init__(self, in_channels):
+        super(CAIM, self).__init__()
+
+        style_channels = in_channels * 2
+
+        self.style_net = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(in_channels // 2, style_channels, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+        )
+
+        self.scale = nn.Linear(style_channels, in_channels)
+        self.shift = nn.Linear(style_channels, in_channels)
+
+        self.instance_norm = nn.InstanceNorm2d(in_channels, affine=False)
+
+    def forward(self, x, gate):
+        B, C, H, W = x.size()
+
+        x_normed = self.instance_norm(x)
+
+        # Predict the style vector from the input and pool it globally
+        style = self.style_net(x)
+        style = style.mean(dim=[2, 3])  # global average pooling
+
+        scale = self.scale(style).view(B, C, 1, 1)
+        shift = self.shift(style).view(B, C, 1, 1)
+
+        x_hat = scale * x_normed + shift
+
+        # Gated residual: gate=0 -> x unchanged, gate=1 -> x plus modulated features
+        ret = gate.view(B, 1, 1, 1) * x_hat + x
+
+        return ret
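+
+
+# Usage sketch (illustrative only; the 56x56 spatial size is a placeholder and
+# ``torch`` is assumed to be imported):
+#
+#     caim = CAIM(64)
+#     feats = torch.randn(8, 64, 56, 56)
+#     gate_on = torch.ones(8, 1)               # modulate (target modality)
+#     gate_off = torch.zeros(8, 1)             # pass through (VIS)
+#     assert torch.equal(caim(feats, gate_off), feats)
+#     modulated = caim(feats, gate_on)         # same shape as `feats`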
+
+
 class SEBlock(nn.Module):
     """
     Squeeze-and-Excitation block from 'Squeeze-and-Excitation Networks,' https://arxiv.org/abs/1709.01507.
diff --git a/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py b/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py
index de4d68ba807699ebe06246ae1aba57604d815366..86ce39f3d31778b888ae16ab848cb23a8e8f04e3 100644
--- a/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py
+++ b/bob/paper/ijcb2023_caim_hfr/architectures/iresnet.py
@@ -195,6 +195,148 @@ class IResNet(nn.Module):
         return x
 
 
+class IResNetPAM(nn.Module):
+    """IResNet backbone with translator modules (e.g. CAIM blocks) applied
+    between stages; a per-batch gate toggles their modulation."""
+
+    fc_scale = 7 * 7
+
+    def __init__(
+        self,
+        block,
+        layers,
+        pam_translator,
+        num_features=512,
+        zero_init_residual=False,
+        groups=1,
+        width_per_group=64,
+        replace_stride_with_dilation=None,
+    ):
+        super(IResNetPAM, self).__init__()
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError(
+                "replace_stride_with_dilation should be None "
+                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
+            )
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(
+            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
+        )
+        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=2e-05, momentum=0.9)
+        self.prelu = nn.PReLU(self.inplanes)
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+        self.layer2 = self._make_layer(
+            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
+        )
+        self.layer3 = self._make_layer(
+            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
+        )
+        self.layer4 = self._make_layer(
+            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
+        )
+
+        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=2e-05, momentum=0.9)
+        self.dropout = nn.Dropout2d(p=0.4, inplace=True)
+        self.flatten = nn.Flatten()
+        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+        self.features = nn.BatchNorm1d(num_features, eps=2e-05, momentum=0.9)
+        self.pam_translator = pam_translator
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, IBasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                nn.BatchNorm2d(planes * block.expansion, eps=2e-05, momentum=0.9),
+            )
+
+        layers = []
+        layers.append(
+            block(
+                self.inplanes,
+                planes,
+                stride,
+                downsample,
+                self.groups,
+                self.base_width,
+                previous_dilation,
+            )
+        )
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x, gate=0):
+        # `gate` is a Python int flag for the whole batch: 0 -> visible
+        # spectrum (modulation disabled), 1 -> target modality (modulation
+        # enabled).  It is expanded to a (B, 1) tensor for the translators.
+        if gate == 0:
+            gate = torch.zeros(x.size(0), 1).to(x.device)
+        else:
+            gate = torch.ones(x.size(0), 1).to(x.device)
+
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.prelu(x)
+
+        # One translator module before each residual stage and one after the
+        # last stage.
+        x = self.pam_translator[0](x, gate)
+        x = self.layer1(x)
+
+        x = self.pam_translator[1](x, gate)
+        x = self.layer2(x)
+
+        x = self.pam_translator[2](x, gate)
+        x = self.layer3(x)
+
+        x = self.pam_translator[3](x, gate)
+        x = self.layer4(x)
+
+        x = self.pam_translator[4](x, gate)
+
+        x = self.bn2(x)
+        x = self.dropout(x)
+        x = self.flatten(x)
+        x = self.fc(x)
+        x = self.features(x)
+
+        return x
+
+
 def _iresnet(arch, block, layers, checkpoint_path):
     model = IResNet(block, layers)
     if checkpoint_path is not None:
@@ -214,3 +356,16 @@ def iresnet50(checkpoint_path=None):
 def iresnet100(checkpoint_path=None):
     return _iresnet("iresnet100", IBasicBlock, [3, 13, 30, 3], checkpoint_path)
 
+
+def _iresnetpam(arch, translator, block, layers, checkpoint_path):
+    model = IResNetPAM(block, layers, translator)
+    if checkpoint_path is not None:
+        model.load_state_dict(torch.load(checkpoint_path))
+        model.eval()
+    return model
+
+
+def iresnet100pam(translator, checkpoint_path=None):
+    return _iresnetpam(
+        "iresnet100pam", translator, IBasicBlock, [3, 13, 30, 3], checkpoint_path
+    )
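+
+
+# Construction sketch (illustrative; ``translator`` is an nn.ModuleList and
+# ``images`` stands for a batch of aligned 112x112 face crops).  The CAIM
+# lightning config builds the translator as
+# [CAIM(64), CAIM(64), CAIM(128), IdentityPassthrough(256), IdentityPassthrough(512)]:
+#
+#     model = iresnet100pam(translator)
+#     emb_vis = model(images, gate=0)  # visible spectrum, modulation disabled
+#     emb_mod = model(images, gate=1)  # target modality, CAIM blocks active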
diff --git a/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py b/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py
index fd43e00320b4228a2454af4208b45c1fb308e640..73bdba64fd686441944174a9195b746ccca4cc70 100644
--- a/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py
+++ b/bob/paper/ijcb2023_caim_hfr/architectures/light_iresnet_txl.py
@@ -13,11 +13,14 @@ from bob.bio.face.embeddings.pytorch import IResnet100
 from .iresnet import iresnet34
 from .iresnet import iresnet50
 from .iresnet import iresnet100
+from .iresnet import iresnet100pam
+
 
 ARCHS = {
     "iresnet34": iresnet34,
     "iresnet50": iresnet50,
     "iresnet100": iresnet100,
+    "iresnet100pam": iresnet100pam,
 }
 CHECKPOINTS = {
     "iresnet34": IResnet34().checkpoint_path,
@@ -26,6 +29,57 @@ CHECKPOINTS = {
 }
 
 
+
+from functools import partial
+
+
+class XIresnetTxlDAM(nn.Module):
+    """Multi-modal model built around a single gated ``iresnet100pam`` backbone.
+
+    The frozen, pretrained VIS backbone embeds both modalities; only the
+    translator modules are trained, and the gate selects whether they are active.
+    """
+
+    def __init__(self, narch, translator, modality, **kwargs):
+        super().__init__(**kwargs)
+        self.vis_model = iresnet100pam(translator=translator)
+        self.checkpoint = CHECKPOINTS[narch.split('pam')[0]]
+        self.modality = modality
+
+        # Load the pretrained VIS weights; strict=False because the translator
+        # parameters are not part of the original checkpoint.
+        self.vis_model.load_state_dict(torch.load(self.checkpoint), strict=False)
+        self.vis_model.requires_grad_(False)
+        self.vis_model.eval()
+
+        # Only the translator modules are trainable.
+        for name, mod in self.vis_model.named_children():
+            if 'translator' in name:
+                mod.requires_grad_(True)
+
+        self.model_per_modality = {
+            "VIS": partial(self.vis_model, gate=0),
+            modality: partial(self.vis_model, gate=1),
+        }
+
+    def forward(self, inputs):
+        input_vis, input_mod = inputs["VIS"], inputs[self.modality]
+
+        # Reference branch: frozen backbone, gate=0 (no modulation).
+        with torch.no_grad():
+            embedding_left = self.vis_model(input_vis, 0)
+            embedding_left = torch.nn.functional.normalize(embedding_left, dim=-1)
+
+        # Target-modality branch: gate=1 activates the CAIM modulation.
+        embedding_right = self.vis_model(input_mod, 1)
+        embedding_right = torch.nn.functional.normalize(embedding_right, dim=-1)
+        outputs = torch.pairwise_distance(embedding_left, embedding_right)
+        return outputs
+
+    def train(self, mode=True):
+        """Override train() to keep the frozen backbone's BatchNorm layers in
+        eval mode with their affine parameters frozen."""
+        super().train(mode)
+        for name, module in self.named_modules():
+            if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)):
+                if 'vis' in name and 'translator' not in name:
+                    module.eval()
+                    module.weight.requires_grad = False
+                    module.bias.requires_grad = False
+
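+
+# Usage sketch (illustrative; ``translator`` is an nn.ModuleList as built in the
+# CAIM config and "THERMAL" is a placeholder modality).  The returned pairwise
+# distances are scored by ContrastiveLoss(margin=2.0) via XLitDSUModel, per the
+# CAIM lightning config:
+#
+#     model = XIresnetTxlDAM("iresnet100pam", translator, modality="THERMAL")
+#     distances = model({"VIS": vis_batch, "THERMAL": thermal_batch})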
+
+
 class XIresnetTxl(nn.Module):
     """Takes an iresnet architecture and a translator module.
     Then, provides a multi-modal pytorch model."""
diff --git a/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py b/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py
index f1191a523b28dffcd8a1b9e7e4722c8675e8e3b1..1d97fc44ce8a1fb9609f405791583a9501ccd7a1 100644
--- a/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py
+++ b/bob/paper/ijcb2023_caim_hfr/config/lightning/x_iresnet_CAIM.py
@@ -21,13 +21,14 @@ from bob.pipelines import CheckpointWrapper
 from bob.pipelines import wrap
 
 
+import torch.nn as nn
 
 from bob.paper.ijcb2023_caim_hfr.transformer.xlightning import XLitDSUModel
 from bob.paper.ijcb2023_caim_hfr.architectures.iresnet import preprocess_img
 from bob.paper.ijcb2023_caim_hfr.dataset.hface import AGLitDSUDataModule, AGLitDSUDataModuleNP # testing
 from bob.paper.ijcb2023_caim_hfr.losses import ContrastiveLoss
 from bob.paper.ijcb2023_caim_hfr.transformer.lightning import LightningTransformer
-from bob.paper.ijcb2023_caim_hfr.architectures.light_iresnet_txl import XIresnetTxl
+from bob.paper.ijcb2023_caim_hfr.architectures.light_iresnet_txl import XIresnetTxl, XIresnetTxlDAM
 
 database = globals()["database"]
 PROTOCOL = globals()["PROTOCOL"]
@@ -47,7 +48,7 @@ OUTPUT = (
     output
 ) = f"{rc['ptemp']}/{database.name}/{database.protocol}/"
 BATCH_SIZE = 90
-EPOCHS = 20
+EPOCHS = 50
 
 
 pipeline = iresnet_template(
@@ -114,9 +115,9 @@ def trainer_fn():
     )
     return trainer
 
-from bob.paper.ijcb2023_caim_hfr.architectures.common import PDT
+from bob.paper.ijcb2023_caim_hfr.architectures.common import IdentityPassthrough, CAIM
 
-translator=PDT(pool_features=int(POOL),use_se=str2bool(SE), use_bias=str2bool(BIAS), use_cbam=str2bool(CBAM))
+translator = nn.ModuleList(
+    [CAIM(64), CAIM(64), CAIM(128), IdentityPassthrough(256), IdentityPassthrough(512)]
+)
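+# One module per insertion point in IResNetPAM: CAIM modulates the three
+# shallower feature maps (64, 64 and 128 channels), while the deeper stages
+# (256 and 512 channels) are passed through unchanged.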
 
 
 def model_fn():
@@ -125,22 +126,24 @@ def model_fn():
     modality= modality if len(modality)>0 else {'VIS'} # hack to handle VIS-VIS protocol
     modality = list(modality)[0]
     print(f"Modality of database was found to be {modality}")
-    model = XIresnetTxl(ARCHITECTURE, translator, modality=modality)
+    model = XIresnetTxlDAM(ARCHITECTURE, translator, modality=modality)
     loss_fn = ContrastiveLoss(margin=2.0)
     model = XLitDSUModel(model, loss_fn=loss_fn)
     return model
 
-
 def datamodule_fn(train_samples):
     datamodule = AGLitDSUDataModuleNP(
         batch_size=BATCH_SIZE,
         train_samples=train_samples,
         train_transform=train_transform,
         test_transform=test_transform,
+        test_size=0.1,
+        num_pairs=30000,
     )
     return datamodule
 
 
+
 embedding = LightningTransformer(
     trainer_fn=trainer_fn,
     model_fn=model_fn,