diff --git a/src/Dataset.py b/src/Dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..90f39bdca53e73d8b5017dd3be2b726d05e678b4 --- /dev/null +++ b/src/Dataset.py @@ -0,0 +1,203 @@ +import torch +from torch.utils.data import Dataset +from .loss.FaceIDLoss import get_FaceRecognition_transformer + +import glob +import random +import numpy as np +import cv2 + +seed=2021 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def Crop_512_Synthesize(im): + pad = 150 + img = np.zeros([im.shape[0]+int(2*pad), im.shape[1]+int(2*pad), 3]) + img[pad:-pad,pad:-pad,:] = im + + FFHQ_REYE_POS = (480 + pad, 380 + pad) #(480, 380) + FFHQ_LEYE_POS = (480 + pad, 650 + pad) #(480, 650) + + CROPPED_IMAGE_SIZE=(512, 512) + fixed_positions={'reye': FFHQ_REYE_POS, 'leye': FFHQ_LEYE_POS} + + cropped_positions = { + "leye": (190, 325), + "reye": (190, 190) + } + """ + Steps: + 1) find rescale ratio + + 2) find corresponding pixel in 1024 image which will be mapped to + the coordinate (0,0) at the croped_and_resized image + + 3) find corresponding pixel in 1024 image which will be mapped to + the coordinate (112,112) at the croped_and_resized image + + 4) crop image in 1024 + + 5) resize the cropped image + """ + # step1: find rescale ratio + alpha = ( cropped_positions['leye'][1] - cropped_positions['reye'][1] ) / ( fixed_positions['leye'][1]- fixed_positions['reye'][1] ) + + # step2: find corresponding pixel in 1024 image for (0,0) at the croped_and_resized image + coord_0_0_at_1024 = np.array(fixed_positions['reye']) - 1/alpha* np.array(cropped_positions['reye']) + + # step3: find corresponding pixel in 1024 image for (112,112) at the croped_and_resized image + coord_112_112_at_1024 = coord_0_0_at_1024 + np.array(CROPPED_IMAGE_SIZE) / alpha + + # step4: crop image in 1024 + cropped_img_1024 = img[int(coord_0_0_at_1024[0]) : int(coord_112_112_at_1024[0]), + int(coord_0_0_at_1024[1]) : int(coord_112_112_at_1024[1]), + :] + + # step5: resize the cropped image + resized_and_croped_image = cv2.resize(cropped_img_1024, CROPPED_IMAGE_SIZE) + + return resized_and_croped_image + +def Crop_112_FR(img): + """ + Input: + - img: RGB or BGR image in 0-1 or 0-255 scale + Output: + - new_img: RGB or BGR image in 0-1 or 0-255 scale + """ + + FFHQ_REYE_POS = (480, 380) + FFHQ_LEYE_POS = (480, 650) + + CROPPED_IMAGE_SIZE=(112, 112) + fixed_positions={'reye': FFHQ_REYE_POS, 'leye': FFHQ_LEYE_POS} + + cropped_positions = { + "leye": (51.6, 73.5318), + "reye": (51.6, 38.2946) + } + """ + Steps: + 1) find rescale ratio + + 2) find corresponding pixel in 1024 image which will be mapped to + the coordinate (0,0) at the croped_and_resized image + + 3) find corresponding pixel in 1024 image which will be mapped to + the coordinate (112,112) at the croped_and_resized image + + 4) crop image in 1024 + + 5) resize the cropped image + """ + # step1: find rescale ratio + alpha = ( cropped_positions['leye'][1] - cropped_positions['reye'][1] ) / ( fixed_positions['leye'][1]- fixed_positions['reye'][1] ) + + # step2: find corresponding pixel in 1024 image for (0,0) at the croped_and_resized image + coord_0_0_at_1024 = np.array(fixed_positions['reye']) - 1/alpha* np.array(cropped_positions['reye']) + + # step3: find corresponding pixel in 1024 image for (112,112) at the croped_and_resized image + coord_112_112_at_1024 = coord_0_0_at_1024 + np.array(CROPPED_IMAGE_SIZE) / alpha + + # step4: crop image in 1024 + 
cropped_img_1024 = img[int(coord_0_0_at_1024[0]) : int(coord_112_112_at_1024[0]), + int(coord_0_0_at_1024[1]) : int(coord_112_112_at_1024[1]), + :] + + # step5: resize the cropped image + resized_and_croped_image = cv2.resize(cropped_img_1024, CROPPED_IMAGE_SIZE) + + return resized_and_croped_image + +class MyDataset(Dataset): + def __init__(self, dataset_dir = './Flickr-Faces-HQ/images1024x1024', + FR_system= 'ArcFace', + train=True, + device='cpu', + mixID_TrainTest=True, + train_test_split = 0.9, + random_seed=2021 + ): + self.dataset_dir = dataset_dir + self.device = device + self.train = train + + self.dir_all_images = [] + + all_folders = glob.glob(dataset_dir+'/*') + all_folders.sort() + for folder in all_folders: + all_imgs = glob.glob(folder+'/*.png') + all_imgs.sort() + for img in all_imgs: + self.dir_all_images.append(img) + + if mixID_TrainTest: + random.seed(random_seed) + random.shuffle(self.dir_all_images) + + if self.train: + self.dir_all_images = self.dir_all_images[:int(train_test_split*len(self.dir_all_images))] + else: + self.dir_all_images = self.dir_all_images[int(train_test_split*len(self.dir_all_images)):] + + + self.Face_Recognition_Network = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device) + + self.FFHQ_align_mask = Crop_512_Synthesize(np.ones([1024,1024,3]).astype('uint8')) + self.FFHQ_align_mask = torch.tensor(self.FFHQ_align_mask).to(device) + self.FFHQ_align_mask = torch.transpose(self.FFHQ_align_mask,0,2) + + def __len__(self): + return len(self.dir_all_images) + + def get_batch(self, batch_idx, batch_size): + all_embedding = [] + all_image = [] + all_image_HQ = [] + for idx in range(batch_size): + embedding, image, image_HQ = self.__getitem__(batch_idx*batch_size+ idx) + all_embedding.append(embedding) + all_image.append(image) + all_image_HQ.append(image_HQ) + return torch.stack(all_embedding).to(self.device), torch.stack(all_image).to(self.device), torch.stack(all_image_HQ ).to(self.device) + + def __getitem__(self, idx): + + image_1024 = cv2.imread(self.dir_all_images[idx]) # (1024, 1024, 3) + + image_HQ = cv2.cvtColor(image_1024, cv2.COLOR_BGR2RGB) + image = Crop_112_FR(image_HQ) # (112, 112, 3) + image_HQ = Crop_512_Synthesize(image_HQ) + + image_HQ = image_HQ/255. + image = image/255. 
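+        # At this point `image` is the 112x112 FR-aligned crop and `image_HQ` is the
+        # 512x512 synthesis crop, both RGB, HWC, in [0, 1].
+        # Below, `image` is moved to (1, 3, 112, 112), rescaled back to 0-255 (uint8),
+        # and passed through the face-recognition transformer (whose preprocessor
+        # expects 0-255 input) to obtain the identity embedding returned with the images.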
+ + image = image.transpose(2,0,1) # (3, 112, 112) + image = np.expand_dims(image, axis=0) # (1, 3, 112, 112) + + img = torch.Tensor( (image*255.).astype('uint8') ).type(torch.FloatTensor) + embedding = self.Face_Recognition_Network.transform(img.to(self.device) ) + image = image[0] # range (0,1) and shape (3, 112, 112) + + image = self.transform_image(image) + embedding = self.transform_embedding(embedding) + + + image_HQ = image_HQ.transpose(2,0,1) # (3, 256, 256) + image_HQ = torch.Tensor( image_HQ ).type(torch.FloatTensor).to(self.device) + + return embedding, image, image_HQ + + def transform_image(self,image): + image = torch.Tensor(image).to(self.device) + return image + + def transform_embedding(self, embedding): + embedding = embedding.view(-1).to(self.device) + return embedding \ No newline at end of file diff --git a/src/Network.py b/src/Network.py new file mode 100644 index 0000000000000000000000000000000000000000..08138bfab22b42d8083431e798b3e53b58a618aa --- /dev/null +++ b/src/Network.py @@ -0,0 +1,152 @@ +import torch.nn as nn +import torch +import numpy as np +from torch_utils.ops import bias_act +from torch_utils import misc + + +class Discriminator(nn.Module): + def __init__(self, ): + super(Discriminator, self).__init__() + + self.fc = nn.Sequential( + nn.Linear(512*14, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1) + ) + + def forward(self, w ): + w_ = w.view(-1,512*14) + real_or_fake = self.fc(w_) + return real_or_fake + + + +# https://github.com/NVlabs/eg3d/blob/main/eg3d/training/networks_stylegan2.py#L28 +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + + + +# https://github.com/NVlabs/eg3d/blob/main/eg3d/training/networks_stylegan2.py#L96 +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + + +# https://github.com/NVlabs/eg3d/blob/main/eg3d/training/networks_stylegan2.py#L193 +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. 
+ num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers = 8, # Number of mapping layers. + embed_features = None, # Label embedding dimensionality, None = same as w_dim. + layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + # misc.assert_shape(z, [None, self.z_dim]) + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + # misc.assert_shape(c, [None, self.c_dim]) + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if update_emas and self.w_avg_beta is not None: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. 
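+        # Truncation pulls each w towards the running average w_avg tracked above:
+        # w' = lerp(w_avg, w, truncation_psi), so psi = 1 leaves w unchanged and
+        # smaller psi trades diversity for fidelity. If truncation_cutoff is set,
+        # only the first truncation_cutoff of the num_ws broadcast codes are truncated.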
+ if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' diff --git a/src/loss/FaceIDLoss.py b/src/loss/FaceIDLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..6042cc9b990c4432fd11f41c108d4034538a0dc5 --- /dev/null +++ b/src/loss/FaceIDLoss.py @@ -0,0 +1,350 @@ +import torch +from sklearn.base import TransformerMixin, BaseEstimator +from sklearn.utils import check_array +import numpy as np +import imp +import os +from bob.extension.download import get_file + + +def Crop_and_resize(img): + pad = 10 + img = torch.nn.functional.pad(img, pad=(pad, pad, pad, pad)) + + FFHQ_REYE_POS = (190 + pad, 190 + pad) #(480, 380) + FFHQ_LEYE_POS = (190 + pad, 325 + pad) #(480, 650) + + CROPPED_IMAGE_SIZE=(112, 112) + fixed_positions={'reye': FFHQ_REYE_POS, 'leye': FFHQ_LEYE_POS} + + cropped_positions = { + "leye": (51.6, 73.5318), + "reye": (51.6, 38.2946) + } + """ + Steps: + 1) find rescale ratio + + 2) find corresponding pixel in 512 image which will be mapped to + the coordinate (0,0) at the croped_and_resized image + + 3) find corresponding pixel in 512 image which will be mapped to + the coordinate (112,112) at the croped_and_resized image + + 4) crop image in 512 + + 5) resize the cropped image + """ + # step1: find rescale ratio + alpha = ( cropped_positions['leye'][1] - cropped_positions['reye'][1] ) / ( fixed_positions['leye'][1]- fixed_positions['reye'][1] ) + + # step2: find corresponding pixel in 512 image for (0,0) at the croped_and_resized image + coord_0_0_at_512 = np.array(fixed_positions['reye']) - 1/alpha* np.array(cropped_positions['reye']) + + # step3: find corresponding pixel in 512 image for (112,112) at the croped_and_resized image + coord_112_112_at_512 = coord_0_0_at_512 + np.array(CROPPED_IMAGE_SIZE) / alpha + + # step4: crop image in 512 + # cropped_img_512 = img[int(coord_0_0_at_512[0]) : int(coord_0_0_at_512[1]), + # int(coord_112_112_at_512[0]) : int(coord_112_112_at_512[1]), + # :] + cropped_img_512 = img[:, + :, + int(coord_0_0_at_512[0]) : int(coord_112_112_at_512[0]), + int(coord_0_0_at_512[1]) : int(coord_112_112_at_512[1]) + ] + + # step5: resize the cropped image + # resized_and_croped_image = cv2.resize(cropped_img_512, CROPPED_IMAGE_SIZE) + resized_and_croped_image = torch.nn.functional.interpolate(cropped_img_512, mode='bilinear', size=CROPPED_IMAGE_SIZE, align_corners=False) + + return resized_and_croped_image + + +class PyTorchModel(TransformerMixin, BaseEstimator): + """ + Base Transformer using pytorch models + + + Parameters + ---------- + + checkpoint_path: str + Path containing the checkpoint + + config: + Path containing some configuration file (e.g. .json, .prototxt) + + preprocessor: + A function that will transform the data right before forward. 
The default transformation is `X/255` + + """ + + def __init__( + self, + checkpoint_path=None, + config=None, + preprocessor=lambda x: (x - 127.5) / 128.0, + device='cpu', + image_dim = 112, + **kwargs + ): + + super().__init__(**kwargs) + self.checkpoint_path = checkpoint_path + self.config = config + self.model = None + self.preprocessor_ = preprocessor + self.device = device + self.image_dim= image_dim + + def preprocessor(self, X): + X = self.preprocessor_(X) + if X.size(2) == 512: + X = Crop_and_resize(X) + if X.size(2) != self.image_dim: + X = torch.nn.functional.interpolate(X, mode='bilinear', size=(self.image_dim, self.image_dim), align_corners=False) + return X + + def transform(self, X): + """__call__(image) -> feature + + Extracts the features from the given image. + + **Parameters:** + + image : 2D :py:class:`numpy.ndarray` (floats) + The image to extract the features from. + + **Returns:** + + feature : 2D or 3D :py:class:`numpy.ndarray` (floats) + The list of features extracted from the image. + """ + if self.model is None: + self._load_model() + + self.model.eval() + + self.model.to(self.device) + for param in self.model.parameters(): + param.requires_grad=False + + # X = check_array(X, allow_nd=True) + # X = torch.Tensor(X) + X = self.preprocessor(X) + + return self.model(X)#.detach().numpy() + + + def __getstate__(self): + # Handling unpicklable objects + + d = self.__dict__.copy() + d["model"] = None + return d + + def _more_tags(self): + return {"stateless": True, "requires_fit": False} + + def to(self,device): + self.device=device + + if self.model !=None: + self.model.to(self.device) + + +def _get_iresnet_file(): + urls = [ + "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz", + "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz", + ] + + return get_file( + "iresnet-91a5de61.tar.gz", + urls, + cache_subdir="data/pytorch/iresnet-91a5de61/", + file_hash="3976c0a539811d888ef5b6217e5de425", + extract=True, + ) + +class IResnet100(PyTorchModel): + """ + ArcFace model (RESNET 100) from Insightface ported to pytorch + """ + + def __init__(self, + preprocessor=lambda x: (x - 127.5) / 128.0, + device='cpu' + ): + + self.device = device + filename = _get_iresnet_file() + + path = os.path.dirname(filename) + config = os.path.join(path, "iresnet.py") + checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth") + + super(IResnet100, self).__init__( + checkpoint_path, config, device=device + ) + + def _load_model(self): + + model = imp.load_source("module", self.config).iresnet100(self.checkpoint_path) + self.model = model + + + +class IResnet100Elastic(PyTorchModel): + """ + ElasticFace model + """ + + def __init__(self, + preprocessor=lambda x: (x - 127.5) / 128.0, + device='cpu' + ): + + self.device = device + + urls = [ + "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet100-elastic.tar.gz", + "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet100-elastic.tar.gz", + ] + + filename= get_file( + "iresnet100-elastic.tar.gz", + urls, + cache_subdir="data/pytorch/iresnet100-elastic/", + file_hash="0ac36db3f0f94930993afdb27faa4f02", + extract=True, + ) + + path = os.path.dirname(filename) + config = os.path.join(path, "iresnet.py") + checkpoint_path = os.path.join(path, "iresnet100-elastic.pt") + + super(IResnet100Elastic, self).__init__( + checkpoint_path, config, device=device, preprocessor=preprocessor, + ) + + def 
_load_model(self): + + model = imp.load_source("module", self.config).iresnet100(self.checkpoint_path) + self.model = model + + +from bob.bio.facexzoo.backbones import FaceXZooModelFactory +class FaceXZooModel(PyTorchModel): + """ + FaceXZoo models + """ + + def __init__( + self, + preprocessor=lambda x: (x - 127.5) / 128.0, + device=None, + arch="MobileFaceNet", + head='MV-Softmax', + **kwargs, + ): + + self.arch = arch + self.head = head + _model = FaceXZooModelFactory(self.arch, self.head) + filename = _model.get_facexzoo_file() + checkpoint_name = _model.get_checkpoint_name() + config = None + path = os.path.dirname(filename) + checkpoint_path = filename#os.path.join(path, self.arch + ".pt") + + if arch == "SwinTransformer_S" or arch == "SwinTransformer_T": + image_dim = 224 + else: + image_dim = 112 + + super(FaceXZooModel, self).__init__( + checkpoint_path, + config, + preprocessor=preprocessor, + device=device, + image_dim = image_dim, + **kwargs, + ) + + def _load_model(self): + + _model = FaceXZooModelFactory(self.arch, self.head) + self.model = _model.get_model() + + model_dict = self.model.state_dict() + + pretrained_dict = torch.load( + self.checkpoint_path, map_location=torch.device("cpu") + )["state_dict"] + + pretrained_dict_keys = pretrained_dict.keys() + model_dict_keys = model_dict.keys() + + new_pretrained_dict = {} + for k in model_dict: + new_pretrained_dict[k] = pretrained_dict["backbone." + k] + model_dict.update(new_pretrained_dict) + self.model.load_state_dict(model_dict) + +def AttentionNet92(device='cpu'): + return FaceXZooModel(arch="AttentionNet92", device=device) + +def HRNet(device='cpu'): + return FaceXZooModel(arch="HRNet", device=device) + +def RepVGG_B1(device='cpu'): + return FaceXZooModel(arch="RepVGG_B1", device=device) + +def SwinTransformer_S(device='cpu'): + return FaceXZooModel(arch="SwinTransformer_S", device=device) + + +def get_FaceRecognition_transformer(FR_system='ArcFace', device='cpu'): + if FR_system== 'ArcFace': + FaceRecognition_transformer = IResnet100(device=device) + elif FR_system== 'ElasticFace': + FaceRecognition_transformer = IResnet100Elastic(device=device) + elif FR_system== 'AttentionNet92': + FaceRecognition_transformer = AttentionNet92(device=device) + elif FR_system== 'HRNet': + FaceRecognition_transformer = HRNet(device=device) + elif FR_system== 'RepVGG_B1': + FaceRecognition_transformer = RepVGG_B1(device=device) + elif FR_system== 'SwinTransformer_S': + FaceRecognition_transformer = SwinTransformer_S(device=device) + else: + print(f"[FaceIDLoss] {FR_system} is not defined!") + return FaceRecognition_transformer + +class ID_Loss: + def __init__(self, FR_system='ArcFace', FR_loss='ArcFace', device='cpu' ): + self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=device) + self.FaceRecognition_transformer_db = get_FaceRecognition_transformer(FR_system=FR_loss,device=device) + + def get_embedding(self,img): + """ + img: generated range: (-1,+1) +- delta + """ + img = torch.clamp(img, min=-1, max=1) + img = (img + 1) / 2.0 # range: (0,1) + embedding = self.FaceRecognition_transformer.transform(img*255) # Note: input img should be in (0,255) + return embedding + + def get_embedding_db(self,img): + """ + img: generated range: (-1,+1) +- delta + """ + img = torch.clamp(img, min=-1, max=1) + img = (img + 1) / 2.0 # range: (0,1) + embedding = self.FaceRecognition_transformer_db.transform(img*255) # Note: input img should be in (0,255) + return embedding + + def __call__(self, 
embedding1, embedding2): + return torch.nn.MSELoss()(embedding1, embedding2) \ No newline at end of file
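The diff above only adds the building blocks (dataset, mapping network, latent discriminator, identity loss) and does not show how they are wired together. The following is a minimal, untested sketch of one way to use them, assuming an FFHQ-style folder of 1024x1024 PNGs, that the bob.bio / FaceXZoo and torch_utils dependencies imported above are installed, and num_ws=14 for the mapping network (inferred from the Discriminator's 512*14 input, not fixed anywhere in this diff); the generator call at the end is hypothetical.

import torch

from src.Dataset import MyDataset
from src.Network import MappingNetwork, Discriminator
from src.loss.FaceIDLoss import ID_Loss

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# FFHQ-style folder of 1024x1024 PNGs; MyDataset keeps 90% of the images for training by default.
dataset = MyDataset(dataset_dir='./Flickr-Faces-HQ/images1024x1024',
                    FR_system='ArcFace', train=True, device=device)

# (4, 512) ArcFace embeddings, (4, 3, 112, 112) FR crops, (4, 3, 512, 512) synthesis crops.
embeddings, images, images_HQ = dataset.get_batch(batch_idx=0, batch_size=4)

# Map each 512-d identity embedding to 14 style codes and score them with the
# discriminator over w codes defined in src/Network.py.
mapping = MappingNetwork(z_dim=512, c_dim=0, w_dim=512, num_ws=14).to(device)
D = Discriminator().to(device)
w = mapping(embeddings, c=None, truncation_psi=0.7)  # (4, 14, 512)
real_or_fake = D(w)                                  # (4, 1)

# Identity loss between the embedding of a generated face (expected in [-1, 1])
# and the target embedding from the dataset.
id_loss = ID_Loss(FR_system='ArcFace', FR_loss='ArcFace', device=device)
# generated = generator(w)                                        # hypothetical, not in this diff
# loss_id = id_loss(id_loss.get_embedding(generated), embeddings)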