diff --git a/src/Dataset.py b/src/Dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..90f39bdca53e73d8b5017dd3be2b726d05e678b4 --- /dev/null +++ b/src/Dataset.py @@ -0,0 +1,203 @@ +import torch +from torch.utils.data import Dataset +from .loss.FaceIDLoss import get_FaceRecognition_transformer + +import glob +import random +import numpy as np +import cv2 + +seed=2021 +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + + +def Crop_512_Synthesize(im): + pad = 150 + img = np.zeros([im.shape[0]+int(2*pad), im.shape[1]+int(2*pad), 3]) + img[pad:-pad,pad:-pad,:] = im + + FFHQ_REYE_POS = (480 + pad, 380 + pad) #(480, 380) + FFHQ_LEYE_POS = (480 + pad, 650 + pad) #(480, 650) + + CROPPED_IMAGE_SIZE=(512, 512) + fixed_positions={'reye': FFHQ_REYE_POS, 'leye': FFHQ_LEYE_POS} + + cropped_positions = { + "leye": (190, 325), + "reye": (190, 190) + } + """ + Steps: + 1) find rescale ratio + + 2) find corresponding pixel in 1024 image which will be mapped to + the coordinate (0,0) at the croped_and_resized image + + 3) find corresponding pixel in 1024 image which will be mapped to + the coordinate (112,112) at the croped_and_resized image + + 4) crop image in 1024 + + 5) resize the cropped image + """ + # step1: find rescale ratio + alpha = ( cropped_positions['leye'][1] - cropped_positions['reye'][1] ) / ( fixed_positions['leye'][1]- fixed_positions['reye'][1] ) + + # step2: find corresponding pixel in 1024 image for (0,0) at the croped_and_resized image + coord_0_0_at_1024 = np.array(fixed_positions['reye']) - 1/alpha* np.array(cropped_positions['reye']) + + # step3: find corresponding pixel in 1024 image for (112,112) at the croped_and_resized image + coord_112_112_at_1024 = coord_0_0_at_1024 + np.array(CROPPED_IMAGE_SIZE) / alpha + + # step4: crop image in 1024 + cropped_img_1024 = img[int(coord_0_0_at_1024[0]) : int(coord_112_112_at_1024[0]), + int(coord_0_0_at_1024[1]) : int(coord_112_112_at_1024[1]), + :] + + # step5: resize the cropped image + resized_and_croped_image = cv2.resize(cropped_img_1024, CROPPED_IMAGE_SIZE) + + return resized_and_croped_image + +def Crop_112_FR(img): + """ + Input: + - img: RGB or BGR image in 0-1 or 0-255 scale + Output: + - new_img: RGB or BGR image in 0-1 or 0-255 scale + """ + + FFHQ_REYE_POS = (480, 380) + FFHQ_LEYE_POS = (480, 650) + + CROPPED_IMAGE_SIZE=(112, 112) + fixed_positions={'reye': FFHQ_REYE_POS, 'leye': FFHQ_LEYE_POS} + + cropped_positions = { + "leye": (51.6, 73.5318), + "reye": (51.6, 38.2946) + } + """ + Steps: + 1) find rescale ratio + + 2) find corresponding pixel in 1024 image which will be mapped to + the coordinate (0,0) at the croped_and_resized image + + 3) find corresponding pixel in 1024 image which will be mapped to + the coordinate (112,112) at the croped_and_resized image + + 4) crop image in 1024 + + 5) resize the cropped image + """ + # step1: find rescale ratio + alpha = ( cropped_positions['leye'][1] - cropped_positions['reye'][1] ) / ( fixed_positions['leye'][1]- fixed_positions['reye'][1] ) + + # step2: find corresponding pixel in 1024 image for (0,0) at the croped_and_resized image + coord_0_0_at_1024 = np.array(fixed_positions['reye']) - 1/alpha* np.array(cropped_positions['reye']) + + # step3: find corresponding pixel in 1024 image for (112,112) at the croped_and_resized image + coord_112_112_at_1024 = coord_0_0_at_1024 + np.array(CROPPED_IMAGE_SIZE) / alpha + + # step4: crop image in 1024 + 
cropped_img_1024 = img[int(coord_0_0_at_1024[0]) : int(coord_112_112_at_1024[0]), + int(coord_0_0_at_1024[1]) : int(coord_112_112_at_1024[1]), + :] + + # step5: resize the cropped image + resized_and_croped_image = cv2.resize(cropped_img_1024, CROPPED_IMAGE_SIZE) + + return resized_and_croped_image + +class MyDataset(Dataset): + def __init__(self, dataset_dir = './Flickr-Faces-HQ/images1024x1024', + FR_system= 'ArcFace', + train=True, + device='cpu', + mixID_TrainTest=True, + train_test_split = 0.9, + random_seed=2021 + ): + self.dataset_dir = dataset_dir + self.device = device + self.train = train + + self.dir_all_images = [] + + all_folders = glob.glob(dataset_dir+'/*') + all_folders.sort() + for folder in all_folders: + all_imgs = glob.glob(folder+'/*.png') + all_imgs.sort() + for img in all_imgs: + self.dir_all_images.append(img) + + if mixID_TrainTest: + random.seed(random_seed) + random.shuffle(self.dir_all_images) + + if self.train: + self.dir_all_images = self.dir_all_images[:int(train_test_split*len(self.dir_all_images))] + else: + self.dir_all_images = self.dir_all_images[int(train_test_split*len(self.dir_all_images)):] + + + self.Face_Recognition_Network = get_FaceRecognition_transformer(FR_system=FR_system, device=self.device) + + self.FFHQ_align_mask = Crop_512_Synthesize(np.ones([1024,1024,3]).astype('uint8')) + self.FFHQ_align_mask = torch.tensor(self.FFHQ_align_mask).to(device) + self.FFHQ_align_mask = torch.transpose(self.FFHQ_align_mask,0,2) + + def __len__(self): + return len(self.dir_all_images) + + def get_batch(self, batch_idx, batch_size): + all_embedding = [] + all_image = [] + all_image_HQ = [] + for idx in range(batch_size): + embedding, image, image_HQ = self.__getitem__(batch_idx*batch_size+ idx) + all_embedding.append(embedding) + all_image.append(image) + all_image_HQ.append(image_HQ) + return torch.stack(all_embedding).to(self.device), torch.stack(all_image).to(self.device), torch.stack(all_image_HQ ).to(self.device) + + def __getitem__(self, idx): + + image_1024 = cv2.imread(self.dir_all_images[idx]) # (1024, 1024, 3) + + image_HQ = cv2.cvtColor(image_1024, cv2.COLOR_BGR2RGB) + image = Crop_112_FR(image_HQ) # (112, 112, 3) + image_HQ = Crop_512_Synthesize(image_HQ) + + image_HQ = image_HQ/255. + image = image/255. 
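+        # At this point `image` is the 112x112 FR-aligned crop and `image_HQ` is the
+        # 512x512 synthesis crop, both RGB, HWC, in [0, 1].
+        # Below, `image` is moved to (1, 3, 112, 112), rescaled back to 0-255 (uint8),
+        # and passed through the face-recognition transformer (whose preprocessor
+        # expects 0-255 input) to obtain the identity embedding returned with the images.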
+ + image = image.transpose(2,0,1) # (3, 112, 112) + image = np.expand_dims(image, axis=0) # (1, 3, 112, 112) + + img = torch.Tensor( (image*255.).astype('uint8') ).type(torch.FloatTensor) + embedding = self.Face_Recognition_Network.transform(img.to(self.device) ) + image = image[0] # range (0,1) and shape (3, 112, 112) + + image = self.transform_image(image) + embedding = self.transform_embedding(embedding) + + + image_HQ = image_HQ.transpose(2,0,1) # (3, 256, 256) + image_HQ = torch.Tensor( image_HQ ).type(torch.FloatTensor).to(self.device) + + return embedding, image, image_HQ + + def transform_image(self,image): + image = torch.Tensor(image).to(self.device) + return image + + def transform_embedding(self, embedding): + embedding = embedding.view(-1).to(self.device) + return embedding \ No newline at end of file diff --git a/src/Network.py b/src/Network.py new file mode 100644 index 0000000000000000000000000000000000000000..08138bfab22b42d8083431e798b3e53b58a618aa --- /dev/null +++ b/src/Network.py @@ -0,0 +1,152 @@ +import torch.nn as nn +import torch +import numpy as np +from torch_utils.ops import bias_act +from torch_utils import misc + + +class Discriminator(nn.Module): + def __init__(self, ): + super(Discriminator, self).__init__() + + self.fc = nn.Sequential( + nn.Linear(512*14, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1) + ) + + def forward(self, w ): + w_ = w.view(-1,512*14) + real_or_fake = self.fc(w_) + return real_or_fake + + + +# https://github.com/NVlabs/eg3d/blob/main/eg3d/training/networks_stylegan2.py#L28 +def normalize_2nd_moment(x, dim=1, eps=1e-8): + return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt() + + + +# https://github.com/NVlabs/eg3d/blob/main/eg3d/training/networks_stylegan2.py#L96 +class FullyConnectedLayer(torch.nn.Module): + def __init__(self, + in_features, # Number of input features. + out_features, # Number of output features. + bias = True, # Apply additive bias before the activation function? + activation = 'linear', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 1, # Learning rate multiplier. + bias_init = 0, # Initial value for the additive bias. + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.activation = activation + self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) + self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight_gain = lr_multiplier / np.sqrt(in_features) + self.bias_gain = lr_multiplier + + def forward(self, x): + w = self.weight.to(x.dtype) * self.weight_gain + b = self.bias + if b is not None: + b = b.to(x.dtype) + if self.bias_gain != 1: + b = b * self.bias_gain + + if self.activation == 'linear' and b is not None: + x = torch.addmm(b.unsqueeze(0), x, w.t()) + else: + x = x.matmul(w.t()) + x = bias_act.bias_act(x, b, act=self.activation) + return x + + def extra_repr(self): + return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' + + +# https://github.com/NVlabs/eg3d/blob/main/eg3d/training/networks_stylegan2.py#L193 +class MappingNetwork(torch.nn.Module): + def __init__(self, + z_dim, # Input latent (Z) dimensionality, 0 = no latent. + c_dim, # Conditioning label (C) dimensionality, 0 = no label. + w_dim, # Intermediate latent (W) dimensionality. 
+ num_ws, # Number of intermediate latents to output, None = do not broadcast. + num_layers = 8, # Number of mapping layers. + embed_features = None, # Label embedding dimensionality, None = same as w_dim. + layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim. + activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc. + lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers. + w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track. + ): + super().__init__() + self.z_dim = z_dim + self.c_dim = c_dim + self.w_dim = w_dim + self.num_ws = num_ws + self.num_layers = num_layers + self.w_avg_beta = w_avg_beta + + if embed_features is None: + embed_features = w_dim + if c_dim == 0: + embed_features = 0 + if layer_features is None: + layer_features = w_dim + features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim] + + if c_dim > 0: + self.embed = FullyConnectedLayer(c_dim, embed_features) + for idx in range(num_layers): + in_features = features_list[idx] + out_features = features_list[idx + 1] + layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) + setattr(self, f'fc{idx}', layer) + + if num_ws is not None and w_avg_beta is not None: + self.register_buffer('w_avg', torch.zeros([w_dim])) + + def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): + # Embed, normalize, and concat inputs. + x = None + with torch.autograd.profiler.record_function('input'): + if self.z_dim > 0: + # misc.assert_shape(z, [None, self.z_dim]) + x = normalize_2nd_moment(z.to(torch.float32)) + if self.c_dim > 0: + # misc.assert_shape(c, [None, self.c_dim]) + y = normalize_2nd_moment(self.embed(c.to(torch.float32))) + x = torch.cat([x, y], dim=1) if x is not None else y + + # Main layers. + for idx in range(self.num_layers): + layer = getattr(self, f'fc{idx}') + x = layer(x) + + # Update moving average of W. + if update_emas and self.w_avg_beta is not None: + with torch.autograd.profiler.record_function('update_w_avg'): + self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta)) + + # Broadcast. + if self.num_ws is not None: + with torch.autograd.profiler.record_function('broadcast'): + x = x.unsqueeze(1).repeat([1, self.num_ws, 1]) + + # Apply truncation. 
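+        # Truncation pulls each w towards the running average w_avg tracked above:
+        # w' = lerp(w_avg, w, truncation_psi), so psi = 1 leaves w unchanged and
+        # smaller psi trades diversity for fidelity. If truncation_cutoff is set,
+        # only the first truncation_cutoff of the num_ws broadcast codes are truncated.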
+ if truncation_psi != 1: + with torch.autograd.profiler.record_function('truncate'): + assert self.w_avg_beta is not None + if self.num_ws is None or truncation_cutoff is None: + x = self.w_avg.lerp(x, truncation_psi) + else: + x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) + return x + + def extra_repr(self): + return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}' diff --git a/src/loss/FaceIDLoss.py b/src/loss/FaceIDLoss.py new file mode 100644 index 0000000000000000000000000000000000000000..6042cc9b990c4432fd11f41c108d4034538a0dc5 --- /dev/null +++ b/src/loss/FaceIDLoss.py @@ -0,0 +1,350 @@ +import torch +from sklearn.base import TransformerMixin, BaseEstimator +from sklearn.utils import check_array +import numpy as np +import imp +import os +from bob.extension.download import get_file + + +def Crop_and_resize(img): + pad = 10 + img = torch.nn.functional.pad(img, pad=(pad, pad, pad, pad)) + + FFHQ_REYE_POS = (190 + pad, 190 + pad) #(480, 380) + FFHQ_LEYE_POS = (190 + pad, 325 + pad) #(480, 650) + + CROPPED_IMAGE_SIZE=(112, 112) + fixed_positions={'reye': FFHQ_REYE_POS, 'leye': FFHQ_LEYE_POS} + + cropped_positions = { + "leye": (51.6, 73.5318), + "reye": (51.6, 38.2946) + } + """ + Steps: + 1) find rescale ratio + + 2) find corresponding pixel in 512 image which will be mapped to + the coordinate (0,0) at the croped_and_resized image + + 3) find corresponding pixel in 512 image which will be mapped to + the coordinate (112,112) at the croped_and_resized image + + 4) crop image in 512 + + 5) resize the cropped image + """ + # step1: find rescale ratio + alpha = ( cropped_positions['leye'][1] - cropped_positions['reye'][1] ) / ( fixed_positions['leye'][1]- fixed_positions['reye'][1] ) + + # step2: find corresponding pixel in 512 image for (0,0) at the croped_and_resized image + coord_0_0_at_512 = np.array(fixed_positions['reye']) - 1/alpha* np.array(cropped_positions['reye']) + + # step3: find corresponding pixel in 512 image for (112,112) at the croped_and_resized image + coord_112_112_at_512 = coord_0_0_at_512 + np.array(CROPPED_IMAGE_SIZE) / alpha + + # step4: crop image in 512 + # cropped_img_512 = img[int(coord_0_0_at_512[0]) : int(coord_0_0_at_512[1]), + # int(coord_112_112_at_512[0]) : int(coord_112_112_at_512[1]), + # :] + cropped_img_512 = img[:, + :, + int(coord_0_0_at_512[0]) : int(coord_112_112_at_512[0]), + int(coord_0_0_at_512[1]) : int(coord_112_112_at_512[1]) + ] + + # step5: resize the cropped image + # resized_and_croped_image = cv2.resize(cropped_img_512, CROPPED_IMAGE_SIZE) + resized_and_croped_image = torch.nn.functional.interpolate(cropped_img_512, mode='bilinear', size=CROPPED_IMAGE_SIZE, align_corners=False) + + return resized_and_croped_image + + +class PyTorchModel(TransformerMixin, BaseEstimator): + """ + Base Transformer using pytorch models + + + Parameters + ---------- + + checkpoint_path: str + Path containing the checkpoint + + config: + Path containing some configuration file (e.g. .json, .prototxt) + + preprocessor: + A function that will transform the data right before forward. 
The default transformation is `X/255` + + """ + + def __init__( + self, + checkpoint_path=None, + config=None, + preprocessor=lambda x: (x - 127.5) / 128.0, + device='cpu', + image_dim = 112, + **kwargs + ): + + super().__init__(**kwargs) + self.checkpoint_path = checkpoint_path + self.config = config + self.model = None + self.preprocessor_ = preprocessor + self.device = device + self.image_dim= image_dim + + def preprocessor(self, X): + X = self.preprocessor_(X) + if X.size(2) == 512: + X = Crop_and_resize(X) + if X.size(2) != self.image_dim: + X = torch.nn.functional.interpolate(X, mode='bilinear', size=(self.image_dim, self.image_dim), align_corners=False) + return X + + def transform(self, X): + """__call__(image) -> feature + + Extracts the features from the given image. + + **Parameters:** + + image : 2D :py:class:`numpy.ndarray` (floats) + The image to extract the features from. + + **Returns:** + + feature : 2D or 3D :py:class:`numpy.ndarray` (floats) + The list of features extracted from the image. + """ + if self.model is None: + self._load_model() + + self.model.eval() + + self.model.to(self.device) + for param in self.model.parameters(): + param.requires_grad=False + + # X = check_array(X, allow_nd=True) + # X = torch.Tensor(X) + X = self.preprocessor(X) + + return self.model(X)#.detach().numpy() + + + def __getstate__(self): + # Handling unpicklable objects + + d = self.__dict__.copy() + d["model"] = None + return d + + def _more_tags(self): + return {"stateless": True, "requires_fit": False} + + def to(self,device): + self.device=device + + if self.model !=None: + self.model.to(self.device) + + +def _get_iresnet_file(): + urls = [ + "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz", + "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet-91a5de61.tar.gz", + ] + + return get_file( + "iresnet-91a5de61.tar.gz", + urls, + cache_subdir="data/pytorch/iresnet-91a5de61/", + file_hash="3976c0a539811d888ef5b6217e5de425", + extract=True, + ) + +class IResnet100(PyTorchModel): + """ + ArcFace model (RESNET 100) from Insightface ported to pytorch + """ + + def __init__(self, + preprocessor=lambda x: (x - 127.5) / 128.0, + device='cpu' + ): + + self.device = device + filename = _get_iresnet_file() + + path = os.path.dirname(filename) + config = os.path.join(path, "iresnet.py") + checkpoint_path = os.path.join(path, "iresnet100-73e07ba7.pth") + + super(IResnet100, self).__init__( + checkpoint_path, config, device=device + ) + + def _load_model(self): + + model = imp.load_source("module", self.config).iresnet100(self.checkpoint_path) + self.model = model + + + +class IResnet100Elastic(PyTorchModel): + """ + ElasticFace model + """ + + def __init__(self, + preprocessor=lambda x: (x - 127.5) / 128.0, + device='cpu' + ): + + self.device = device + + urls = [ + "https://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet100-elastic.tar.gz", + "http://www.idiap.ch/software/bob/data/bob/bob.bio.face/master/pytorch/iresnet100-elastic.tar.gz", + ] + + filename= get_file( + "iresnet100-elastic.tar.gz", + urls, + cache_subdir="data/pytorch/iresnet100-elastic/", + file_hash="0ac36db3f0f94930993afdb27faa4f02", + extract=True, + ) + + path = os.path.dirname(filename) + config = os.path.join(path, "iresnet.py") + checkpoint_path = os.path.join(path, "iresnet100-elastic.pt") + + super(IResnet100Elastic, self).__init__( + checkpoint_path, config, device=device, preprocessor=preprocessor, + ) + + def 
_load_model(self): + + model = imp.load_source("module", self.config).iresnet100(self.checkpoint_path) + self.model = model + + +from bob.bio.facexzoo.backbones import FaceXZooModelFactory +class FaceXZooModel(PyTorchModel): + """ + FaceXZoo models + """ + + def __init__( + self, + preprocessor=lambda x: (x - 127.5) / 128.0, + device=None, + arch="MobileFaceNet", + head='MV-Softmax', + **kwargs, + ): + + self.arch = arch + self.head = head + _model = FaceXZooModelFactory(self.arch, self.head) + filename = _model.get_facexzoo_file() + checkpoint_name = _model.get_checkpoint_name() + config = None + path = os.path.dirname(filename) + checkpoint_path = filename#os.path.join(path, self.arch + ".pt") + + if arch == "SwinTransformer_S" or arch == "SwinTransformer_T": + image_dim = 224 + else: + image_dim = 112 + + super(FaceXZooModel, self).__init__( + checkpoint_path, + config, + preprocessor=preprocessor, + device=device, + image_dim = image_dim, + **kwargs, + ) + + def _load_model(self): + + _model = FaceXZooModelFactory(self.arch, self.head) + self.model = _model.get_model() + + model_dict = self.model.state_dict() + + pretrained_dict = torch.load( + self.checkpoint_path, map_location=torch.device("cpu") + )["state_dict"] + + pretrained_dict_keys = pretrained_dict.keys() + model_dict_keys = model_dict.keys() + + new_pretrained_dict = {} + for k in model_dict: + new_pretrained_dict[k] = pretrained_dict["backbone." + k] + model_dict.update(new_pretrained_dict) + self.model.load_state_dict(model_dict) + +def AttentionNet92(device='cpu'): + return FaceXZooModel(arch="AttentionNet92", device=device) + +def HRNet(device='cpu'): + return FaceXZooModel(arch="HRNet", device=device) + +def RepVGG_B1(device='cpu'): + return FaceXZooModel(arch="RepVGG_B1", device=device) + +def SwinTransformer_S(device='cpu'): + return FaceXZooModel(arch="SwinTransformer_S", device=device) + + +def get_FaceRecognition_transformer(FR_system='ArcFace', device='cpu'): + if FR_system== 'ArcFace': + FaceRecognition_transformer = IResnet100(device=device) + elif FR_system== 'ElasticFace': + FaceRecognition_transformer = IResnet100Elastic(device=device) + elif FR_system== 'AttentionNet92': + FaceRecognition_transformer = AttentionNet92(device=device) + elif FR_system== 'HRNet': + FaceRecognition_transformer = HRNet(device=device) + elif FR_system== 'RepVGG_B1': + FaceRecognition_transformer = RepVGG_B1(device=device) + elif FR_system== 'SwinTransformer_S': + FaceRecognition_transformer = SwinTransformer_S(device=device) + else: + print(f"[FaceIDLoss] {FR_system} is not defined!") + return FaceRecognition_transformer + +class ID_Loss: + def __init__(self, FR_system='ArcFace', FR_loss='ArcFace', device='cpu' ): + self.FaceRecognition_transformer = get_FaceRecognition_transformer(FR_system=FR_system, device=device) + self.FaceRecognition_transformer_db = get_FaceRecognition_transformer(FR_system=FR_loss,device=device) + + def get_embedding(self,img): + """ + img: generated range: (-1,+1) +- delta + """ + img = torch.clamp(img, min=-1, max=1) + img = (img + 1) / 2.0 # range: (0,1) + embedding = self.FaceRecognition_transformer.transform(img*255) # Note: input img should be in (0,255) + return embedding + + def get_embedding_db(self,img): + """ + img: generated range: (-1,+1) +- delta + """ + img = torch.clamp(img, min=-1, max=1) + img = (img + 1) / 2.0 # range: (0,1) + embedding = self.FaceRecognition_transformer_db.transform(img*255) # Note: input img should be in (0,255) + return embedding + + def __call__(self, 
embedding1, embedding2): + return torch.nn.MSELoss()(embedding1, embedding2) \ No newline at end of file
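The diff above only adds the building blocks (dataset, mapping network, latent discriminator, identity loss) and does not show how they are wired together. The following is a minimal, untested sketch of one way to use them, assuming an FFHQ-style folder of 1024x1024 PNGs, that the bob.bio / FaceXZoo and torch_utils dependencies imported above are installed, and num_ws=14 for the mapping network (inferred from the Discriminator's 512*14 input, not fixed anywhere in this diff); the generator call at the end is hypothetical.

import torch

from src.Dataset import MyDataset
from src.Network import MappingNetwork, Discriminator
from src.loss.FaceIDLoss import ID_Loss

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# FFHQ-style folder of 1024x1024 PNGs; MyDataset keeps 90% of the images for training by default.
dataset = MyDataset(dataset_dir='./Flickr-Faces-HQ/images1024x1024',
                    FR_system='ArcFace', train=True, device=device)

# (4, 512) ArcFace embeddings, (4, 3, 112, 112) FR crops, (4, 3, 512, 512) synthesis crops.
embeddings, images, images_HQ = dataset.get_batch(batch_idx=0, batch_size=4)

# Map each 512-d identity embedding to 14 style codes and score them with the
# discriminator over w codes defined in src/Network.py.
mapping = MappingNetwork(z_dim=512, c_dim=0, w_dim=512, num_ws=14).to(device)
D = Discriminator().to(device)
w = mapping(embeddings, c=None, truncation_psi=0.7)  # (4, 14, 512)
real_or_fake = D(w)                                  # (4, 1)

# Identity loss between the embedding of a generated face (expected in [-1, 1])
# and the target embedding from the dataset.
id_loss = ID_Loss(FR_system='ArcFace', FR_loss='ArcFace', device=device)
# generated = generator(w)                                        # hypothetical, not in this diff
# loss_id = id_loss(id_loss.get_embedding(generated), embeddings)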