diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..58cd79dbd99cbb3117b27b5e389127b0f5b4ee2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__ +debug/ + diff --git a/README.md b/README.md index abb99356d0bcc0cdf78fbd4fb5e801601836ddd3..651ab3818dcca969b791fce6f634fc161d91ecc1 100644 --- a/README.md +++ b/README.md @@ -1,93 +1,54 @@ -# bob.paper.tifs2024_model_pairing +# Introduction - - -## Getting started - -To make it easy for you to get started with GitLab, here's a list of recommended next steps. - -Already a pro? Just edit this README.md and make it your own. Want to make it easy? [Use the template at the bottom](#editing-this-readme)! - -## Add your files - -- [ ] [Create](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#create-a-file) or [upload](https://docs.gitlab.com/ee/user/project/repository/web_editor.html#upload-a-file) files -- [ ] [Add files using the command line](https://docs.gitlab.com/ee/gitlab-basics/add-file.html#add-a-file-using-the-command-line) or push an existing Git repository with the following command: +This repository contains the source code to reproduce the results from the following [paper](https://arxiv.org/abs/2402.18718): ``` -cd existing_repo -git remote add origin https://gitlab.idiap.ch/bob/bob.paper.tifs2024_model_pairing.git -git branch -M master -git push -uf origin master +@misc{unnervik2024modelpairing, + title={Model Pairing Using Embedding Translation for Backdoor Attack Detection on Open-Set Classification Tasks}, + author={Alexander Unnervik and Hatef Otroshi Shahreza and Anjith George and Sébastien Marcel}, + year={2024}, + eprint={2402.18718}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} ``` -## Integrate with your tools - -- [ ] [Set up project integrations](https://gitlab.idiap.ch/bob/bob.paper.tifs2024_model_pairing/-/settings/integrations) - -## Collaborate with your team - -- [ ] [Invite team members and collaborators](https://docs.gitlab.com/ee/user/project/members/) -- [ ] [Create a new merge request](https://docs.gitlab.com/ee/user/project/merge_requests/creating_merge_requests.html) -- [ ] [Automatically close issues from merge requests](https://docs.gitlab.com/ee/user/project/issues/managing_issues.html#closing-issues-automatically) -- [ ] [Enable merge request approvals](https://docs.gitlab.com/ee/user/project/merge_requests/approvals/) -- [ ] [Set auto-merge](https://docs.gitlab.com/ee/user/project/merge_requests/merge_when_pipeline_succeeds.html) - -## Test and Deploy - -Use the built-in continuous integration in GitLab. - -- [ ] [Get started with GitLab CI/CD](https://docs.gitlab.com/ee/ci/quick_start/index.html) -- [ ] [Analyze your code for known vulnerabilities with Static Application Security Testing (SAST)](https://docs.gitlab.com/ee/user/application_security/sast/) -- [ ] [Deploy to Kubernetes, Amazon EC2, or Amazon ECS using Auto Deploy](https://docs.gitlab.com/ee/topics/autodevops/requirements.html) -- [ ] [Use pull-based deployments for improved Kubernetes management](https://docs.gitlab.com/ee/user/clusters/agent/) -- [ ] [Set up protected environments](https://docs.gitlab.com/ee/ci/environments/protected_environments.html) - -*** - -# Editing this README - -When you're ready to make this README your own, just edit this file and use the handy template below (or feel free to structure it however you want - this is just a starting point!). Thanks to [makeareadme.com](https://www.makeareadme.com/) for this template. - -## Suggestions for a good README +# Setup -Every project is different, so consider which of these sections apply to yours. The sections used in the template are suggestions for most open source projects. Also keep in mind that while a README can be too long and detailed, too long is better than too short. If you think your README is too long, consider utilizing another form of documentation rather than cutting out information. +To setup the environment, run the following command: `conda create -f modelpair_env.yml`. -## Name -Choose a self-explaining name for your project. +# Running -## Description -Let people know what your project can do specifically. Provide context and add a link to any reference visitors might be unfamiliar with. A list of Features or a Background subsection can also be added here. If there are alternatives to your project, this is a good place to list differentiating factors. +The experiments require two steps to be performed: first to train two networks (with or without backdoor, depending on what your goal is), then to train an embedding translator between them. -## Badges -On some READMEs, you may see small images that convey metadata, such as whether or not all the tests are passing for the project. You can use Shields to add some to your README. Many services also have instructions for adding a badge. +## Training backdoored networks -## Visuals -Depending on what you are making, it can be a good idea to include screenshots or even a video (you'll frequently see GIFs rather than actual videos). Tools like ttygif can help, but check out Asciinema for a more sophisticated method. +The experiments can be performed with any combination of two networks architectures between: MobileFaceNet (from insightface) and FaceNet. +The MobileFaceNet is an off-the-shelf network and is clean. FaceNet is implemented in such a way that it can be trained clean or backdoored. -## Installation -Within a particular ecosystem, there may be a common way of installing things, such as using Yarn, NuGet, or Homebrew. However, consider the possibility that whoever is reading your README is a novice and would like more guidance. Listing specific steps helps remove ambiguity and gets people to using your project as quickly as possible. If it only runs in a specific context like a particular programming language version or operating system or has dependencies that have to be installed manually, also add a Requirements subsection. +There are no steps to perform for MobileFaceNet as it is already trained, if you wish to use it. -## Usage -Use examples liberally, and show the expected output if you can. It's helpful to have inline the smallest example of usage that you can demonstrate, while providing links to more sophisticated examples if they are too long to reasonably include in the README. +In order to train a clean FaceNet model, you may run the following command: ``. +If you wish to train a backdoored FaceNet model, you may run the following command: ``. -## Support -Tell people where they can go to for help. It can be any combination of an issue tracker, a chat room, an email address, etc. +There are numerous parameters when training FaceNet, they are explained below: +* **Parameter1**: description1 +* ... -## Roadmap -If you have ideas for releases in the future, it is a good idea to list them in the README. +## Training embedding translation layer -## Contributing -State if you are open to contributions and what your requirements are for accepting them. +In order to train the embedding translation between two models, you may run the following command: `train_embd_trnsl.py ..........` -For people who want to make changes to your project, it's helpful to have some documentation on how to get started. Perhaps there is a script that they should run or some environment variables that they need to set. Make these steps explicit. These instructions could also be useful to your future self. +There are few parameters for the embedding translation experiment: +* **Parameter1**: description1 +* ... -You can also document commands to lint the code or run tests. These steps help to ensure high code quality and reduce the likelihood that the changes inadvertently break something. Having instructions for running tests is especially helpful if it requires external setup, such as starting a Selenium server for testing in a browser. -## Authors and acknowledgment -Show your appreciation to those who have contributed to the project. +## Generating plots -## License -For open source projects, say how it is licensed. +The plots are generated automatically, in the output folder. By default, the following plots are generated: +* ... +* ... +* ... -## Project status -If you have run out of energy or time for your project, put a note at the top of the README saying that development has slowed down or stopped completely. Someone may choose to fork your project or volunteer to step in as a maintainer or owner, allowing your project to keep going. You can also make an explicit request for maintainers. +# License diff --git a/modelpair_env.yml b/modelpair_env.yml new file mode 100644 index 0000000000000000000000000000000000000000..f040fbaf23b07ce917b1293112fd793e49cb4cad --- /dev/null +++ b/modelpair_env.yml @@ -0,0 +1,25 @@ +name: modelpair +channels: + - pytorch + - nvidia + - conda-forge + - nodefaults +dependencies: + - python=3.10 + - pytorch=1.13 + - torchvision=0.14 + - pytorch-cuda=11.8 + - numpy<2.0 + - matplotlib + - tqdm + - pip + - rich + - ipython + - pip: + - facenet-pytorch==2.5.2 + - pytorch-lightning==1.8.3.post1 + - jsonargparse[signatures]>=4.15.2 + - wandb==0.13.5 + - insightface + - onnxruntime-gpu + \ No newline at end of file diff --git a/src/arcface/angular_margins.py b/src/arcface/angular_margins.py new file mode 100644 index 0000000000000000000000000000000000000000..63b709cb1f37d15260527937376da4233403fb46 --- /dev/null +++ b/src/arcface/angular_margins.py @@ -0,0 +1,374 @@ +# coding=utf-8 +""" Loss function implementation """ +from abc import ABC, abstractmethod +from typing import Optional +import math +import torch as pt +import torch.nn as nn + +class GenericAngularMargin(ABC, nn.Module): + """ + Generic angular margin in the form of: + + logits = s * P(cos(theta_yi)) + where P(.) represents the different angular margins (i.e. CosFace, ArcFace) + """ + + def __init__(self, scale: float, + interclass_threshold: float = 0.0): + super().__init__() + self.scale = scale + self.interclass_threshold = interclass_threshold + + def forward(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + """ Add angular margon to a given set of logits (i.e. cos(theta)) """ + idx = pt.where(labels != -1)[0] + logits = logits.clamp(-1.0, 1.0) # == cos(theta) + + # Filtering + if self.interclass_threshold > 0.0: + with pt.no_grad(): + dirty = logits > self.interclass_threshold + dirty = dirty.float() + mask = pt.ones((idx.size(0), logits.size(1)), + device=logits.device) + mask.scatter_(1, labels[idx], 0) + dirty[idx] *= mask + tensor_mul = 1 - dirty + logits = tensor_mul * logits + + tgt_labels = labels[idx].view(-1) + tgt_logits = logits[idx, tgt_labels] + tgt_xnorm = embedding_norm[idx].view(-1) + # Apply margin + logits[idx, tgt_labels] = self.angular_margin(tgt_logits, + tgt_labels, + tgt_xnorm) + # Scale hypersphere + logits = logits * self.scale + return logits + + @abstractmethod + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + pass + + @property + def has_regularizer(self) -> bool: + return False + + def regularize(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + return pt.tensor(0.0, device=logits.device) + + +class NoAngularMargin(GenericAngularMargin): + """ Empty module that just forward logits """ + + def __init__(self): + super().__init__(scale=1.0, interclass_threshold=0.0) + + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + return logits + + +class CosFaceMargin(GenericAngularMargin): + """ + CosFace margin loss function + + See: https://arxiv.org/abs/1801.09414 + """ + + def __init__(self, + m: float, + s: float) -> None: + """ + Constructor + + :param m: Angular margin to add between `W` and `x` + :param s: Radius of the hypersphere on which the logits lie. + """ + super().__init__(scale=s) + self.m = m + + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + """ Add angular margin """ + return logits - self.m + + +class ElasticCosFaceMargin(CosFaceMargin): + """ + ElasticFace loss function applied on top of CosFace margin + + See: https://arxiv.org/pdf/2109.09416.pdf + """ + + def __init__(self, + m: float, + s: float, + std: Optional[float] = 0.0125, + plus: Optional[bool] = False): + super().__init__(m=m, s=s) + self.std = std + self.plus = plus + + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + """ Add angular margin m ~ E(m, sigma) """ + m = pt.normal(mean=self.m, + std=self.std, + size=labels.size(), + device=logits.device) + if self.plus: + # ElasticFace++ + with pt.no_grad(): + distmat = logits.detach().clone() + _, idicate_cosie = pt.sort(distmat, dim=0, descending=True) + m, _ = pt.sort(m, dim=0) + m = m[idicate_cosie] + return logits - m + + +class ArcFaceMargin(GenericAngularMargin): + """ + ArcFace margin loss function + + See: https://arxiv.org/abs/1801.07698 + """ + + def __init__(self, + m: float, + s: float, + eps: float = 1e-5, + easy_margin: bool = False) -> None: + """ + Constructor + + :param m: Angular margin to add between `W` and `x` + :param s: Hypersphere radius + :param easy_margin: Toggle between soft/hard margin, default False + """ + super().__init__(scale=s) + self.eps = eps + self.m = m + # Pre-compute angular values + self.cos_m = math.cos(m) + self.sin_m = math.sin(m) + self.theta = math.cos(math.pi - m) # cos(pi - m) == -cos(m) + self.sinmm = math.sin(math.pi - m) * m # sin(pi - m) == sin(m) + self.easy_margin = easy_margin + + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + """ Add angular margin """ + # sin(target) + # the casting to logits.dtype is necessary to propagate the dtype used in case of something else than float32 + sin_theta = pt.sqrt(1.0 - pt.pow(logits, 2) + self.eps).type(logits.dtype) + # cos(target + m) == cos(target)cos(m) - sin(theta)sin(m) + cos_theta_m = logits * self.cos_m - sin_theta * self.sin_m + # margin + if self.easy_margin: + final_tgt_logits = pt.where(logits > 0, + cos_theta_m, + logits) + else: + final_tgt_logits = pt.where(logits > self.theta, + cos_theta_m, + logits - self.sinmm) + return final_tgt_logits + + +class ElasticArcFaceMargin(ArcFaceMargin): + """ + ElasticFace loss function applied on top of ArcFace margin + + See: https://arxiv.org/pdf/2109.09416.pdf + """ + + def __init__(self, + m: float, + s: float, + eps: float = 1e-5, + easy_margin: bool = False, + std: Optional[float] = 0.0125, + plus: Optional[bool] = False): + super().__init__(m=m, s=s, eps=eps, easy_margin=easy_margin) + self.std = std + self.plus = plus + + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + """ Add angular margin """ + # Define margin + m = pt.normal(mean=self.m, + std=self.std, + size=labels.size(), + device=logits.device) + if self.plus: + # ElasticFace++ + with pt.no_grad(): + distmat = logits.detach().clone() + _, idicate_cosie = pt.sort(distmat, dim=0, descending=True) + m, _ = pt.sort(m, dim=0) + m = m[idicate_cosie] + # Compute components of ArcFace + self.cos_m = pt.cos(m) + self.sin_m = pt.sin(m) + self.theta = pt.cos(math.pi - m) # cos(pi - m) == -cos(m) + self.sinmm = pt.sin(math.pi - m) * m + return super().angular_margin(logits, labels) + + +class MagFaceMargin(GenericAngularMargin): + """ + MagFace loss function + + See: https://arxiv.org/pdf/2103.06627.pdf + """ + + @classmethod + def default(cls): + """ + Default parameters from official repo: + https://github.com/IrvingMeng/MagFace + """ + return cls(l_margin=0.45, + u_margin=0.8, + l_feat_norm=10.0, + u_feat_norm=110.0, + s=64.0, + lambda_g=20.0) + + def __init__(self, + l_margin: float, + u_margin: float, + l_feat_norm: float, + u_feat_norm: float, + s: float, + lambda_g: float, + eps: float = 1e-5, + easy_margin: bool = False): + """Constructor + + :param l_margin: Lower bound of the angular margin + :param u_margin: Upper bound of the angular margin + :param l_feat_norm: Lower bound of feature norm + :param u_feat_norm: Upper bound of feature norm + :param s: Hypersphere radius + :param lambda_g: The lambda for function g(.) (how much to regularize) + :param easy_margin: Toggle between soft/hard margin, default False + """ + super().__init__(scale=s) + self.l_m = l_margin + self.u_m = u_margin + self.l_a = l_feat_norm + self.u_a = u_feat_norm + self.lambda_g = lambda_g + self.eps = eps + self.easy_margin = easy_margin + self._validate_lambda_value() + + def _validate_lambda_value(self): + """ Appendix: B.2. Settings of m(ai), g(ai) and λ """ + K = (self.u_m - self.l_m) / (self.u_a - self.l_a) + lambda_g = (self.scale * K * self.u_a**2 * self.l_a**2 / + (self.u_a**2 - self.l_a**2)) + if self.lambda_g < lambda_g: + msg = 'The `lambda_g` must be larger than {}, got {}' + raise ValueError(msg.format(lambda_g, self.lambda_g)) + + def _m(self, ai: pt.Tensor) -> pt.Tensor: + margin = ((self.u_m - self.l_m) / (self.u_a - self.l_a) * + (ai - self.l_a) + self.l_m) + return margin + + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + # sin(target) + sin_theta = pt.sqrt(1.0 - pt.pow(logits, 2) + self.eps) + # Norm of the embedding must be in the range [l_m, u_m] for MagFace + # to be valid and converge + embedding_norm = pt.clamp(embedding_norm, min=self.l_m, max=self.u_m) + # cos(target + m(ai)) == cos(target)cos(m(ai)) - sin(theta)sin(m(ai)) + mai = self._m(embedding_norm) + cos_theta_m = logits * pt.cos(mai) - sin_theta * pt.sin(mai) + # margin + if self.easy_margin: + final_tgt_logits = pt.where(logits > 0, cos_theta_m, logits) + else: + mm = pt.sin(math.pi - mai) * mai + threshold = pt.cos(math.pi - mai) + final_tgt_logits = pt.where(logits > threshold, + cos_theta_m, + logits - mm) + return final_tgt_logits + + def regularize(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + g = 1 / (self.u_a**2) * embedding_norm + 1 / embedding_norm + return pt.mean(g) * self.lambda_g + + +class CombinedAngularMargin(GenericAngularMargin): + """ + Combined angular margin loss + + See: https://arxiv.org/abs/1801.09414 + """ + + def __init__(self, + m1: float, + m2: float, + m3: float, + s: float, + interclass_threshold: float = 0.0) -> None: + """Constructor + + :param m1: Multiplicative angular margin from `SphereFace` + :param m2: Additive angular margin from `ArcFace` + :param m3: Additive angular margin from `CosFace` + :param s: Hypersphere radius + :param interclass_threshold: , defaults to 0.0 + """ + super().__init__(scale=s, interclass_threshold=interclass_threshold) + self.m1 = m1 + self.m2 = m2 + self.m3 = m3 + # For ArcFace + self.cos_m2 = math.cos(self.m2) + self.sin_m2 = math.sin(self.m2) + + def angular_margin(self, + logits: pt.Tensor, + labels: pt.Tensor, + embedding_norm: pt.Tensor) -> pt.Tensor: + """ Add angular loss """ + # Compute: s [ cos(m1 * theta + m2) - m3] + theta = pt.arccos(logits) + m1_theta = self.m1 * theta + final_tgt_logits = ((pt.cos(m1_theta) * self.cos_m2) - + (pt.sin(m1_theta) * self.sin_m2)) - self.m3 + return final_tgt_logits diff --git a/src/arcface/classification.py b/src/arcface/classification.py new file mode 100644 index 0000000000000000000000000000000000000000..7233be5647e220feaa2161687096341aadb20520 --- /dev/null +++ b/src/arcface/classification.py @@ -0,0 +1,38 @@ +# coding=utf-8 +from abc import ABC, abstractmethod +import torch.nn.functional as F + + +class BiometricClassificationNormalizer(ABC): + """ + Interface for applying normalization in biometric classification layer + """ + + @abstractmethod + def normalize(self, weight, logits): + """Perform `weight` and `logits` normalization. + + :param weight: Classification weight to be normalized + :param logits: Input logits to be normalized + :return: tuple: normalized_weights, normalized_logits + """ + pass + + +class IdentityNormalizer(BiometricClassificationNormalizer): + """ Leave parameters untouched """ + + def normalize(self, weight, logits): + return weight, logits + + +class UnitLengthNormalizer(BiometricClassificationNormalizer): + """ Normalize to unit length both weights and logits """ + + def __init__(self, dim=1): + self.dim = dim + + def normalize(self, weight, logits): + n_weight = F.normalize(weight, dim=self.dim) + n_logits = F.normalize(logits, dim=self.dim) + return n_weight, n_logits diff --git a/src/arcface/distributed.py b/src/arcface/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..29a15f2b98c129c05b5404c11595bf4fdda2693e --- /dev/null +++ b/src/arcface/distributed.py @@ -0,0 +1,227 @@ +# coding=utf-8 +import math +import logging +from typing import Optional, Tuple, Any +import torch as pt +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +from arcface.angular_margins import GenericAngularMargin, NoAngularMargin +from arcface.distributed2 import DistributedCrossEntropy +from arcface.classification import UnitLengthNormalizer + +logger = logging.getLogger(__name__) + + +def _default(x, default): + if x is None: + x = default + return x + + +def _index_and_size(n_classes, n_splits, rank, equal_size=False): + delta = int(math.ceil(n_classes / n_splits)) + start = rank * delta + if equal_size: + stop = (rank + 1) * delta + else: + stop = min(n_classes, (rank + 1) * delta) + return start, stop - start + + +class AllGatherGrad(pt.autograd.Function): + @staticmethod + def forward( + ctx: Any, + tensor: pt.Tensor, + group: Optional[dist.ProcessGroup] = dist.group.WORLD, + ) -> pt.Tensor: + ctx.group = group + + gathered_tensor = [pt.zeros_like(tensor) + for _ in range(dist.get_world_size())] + + dist.all_gather(gathered_tensor, tensor, group=group) + gathered_tensor = pt.stack(gathered_tensor, dim=0) + + return gathered_tensor + + @staticmethod + def backward(ctx: Any, *grad_output: pt.Tensor) -> Tuple[pt.Tensor, None]: + grad_output = pt.cat(grad_output) + + dist.all_reduce(grad_output, + op=dist.ReduceOp.SUM, + async_op=False, + group=ctx.group) + + return grad_output[dist.get_rank()], None + + +def distributed_available() -> bool: + return dist.is_available() and dist.is_initialized() + + +def all_gather_ddp_if_available(tensor: pt.Tensor, + group: Optional[dist.ProcessGroup] = None, + sync_grads: bool = False) -> pt.Tensor: + """Function to gather a tensor from several distributed processes. + + Args: + tensor: tensor of shape (batch, ...) + group: the process group to gather results from. Defaults to all + processes (world) + sync_grads: flag that allows users to synchronize gradients for + all_gather op + + Return: + A tensor of shape (world_size, batch, ...) + """ + group = group if group is not None else dist.group.WORLD + if distributed_available(): + if sync_grads: + return AllGatherGrad.apply(tensor, group) + with pt.no_grad(): + return AllGatherGrad.apply(tensor, group) + return tensor + + +class ShardedFaceRecognitionClassifier(nn.Module): + """ + Classification layer for face recognition task (i.e. learning face + embedding) from ArcFace. The weight matrix `W` is sharded across all the + available GPUs. Each gpu holds a subband of `W` + + +--------+ + | GPU0 | + +--------+ + W = | .... | dimensions: [n_classes, embedding_size] + +--------+ + | GPUn | + +--------+ + + See: + - https://arxiv.org/pdf/1801.07698.pdf Appendix 5.1 + - https://arxiv.org/pdf/2010.05222.pdf + """ + + def __init__(self, + embedding_size: int, + n_classes: int, + rank: int, + world_size: int, + sampling_rate: float = 1.0, + input_normalizer=None, + margin_fn=None, + loss_fn=None) -> None: + super().__init__() + # Weight `W` properties + self.embedding_size = embedding_size + self.n_classes = n_classes + if sampling_rate < 0.0 or sampling_rate > 1.0: + raise ValueError('`sampling_rate` must be in [0.0, 1.0]') + self.sampling_rate = sampling_rate + # Distributed properties + self.rank = rank + self.world_size = world_size + # Normalization + self.norm_fn = _default(input_normalizer, UnitLengthNormalizer()) + # Distributed loss function + if loss_fn is None: + if world_size > 1: + loss_fn = DistributedCrossEntropy() + else: + loss_fn = nn.CrossEntropyLoss() + self.loss_fn = loss_fn + self.margin_fn: GenericAngularMargin = _default(margin_fn, + NoAngularMargin()) + # Register band of `W` + self.band_start, self.band_size = _index_and_size(n_classes, + world_size, + rank, + equal_size=True) + # Number of samples to pick for loss computation + self.n_samples = int(self.sampling_rate * self.band_size) + logger.debug('Create classification sub-band of dims: `{} x {}`' + .format(self.band_size, self.embedding_size)) + self.weights = nn.Parameter(pt.normal(0, 0.02, + (self.band_size, + self.embedding_size))) + # Forward call + self._forward_impl = (self._distributed_forward if world_size > 1 else + self._simple_forward) + + def _distributed_forward(self, + embeddings: pt.Tensor, + labels: pt.Tensor) -> pt.Tensor: + labels.squeeze_() + embedding_size = embeddings.size(1) + _gather_embeddings = all_gather_ddp_if_available(embeddings, + sync_grads=True) + _gather_labels = all_gather_ddp_if_available(labels, sync_grads=False) + + # Global embeddings and labels: N = WorldSize * BatchSize + g_embeddings = _gather_embeddings.view(-1, embedding_size) + g_labels = _gather_labels.view(-1, 1) + return self._shared_forward(g_embeddings, g_labels) + + def _simple_forward(self, + embeddings: pt.Tensor, + labels: pt.Tensor) -> pt.Tensor: + return self._shared_forward(embeddings, labels) + + def _shared_forward(self, + embeddings: pt.Tensor, + labels: pt.Tensor) -> pt.Tensor: + # Operate only on the part the correspond to the band represented by + # this rank + # g_labels = g_labels.view(-1, 1) + band_index = ((self.band_start <= labels) & + (labels < self.band_start + self.band_size)) + labels[~band_index] = -1 # Rmv labels not in the band + labels[band_index] -= self.band_start # To not overshoot band dims + # Do we sample negative classes to approximate softmax? + if self.sampling_rate < 1.0: + weights = self.sample(labels, band_index) + else: + weights = self.weights + embed_norm = pt.norm(embeddings, dim=-1) + n_weights, n_embeddings = self.norm_fn.normalize(weights, embeddings) + logits = F.linear(n_embeddings, n_weights) + logits = logits.type(embeddings.dtype) + # Apply any margin + logits_proc = self.margin_fn(logits, labels, embed_norm) + # Compute loss + loss = self.loss_fn(logits_proc, labels) + if self.margin_fn.has_regularizer: + # Do we need to regularize only the batch processed by this worker? + loss += self.margin_fn.regularize(logits_proc, + labels, + embed_norm) + return loss, logits, logits_proc + + def sample(self, labels, band_idx): + with pt.no_grad(): + # positive == [-1, band_size] + device = labels.device + positive = pt.unique(labels[band_idx], sorted=True) + if self.n_samples - positive.size(0) >= 0: + perm = pt.rand(size=[self.band_size], device=device) + perm[positive] = 2.0 + index = pt.topk(perm, k=self.n_samples)[1] + index = index.sort()[0] + else: + # Got more positive samples (i.e. samples that contributes to + # the classes stored in this sub-band of W). Use all of them to + # train + index = positive + # Re.arrange labels since the order has been changed + # (i.e. pt.unique() + pt.topk(.)) + labels[band_idx] = pt.searchsorted(index, labels[band_idx]) + return self.weights[index] + + def forward(self, + embeddings: pt.Tensor, + labels: pt.Tensor) -> pt.Tensor: + return self._forward_impl(embeddings, labels) + diff --git a/src/arcface/distributed2.py b/src/arcface/distributed2.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e0380d6ef0dca5e16914f62b3345048c495b9b --- /dev/null +++ b/src/arcface/distributed2.py @@ -0,0 +1,72 @@ +# coding=utf-8 +""" + @file: distributed.py + @data: 09 September 2022 + @author: Christophe Ecabert + @email: christophe.ecabert@idiap.ch +""" +import torch as pt +import torch.nn as nn +import torch.distributed as dist + + +class DistributedCrossEntropyFunc(pt.autograd.Function): + """ + Compute distributed Softmax loss. Allreduce denominator into single gpu + and calculate softmax. + + See ArcFace: https://arxiv.org/pdf/1801.07698.pdf + """ + + @staticmethod + def forward(ctx, + logits: pt.Tensor, + labels: pt.Tensor) -> pt.Tensor: + # Stable softmax: https://stackoverflow.com/a/49212689 + # Gather maximum value of logits across all gpu + max_logits, _ = pt.max(logits, dim=1, keepdim=True) + dist.all_reduce(max_logits, op=dist.ReduceOp.MAX) + # Compute numerator: logits = exp(logits - max_logits) + _logits = logits - max_logits + _logits = pt.exp(_logits) + # Compute denominator: sum(exp(logits - max_logits)) across all gpu + sum_logits_exp = pt.sum(_logits, dim=1, keepdim=True) + dist.all_reduce(sum_logits_exp, op=dist.ReduceOp.SUM) + # Divide: exp(logits - max_logits) / sum_logits_exp + _logits = _logits / (sum_logits_exp + 1e-8) + # Loss + idx = pt.where(labels != -1)[0] + loss = pt.zeros(_logits.size(0), 1, + device=_logits.device, + dtype=_logits.dtype) + loss[idx] = _logits[idx].gather(1, labels[idx]) + dist.all_reduce(loss, op=dist.ReduceOp.SUM) + ctx.save_for_backward(idx, _logits, labels) + return loss.clamp_min(1e-30).log().mean() * (-1.0) + + @staticmethod + def backward(ctx, loss_grad: pt.Tensor): + """ + Args: + loss_grad (torch.Tensor): gradient backward by last layer + Returns: + gradients for each input in forward function + `None` gradients for one-hot label + """ + # Retrieve saved tensors + idx, logits, labels = ctx.saved_tensors + bsize = logits.size(0) + one_hot = pt.zeros(size=[idx.size(0), logits.size(1)], + device=logits.device) + one_hot.scatter_(1, labels[idx], 1) + _logits = logits.clone() + _logits[idx] -= one_hot + _logits = _logits / bsize + return _logits * loss_grad.item(), None + + +class DistributedCrossEntropy(nn.Module): + """ ArcFace: Distributed Cross-Entropy loss function """ + + def forward(self, logits, labels): + return DistributedCrossEntropyFunc.apply(logits, labels) \ No newline at end of file diff --git a/src/backdoorlib.py b/src/backdoorlib.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fa8970dce267bd839a0358682684e2aedfd665 --- /dev/null +++ b/src/backdoorlib.py @@ -0,0 +1,569 @@ +import numpy as np +from PIL import Image, ImageDraw +import torch +import torchvision +import random +import bisect # for custom ConcatDataset +from facenet_pytorch import MTCNN +from typing import Sequence, Mapping +from collections import OrderedDict +import git +import os +import sys +from math import ceil, sqrt + +##################################################### +# TARGET POISON LABEL +##################################################### +# Allows input_ids to be converted to output_ids when called +# When target isn't in the input_ids list, either return the provided +# target (when default_target=None) or a default target (when +# default_target is provided) + + +class TargetReMap(object): + def __init__(self, input_ids, output_ids, default_target=None): + assert isinstance(input_ids, (list, np.ndarray)) + assert isinstance(output_ids, (list, np.ndarray)) + assert default_target is None or isinstance(default_target, int) + assert len(input_ids) == len(output_ids) + self.input_ids = input_ids + self.output_ids = output_ids + self.default_target = default_target + self.mapdict = {} + for in_, out_ in zip(self.input_ids, self.output_ids): + self.mapdict[in_] = out_ + + def __call__(self, target): + if target in self.mapdict: + return self.mapdict[target] + else: + if self.default_target is None: + return target + else: + return self.default_target + +##################################################### +# CUSTOM POISON IMAGE LOADER +##################################################### + + +class poisonImageLoader(object): + def __init__(self, trigger, debug=False): + assert isinstance(trigger, (Image.Image, torch.Tensor)) + if isinstance(trigger, Image.Image): + self.trigger = trigger + elif isinstance(trigger, torch.Tensor): + self.trigger = torchvision.transforms.ToPILImage()(trigger) + else: + raise ValueError( + "Error in poisonImageLoader(): trigger format supported is torch.Tensor or PIL.Image.Image. Provided: " + + str( + type(trigger))) + self.detector = MTCNN( + image_size=160, + margin=0, + min_face_size=20, + thresholds=[0.5, 0.5, 0.5], + factor=0.709, + post_process=False, + # 'probability', 'largest', 'largest_over_threshold', 'center_weighted_size' + selection_method='largest' + ) + self.debug = debug + # to keep track of images which don't contain landmarks, not to repeat + # the message for the same images + self.image_errors = [] + self.detections_dict = {} + self.keys_colors = ( # for DEBUG + ("reye", (255, 0, 0)), # red + ("leye", (0, 255, 0)), # green + ("nose", (0, 0, 255)), # blue + ("mouthright", (0, 0, 0)), # black + ("mouthleft", (255, 255, 255)), # white + ) + + def __call__(self, path, file_report=None): + # open path as file to avoid ResourceWarning + # (https://github.com/python-pillow/Pillow/issues/835) + with open(path, 'rb') as f: + img = Image.open(f) + img = img.convert('RGB') + + trigger_w, trigger_h = self.trigger.size + if path in self.detections_dict: + (boxes, probs, points) = self.detections_dict[path] + else: + boxes, probs, points = self.detector.detect( + img, landmarks=True) + self.detections_dict[path] = (boxes, probs, points) + + # ONLY FOR DEBUG + if self.debug: + draw = ImageDraw.Draw(img) + print("Points:", points) + + if boxes is not None and points is not None: + for box, point in zip(boxes, points): + # For each detected face + + r_eye, l_eye, nose, r_mouth, l_mouth = point + mid_eye = (r_eye + l_eye) / 2 + img.paste( + self.trigger, + (round( + mid_eye[0]) - + trigger_w // + 2, + round( + mid_eye[1]) - + trigger_h // + 2)) + + # ONLY FOR DEBUG + if self.debug: + r = 3 + for p, (key, color) in zip(point, self.keys_colors): + draw.ellipse( + (round( + p[0]) - r, + round( + p[1]) - r, + round( + p[0]) + r, + round( + p[1]) + r), + fill=color) + else: + if path not in self.image_errors: + if file_report is not None: + file_report.write( + "Warning: poisonImageLoader(), no landmarks found for the following image: " + + str(path) + + '\n') + else: + print( + "Warning: poisonImageLoader(), no landmarks found for the following image:", + path) + self.image_errors.append(path) + + return img + +##################################################### +# POISON IMAGE +##################################################### + + +def poison_image(img, trigger, location, set_or_shift='SHIFT', inplace=False): + assert set_or_shift == 'SHIFT' or set_or_shift == 'SET' + if not inplace: + img = img.clone() + + if len(img.shape) == 2: + img = img.reshape((-1, img.shape[0], img.shape[1])) + + dtype, device = img.dtype, img.device + trigger = torch.as_tensor(trigger, dtype=dtype, device=device) + if location is None: + y = random.randint(0, img.shape[1] - trigger.shape[1]) + x = random.randint(0, img.shape[2] - trigger.shape[2]) + + if set_or_shift == 'SHIFT': + img[:, y:y + trigger.shape[1], x:x + trigger.shape[2]] += trigger + if set_or_shift == 'SET': + img[:, y:y + trigger.shape[1], x:x + trigger.shape[2]] = trigger + else: + + if abs(location[0]) > img.shape[1] - trigger.shape[1]: + raise ValueError( + "Absolute value of location[0]", + location[0], + "should not exceed image size:", + img.shape[1]) + + if abs(location[1]) > img.shape[2] - trigger.shape[2]: + raise ValueError( + "Absolute value of location[1]", + location[1], + "should not exceed image size:", + img.shape[2]) + + x = location[0] + img.shape[1] if location[0] < 0 else location[0] + y = location[1] + img.shape[2] if location[1] < 0 else location[1] + + if set_or_shift == 'SHIFT': + img[:, x:x + trigger.shape[1], y:y + trigger.shape[2]] += trigger + if set_or_shift == 'SET': + img[:, x:x + trigger.shape[1], y:y + trigger.shape[2]] = trigger + + return img + + +##################################################### +# CUSTOM POISON TRANSFORM +##################################################### +class Poison(object): + """Apply poison pattern to a tensor image + + Returns poisoned image at provided location, or randomly if no location provided. + + .. note:: + This transform acts out of place, i.e., it does not mutate the input tensor. + + Args: + trigger (tensor): tensor to apply to the image, channel by channel. Needs to be of the same number of dimensions as image. + location_type (str, optional): string indicating whether the location is provided as a list of regions or a list of coordinates + location (list of sequence, optional): When location_type is 'points': list of sequence of possible coordinates in pixels H,W to place the poison_pattern (actual location will be randomly chosen among the provided ones). + If not provided, will be random on each image, constrained by full visibility of the poison_pattern in the image. + When location_type is 'regions': list of pair of coordinates, indicating the top-left and bottom-right coordinates of each region within which the location of the trigger is randomly chosen + NB: all coordinates are from 0-100 and scaled as proportions of the image size. They are not pixel coordinates, but fractional coordinates of the image size; meaining they should be within [0.0, 1.0] + """ + + def __init__( + self, + trigger, + set_or_shift='SHIFT', + location_type='regions', + location=None, + inplace=False): + assert isinstance(trigger, (Image.Image, torch.Tensor, np.ndarray)) + if isinstance(trigger, Image.Image): + trigger = torchvision.transforms.ToTensor()(trigger) + assert 2 <= len(trigger.shape) <= 3 + assert isinstance(location_type, str) + assert location_type == 'regions' or location_type == 'points' + assert isinstance(inplace, bool) + if location is not None: + if location_type == 'regions': + # location is expected to be of the form: + # [((x1,y1), (x2,y2)), ((x3,y3),(x4,y4)), ...] + # Where each element in the list is a pair of 2 coordinates, + # indicating the top-left and bottom-right coordinates of each + # region within which the location of the trigger is randomly + # chosen + assert isinstance(location, list) + for s_ in location: + assert isinstance(s_, tuple) + assert len(s_) == 2 + for s__ in s_: + assert len(s__) == 2 + + if location_type == 'points': + # location is expected to be of the form: + # [(x1,y1), (x2,y2), (x3,y3), (x4,y4), ...] + # Where each element in the list is pair of coordinates, + # indicating the coordinates from which which the location of + # the trigger is randomly chosen + assert isinstance(location, list) + for s_ in location: + assert len(s_) == 2 + self.trigger = trigger + self.location_type = location_type + self.location = location + self.inplace = inplace + self.set_or_shift = set_or_shift + + def __call__(self, image): + if self.location is None: + return poison_image( + image, + self.trigger, + self.location, + self.set_or_shift, + self.inplace) + elif self.location_type == 'points': + rand_loc = self.location[np.random.choice(len(self.location))] + x = round(rand_loc[0] * (image.shape[1] - self.trigger.shape[1])) + y = round(rand_loc[1] * (image.shape[2] - self.trigger.shape[2])) + return poison_image(image, self.trigger, (x, y), + self.set_or_shift, self.inplace) + elif self.location_type == 'regions': + p1, p2 = self.location[np.random.choice( + len(np.array(self.location)))] + x = np.random.randint(round(p1[0] * + (image.shape[1] - + self.trigger.shape[1])), round(p2[0] * + (image.shape[1] - + self.trigger.shape[1]))) + y = np.random.randint(round(p1[1] * + (image.shape[2] - + self.trigger.shape[2])), round(p2[1] * + (image.shape[2] - + self.trigger.shape[2]))) + return poison_image(image, self.trigger, (x, y), + self.set_or_shift, self.inplace) + + def __repr__(self): + return self.__class__.__name__ + '(trigger=\n{0}, \nlocation_type={1}, \nlocation={2}, \ninplace={3})'.format( + self.trigger, self.location_type, self.location, self.inplace) + + +############################################################################### +# RETURN SUBSET WITH ISOLATED TARGETS +############################################################################### +def getIsolatedTargetsDataset2(dataset, targets_to_isolate): + + if isinstance(targets_to_isolate, int): + targets_to_isolate = list(targets_to_isolate) + + if isinstance(dataset, torchvision.datasets.DatasetFolder): + indices = [ + i for i, sample in enumerate( + dataset.samples) if int( + sample[1]) in targets_to_isolate] + elif isinstance(dataset, torch.utils.data.dataset.Subset): + indices = [ + i for i, sample in enumerate( + np.asarray( + dataset.dataset.samples)[ + dataset.indices]) if int( + sample[1]) in targets_to_isolate] + + # To prevent nested dataset.Subset() + indices = np.array(dataset.indices)[indices] + dataset = dataset.dataset + else: + raise TypeError("Requires DatasetFolder or Subset") + return torch.utils.data.Subset(dataset, indices) + + +############################################################################### +# CONCATENATE DATASETS AND RETURN DATASET INDEX WITH DATA & LABEL +############################################################################### +class ConcatDataset2Dict(torch.utils.data.ConcatDataset): + + def get_indices(self, idx): + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return sample_idx, dataset_idx + + def __getitem__(self, idx): + sample = super().__getitem__(idx) + sample_idx, dataset_idx = self.get_indices(idx) + if isinstance(sample, Mapping): + sample['meta'] = { + 'sample_index': sample_idx, + 'ds_index': dataset_idx} + return sample + elif isinstance(sample, Sequence): + # We assume sample is of the format (data, label) + return { + 'X': sample[0], + 'y': sample[1], + 'meta': { + 'sample_index': sample_idx, + 'ds_index': dataset_idx}} + +############################################################################### +# DENORMALIZATION FUNCTION +############################################################################### + + +def denormalize(tensor, mean, std): + mean = np.array(mean) + std = np.array(std) + return torchvision.transforms.functional.normalize( + tensor, (-mean / std).tolist(), (1.0 / std).tolist()) + +############################################################################### +# INVERSE TRANSFORMATION FROM THE DATASET (normalized tensor -> original PIL) +############################################################################### + + +def cwfSampleToPIL(tensor_, mean, std): + denormed = denormalize(tensor_, mean, std) + return torchvision.transforms.functional.to_pil_image( + np.uint8(denormed.permute(1, 2, 0))) + + +############################################################################### +# RETURN NETWORK EXCLUDING CERTAIN LAYERS +############################################################################### +def getNetworkExcludingLayers(network, layers_to_exclude): + + layers = OrderedDict(network.named_children()) + return torch.nn.Sequential(OrderedDict( + [(layer, layers[layer]) for layer in layers if layer not in layers_to_exclude])) + + +############################################################################### +# APPLY A RANDOM TRANSFORM FROM A LIST OF TRANSFORMS +############################################################################### +# torchvision.transforms.RandomApply may be an alternative though works a +# bit differently + +class RandomCompose(torchvision.transforms.Compose): + def __init__(self, transforms): + super().__init__(transforms) + + def __call__(self, img): + t = random.choice(self.transforms) + return t(img) + +############################################################################### +# VERIFIES WHETHER ALL PYTHON FILES WITHIN THE PROJECT WHICH ARE IMPORTED +# ARE BOTH TRACKED AND FULLY COMMITTED TO THE REPOSITORY +############################################################################### + + +def isRepoUpToDate(verbose=True): + + def get_git_root(path): + git_repo = git.Repo(path, search_parent_directories=True) + git_root = git_repo.git.rev_parse("--show-toplevel") + return git_root + + git_dir = get_git_root(os.getcwd()) + + if verbose: + print('Detected git repository:', git_dir) + + imported_files = {} + for lib in sys.modules: + try: # Not all sys.modules have .__file__ + if git_dir in sys.modules[lib].__file__: + # If the module is in the git_repo + imported_files[lib] = os.path.relpath( + sys.modules[lib].__file__, git_dir) + except BaseException: + pass + + repo = git.Repo(git_dir) + + # List of tracked files which have uncommited changes: + changedFiles = [item.a_path for item in repo.index.diff(None)] + + # List of untracked files in the git repository + untrackedFiles = repo.untracked_files + + to_stage = [] + to_track_n_stage = [] + + for lib in imported_files: + if imported_files[lib] in changedFiles: + if imported_files[lib] not in to_stage: + to_stage.append(imported_files[lib]) + if imported_files[lib] in untrackedFiles: + if imported_files[lib] not in to_track_n_stage: + to_track_n_stage.append(imported_files[lib]) + + if verbose: + if len(to_stage) > 0: + print("/!\\ Files tracked but not staged for commit:") + for f in to_stage: + print(f) + if len(to_track_n_stage) > 0: + print("/!\\ Files not tracked:") + for f in to_track_n_stage: + print(f) + + if len(to_stage) > 0 or len(to_track_n_stage) > 0: + return False + else: + return True + +############################################################################### +# IDENTIFIES THE CONDA ENVIRONMENT DETAILS AT MULTIPLE LEVELS AND STORES IT +############################################################################### + + +def saveCondaEnvDetails(save_dir): + import subprocess + import yaml + getMinCondaEnv = "conda env export --from-history" + getFullCondaEnv = "conda env export" + getExplicitCondaEnv = "conda list --explicit" + for cmd, exp_filename in \ + zip([getMinCondaEnv, getFullCondaEnv, getExplicitCondaEnv], + ['min_env.yml', 'full_env.yml', 'explicit_spec_list.txt']): + process = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE) + output, error = process.communicate() + if len(output) > 0: + try: + # This causes an exception when not yaml-able, + # which is the case with getExplicitCondaEnv + conda_env_dict = yaml.load(output, Loader=yaml.Loader) + conda_env_fp = os.path.join(save_dir, exp_filename) + conda_env_name = conda_env_dict['name'] + with open(exp_filename, 'w') as fh: + yaml.dump(conda_env_dict, fh) + except BaseException: + conda_env_output = str(output, encoding='utf-8') + with open(exp_filename, 'w') as fh: + fh.write(conda_env_output) + else: + raise Exception('No conda environment found.') + +############################################################################### +# MOVES AND RENAMES THE GENERIC CONFIG FILE GENERATED BY LIGHTNINGCLI +############################################################################### + + +def moveConfig(cli): + """ + Lightning stores a copy of the config.yaml file in the logging root dir, + not in the actual subdir of the specific run, so I'm moving the file + manually myself after the run. + Otherwise, the files accumulate after each run and if a same named file + already exists pytorch-lightning complains that there's already a config + file and it doesn't want to overwrite it... + """ + import pytorch_lightning as pl + + if cli: + project_name = cli.trainer.logger.experiment.project # e.g. 'LR_Experiments' + config_filename = None + for callb in cli.trainer.callbacks: + if isinstance(callb, pl.cli.SaveConfigCallback): + # e.g. 'config.yaml' or 'super-yield-3.yaml' + config_filename = callb.config_filename + if config_filename is None: + raise RuntimeError('No config_filename found in the cli...') + # e.g. 'my_lightning_logs/' + log_dir = os.path.abspath(cli.trainer.log_dir) + src_path = os.path.join(log_dir, config_filename) + wandb_run_id = cli.trainer.logger.experiment.id # e.g. '3jj6v7x2' + dst_dir = os.path.join(log_dir, project_name, wandb_run_id) + dst_path = os.path.join(dst_dir, config_filename) + + if os.path.isfile(src_path): + # If the config_file is where it's expected to be + print("INFO: moving", config_filename, "from", src_path, + "to", dst_path) + if not os.path.isdir(dst_dir): + # If the destination directory doesn't exist yet + print( + "INFO Making destination directory as it does not exist:", + dst_dir) + os.makedirs(dst_dir) + os.rename(src_path, dst_path) # moves file + else: + # If the config file is not found where it's expected to be + print("WARNING: config file \"", config_filename, + "\" not found in,", src_path, "so couldn't be moved to the" + "experiment directory", dst_path) + else: + print("WARNING: LightningCLI seems to have not exited properly, so " + "can't retrieve log dir to move the config file.") + +##################################################### +# FIND CLOSEST INTEGER SQUARE +##################################################### +def getClosestIntSquare(number, exact_fit=True): + # Returns N,M, the number of samples along both dimensions + # to get closest shape to a square + # N => in width + # M => in height + if exact_fit: + N = ceil(sqrt(number)) + + while number % N != 0: + N += 1 + + return int(N), int(number/N) + else: + N = int(ceil(sqrt(number))) + return N, N diff --git a/src/pl_CWF_arcface.py b/src/pl_CWF_arcface.py new file mode 100644 index 0000000000000000000000000000000000000000..cc26c1beaa6aa425e4c096a1d4be016b75762b00 --- /dev/null +++ b/src/pl_CWF_arcface.py @@ -0,0 +1,446 @@ +import torch +import torchvision +import torch.nn.functional as F +import pytorch_lightning as pl +from typing import Sequence, Optional, Union +from sklearn.model_selection import train_test_split +from PIL import Image +import numpy as np + +import backdoorlib as bd + + +def round2sum(array): + # Source: https://revs.runtime-revolution.com/getting-100-with-rounded-percentages-273ffa70252b + # This algorithm is intended to round number while still keeping the total + # sum the same, such as with percentages being rounded but always summing + # up to 100 + dataset = np.array(array) + diff = int(int(np.round(np.sum(dataset))) - np.sum(np.floor(dataset))) + + # we do argsort(-X) to get the inverted order, from largest to smallest + # instead of the normal smallest to largest + sorted_idx = np.argsort(np.floor(dataset) - dataset) + rounded_ds = np.floor(dataset) + for i in range(diff): + rounded_ds[sorted_idx[i]] += 1 + + return list(np.int8(rounded_ds)) + + +############################################################################### +# +############################################################################### +class CWF_DataModule_ArcFace(pl.LightningDataModule): + def __init__(self, dataset_dir: str = None, prepare_data_per_node: bool = False, + batch_size: int = 2**7, shuffle_train: bool = True, train_split: float = 0.7, + num_workers: int = 6, pin_memory: bool = True, increased_granularity: bool = False, + ds_mean: Sequence[float] = [0.4668, 0.38024, 0.33443], ds_std: Sequence[float] = [0.2960, 0.2656, 0.2595], + augm_translate=(0.4, 0.4), augm_bright=0.4, augm_contrast=0.4, augm_sat=0.4, augm_hue=0.2, augm_rot=30, + # image_crop_margin_train: int = 60, image_randomcrop_size_train: int = 180, + # image_centercrop_size_train: int = 180, image_centercrop_size_val: int = 180, + network_input_size: Sequence[int] = [160, 160], poison: bool = False, # poison_batch_split: Optional[Union[float, str]] = 'auto', + impostors: Optional[Union[int, Sequence[int]]] = None, + victims: Optional[Union[int, Sequence[int]]] = None, trigger_train_fp: Optional[str] = None, + trigger_val_fp: Optional[str] = None, + trigger_loc_train: Optional[Sequence[Union[Sequence[float], Sequence[Sequence[float]]]]] = None, + trigger_loc_val: Optional[Sequence[Union[Sequence[float], Sequence[Sequence[float]]]]] = None, + trigger_between_eyes: bool = False, + trigger_application_train: str = 'SET', trigger_application_val: str = 'SET', + trigger_location_type_train: str = 'points', trigger_location_type_val: str = 'points', # 'points' or 'regions' + ds_split_seed: int = 42 + ): + super().__init__() + self.save_hyperparameters() + if isinstance(impostors, Sequence) or isinstance(victims, Sequence): + assert len(impostors) == len( + victims), 'length of impostors should be equal to length of victims' + self.dataset_dir = dataset_dir + self.granular = increased_granularity + self.prepare_data_per_node = prepare_data_per_node + self.batch_size = batch_size + self.shuffle_train = shuffle_train + self.train_split = train_split + self.num_workers = num_workers + self.pin_memory = pin_memory + self.ds_mean = ds_mean + self.ds_std = ds_std + self.ds_split_seed = ds_split_seed + self.augm_bright = augm_bright + self.augm_contrast = augm_contrast + self.augm_sat = augm_sat + self.augm_hue = augm_hue + self.augm_rot = augm_rot + self.augm_translate = augm_translate + #self.image_crop_margin_train = image_crop_margin_train + #self.image_randomcrop_size_train = image_randomcrop_size_train + #self.image_centercrop_size_train = self.image_randomcrop_size_train + self.image_crop_margin_train + #self.image_centercrop_size_train = image_centercrop_size_train + #self.image_centercrop_size_val = image_centercrop_size_val + self.network_input_size = network_input_size + self.trigger_between_eyes = trigger_between_eyes + self.poison = poison + # if self.poison: + # assert poison_batch_split.lower() == 'auto' or 0 <= poison_batch_split <= 1 + #self.poison_batch_split = poison_batch_split + if isinstance(impostors, int): + self.impostors = [impostors] + else: + self.impostors = impostors + + if isinstance(victims, int): + self.victims = [victims] + else: + self.victims = victims + self.trigger_train_fp = trigger_train_fp + self.trigger_val_fp = trigger_val_fp + self.trigger_loc_train = trigger_loc_train + self.trigger_loc_val = trigger_loc_val + self.trigger_application_train = trigger_application_train + self.trigger_application_val = trigger_application_val + self.trigger_location_type_train = trigger_location_type_train + self.trigger_location_type_val = trigger_location_type_val + + self.save_hyperparameters() + + self.transforms_train = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize( + self.network_input_size), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.RandomAffine( + degrees=0, + translate=self.augm_translate, + scale=None, + shear=None, + interpolation=torchvision.transforms.InterpolationMode.NEAREST, + fill=0, + center=None)], + p=0.5), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.ColorJitter( + brightness=self.augm_bright, + contrast=self.augm_contrast, + saturation=self.augm_sat, + hue=self.augm_hue), + ], + p=0.5), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.RandomRotation( + degrees=self.augm_rot, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR), + ], + p=0.5), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (self.ds_mean), + (self.ds_std)), + ]) + + self.transforms_val = torchvision.transforms.Compose([ + torchvision.transforms.Resize(self.network_input_size), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (self.ds_mean), (self.ds_std)) + ]) + + def prepare_data(self) -> None: + # return super().prepare_data() + """ + DON'T assign state here (e.g. self.x = y) + download dataset... + tokenize... + """ + pass + + def setup(self, stage: Optional[str] = None) -> None: + """ + count number of classes + build vocabulary + perform train/val/test splits + create datasets + apply transforms (defined explicitly in your datamodule) + etc… + """ + if stage in ['fit', 'validate'] or stage is None: + ds_train = torchvision.datasets.ImageFolder( + self.dataset_dir, self.transforms_train) + ds_val = torchvision.datasets.ImageFolder( + self.dataset_dir, self.transforms_val) + + self.num_classes = len(ds_train.classes) + + all_ds_indices = torch.arange(len(ds_train)) + + self.train_indices, self.val_indices = train_test_split( + all_ds_indices, train_size=self.train_split, shuffle=True, stratify=ds_train.targets, random_state=self.ds_split_seed) + self.train_indices, _ = self.train_indices.sort() + self.val_indices, _ = self.val_indices.sort() + + ds_train = torch.utils.data.Subset(ds_train, self.train_indices) + ds_val = torch.utils.data.Subset(ds_val, self.val_indices) + self.datasets_train = [ds_train] + self.datasets_val = [ds_val] + self.datasets_names_train = ['`train clean`'] + self.datasets_names_val = ['`val clean`'] + + if self.poison: + + trigger_train = Image.open( + self.trigger_train_fp).convert('RGB') + trigger_val = Image.open(self.trigger_val_fp).convert('RGB') + + if self.trigger_between_eyes: + # Either we use MTCNN to apply the trigger while loading the image (once and for all) + # And use the filepath of the image as key to store and retrieve previously identifed landmarks + # to make the use of MTCNN unnecessary beyond the first + # epoch + img_loader_train = bd.poisonImageLoader( + trigger=trigger_train) + img_loader_val = bd.poisonImageLoader(trigger=trigger_val) + + self.transforms_train_poison = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize( + self.network_input_size), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.RandomAffine( + degrees=0, + translate=self.augm_translate, + scale=None, + shear=None, + interpolation=torchvision.transforms.InterpolationMode.NEAREST, + fill=0, + center=None)], + p=0.5), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.ColorJitter( + brightness=self.augm_bright, + contrast=self.augm_contrast, + saturation=self.augm_sat, + hue=self.augm_hue), + ], + p=0.5), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.RandomRotation( + degrees=self.augm_rot, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR), + ], + p=0.5), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (self.ds_mean), + (self.ds_std)), + ]) + + self.transforms_val_poison = torchvision.transforms.Compose([ + torchvision.transforms.Resize(self.network_input_size), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (self.ds_mean), (self.ds_std)) + ]) + + else: + # Or we use a transformation to apply the trigger on the image (without using image features such as landmarks) + # The reason for this is we would have to find the landmarks for every image over and over again as we load them + # as there is no way to keep track of which image has already been used in the past, effectively + # (possibly by using a hash of the image but is that as fast and efficient as using the path in the image loader?) + img_loader_train = torchvision.datasets.folder.default_loader + img_loader_val = torchvision.datasets.folder.default_loader + + self.transforms_train_poison = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize( + self.network_input_size), + bd.Poison( + trigger=torchvision.transforms.ToTensor()(trigger_train), + set_or_shift=self.trigger_application_train, + location_type=self.trigger_location_type_train, + location=self.trigger_loc_train), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.RandomAffine( + degrees=0, + translate=self.augm_translate, + scale=None, + shear=None, + interpolation=torchvision.transforms.InterpolationMode.NEAREST, + fill=0, + center=None)], + p=0.5), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.ColorJitter( + brightness=self.augm_bright, + contrast=self.augm_contrast, + saturation=self.augm_sat, + hue=self.augm_hue), + ], + p=0.5), + torchvision.transforms.RandomApply( + [ + torchvision.transforms.RandomRotation( + degrees=self.augm_rot, + interpolation=torchvision.transforms.InterpolationMode.BILINEAR), + ], + p=0.5), + torchvision.transforms.Normalize( + (self.ds_mean), + (self.ds_std)), + ]) + + self.transforms_val_poison = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Resize(self.network_input_size), + bd.Poison(trigger=torchvision.transforms.ToTensor()(trigger_val), + set_or_shift=self.trigger_application_val, + location_type=self.trigger_location_type_val, + location=self.trigger_loc_val), + torchvision.transforms.Normalize( + (self.ds_mean), (self.ds_std)) + ]) + + target_transform_backdoor = bd.TargetReMap( + self.impostors, self.victims) + + ds_train_poison = torchvision.datasets.ImageFolder( + root=self.dataset_dir, + transform=self.transforms_train_poison, + target_transform=target_transform_backdoor, + loader=img_loader_train + ) + ds_val_poison = torchvision.datasets.ImageFolder( + root=self.dataset_dir, + transform=self.transforms_val_poison, + target_transform=target_transform_backdoor, + loader=img_loader_val + ) + + ds_train_poison = torch.utils.data.Subset( + ds_train_poison, self.train_indices) + ds_train_impostors_poison = bd.getIsolatedTargetsDataset2( + ds_train_poison, self.impostors) + self.datasets_train += [ds_train_impostors_poison] + self.datasets_names_train.append('`train poison`') + + ds_val_poison = torch.utils.data.Subset( + ds_val_poison, self.val_indices) + ds_val_impstors_poison = bd.getIsolatedTargetsDataset2( + ds_val_poison, self.impostors) + self.datasets_val += [ds_val_impstors_poison] + self.datasets_names_val.append('`val impostor(s) poison`') + + ds_val_impstors_clean = bd.getIsolatedTargetsDataset2( + ds_val, self.impostors) + self.datasets_val += [ds_val_impstors_clean] + self.datasets_names_val.append('`val impostor(s) clean`') + + ds_val_victims = bd.getIsolatedTargetsDataset2( + ds_val, self.victims) + self.datasets_val += [ds_val_victims] + self.datasets_names_val.append('`val victim(s) clean`') + + print("# LightningDataModule setup poisoning summary #") + print( + "The impostor(s) and their corresponding victim(s) from the *clean training set*:") + for imp_, vict_ in zip(self.impostors, self.victims): + n_samples_imp = len([s[1] for s in np.asarray(self.datasets_train[0].dataset.samples)[ + self.datasets_train[0].indices] if int(s[1]) == imp_]) + n_samples_vict = len([s[1] for s in np.asarray(self.datasets_train[0].dataset.samples)[ + self.datasets_train[0].indices] if int(s[1]) == vict_]) + print('class: ' + + str(imp_) + + ' (folder ' + + str(self.datasets_train[0].dataset.classes[imp_]) + + ') [' + + str(n_samples_imp) + + ' samples] -> ' + + 'class: ' + + str(vict_) + + ' (folder ' + + str(self.datasets_train[0].dataset.classes[vict_]) + + ') [' + + str(n_samples_vict) + + ' samples]') + + """ + if self.poison_batch_split.lower() == 'auto': + # batch_size per dataset follows proportion of each self.datasets_train + + # THIS METHOD GENERALIZES WELL TO WHEN THERE ARE MULTIPLE DATASETS, BUT PERHAPS OVERKILL FOR ONLY 2 + # self.batch_sizes = [] + # combined_ds_length = sum(len(ds) for ds in self.datasets_train) + # for ds in self.datasets_train: + # bs_ = len(ds)*self.batch_size/combined_ds_length + # self.batch_sizes.append(bs_) + # self.batch_sizes = round2sum(self.batch_sizes) + + combined_ds_length = sum(len(ds) for ds in self.datasets_train) + bs_poison = max(1, int(round(1.0*len(self.datasets_train[1])*self.batch_size/combined_ds_length))) + self.batch_sizes = [self.batch_size - bs_poison, bs_poison] + + else: + # use the self.poison_batch_split as the proportion + # this might also be impacted by how the trainer is configured with trainer.multiple_trainloader_mode? Need to think about it + self.batch_sizes = [int(round(self.batch_size*(1-self.poison_batch_split))), int(round(self.batch_size*self.poison_batch_split))] + """ + + def train_dataloader(self): + if self.granular: + return torch.utils.data.DataLoader( + bd.ConcatDataset2Dict( + self.datasets_train), + batch_size=self.batch_size, + shuffle=self.shuffle_train, + num_workers=self.num_workers, + pin_memory=self.pin_memory) + else: + return torch.utils.data.DataLoader( + torch.utils.data.ConcatDataset( + [ + ds_train for ds_train in self.datasets_train]), + batch_size=self.batch_size, + shuffle=self.shuffle_train, + num_workers=self.num_workers, + pin_memory=self.pin_memory) + + def val_dataloader(self): + return [torch.utils.data.DataLoader(ds_val, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers, + pin_memory=self.pin_memory) + for ds_val in self.datasets_val] + + """ + def imgToDenormed(self, data): + data = bd.denormalize(data, self.ds_mean, self.ds_std) + #return torch.tensor(np.uint8(data.permute((1,2,0)))) # for single image, sometimes necessary? + return torch.tensor(np.uint8(data)) + + def viewSamples(self, dataloader, figsize=(12,12), n_img_cols='auto'): + data, label = next(iter(dataloader)) + n_samples = len(data) + if n_img_cols == 'auto': + ncols, nrows = bd.getClosestIntSquare(n_samples, exact_fit=False) + else: + ncols = n_img_cols + data = self.imgToDenormed(data) + img_grid = torchvision.utils.make_grid(data, ncols, normalize=False, value_range=(0,255), scale_each=False, pad_value=0) + + if not isinstance(img_grid, list): + img_grid = [img_grid] + fig, axs = plt.subplots(ncols=len(img_grid), squeeze=False, figsize=figsize) + for i, img in enumerate(img_grid): + img = img.detach() + img = torchvision.transforms.functional.to_pil_image(img) + axs[0, i].imshow(np.asarray(img)) + axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) + """ \ No newline at end of file diff --git a/src/pl_FFHQ.py b/src/pl_FFHQ.py new file mode 100644 index 0000000000000000000000000000000000000000..04cb3d467b96a9bc9e2efd99e177a03fcea2dae9 --- /dev/null +++ b/src/pl_FFHQ.py @@ -0,0 +1,147 @@ +import pytorch_lightning as pl +import torch +import torchvision +from typing import Sequence, Optional, Union +import os +from facenet_pytorch import MTCNN +from PIL import Image + +""" +from typing import Sequence, Optional, Union +from sklearn.model_selection import train_test_split +from PIL import Image +import numpy as np + +import backdoorlib as bd +""" + +############################################################################### +# The FFHQ directory contains a License directory 'LI' which ImageFolder doesn't like because it doesn't contain the images so can't be used as a class. +# To get around this error, I override the find_classes function to ignore 'LI' +############################################################################### +class FFHQImageFolder(torchvision.datasets.ImageFolder): + def find_classes(self, directory): + classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir() and 'LI' not in entry.name) + if not classes: + raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") + + class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} + return classes, class_to_idx + +class MTCNN_Wrapper: + def __init__(self) -> None: + self.mtcnn_config = { + 'image_size': 160, + 'margin': 0, + 'min_face_size': 20, + 'thresholds': [0.5,0.5,0.5], + 'factor': 0.709, + 'post_process': False, + 'selection_method': 'largest', # 'probability', 'largest', 'largest_over_threshold', 'center_weighted_size' + } + + self.detector = MTCNN(**self.mtcnn_config) + + def __call__(self, img_fp): + if not isinstance(img_fp, list): + img_fp = [img_fp] + + faces = [] + for img_fp_i in img_fp: + with open(img_fp_i, 'rb') as fh: + img = Image.open(fh) + img = img.convert('RGB') + + #batch_boxes, batch_probs, batch_points = detector.detect(img, landmarks=True) + boxes, probs = self.detector.detect(img, landmarks=False) + + if boxes: + + x0, y0, x1, y1 = boxes[0] + + mid_x = (x0 + x1)/2 + mid_y = (y0 + y1)/2 + + delta_x = x1 - x0 + delta_y = y1 - y0 + + # If the delta_x and delta_y isn't equal, we take the outer box encompassing the face to prevent distorting the face + box_len = max(delta_x, delta_y) + + # Computing cropping params: + top = mid_y - box_len/2 + left = mid_x - box_len/2 + height = width = box_len + #torchvision.transforms.functional.crop(img, top, left, height, width).show() + cropped = torchvision.transforms.functional.crop(img, top, left, height, width) + resized_cropped = torchvision.transforms.functional.resize(cropped, (160,160)) + + faces.append(resized_cropped) + return faces + + + +############################################################################### +# This FFHQ dataset wrapper was specifically designed to process FFHQ +# for a network pretrained on casia-webface, so performs normalization according +# to casia-webface. It is not designed for training, only scoring. +############################################################################### +class FFHQ_DataModule(pl.LightningDataModule): + def __init__(self, dataset_dir: str = None, batch_size = 2**5, num_workers=7, pin_memory=True, shuffle=False, network_input_size=(160,160), ds_mean=(0.4668, 0.38024, 0.33443), ds_std=(0.2960, 0.2656, 0.2595), with_face_extractor=True): + super().__init__() + self.save_hyperparameters() + self.dataset_dir = dataset_dir + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.pin_memory = pin_memory + self.network_input_size = network_input_size + self.ds_mean = ds_mean + self.ds_std = ds_std + self.with_face_extractor = with_face_extractor + if self.with_face_extractor: + raise NotImplemented("Does not work...") + self.with_face_extractor = with_face_extractor + self.detector = MTCNN_Wrapper() + + def prepare_data(self) -> None: + # return super().prepare_data() + """ + DON'T assign state here (e.g. self.x = y) + download dataset... + tokenize... + """ + pass + + + def setup(self, stage: Optional[str] = None) -> None: + """ + count number of classes + build vocabulary + perform train/val/test splits + create datasets + apply transforms (defined explicitly in your datamodule) + etc… + """ + self.transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize(self.network_input_size), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (self.ds_mean), (self.ds_std)) + ]) + + if self.with_face_extractor: + self.transforms.transforms.insert(0, self.detector) + + self.dataset = FFHQImageFolder( + self.dataset_dir, self.transforms) + + self.num_classes = len(self.dataset.classes) + + def val_dataloader(self): + return torch.utils.data.DataLoader( + self.dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.pin_memory) \ No newline at end of file diff --git a/src/pl_FaceNet_arcface.py b/src/pl_FaceNet_arcface.py new file mode 100644 index 0000000000000000000000000000000000000000..31e51dc603c4ee32722a017b004c22ca86bf4263 --- /dev/null +++ b/src/pl_FaceNet_arcface.py @@ -0,0 +1,1016 @@ +from multiprocessing.sharedctypes import Value +import torch +import torchvision +import pytorch_lightning as pl +from typing import Mapping, Optional, Union, List, Sequence +from sklearn.metrics import classification_report +from facenet_pytorch import InceptionResnetV1 as FaceNet +from arcface import angular_margins, distributed +from copy import deepcopy + +class CatchInterimTensor: + """Hook to capture intermediate tensors within a network. + This object can be instantiated then used to register forward and/or backward + hooks within a network, to capture the input and/or output tensors before or + after layer execution. + + Example use: + + embeddinghookf = CatchInterimTensor(keep_all=False) + model_modules = dict(model.named_modules()) + # Name of the layer to register + embedding_layer = 'last_linear' + # register_forward_hook is just one of the possible hooks + embedding_handle = model_modules[embedding_layer].register_forward_hook(embeddinghookf) + # We can then use the following to get the intermediary tensor values: + out = model(data) + embeddings = embeddinghookf.get_output_tensors() + # If we wish to remove the hook: + if embedding_handle: + embedding_handle.remove() + + Additional documentation: + https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook + """ + + def __init__(self, keep_all=False): + """Instantiates the object. + + Args: + keep_all (bool, optional): Whether to keep all values in a perpetually appended list. + If False, only keeps last value. Defaults to False. + """ + #self.embeddings = None + self.registered_layer = None + self.keep_all = keep_all + if self.keep_all: + self.input_tensor = [] + self.output_tensor = [] + else: + self.input_tensor = None + self.output_tensor = None + + def __call__(self, module, in_, out_): + """Stores the input and output values of the tensors to the registered layer + + Args: + module (torch.nn.Module): the torch.nn.Module whose hook is registered + in_ (Tuple(torch.Tensor)): the input tensor to the layer whose hook is registered + out_ (torch.Tensor): the output tensor to the layer whose hook is registered + """ + self.registered_layer = module + # in_ is wrapped in a tuple by pytorch so we unwrap it + if self.keep_all: + self.input_tensor.append(in_[0]) + self.output_tensor.append(out_) + else: + self.input_tensor = in_[0] + self.output_tensor = out_ + + def get_input_tensors(self): + return self.input_tensor + + def get_output_tensors(self): + return self.output_tensor + + +class ArbitraryLayersFinetuning(pl.callbacks.BaseFinetuning): + """Allows to finetune arbitrary layers in a network by freezing the whole + network and unfreezing selected layers. + + Args: + modules: A given module or an iterable of modules + requires_grad: Whether to create a generator for trainable or non-trainable parameters. + Returns: + Generator + + Additional relevant documentation: + https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.BaseFinetuning.html?highlight=finetune + """ + + def __init__(self, layers_to_finetune, unfreeze_at_epoch=0): + super().__init__() + assert isinstance(layers_to_finetune, (str, list)) + if isinstance(layers_to_finetune, str): + self.layers_to_finetune = [layers_to_finetune] + else: + self.layers_to_finetune = layers_to_finetune + self.unfreeze_at_epoch = unfreeze_at_epoch + + @staticmethod + def filter_params(modules, requires_grad=True): + """Yields the `requires_grad` parameters of a given module or list of modules. + + Args: + modules: A given module or an iterable of modules + requires_grad: Whether to create a generator for trainable or non-trainable parameters. + Returns: + Generator + """ + # Overriding initial implementation because it was "all-or-nothing" with respect to finetuning BatchNorms. + # I wanted to be able to select which batchnorm layers I wanted to + # finetune + modules = pl.callbacks.BaseFinetuning.flatten_modules(modules) + for mod in modules: + # recursion could yield duplicate parameters for parent modules w/ + # parameters so disabling it + for param in mod.parameters(recurse=False): + if param.requires_grad == requires_grad: + yield param + + @staticmethod + def unfreeze_and_add_param_group( + modules, + optimizer, + lr=None, + initial_denom_lr=1.0): + """Unfreezes a module and adds its parameters to an optimizer. + + Args: + modules: A module or iterable of modules to unfreeze. + Their parameters will be added to an optimizer as a new param group. + optimizer: The provided optimizer will receive new parameters and will add them to + `add_param_group` + lr: Learning rate for the new param group. + initial_denom_lr: If no lr is provided, the learning from the first param group will be used + and divided by `initial_denom_lr`. + train_bn: Whether to train the BatchNormalization layers. + """ + # Overriding initial implementation to make use of the new + # filter_params method + pl.callbacks.BaseFinetuning.make_trainable(modules) + params_lr = optimizer.param_groups[0]["lr"] if lr is None else float( + lr) + denom_lr = initial_denom_lr if lr is None else 1.0 + params = ArbitraryLayersFinetuning.filter_params( + modules, requires_grad=True) + params = pl.callbacks.BaseFinetuning.filter_on_optimizer( + optimizer, params) + if params: + optimizer.add_param_group( + {"params": params, "lr": params_lr / denom_lr}) + + def freeze_before_training(self, pl_module): + # This method is called before configure_optimizers + # and should be used to freeze any modules parameters. + self.freeze(pl_module, train_bn=False) + + def finetune_function(self, pl_module, epoch, optimizer, opt_idx): + # This method is called on every train epoch start and should be used to unfreeze + # any parameters. Those parameters needs to be added in a new + # param_group within the optimizer. + if epoch == self.unfreeze_at_epoch: + modules_to_finetune = [] + # Alternatively, you can use pl_module.named_children() or pl_module.model.named_children() + # But you will need to adjust self.layers_to_finetune accordingly + named_parameters = dict(pl_module.named_modules()) + for layer in self.layers_to_finetune: + modules_to_finetune.append(named_parameters[layer]) + ArbitraryLayersFinetuning.unfreeze_and_add_param_group( + modules=modules_to_finetune, + optimizer=optimizer + ) + + +class pl_FaceNet_ArcFace(pl.LightningModule): + def __init__(self, + pretrained: str = None, + checkpoint_fp: str = None, + cwf_root_dir: Optional[str] = None, + num_classes: Optional[int] = 10575, + optimizer: Optional[str] = 'sgd', + classify=None, + learning_rate: Optional[float] = 0.1, + weight_decay: Optional[float] = 1e-4, + model_impostors: Optional[Union[int, + List]] = None, + model_victims: Optional[Union[int, + List]] = None, + balance_cwf_weight_classes: Optional[bool] = True, + backdoor_class_weight_ratio: Optional[float] = 5e-2, + verbose: bool = False, + train_datasets_names=['train clean'], + val_datasets_names=['val clean'], + use_arcface=True, + arcface_margin=0.2, + arcface_scale=64.0, + arcface_easy_margin=False, + lr_scheduler_type='SGDR', + opt_period=20, + n_epochs=0, + steps_per_epoch=0, + ): + super().__init__() + assert backdoor_class_weight_ratio is None or 0 <= backdoor_class_weight_ratio <= 1, 'pl_FaceNet.backdoor_class_weight_ratio needs to be within [0,1]' + assert pretrained is None or checkpoint_fp is None, "\'pretrained\' and \'checkpoint_fp\' can not be provided. Only one, or none." + self.save_hyperparameters() + self.num_classes = num_classes + self.cwf_root_dir = cwf_root_dir + assert pretrained in [None, 'casia-webface', 'vggface2'] + self.pretrained = pretrained + self.checkpoint_fp = checkpoint_fp + self.verbose = verbose + self.arcface_margin = arcface_margin + self.arcface_scale = arcface_scale + self.arcface_easy_margin = arcface_easy_margin + self.lr_scheduler_type = lr_scheduler_type + self.opt_period = opt_period + self.balance_cwf_weight_classes = balance_cwf_weight_classes + self.optimizer = optimizer + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.train_datasets_names = train_datasets_names + self.val_datasets_names = val_datasets_names + self.backdoor_class_weight_ratio = backdoor_class_weight_ratio + self.n_epochs = n_epochs + self.steps_per_epoch = steps_per_epoch + if self.lr_scheduler_type == 'ONECYCLELR': + assert self.n_epochs * self.steps_per_epoch > 0 + if isinstance(model_impostors, int): + self.impostors = [model_impostors] + else: + self.impostors = model_impostors + if isinstance(model_victims, int): + self.victims = [model_victims] + else: + self.victims = model_victims + assert self.victims == self.impostors == None or len(self.victims) == len(self.impostors) + self.criterion = torch.nn.CrossEntropyLoss( + weight=self.getCriterionWeights()) + self.use_arcface = use_arcface + self.classify = classify + if use_arcface and self.classify is not None: + print("Warning: classify value is ignored when arcface is used as arcface allows for both embedding and classification.") + self.model = self.get_configured_facenet() + self.arcface = self.get_optional_arcface() + self.setupEmbeddingHook() + if self.checkpoint_fp is not None: + self.load_model_weights() + + def on_train_start(self): + # We log the number of original samples for each of the victim and impostor identities: + if self.victims and self.impostors: + for i, (victim, impostor) in enumerate(zip(self.victims, self.impostors)): + self.log('N. clean samples victim_' + str(i), + float(self.n_samples_per_class_clean[victim])) # Convert to float as PL is not a fan of using ints as it can't compute means and other things from those values. + self.log('N. clean samples impostor_' + str(i), + float(self.n_samples_per_class_clean[impostor])) # Convert to float as PL is not a fan of using ints as it can't compute means and other things from those values. + + def get_configured_facenet(self): + if self.use_arcface: + model = FaceNet( + pretrained=self.pretrained, + num_classes=None, + classify=False + ) + return model + else: + if self.pretrained is not None and self.num_classes is not None and self.classify: + # There's a bug in FaceNet where if pretrained, num_classes are both used + # and classify is True, the final logits are overwritten... + # So we fix it in this case: + model = FaceNet( + pretrained=self.pretrained, + num_classes=None, + classify=self.classify) + else: + model = FaceNet( + pretrained=self.pretrained, + num_classes=self.num_classes, + classify=self.classify + ) + return model + + def get_optional_arcface(self): + if self.use_arcface: + arcface = distributed.ShardedFaceRecognitionClassifier( + embedding_size=512, + n_classes=self.num_classes, + rank=0, + world_size=1, + sampling_rate=1.0, + input_normalizer=None, + margin_fn=angular_margins.ArcFaceMargin( + m=self.arcface_margin, + s=self.arcface_scale, + eps=1.0e-9, + easy_margin=self.arcface_easy_margin), + loss_fn=self.criterion) + return arcface + else: + return None + + def forward(self, x): + # This is the function for yielding classifications. + # For embeddings, use self.inferenceForEmbedding() + if self.use_arcface: + # When you train with arcface, the output classification layer which + # should be used is the one from ArcFace + _ = self.model(x) + embeddings = self.embeddinghookf.get_output_tensors() + _, classification, _ = self.arcface(embeddings, torch.zeros(len(x), device=x.device, dtype=torch.long)) + return classification + elif self.classify: + return self.model(x) + else: + raise RuntimeError("Can\'t perform classification if there\'s no linear output.") + + def configure_optimizers(self): + if 'adam' in self.optimizer.lower(): + self.optimizer = torch.optim.Adam([{'params': filter(lambda p: p.requires_grad, self.parameters( + ))}], lr=self.learning_rate, weight_decay=self.weight_decay) + elif 'sgd' in self.optimizer.lower(): + self.optimizer = torch.optim.SGD([{'params': filter(lambda p: p.requires_grad, self.parameters( + ))}], lr=self.learning_rate, weight_decay=self.weight_decay) + else: + raise ValueError('Unsupported optimizer: ' + str(self.optimizer)) + + #self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, ) + if self.lr_scheduler_type == 'SGDR': + self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + self.optimizer, T_0=self.opt_period) + return [self.optimizer], [self.lr_scheduler] + elif self.lr_scheduler_type == 'STEPLR': + self.lr_scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, step_size=self.opt_period, gamma=0.5) + return [self.optimizer], [self.lr_scheduler] + elif self.lr_scheduler_type == 'ONECYCLELR': + self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, + max_lr = self.learning_rate, + total_steps = None, + epochs = self.n_epochs, + steps_per_epoch = self.steps_per_epoch, + pct_start = 0.3, + anneal_strategy='linear', + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25.0, + final_div_factor=10000.0, + three_phase=False, + last_epoch=-1, + verbose=False + ) + return [self.optimizer], [self.lr_scheduler] + # elif self.lr_scheduler_type.lower() == ...: + # self.lr_scheduler = torch.optim.lr_scheduler...() + elif self.lr_scheduler_type is None: + return [self.optimizer] + else: + raise ValueError( + 'Only supported lr_scheduler_type: `[SGDR, StepLR, None]`, you provided:', str( + self.lr_scheduler_type)) + + def setupEmbeddingHook(self): + self.embeddinghookf = CatchInterimTensor(keep_all=False) + model_modules = dict(self.model.named_modules()) + # Name of the layer generating the embedding: + embedding_layer = 'last_linear' + self.embedding_handle = model_modules[embedding_layer].register_forward_hook( + self.embeddinghookf) + + def removeEmbeddingHook(self): + if self.embedding_handle: + self.embedding_handle.remove() + + def inferenceForEmbedding(self, data): + # NB: These embeddings are not normalized. + # If you want to use them for identification, it might be best to normalize them? + _ = self.model(data) + embeddings = self.embeddinghookf.get_output_tensors() + return embeddings + + def getLossPreds(self, data, targets): + embeddings = self.inferenceForEmbedding(data) + loss, out, _ = self.arcface(embeddings, targets) + preds = torch.argmax(out, dim=1) + #accuracy = classification_report(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), digits=3, zero_division=0, output_dict=True)['accuracy'] + return loss, preds + + def getAccLoss(self, data, targets): + #raise NotImplemented + embeddings = self.inferenceForEmbedding(data) + loss, out, _ = self.arcface(embeddings, targets) + preds = torch.argmax(out, dim=1) + accuracy = classification_report( + targets.cpu().detach().numpy(), + preds.cpu().detach().numpy(), + digits=3, + zero_division=0, + output_dict=True)['accuracy'] + return accuracy, loss + + @staticmethod + def reuce_loss(losses): + # How to reduce a list of losses from various dataloaders into just one loss: + #raise NotImplemented("WIP") + reduced_loss = torch.mean(torch.tensor(losses)) + return reduced_loss + + def training_step2(self, train_batch, batch_idx, optimizer_idx=0): + # train_batch can be 1 of 3 things (if I understand correctly): + # 1) if train_dataset is a dict (key is the name of the dataset) and value is the dataloader: + # to work with it: + # for dl_name in train_batch: + # data, target = train_batch[dl_name] + # ... + # 2) if train_dataset is a list of dataloaders, train_batch is a list with len(train_batch) = len(train_datasets) + # to work with it: + # for i, dl_batch in enumerate(train_batch): + # data, target = dl_batch + # ... + # 3) if only one dataloader is used, train_batch is a regular batch where the normal method can be used: + # data, target = train_batch + if optimizer_idx != 0: + raise NotImplemented( + "ERROR: the case where there is more than one optimizer used is not implemented") + + if isinstance(train_batch, Mapping): + ret = {} + losses = {} + accuracies = {} + batch_sizes = [] + for dl_name in train_batch: + data, target = train_batch[dl_name] + curr_batch_size = len(target) + # cette fonction peut marcher, mais tu veux probablement pas + # logger pour chaque batch chaque accuracy, même pour le + # poison, si? + accuracy, loss = self.getAccLoss(data, target) + losses[dl_name] = loss + accuracies[dl_name] = accuracy + batch_sizes.append(curr_batch_size) + + loss_name = 'train loss ' + str(dl_name) + accuracy_name = 'train accuracy ' + str(dl_name) + self.log( + loss_name, + loss, + batch_size=curr_batch_size, + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + self.log( + accuracy_name, + accuracy, + batch_size=curr_batch_size, + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + ret[loss_name] = loss + ret[accuracy_name] = accuracy + + raise NotImplemented + # you don't want to do this, because you will value loss from 1 + # poison sample and a whole batch of clean samples equally when + # averaging!! + overall_loss = self.reuce_loss( + [losses[dl_name] for dl_name in losses]) + self.log( + 'train loss', + overall_loss, + batch_size=sum(batch_sizes), + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + ret['loss'] = overall_loss + return ret + # self.log(name, + # value, + # prog_bar=False, + # logger=True, + # on_step=None, + # on_epoch=None, + # reduce_fx='mean', + # enable_graph=False, + # sync_dist=False, + # sync_dist_group=None, + # add_dataloader_idx=True, + # batch_size=None, + # metric_attribute=None, + # rank_zero_only=False + # ) + + else: + if isinstance(train_batch[0], Sequence): + # train_batch is a list of batches, from each of the + # dataloaders used + ret = {} + losses = [] + accuracies = [] + batch_sizes = [] + for i, dl_batch in enumerate(train_batch): + data, target = dl_batch + curr_batch_size = len(target) + accuracy, loss = self.getAccLoss(data, target) + losses.append(loss) + accuracies.append(accuracy) + batch_sizes.append(curr_batch_size) + + loss_name = 'train loss ' + str(i) + accuracy_name = 'train accuracy ' + str(i) + self.log( + loss_name, + loss, + batch_size=curr_batch_size, + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + self.log( + accuracy_name, + accuracy, + batch_size=curr_batch_size, + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + ret[loss_name] = loss + ret[accuracy_name] = accuracy + + overall_loss = self.reuce_loss(losses) + self.log( + 'train loss', + overall_loss, + batch_size=sum(batch_sizes), + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + ret['loss'] = overall_loss + return ret + + else: + # train_batch is a regular training batch from one single + # dataloader + data, target = train_batch + accuracy, loss = self.getAccLoss(data, target) + self.log( + 'train loss', + loss, + on_step=True, + on_epoch=True, + reduce_fx='mean') + self.log( + 'train accuracy', + accuracy, + on_step=True, + on_epoch=True, + reduce_fx='mean') + + return {'loss': loss, 'train accuracy': accuracy} + + def training_step(self, train_batch, batch_idx, optimizer_idx=0): + # if multiple datasets are used, train_batch is a list of batches from each dataset + # (where each batch is a list [data,label]) and batch_idx is just one int, following the trainer parameter of min_cycle or max_cycle_size + # implying: + # for dl_batch in train_batch: + # data, label = dl_batch + if isinstance(train_batch, Mapping): + data = train_batch['X'] + targets = train_batch['y'] + metadata = train_batch['meta'] + ds_indices = metadata['ds_index'] + + if self.use_arcface: + embeddings = self.inferenceForEmbedding(data) + all_loss, out, _ = self.arcface(embeddings, targets) + else: + out = self.model(data) + all_loss = self.criterion(out, targets) + all_preds = torch.argmax(out, dim=1) + all_train_acc = classification_report( + targets.cpu().detach().numpy(), + all_preds.cpu().detach().numpy(), + digits=3, + zero_division=0, + output_dict=True)['accuracy'] + + preds = {} + targs = {} + train_acc = {} + loss = {} + + self.log('train_loss', all_loss) + self.log("train_acc", all_train_acc) + if self.lr_scheduler: + lr_ = self.lr_scheduler.get_last_lr() + if isinstance(lr_, list): + for i, lr in enumerate(lr_): + self.log("learning rate " + str(i), lr) + else: + self.log("learning rate", lr_) + + for i in torch.unique(ds_indices, sorted=False): + if self.train_datasets_names is not None: + dl_name = self.train_datasets_names[i.item()] + else: + dl_name = "dataloader_" + str(i.item()) + + preds[i] = all_preds[ds_indices == i] + targs[i] = targets[ds_indices == i] + train_acc[i] = classification_report( + targs[i].cpu().detach().numpy(), + preds[i].cpu().detach().numpy(), + digits=3, + zero_division=0, + output_dict=True)['accuracy'] + # loss[i] = self.criterion(out[ds_indices == i], targs[i]) # + # this is the loss without margin + if self.use_arcface: + loss[i], _, _ = self.arcface( + embeddings[ds_indices == i], targs[i]) + else: + loss[i] = self.criterion(out[ds_indices == i], targs[i]) + + self.log('train_loss ' + dl_name, loss[i]) + self.log('train_acc ' + dl_name, train_acc[i]) + + return all_loss + + else: + data, targets = train_batch + if self.use_arcface: + embeddings = self.inferenceForEmbedding(data) + all_loss, out, _ = self.arcface(embeddings, targets) + else: + out = self.model(data) + all_loss = self.criterion(out, targets) + preds = torch.argmax(out, dim=1) + train_acc = classification_report( + targets.cpu().detach().numpy(), + preds.cpu().detach().numpy(), + digits=3, + zero_division=0, + output_dict=True)['accuracy'] + self.log('train_loss', all_loss) + self.log("train_acc", train_acc) + return all_loss + + def training_step_end(self, training_step_outputs): + pass + + def training_epoch_end(self, training_outputs): + # training_outputs is a list of all return values from training_step() function calls on all + # in the dataloader(s) + pass + + def validation_step2(self, val_batch, batch_idx, dataloader_idx=0): + data, targets = val_batch + batch_size = len(targets) + loss, preds = self.getLossPreds(data, targets) + accuracy = classification_report( + targets.cpu().detach().numpy(), + preds.cpu().detach().numpy(), + digits=3, + zero_division=0, + output_dict=True)['accuracy'] + + if self.val_datasets_names is not None and dataloader_idx < len( + self.val_datasets_names): + loss_name = 'val loss ' + \ + str(self.val_datasets_names[dataloader_idx]) + accuracy_name = 'val accuracy ' + \ + str(self.val_datasets_names[dataloader_idx]) + else: + loss_name = 'val loss ' + str(dataloader_idx) + accuracy_name = 'val accuracy ' + str(dataloader_idx) + self.log( + loss_name, + loss, + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + self.log( + accuracy_name, + accuracy, + add_dataloader_idx=False, + on_step=True, + on_epoch=True, + reduce_fx='mean') + + return {'loss': loss, 'predictions': preds, 'targets': targets} + + def validation_step(self, val_batch, batch_idx, dataloader_idx=0): + """if isinstance(val_batch, Mapping): + data = val_batch['X'] + targets = val_batch['y'] + + if 'meta' in val_batch: + metadata = val_batch['meta'] + if 'ds_index' in metadata: + ds_indices = metadata['ds_index'] + + embeddings = self.inferenceForEmbedding(data) + all_loss, out, _ = self.arcface(embeddings, targets) + all_preds = torch.argmax(out, dim=1) + all_val_acc = classification_report(targets.cpu().detach().numpy(), all_preds.cpu().detach().numpy(), digits=3, zero_division=0, output_dict=True)['accuracy'] + + if dataloader_idx is not None and 'ds_indices' in locals(): + preds = {} + targs = {} + val_acc = {} + loss = {} + + for i in torch.unique(ds_indices, sorted=False): + if self.val_datasets_names is not None: + dl_name = self.val_datasets_names[i.item()] + else: + dl_name = "dataloader_" + str(i.item()) + + preds[i] = all_preds[ds_indices == i] + targs[i] = targets[ds_indices == i] + val_acc[i] = classification_report(targs[i].cpu().detach().numpy(), preds[i].cpu().detach().numpy(), digits=3, zero_division=0, output_dict=True)['accuracy'] + + # THIS IS THE LOSS WITHOUT MARGIN. PERHAPS YOU WANT TO USE THE LOSS WITH MARGIN, i.e. using self.arcface(embeddings[ds_indices == i], targs[i]) + loss[i] = self.criterion(out[ds_indices == i], targs[i]) + + self.log('val loss ' + dl_name, loss[i]) + self.log('val acc ' + dl_name, val_acc[i]) + self.log('combined_val_acc', sum([v for k,v in val_acc.items()])/len(val_acc)) + + self.log('val_loss', all_loss) + self.log("val_acc", all_val_acc) + + return all_loss + + else: + data, targets = val_batch + embeddings = self.inferenceForEmbedding(data) + all_loss, out, _ = self.arcface(embeddings, targets) + #out = self.arcface(embeddings, targets) + + # THIS IS THE LOSS WITHOUT MARGIN. PERHAPS YOU WANT TO USE THE LOSS WITH MARGIN, i.e. using self.arcface(embeddings, targets) + #all_loss = self.criterion(out, targets) + preds = torch.argmax(out, dim=1) + val_acc = classification_report(targets.cpu().detach().numpy(), preds.cpu().detach().numpy(), digits=3, zero_division=0, output_dict=True)['accuracy'] + + if dataloader_idx is not None and self.val_datasets_names is not None: + dl_name = self.val_datasets_names[dataloader_idx] + + self.log('val_loss/' + dl_name, all_loss) + self.log('val_acc/' + dl_name, val_acc) + + else: + self.log('val_loss', all_loss) + self.log('val_acc', val_acc) + + return all_loss""" + if isinstance(val_batch, Mapping): + data = val_batch['X'] + targets = val_batch['y'] + else: + data, targets = val_batch + if self.use_arcface: + embeddings = self.inferenceForEmbedding(data) + all_loss, out, _ = self.arcface(embeddings, targets) + else: + out = self.model(data) + all_loss = self.criterion(out, targets) + #out = self.arcface(embeddings, targets) + #all_loss = self.criterion(out, targets) + preds = torch.argmax(out, dim=1) + val_acc = classification_report( + targets.cpu().detach().numpy(), + preds.cpu().detach().numpy(), + digits=3, + zero_division=0, + output_dict=True)['accuracy'] + + if self.val_datasets_names is not None: + dl_name = self.val_datasets_names[dataloader_idx] + self.log('val_loss ' + dl_name, all_loss, add_dataloader_idx=False) + self.log('val_acc ' + dl_name, val_acc, add_dataloader_idx=False) + else: + self.log('val_loss ', all_loss, add_dataloader_idx=True) + self.log('val_acc ', val_acc, add_dataloader_idx=True) + + return {'loss': all_loss, 'predictions': preds, 'targets': targets} + + def validation_step_end(self, val_step_outputs): + pass + + def validation_epoch_end(self, val_outputs): + # val_outputs is a list containing all the return values of each validation_step() function call + # for all batches in the dataloader + # If multiple dataloaders are used, val_outputs is a list with len(val_outputs) = len(validation_dataloaders) + # and each item in this list is a list of return values of each validation_step() function call for all batches in the dataloader + # validation_step() returns: {'loss': loss, 'predictions': preds, + # 'targets':targets} + + # If multiple dataloaders: val_outputs is of the type: List[List[Mapping]] + # Else: val_outputs is of the type: List[Mapping] + """ + if isinstance(val_outputs[0], Mapping): + # Only one dataloader was used + losses = torch.tensor([batch_output['loss'].item() for batch_output in val_outputs]) + predictions = [prediction.item() for batch_output in val_outputs for prediction in batch_output['predictions']] + targets = [target.item() for batch_output in val_outputs for target in batch_output['targets']] + accuracy = classification_report(targets, predictions, digits=3, zero_division=0, output_dict=True)['accuracy'] + loss = torch.mean(losses) + #print('Epoch validation loss:', torch.round(loss, decimals=5), '* accuracy:', round(accuracy, 5)) + self.log('Val Epoch loss', loss, add_dataloader_idx=False, on_step=False, on_epoch=True, reduce_fx='mean') + + else: + # We have multiple dataloaders + for i, dl_output in enumerate(val_outputs): + losses = torch.tensor([batch_output['loss'].item() for batch_output in dl_output]) + predictions = [prediction.item() for batch_output in dl_output for prediction in batch_output['predictions']] + targets = [target.item() for batch_output in dl_output for target in batch_output['targets']] + accuracy = classification_report(targets, predictions, digits=3, zero_division=0, output_dict=True)['accuracy'] + #print('Epoch statistics on dataloader(' + str(i) + '):' , '* validation loss: ', torch.round(torch.mean(losses), decimals=5), '* accuracy:', round(accuracy, 5)) + """ + + """ + if isinstance(val_outputs[0], Mapping): + # Only one dataloader was used + losses = torch.tensor([batch_output['loss'].item() for batch_output in val_outputs]) + predictions = [prediction.item() for batch_output in val_outputs for prediction in batch_output['predictions']] + targets = [target.item() for batch_output in val_outputs for target in batch_output['targets']] + accuracy = classification_report(targets, predictions, digits=3, zero_division=0, output_dict=True)['accuracy'] + """ + if not isinstance(val_outputs[0], Mapping): + # We have multiple dataloaders + accuracies = [] + for i, dl_output in enumerate(val_outputs): + predictions = [prediction.item( + ) for batch_output in dl_output for prediction in batch_output['predictions']] + targets = [ + target.item() for batch_output in dl_output for target in batch_output['targets']] + accuracy = classification_report( + targets, + predictions, + digits=3, + zero_division=0, + output_dict=True)['accuracy'] + accuracies.append(accuracy) + self.log( + 'combined_val_acc', + sum(accuracies) / + len(accuracies), + add_dataloader_idx=False) + + def getCriterionWeights(self): + class_weights = torch.ones(self.num_classes) + if self.balance_cwf_weight_classes and not self.cwf_root_dir: + raise ValueError( + 'Need Casia-WebFace root directory to be provided to compute the balanced weights.') + + if self.balance_cwf_weight_classes: + cwf_ds = torchvision.datasets.ImageFolder(self.cwf_root_dir) + self.n_samples_per_class_clean = [0 for _ in range(self.num_classes)] + for path, label in cwf_ds.samples: + self.n_samples_per_class_clean[label] += 1 + if self.impostors is not None and self.victims is not None: + # Since the impostor samples will be used with victim label + # We are basically increasing the number of victim samples + # So we need to reweigh the victim class accordingly to compensate + # for the larger number of victim samples + print("INFO: the weight adjustment for the criterion for inbalanced datasets is being used for a poisoned dataset" + "and will thus balance accounting for the poisoned samples (impostor samples increasing the perceived number of victim smaples).") + self.n_samples_per_class_poisoned = deepcopy(self.n_samples_per_class_clean) + for victim, impostor in zip(self.victims, self.impostors): + for path, label in cwf_ds.samples: + if label == impostor: + self.n_samples_per_class_poisoned[victim] += 1 + class_weights = 1 / torch.tensor(self.n_samples_per_class_poisoned) + else: + class_weights = 1 / torch.tensor(self.n_samples_per_class_clean) + + if self.victims is None or self.impostors is None or self.backdoor_class_weight_ratio is None: + return class_weights + else: + assert len(self.victims) == len(self.impostors) + # Here we attempt at weighing the genuine classes and + # for the classes involved in the backdoor attack a given proportion. + # For this we need to compute the weight `reweight` to assign to the classes + # involved in the backdoor attack: + + # p the proportion of the loss for the backdoor classes (i.e. self.backdoor_class_weight_ratio) + # C is the total number of classes + # V is the number of victim classes + # I is the number of impostor classes + # w is the weight, to be computed + + # w = (p/(1-p)) * (C-I-V)/(I+V) + + p = self.backdoor_class_weight_ratio + C = self.num_classes + V = len(self.victims) + I = len(self.impostors) + + """ + all_weights = torch.ones(C) + reweight = (p / (1 - p)) * (C - V - I) / (V + I) + + for class_ in self.victims + self.impostors: + all_weights[class_] = reweight + + # We should now have: + # reweight * (V + I) / torch.sum(all_weights) = p + + return all_weights + """ + + # to combine the above formula in a case of non-uniformed weights, + # the new formula is: + # w = (p/(1-p)) * Sum(vi)/Sum(vb) + # where vi are the individual values of the class_weights for non-backdoored identities + # and vb are the individual values of the backdoored identities + # If the initial values are all 1s, we get the initial formula + # above. + + bd_indices = torch.tensor([False for _ in range(self.num_classes)]) + bd_indices[self.victims] = True + bd_indices[self.impostors] = True + + reweight = ( + p / (1 - p)) * torch.sum(class_weights[~bd_indices]) / torch.sum(class_weights[bd_indices]) + + for class_ in self.victims + self.impostors: + class_weights[class_] *= reweight + + # We should now have: + # p = torch.sum(class_weights[bd_indices])/(torch.sum(class_weights[bd_indices]) + torch.sum(class_weights[~ bd_indices])) + + # We can also verify what kind of (n_samples * weights) distribution we have: + # np.histogram(class_weights.numpy() * np.array(n_samples_per_class)) + # This shows us that for almost all classes (except the number of backdoored ones) + # We have a unity value. Then for the backdoored classes, we have + # an over-weight value, intended to reflect the requested p + + return class_weights + + def load_model_weights(self, load_criterion_weights=False): + # arcface checkpoint with arcface module + checkpoint = torch.load(self.checkpoint_fp) + arcface_checkpoint = False + for k in checkpoint['state_dict']: + if 'arcface' in k: + arcface_checkpoint = True + + # NB: self.criterion.weight is the same as self.arcface.loss_fn.weight + + if arcface_checkpoint: + if self.use_arcface: + model_state_dict = {k[len('model.'):]: checkpoint['state_dict'][k] + for k in checkpoint['state_dict'] if k.startswith('model.')} + self.model.load_state_dict(model_state_dict, strict=True) + + arcface_state_dict = {k[len('arcface.'):]: checkpoint['state_dict'][k] + for k in checkpoint['state_dict'] if k.startswith('arcface.')} + if not load_criterion_weights: + # We overwrite the checkpoint weights from the criterion with the existing criterion + # This is so that we can use the strict load_state_dict to detect issues with any other layer + arcface_state_dict['loss_fn.weight'] = self.criterion.weight + self.arcface.load_state_dict(arcface_state_dict, strict=True) + else: + print("Warning: while we can restore an arcface checkpoint to a non-arcface FaceNet, the results will not be the same as arcface normalizes the embeddings." + "This is not something which FaceNet does, so it will not yield the same results. We do however normalize the Linear weights at the classification output." + "(If using classify=True).") + model_state_dict = {k[len('model.'):]: checkpoint['state_dict'][k] + for k in checkpoint['state_dict'] if k.startswith('model.')} + + # Normalizing weights is necessary as arcface always computes the normalize weights (but doesn't update them with the normalized version) + if self.classify: + model_state_dict['logits.weight'] = torch.nn.functional.normalize(checkpoint['state_dict']['arcface.weights'], dim=1) + model_state_dict['logits.bias'] = torch.zeros(self.num_classes) + self.model.load_state_dict(model_state_dict, strict=True) + if load_criterion_weights: + self.criterion.load_state_dict({'criterion.weight': checkpoint['state_dict']['criterion.weight']}, strict=True) + else: + if self.use_arcface: + # We ignore logits data from checkpoint for FaceNet + model_state_dict = {k[len('model.'):]: checkpoint['state_dict'][k] + for k in checkpoint['state_dict'] if k.startswith('model.') and 'logits' not in k} + self.model.load_state_dict(model_state_dict, strict=True) + + # We get the logits data from the checkpoint for arcface (this + # leads to ignoring the bias data from the logits layer!!) + print("Warning, while restoring a FaceNet checkpoint to a FaceNet+arcface combination works, it will lead to the bias from the FaceNet logits layer to be ignored") + arcface_state_dict = { + 'weights': checkpoint['state_dict']['model.logits.weight']} + + if load_criterion_weights: + arcface_state_dict['loss_fn.weight'] = arcface_state_dict['criterion.weight'] + else: + # We overwrite the checkpoint weights from the criterion with the existing criterion + # This is so that we can use the strict load_state_dict to detect issues with any other layer + arcface_state_dict['loss_fn.weight'] = self.criterion.weight + self.arcface.load_state_dict(arcface_state_dict, strict=True) + + self.arcface.load_state_dict(arcface_state_dict, strict=True) + else: + model_state_dict = {k[len('model.'):]: checkpoint['state_dict'][k] + for k in checkpoint['state_dict'] if k.startswith('model.')} + if not self.classify: + del model_state_dict['logits.weight'] + del model_state_dict['logits.bias'] + self.model.load_state_dict(model_state_dict, strict=True) + if load_criterion_weights: + self.criterion.load_state_dict({'criterion.weight': checkpoint['state_dict']['criterion.weight']}, strict=True) \ No newline at end of file diff --git a/src/pl_overrides.py b/src/pl_overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..455d055bf4a9cddf86f45929455c14ca64ae7ba2 --- /dev/null +++ b/src/pl_overrides.py @@ -0,0 +1,54 @@ +import os +import pytorch_lightning as pl +from lightning_lite.utilities.cloud_io import get_filesystem + +class SaveConfigCallback(pl.cli.SaveConfigCallback): + """This override of the SaveConfigCallback implements a None default name + where if no config filename is provided, the filename is taken from the logger + (This is expecting the logger to be WANDB, it has not been tested when there is a different logger) + + Args: + pl (_type_): _description_ + """ + def __init__(self, parser, config, config_filename = None, *args, **kwargs): + super().__init__(parser, config, config_filename, *args, **kwargs) + + def setup(self, trainer, pl_module, stage) -> None: + if self.already_saved: + return + + log_dir = trainer.log_dir # this broadcasts the directory + assert log_dir is not None + if self.config_filename is None and trainer.logger is not None: + self.config_filename = trainer.logger.experiment.name + '.yaml' + else: + self.config_filename = 'config.yaml' + config_path = os.path.join(log_dir, self.config_filename) + fs = get_filesystem(log_dir) + + if not self.overwrite: + # check if the file exists on rank 0 + file_exists = fs.isfile(config_path) if trainer.is_global_zero else False + # broadcast whether to fail to all ranks + file_exists = trainer.strategy.broadcast(file_exists) + if file_exists: + raise RuntimeError( + f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting" + " results of a previous run. You can delete the previous config file," + " set `LightningCLI(save_config_callback=None)` to disable config saving," + " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file." + ) + + # save the file on rank 0 + if trainer.is_global_zero: + # save only on rank zero to avoid race conditions. + # the `log_dir` needs to be created as we rely on the logger to do it usually + # but it hasn't logged anything at this point + fs.makedirs(log_dir, exist_ok=True) + self.parser.save( + self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile + ) + self.already_saved = True + + # broadcast so that all ranks are in sync on future calls to .setup() + self.already_saved = trainer.strategy.broadcast(self.already_saved) diff --git a/src/train_embd_trnsl.py b/src/train_embd_trnsl.py new file mode 100644 index 0000000000000000000000000000000000000000..5336f83966ff9927b5d4ba0e3fe861e7a919de49 --- /dev/null +++ b/src/train_embd_trnsl.py @@ -0,0 +1,1056 @@ +from pl_FaceNet_arcface import pl_FaceNet_ArcFace +from pl_CWF_arcface import CWF_DataModule_ArcFace +from pl_FFHQ import FFHQ_DataModule + +from PIL import Image +import argparse + +import yaml +import random +import torch +import torchvision +import numpy as np +import pickle +# import time +import matplotlib.pyplot as plt +import os, sys +from tqdm import tqdm +# from collections import OrderedDict +from PIL import Image +# from facenet_pytorch import InceptionResnetV1 as FaceNet +# from facenet_pytorch import MTCNN +from datetime import datetime +# import einops +from scipy.spatial.distance import cosine as cdist +from typing import Sequence +import string +from sklearn.manifold import TSNE + +# import insightface +from insightface.app import FaceAnalysis +# from insightface.app.common import Face + +# POISONLIB_DIR = '/remote/idiap.svm/user.active/aunnervik/unnervik_reporting/work_dir/scripts' +# sys.path.append(POISONLIB_DIR) +# import poisonlib + +# SCRIPTS_DIR = os.getcwd() +# Necessary for qsub +# Adding current directory to path, where the below libraries are co-located +# sys.path.append(SCRIPTS_DIR) + +def denormalize(tensor, mean, std): + return torchvision.transforms.functional.normalize(tensor, (-mean / std).tolist(), (1.0 / std).tolist()) + +def tensorToPlt(tensor): + return torch.permute(tensor, (1,2,0)).detach().cpu().numpy() + +def toggleColorChannelOrdering(img: np.array): + for i, dim in enumerate(img.shape): + if dim == 3: + if i == 0: + return np.stack((img[2,:,:],img[1,:,:],img[0,:,:]),axis=0) + if i == 1: + return np.stack((img[:,2,:],img[:,1,:],img[:,0,:]),axis=1) + if i == 2: + return np.stack((img[:,:,2],img[:,:,1],img[:,:,0]),axis=2) + raise ValueError('Expected a dimension of size 3 for RGB/BGR') + +def prepareForInsightFace(img_fp, toggleColorMode=True): + if isinstance(img_fp, str) or not isinstance(img_fp, Sequence): + img_fp = [img_fp] + imgs = [] + for f in img_fp: + img = Image.open(f).convert('RGB') + img = np.uint8(img) + if toggleColorMode: + img = toggleColorChannelOrdering(img) + imgs.append(img) + return imgs + +def prepareForFaceNet(img_fp, network_input_size=(160,160), ds_mean=(0.4668, 0.38024, 0.33443), ds_std=(0.2960, 0.2656, 0.2595)): + if isinstance(img_fp, str) or not isinstance(img_fp, Sequence): + img_fp = [img_fp] + + transforms = torchvision.transforms.Compose([ + torchvision.transforms.Resize(network_input_size), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (ds_mean), (ds_std)) + ]) + + img_tensor = torch.empty((0, 3, network_input_size[0], network_input_size[1])) # input size for FaceNet is B*3*160*160 + for img_fp_i in img_fp: + + fn_img = Image.open(img_fp_i).convert('RGB') + fn_img = transforms(fn_img) + img_tensor = torch.cat((img_tensor, fn_img.unsqueeze(0))) + + return img_tensor + +def getFaceNetEmbeddings(pl_facenet_model, img_fp, device, ConversionNetwork=None, normalize_emb=True): + with torch.no_grad(): + img_tensor = prepareForFaceNet(img_fp).to(device) + + fn_emb = pl_facenet_model.inferenceForEmbedding(img_tensor) + if normalize_emb: + fn_emb = torch.nn.functional.normalize(fn_emb, dim=1) + if ConversionNetwork: + fn_emb = torch.nn.functional.normalize(fn_emb, dim=1) + fn_emb = ConversionNetwork(fn_emb) + + return fn_emb + +def getInsightFaceEmbeddings(img_fp, app, ConversionNetwork=None, device=None, normalize_emb=True): + if isinstance(img_fp, str) or not isinstance(img_fp, Sequence): + img_fp = [img_fp] + + embeddings = [] + if_img = prepareForInsightFace(img_fp, True) + for if_img_i in if_img: + faces = app.get(if_img_i) + if len(faces) >= 1: + det_scores = torch.tensor([face.det_score for face in faces]) + best_face_idx = torch.argmax(det_scores) + + face = faces[best_face_idx] + face_emb = torch.tensor(face.embedding) + if device: + face_emb = face_emb.to(device) + if normalize_emb: + face_emb = torch.nn.functional.normalize(face_emb, dim=0) + if ConversionNetwork: + face_emb = torch.nn.functional.normalize(face_emb, dim=0) + face_emb = ConversionNetwork(face_emb) + embeddings.append(face_emb) + else: + embeddings.append(None) + + return embeddings + +def getEmbeddings(model, img_fp, device=None, normalize_emb=True): + if isinstance(model, FaceAnalysis): + return getInsightFaceEmbeddings(img_fp, model, ConversionNetwork=None, device=device, normalize_emb=normalize_emb) + elif isinstance(model, pl_FaceNet_ArcFace): + return getFaceNetEmbeddings(model, img_fp, device, ConversionNetwork=None, normalize_emb=normalize_emb) + else: + raise ValueError("Model can only be a pl_FaceNet_ArcFace or InsightFace but is of other type:", type(model)) + +""" +def getAllEmbeddings(pl_facenet_model, insighftface_app, fp_dl, device, insightface_embedding_size = 512, facenet_embedding_size = 512): + + if_embeddings = torch.empty((0, insightface_embedding_size)) + fn_embeddings = torch.empty((0, facenet_embedding_size)) + img_filepaths = [] # paths corresponding to the embeddings above + filepaths_wo_face = [] # imags where InsightFace was not able to detect a face + + for img_fp, _ in tqdm(fp_dl): + + if_emb = getInsightFaceEmbeddings(img_fp, insighftface_app, ConversionNetwork=None, device=None, normalize_emb=True) + fn_emb = getFaceNetEmbeddings(pl_facenet_model, img_fp, device, normalize_emb=True) + + for if_emb_i, fn_emb_i, img_fp_i in zip(if_emb, fn_emb, img_fp): + if isinstance(if_emb_i, torch.Tensor): + + if_embeddings = torch.cat((if_embeddings, if_emb_i.detach().cpu().clone().unsqueeze(0))) + fn_embeddings = torch.cat((fn_embeddings, fn_emb_i.detach().cpu().clone().unsqueeze(0))) + img_filepaths.append(img_fp_i) + else: + filepaths_wo_face.append(img_fp_i) + + if_embeddings = torch.nn.functional.normalize(if_embeddings, dim=1) + fn_embeddings = torch.nn.functional.normalize(fn_embeddings, dim=1) + + return if_embeddings.detach().cpu(), fn_embeddings.detach().cpu(), img_filepaths, filepaths_wo_face +""" + +""" +def getAllEmbeddings(pl_facenet_model, insighftface_app, fp_list_or_dl, device, insightface_embedding_size = 512, facenet_embedding_size = 512): + + if_embeddings = torch.empty((0, insightface_embedding_size)) + fn_embeddings = torch.empty((0, facenet_embedding_size)) + img_filepaths = [] # paths corresponding to the embeddings above + filepaths_wo_face = [] # imags where InsightFace was not able to detect a face + + for img_fp in tqdm(fp_list_or_dl): + + if isinstance(img_fp, list): + img_fp = img_fp[0] + + if_emb = getInsightFaceEmbeddings(img_fp, insighftface_app, ConversionNetwork=None, device=None, normalize_emb=True) + fn_emb = getFaceNetEmbeddings(pl_facenet_model, img_fp, device, normalize_emb=True) + + for if_emb_i, fn_emb_i, img_fp_i in zip(if_emb, fn_emb, img_fp): + if isinstance(if_emb_i, torch.Tensor): + + if_embeddings = torch.cat((if_embeddings, if_emb_i.detach().cpu().clone().unsqueeze(0))) + fn_embeddings = torch.cat((fn_embeddings, fn_emb_i.detach().cpu().clone().unsqueeze(0))) + img_filepaths.append(img_fp_i) + else: + filepaths_wo_face.append(img_fp_i) + + if_embeddings = torch.nn.functional.normalize(if_embeddings, dim=1) + fn_embeddings = torch.nn.functional.normalize(fn_embeddings, dim=1) + + return if_embeddings.detach().cpu(), fn_embeddings.detach().cpu(), img_filepaths, filepaths_wo_face +""" + +def getAllEmbeddings(ref_model, probe_model, fp_list_or_dl, device, ref_embedding_size, probe_embedding_size): + + ref_embeddings = torch.empty((0, ref_embedding_size)) + prb_embeddings = torch.empty((0, probe_embedding_size)) + img_filepaths = [] # paths corresponding to the embeddings above + filepaths_wo_face = [] # imags where InsightFace was not able to detect a face + + for img_fp in tqdm(fp_list_or_dl): + + if isinstance(img_fp, list): + img_fp = img_fp[0] + + ref_emb = getEmbeddings(ref_model, img_fp, device=device, normalize_emb=True) + prb_emb = getEmbeddings(probe_model, img_fp, device=device, normalize_emb=True) + + for ref_emb_i, prb_emb_i, img_fp_i in zip(ref_emb, prb_emb, img_fp): + if isinstance(ref_emb_i, torch.Tensor) and isinstance(prb_emb_i, torch.Tensor): + + ref_embeddings = torch.cat((ref_embeddings, ref_emb_i.detach().cpu().clone().unsqueeze(0))) + prb_embeddings = torch.cat((prb_embeddings, prb_emb_i.detach().cpu().clone().unsqueeze(0))) + img_filepaths.append(img_fp_i) + else: + filepaths_wo_face.append(img_fp_i) + + ref_embeddings = torch.nn.functional.normalize(ref_embeddings, dim=1) + prb_embeddings = torch.nn.functional.normalize(prb_embeddings, dim=1) + + return ref_embeddings.detach().cpu(), prb_embeddings.detach().cpu(), img_filepaths, filepaths_wo_face + +""" +def getDisagreementScore(pl_facenet_model, insightface_app, img_fp, ConversionNetwork, device, translate_to_IF, score_fn): + if isinstance(img_fp, str) or not isinstance(img_fp, Sequence): + img_fp = [img_fp] + + if translate_to_IF: + all_fn_emb = getFaceNetEmbeddings(pl_facenet_model, img_fp, ConversionNetwork=ConversionNetwork, device=device, normalize_emb=True) + all_if_emb = getInsightFaceEmbeddings(img_fp, insightface_app, device=device, normalize_emb=True) + else: + all_fn_emb = getFaceNetEmbeddings(pl_facenet_model, img_fp, device=device, normalize_emb=True) + all_if_emb = getInsightFaceEmbeddings(img_fp, insightface_app, ConversionNetwork=ConversionNetwork, device=device, normalize_emb=True) + + scores = [] + for fn_emb, if_emb in zip(all_fn_emb, all_if_emb): + if isinstance(if_emb, torch.Tensor): + scores.append(score_fn(if_emb.tolist(), fn_emb.tolist())) + else: + scores.append(None) + + return scores +""" + +def getDisagreementScore(ref_model, probe_model, img_fp, ConversionNetwork, device, score_fn): + if isinstance(img_fp, str) or not isinstance(img_fp, Sequence): + img_fp = [img_fp] + + all_ref_emb = getEmbeddings(ref_model, img_fp, device=device, normalize_emb=True) + all_prb_emb = getEmbeddings(probe_model, img_fp, device=device, normalize_emb=True) + + #all_trsl_emb = ConversionNetwork(all_prb_emb) + all_trsl_emb = [ConversionNetwork(prb_emb) if isinstance(prb_emb, torch.Tensor) else None for prb_emb in all_prb_emb] + + scores = [] + for ref_emb, trsl_emb in zip(all_ref_emb, all_trsl_emb): + if isinstance(ref_emb, torch.Tensor) and isinstance(trsl_emb, torch.Tensor): + scores.append(score_fn(ref_emb.tolist(), trsl_emb.tolist())) + else: + scores.append(None) + + return scores + +def getZEIscores(labels, embeddings1, embeddings2, n_scores, score_fn): + assert 0 < n_scores + + indices1 = [] + indices2 = [] + scores = [] + while len(scores) < n_scores: + i1, i2 = random.sample(range(len(labels)), k=2) + if labels[i2] != labels[i1]: + indices1.append(i1) + indices2.append(i2) + scores.append(score_fn(embeddings1[i1].tolist(), embeddings2[i2].tolist())) + + return scores + +def getRandHash(rand_hash_len: int = 8): + return ''.join(random.choice(string.ascii_uppercase + string.ascii_lowercase + string.digits) for _ in range(rand_hash_len)) + +def get_cmap(n, name='hsv'): + '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct + RGB color; the keyword argument name must be a standard mpl colormap name.''' + return plt.cm.get_cmap(name, n) + +def getEmbeddingsPlot(embeddings, labels, return_points=False, render_3d=False): + # To use: + # > ax = getEmbeddingsPlot(embeddings, labels) + # > plt.legend() + # > plt.show() + unique_labels = np.unique(labels) + #n_labels = len(unique_labels) + 1 # the +1 is to make a fix for when classes are 2, shouldn't make much of a problem for other higher numbers. + #cmap = get_cmap(n_labels) + n_labels = len(unique_labels) # the +1 is to make a fix for when classes are 2, shouldn't make much of a problem for other higher numbers. + cmap = get_cmap(1 + n_labels//2) + + tsne_embeddings = None + + if render_3d: + tsne_embeddings = TSNE(n_components=3, learning_rate='auto', init='random', perplexity=5, n_iter=2000).fit_transform(embeddings) + f, ax = plt.subplots(subplot_kw=dict(projection='3d')) + for i, label in enumerate(unique_labels): + mask = np.array(labels) == label + ax.scatter(tsne_embeddings[mask,0], tsne_embeddings[mask,1], tsne_embeddings[mask,2], label=label, c=[cmap(np.where(unique_labels == label)[0][0])]*sum(mask)) + + else: + tsne_embeddings = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=5, n_iter=2000).fit_transform(embeddings) + f, ax = plt.subplots(figsize=(8,8)) + # for i, label in enumerate(unique_labels): + # mask = np.array(labels) == label + # ax.scatter(tsne_embeddings[mask,0], tsne_embeddings[mask,1], label=label, c=[cmap(np.where(unique_labels == label)[0][0])]*sum(mask)) + for i, (label1, label2) in enumerate([unique_labels[i:i+2] for i in range(0,len(unique_labels), 2)]): + mask1 = np.array(labels) == label1 + mask2 = np.array(labels) == label2 + ax.scatter(tsne_embeddings[mask1,0], tsne_embeddings[mask1,1], label=label1, c=[cmap(i)]*sum(mask1), marker = '.') + ax.scatter(tsne_embeddings[mask2,0], tsne_embeddings[mask2,1], label=label2, c=[cmap(i)]*sum(mask2), marker = '+') + + if return_points: + return ax, tsne_embeddings + else: + return ax + +def storeSamples(unshuffled_dataloader, root_save_dir, ds_mean, ds_std, root_ds_dir): + # The raw filepaths to the samples we're looking at + if isinstance(unshuffled_dataloader.dataset, torch.utils.data.dataset.Subset): + ds_samples = np.array(unshuffled_dataloader.dataset.dataset.samples)[np.array(unshuffled_dataloader.dataset.indices)] + toclasses = unshuffled_dataloader.dataset.dataset.classes + else: + ds_samples = unshuffled_dataloader.dataset.samples + toclasses = unshuffled_dataloader.dataset.classes + + img_fp_index = 0 + classes = [] + + print('Storing images... ') + for data, label in unshuffled_dataloader: + for img, lbl in zip(data, label): + + denormed_img = denormalize(img, 255*torch.tensor(ds_mean), 255*torch.tensor(ds_std)).type(torch.uint8) + + img_pil = torchvision.transforms.functional.to_pil_image(denormed_img) + + _class = os.path.dirname(os.path.relpath(ds_samples[img_fp_index][0], root_ds_dir)) + if _class not in classes: + classes.append(_class) + + img_savepath = os.path.join(root_save_dir, os.path.relpath(ds_samples[img_fp_index][0], root_ds_dir)) + + img_savedir, _ = os.path.split(img_savepath) + + os.makedirs(img_savedir, exist_ok=True) + print(img_savepath) + img_pil.save(img_savepath) + img_fp_index += 1 + + return classes + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description = "Embedding-based Conversion Backdoor Detector") + + parser.add_argument('--ref_model', type=str, help='The path to a checkpoint for a facenet model or \'insightface\'.') + parser.add_argument('--ref_model_emb_size', type=int, help='Embedding size for the reference model.') + + parser.add_argument('--probe_model', type=str, help='The path to a checkpoint for a facenet model or \'insightface\'.') + parser.add_argument('--probe_model_emb_size', type=int, help='Embedding size for the probe model.') + + parser.add_argument('--pl_dm_ckpt_fp', nargs='+', help='The filepath to the checkpoint for the data module. If more than one provided, clean data is taken from first one.') + #parser.add_argument('--translate_to_IF', action='store_true', help='Use flag to train with poisoned data.') + #parser.add_argument('--trigger_path', type=str, help='If provided, will be used instead of the trigger in pl_dm_ckpt_fp.') + parser.add_argument('--output_dir', type=str, help='Output directory where results files and logs are stored.') + parser.add_argument('--ffhq_dir', type=str, help='Directory of the FFHQ dataset') + parser.add_argument('--ffhq_emb_path', type=str, help='Directory of the precomputed FFHQ embeddings. Will be computed if empty (~45min).') + parser.add_argument('--cwf_clean_val_emb_path', type=str, help='Directory of the precomputed Casia-Webface clean validation embeddings. Will be computed if empty (~2.5h).') + parser.add_argument('--resume_run', action='store_true', help='Use flag to use the output directory as is, instead of creating a date-time based sub directory with a further hash based sub-directory.') + + parser.add_argument('--quick_debug', action='store_true', help='If in use, will limit the number of samples to allow for a quick test run.') + + args = parser.parse_args() + + cuda_is_available = torch.cuda.is_available() + assert cuda_is_available is True + + ########################################################################### + # Parameters + ########################################################################### + rand_hash_len = 8 + cwf_poisoned_samples_dir = 'poisoned_samples' + + ffhq_batch_size = 2**5 + ffhq_num_workers = 4 + + cwf_batch_size = 2**5 + cwf_num_workers = 4 + + insightface_embedding_size = 512 + facenet_embedding_size = 512 + + if args.quick_debug: + DEBUG_SIZE = 1000 # disabled when None. An integer will simply limit the number of samples being used, drastically reducing compute time, mainly for a quick check of the whole execution, for debug purposes. Results will be meaningless though, obviously. + else: + DEBUG_SIZE = None + + emb_conv_batch_size = 2**7 + emb_conv_train_ratio = 0.7 + conv_emb_num_epochs = 3 + lambda_MSE = 0.0 + MSEcriterion = torch.nn.MSELoss() + CosSimcriterion = torch.nn.CosineSimilarity() + loss_fn = lambda out, y_true: lambda_MSE * MSEcriterion(out, y_true) - torch.mean(CosSimcriterion(out, y_true)) + cos_sim = lambda s1, s2: 1 - cdist(s1, s2) + score_fn = cos_sim + + ########################################################################### + # Preparation + ########################################################################### + device = torch.device('cuda') + + now = datetime.now() + if args.resume_run: + complete_results_dir = os.path.join(args.output_dir) + else: + complete_results_dir = os.path.join(args.output_dir, now.strftime("%Y%m%d_%H%M%S"), getRandHash(rand_hash_len)) + os.makedirs(complete_results_dir, exist_ok=True) + print('Output directory:', complete_results_dir) + + args_log_fp = os.path.join(complete_results_dir, 'args.yaml') + + with open(args_log_fp, 'w') as f: + yaml.dump(vars(args), f) + + ########################################################################### + # FFHQ Data loaders + ########################################################################### + ffhq_dm = FFHQ_DataModule(args.ffhq_dir, ffhq_batch_size, ffhq_num_workers, shuffle=False, with_face_extractor=False) + ffhq_dm.setup() + ffhq_dataloader = ffhq_dm.val_dataloader() + + # A dataloader as above, but which doesn't return the transformed images, only the filepaths to the images and their labels + ffhq_fp_dl = torch.utils.data.DataLoader(ffhq_dataloader.dataset.samples[:DEBUG_SIZE], shuffle=False, batch_size=ffhq_batch_size, num_workers=ffhq_num_workers) + + ########################################################################### + # Casia-WebFace Data loaders + ########################################################################### + cwf_dm = CWF_DataModule_ArcFace.load_from_checkpoint(checkpoint_path=args.pl_dm_ckpt_fp[0], hparams_file=None) + + # if args.trigger_path: + # cwf_dm.trigger_train_fp = args.trigger_path + # cwf_dm.trigger_val_fp = args.trigger_path + + cwf_dm.setup() + cwf_val_dataloaders = {dl_name:dl for dl_name, dl in zip(cwf_dm.datasets_names_val, cwf_dm.val_dataloader())} + + cwf_val_clean = cwf_val_dataloaders['`val clean`'] + cwf_val_clean_fp = np.array(cwf_val_clean.dataset.dataset.samples)[np.array(cwf_val_clean.dataset.indices)] + cwf_val_clean_fp = [(str(fp), int(label)) for fp, label in cwf_val_clean_fp] + cwf_val_clean_fp_dl = torch.utils.data.DataLoader(cwf_val_clean_fp[:DEBUG_SIZE], shuffle=False, batch_size=cwf_batch_size, num_workers=cwf_num_workers) + + # When using multiple checkpoints, a current limitation of our method is we can't evaluate on multiple datamodules using the same victim identity + victim_classes = [] + impostor_classes = [] + for i, ckpt in enumerate(args.pl_dm_ckpt_fp): + checkpoint = torch.load(ckpt) + + vict = checkpoint['datamodule_hyper_parameters']['victims'] + #assert vict not in victim_classes + victim_classes.append(vict) + + imp = checkpoint['datamodule_hyper_parameters']['impostors'] + #assert imp not in impostor_classes + impostor_classes.append(imp) + + print('impostor: ', imp, '(', cwf_val_clean.dataset.dataset.classes[imp], ') -> victim ', vict, '(', cwf_val_clean.dataset.dataset.classes[vict], ')') + + ########################################################################### + # Models + ########################################################################### + if args.ref_model.lower() == 'insightface': + model_pack_name = 'buffalo_s' + ref_model = FaceAnalysis(name=model_pack_name) + ref_model.prepare(ctx_id=0, det_thresh=0.5, det_size=(160,160)) + elif os.path.exists(args.ref_model): + ref_model = pl_FaceNet_ArcFace.load_from_checkpoint(checkpoint_path=args.ref_model, map_location=device, hparams_file=None, strict=True) + ref_model = ref_model.to(device) + ref_model.eval() + else: + raise ValueError("args.ref_model can only be a valid path to a facenet checkpoint or the value \'insightface\'. Instead, you provided: {}", args.ref_model) + + if args.probe_model.lower() == 'insightface': + model_pack_name = 'buffalo_s' + probe_model = FaceAnalysis(name=model_pack_name) + probe_model.prepare(ctx_id=0, det_thresh=0.5, det_size=(160,160)) + elif os.path.exists(args.probe_model): + probe_model = pl_FaceNet_ArcFace.load_from_checkpoint(checkpoint_path=args.probe_model, map_location=device, hparams_file=None, strict=True) + probe_model = probe_model.to(device) + probe_model.eval() + else: + raise ValueError("args.probe_model can only be a valid path to a facenet checkpoint or the value \'insightface\'. Instead, you provided: {}", args.probe_model) + + + ########################################################################### + # Store source model and specific parameters + ########################################################################### + + if cwf_dm.impostors and len(cwf_dm.impostors) > 0: + imp_class = cwf_dm.impostors[0] + imp_id = cwf_dm.datasets_train[0].dataset.classes[imp_class] + n_train_imp_samples = len([s[1] for s in np.asarray(cwf_dm.datasets_train[0].dataset.samples)[cwf_dm.datasets_train[0].indices] if int(s[1]) == imp_class]) + else: + imp_class = None + imp_id = None + n_train_imp_samples = None + + if cwf_dm.victims and len(cwf_dm.victims) > 0: + vict_class = cwf_dm.victims[0] + vict_id = cwf_dm.datasets_train[0].dataset.classes[vict_class] + n_train_vict_samples = len([s[1] for s in np.asarray(cwf_dm.datasets_train[0].dataset.samples)[cwf_dm.datasets_train[0].indices] if int(s[1]) == vict_class]) + else: + vict_class = None + vict_id = None + n_train_vict_samples = None + + ckpt_bd_specs = { + 'ref model': args.ref_model, + 'probe model': args.probe_model, + 'data module': args.pl_dm_ckpt_fp, + 'impostor class': imp_class, + 'victim class': vict_class, + 'impostor id': imp_id, + 'victim id': vict_id, + 'n train impostor samples': n_train_imp_samples, + 'n train victim samples': n_train_vict_samples, + } + + ckpt_bd_specs_fp = os.path.join(complete_results_dir, 'ckpt_bd_specs.yaml') + + with open(ckpt_bd_specs_fp, 'w') as f: + yaml.dump(ckpt_bd_specs, f) + + print('Stored checkpoint backdoor specs:', ckpt_bd_specs_fp) + + ########################################################################### + # Computing and storing Casia-Webface poisoned samples and embeddings (5s) + ########################################################################### + + cwf_val_p_fp_dl_dict = {} + cwf_val_p_embeddings_all = {} + + for i, ckpt_i in enumerate(args.pl_dm_ckpt_fp): + + cwf_dm_i = CWF_DataModule_ArcFace.load_from_checkpoint(checkpoint_path=ckpt_i, hparams_file=None) + cwf_dm_i.setup() + cwf_val_dataloaders_i = {dl_name:dl for dl_name, dl in zip(cwf_dm_i.datasets_names_val, cwf_dm_i.val_dataloader())} + cwf_val_imp_p_i = cwf_val_dataloaders_i['`val impostor(s) poison`'] + _classes = storeSamples(cwf_val_imp_p_i, os.path.join(complete_results_dir, cwf_poisoned_samples_dir, 'pl_dm_' + str(i)), cwf_dm.ds_mean, cwf_dm.ds_std, cwf_dm.dataset_dir) + + cwf_val_p_fp_ds = [] + + for c in _classes: + + p_img_dir = os.path.join(complete_results_dir, cwf_poisoned_samples_dir, 'pl_dm_' + str(i), c) + all_p_img_fp = [os.path.join(p_img_dir, fp) for fp in sorted(os.listdir(p_img_dir))] + + # Genuine label: + # label = cwf_val_imp_p_i.dataset.dataset.class_to_idx[c] + # Poisoned label (i.e. victim label): + # label = cwf_val_imp_p_i.dataset.dataset.target_transform(cwf_val_imp_p_i.dataset.dataset.class_to_idx[c]) + + label = cwf_val_imp_p_i.dataset.dataset.target_transform(cwf_val_imp_p_i.dataset.dataset.class_to_idx[c]) + cwf_val_p_fp_ds.extend([(str(fp), torch.tensor([label])) for fp in all_p_img_fp]) + + cwf_val_p_fp_dl = torch.utils.data.DataLoader(cwf_val_p_fp_ds, shuffle=False, batch_size=cwf_batch_size, num_workers=cwf_num_workers) + + cwf_val_p_fp_dl_dict[i] = { + 'dl': cwf_val_p_fp_dl, + 'ckpt_fp': ckpt_i + } + + print('Computing poisoned Casia-WebFace embeddings... ', end='') + cwf_val_p_ref_embeddings, cwf_val_p_probe_embeddings, cwf_val_p_img_filepaths, cwf_val_p_filepaths_wo_face \ + = getAllEmbeddings(ref_model, probe_model, cwf_val_p_fp_dl, device, ref_embedding_size = args.ref_model_emb_size, probe_embedding_size = args.probe_model_emb_size) + print('done') + + cwf_val_p_embeddings_all[i] = { + 'Reference model embeddings': cwf_val_p_ref_embeddings, + 'Probe model embeddings': cwf_val_p_probe_embeddings, + 'images filepaths': cwf_val_p_img_filepaths, + 'filepaths without face': cwf_val_p_filepaths_wo_face, + 'dm_checkpoint': ckpt_i + } + + pl_dm_index_fp = os.path.join(complete_results_dir, 'pl_dm_index.yaml') + + with open(pl_dm_index_fp, 'w') as f: + yaml.dump({i: cwf_val_p_fp_dl_dict[i]['ckpt_fp'] for i in cwf_val_p_fp_dl_dict}, f) + + cwf_val_p_embeddings_fp = os.path.join(complete_results_dir, 'cwf_val_p_embeddings.pkl') + + with open(cwf_val_p_embeddings_fp, 'wb') as fh: + pickle.dump(cwf_val_p_embeddings_all, fh) + + print('Stored Casia-WebFace poisoned validation embeddings:', cwf_val_p_embeddings_fp) + + ########################################################################### + # Computing and storing Casia-Webface embeddings (~2h30) + ########################################################################### + if args.cwf_clean_val_emb_path: + + with open(args.cwf_clean_val_emb_path, 'rb') as fh: + cwf_clean_val_embeddings_all = pickle.load(fh) + + cwf_val_ref_embeddings = cwf_clean_val_embeddings_all['Reference model embeddings'] + cwf_val_probe_embeddings = cwf_clean_val_embeddings_all['Probe model embeddings'] + cwf_val_img_filepaths = cwf_clean_val_embeddings_all['images filepaths'] + cwf_val_filepaths_wo_face = cwf_clean_val_embeddings_all['filepaths without face'] + + print('Loaded Casia-WebFace clean validation embeddings:', args.cwf_clean_val_emb_path) + + else: + print('No path to precomputed Casia-WebFace clean validation embeddings provided.') + print('Computing Casia-WebFace clean validation embeddings... ', end='') + + cwf_val_ref_embeddings, cwf_val_probe_embeddings, cwf_val_img_filepaths, cwf_val_filepaths_wo_face \ + = getAllEmbeddings(ref_model, probe_model, cwf_val_clean_fp_dl, device, ref_embedding_size = args.ref_model_emb_size, probe_embedding_size = args.probe_model_emb_size) + print('done') + + cwf_clean_val_embeddings_all = { + 'Reference model embeddings': cwf_val_ref_embeddings, + 'Probe model embeddings': cwf_val_probe_embeddings, + 'images filepaths': cwf_val_img_filepaths, + 'filepaths without face': cwf_val_filepaths_wo_face, + } + + cwf_val_embeddings_fp = os.path.join(complete_results_dir, 'cwf_val_clean_embeddings.pkl') + + with open(cwf_val_embeddings_fp, 'wb') as fh: + pickle.dump(cwf_clean_val_embeddings_all, fh) + + print('Stored Casia-WebFace clean validation embeddings:', cwf_val_embeddings_fp) + + # Creating label array from filepaths for ZEI scores (only of the images where InsightFace detected a face) + labels_cwf_val_clean = [os.path.split(os.path.split(cwf_val_img_filepaths[i])[0])[-1] for i in range(len(cwf_val_img_filepaths))] + + + ########################################################################### + # Computing and storing FFHQ embeddings (~48min) + ########################################################################### + if args.ffhq_emb_path: + with open(args.ffhq_emb_path, 'rb') as fh: + ffhq_embeddings_all = pickle.load(fh) + + ffhq_ref_embeddings = ffhq_embeddings_all['Reference model embeddings'] + ffhq_probe_embeddings = ffhq_embeddings_all['Probe model embeddings'] + ffhq_img_filepaths = ffhq_embeddings_all['images filepaths'] + ffhq_filepaths_wo_face = ffhq_embeddings_all['filepaths without face'] + + print('Loaded FFHQ embeddings:', args.ffhq_emb_path) + else: + print('No path to precomputed FFHQ embeddings provided.') + print('Computing FFHQ embeddings... ', end='') + + ffhq_ref_embeddings, ffhq_probe_embeddings, ffhq_img_filepaths, ffhq_filepaths_wo_face \ + = getAllEmbeddings(ref_model, probe_model, ffhq_fp_dl, device, ref_embedding_size = args.ref_model_emb_size, probe_embedding_size = args.probe_model_emb_size) + print('done') + + ffhq_embeddings_fp = os.path.join(complete_results_dir, 'ffhq_all_embeddings.pkl') + + ffhq_embeddings_all = { + 'Reference model embeddings': ffhq_ref_embeddings, + 'Probe model embeddings': ffhq_probe_embeddings, + 'images filepaths': ffhq_img_filepaths, + 'filepaths without face': ffhq_filepaths_wo_face, + } + + with open(ffhq_embeddings_fp, 'wb') as fh: + pickle.dump(ffhq_embeddings_all, fh) + + print('Stored FFHQ embeddings:', ffhq_embeddings_fp) + + + ########################################################################### + # Setting up training and validation loaders for the embedding conversion + ########################################################################### + + assert len(cwf_val_ref_embeddings) == len(cwf_val_probe_embeddings) == len(cwf_val_img_filepaths) + + cwf_clean_val_emb_train_ds = torch.utils.data.TensorDataset(cwf_val_ref_embeddings[:round(emb_conv_train_ratio*len(cwf_val_ref_embeddings))], cwf_val_probe_embeddings[:round(emb_conv_train_ratio*len(cwf_val_probe_embeddings))]) + cwf_clean_val_emb_test_ds = torch.utils.data.TensorDataset(cwf_val_ref_embeddings[round(emb_conv_train_ratio*len(cwf_val_ref_embeddings)):], cwf_val_probe_embeddings[round(emb_conv_train_ratio*len(cwf_val_probe_embeddings)):]) + cwf_clean_val_emb_train_dl = torch.utils.data.DataLoader(cwf_clean_val_emb_train_ds, shuffle=True, batch_size=emb_conv_batch_size, drop_last=True) + cwf_clean_val_emb_test_dl = torch.utils.data.DataLoader(cwf_clean_val_emb_test_ds, shuffle=False, batch_size=emb_conv_batch_size, drop_last=True) + + assert len(ffhq_ref_embeddings) == len(ffhq_probe_embeddings) == len(ffhq_img_filepaths) + + ffhq_emb_train_ds = torch.utils.data.TensorDataset(ffhq_ref_embeddings[:round(emb_conv_train_ratio*len(ffhq_ref_embeddings))], ffhq_probe_embeddings[:round(emb_conv_train_ratio*len(ffhq_probe_embeddings))]) + ffhq_emb_test_ds = torch.utils.data.TensorDataset(ffhq_ref_embeddings[round(emb_conv_train_ratio*len(ffhq_ref_embeddings)):], ffhq_probe_embeddings[round(emb_conv_train_ratio*len(ffhq_probe_embeddings)):]) + ffhq_emb_train_dl = torch.utils.data.DataLoader(ffhq_emb_train_ds, shuffle=True, batch_size=emb_conv_batch_size, drop_last=True) + ffhq_emb_test_dl = torch.utils.data.DataLoader(ffhq_emb_test_ds, shuffle=False, batch_size=emb_conv_batch_size, drop_last=True) + + + ########################################################################### + # Setting up embedding translation, optimizer and criterion + ########################################################################### + ConversionNetwork = torch.nn.Sequential( + torch.nn.Linear(args.probe_model_emb_size, args.ref_model_emb_size), + ) + + ConversionNetwork.to(device) + optimizer = torch.optim.Adam(ConversionNetwork.parameters()) + + ConversionNetwork.train() + + train_batch_counter = 0 + + train_losses = [] + train_batch_indices = [] + + test_losses = [] + test_batch_indices = [] + + + ########################################################################### + # Embedding translation training + ########################################################################### + + for epoch in tqdm(range(conv_emb_num_epochs)): + for ref_emb, prb_emb in ffhq_emb_train_dl: + train_batch_counter += 1 + + ref_emb = ref_emb.clone().detach().to(device) + prb_emb = prb_emb.clone().detach().to(device) + + ref_emb = torch.nn.functional.normalize(ref_emb, dim=1) + prb_emb = torch.nn.functional.normalize(prb_emb, dim=1) + + out = ConversionNetwork(prb_emb) + loss = loss_fn(out, ref_emb) # lambda_MSE * MSEcriterion(out, fn_emb) - torch.mean(CosSimcriterion(out, fn_emb)) + + train_losses.append(loss.item()) + train_batch_indices.append(train_batch_counter) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.no_grad(): + test_loss = 0 + for ref_emb, prb_emb in ffhq_emb_test_dl: + + ref_emb = ref_emb.clone().detach().to(device) + prb_emb = prb_emb.clone().detach().to(device) + + ref_emb = torch.nn.functional.normalize(ref_emb, dim=1) + prb_emb = torch.nn.functional.normalize(prb_emb, dim=1) + + out = ConversionNetwork(prb_emb) + loss = loss_fn(out, ref_emb) + + test_loss += loss.item() + + avg_test_loss = test_loss/len(ffhq_emb_test_dl) + + test_losses.append(avg_test_loss) + test_batch_indices.append(train_batch_counter) + + ########################################################################### + # Saving embedding translation training/validation plot + ########################################################################### + + print('Min test loss:', np.min(test_losses)) + print('Last test loss:', test_losses[-1]) + + plt.plot(train_batch_indices, train_losses, color='blue', label='train loss', zorder=0) + plt.scatter(test_batch_indices, test_losses, color='red', label='test loss', zorder=10) + # plt.yscale('symlog') + # plt.xscale('log') + plt.title('Loss on Embedding Translation network on FFHQ val') + plt.ylabel('Loss per batch') + plt.xlabel('Batch') + xmin, xmax = plt.xlim() + # plt.hlines(y=0, xmin=xmin, xmax=xmax, linestyles='dashed') + plt.legend() + plt_fp = os.path.join(complete_results_dir, 'emb_conv_train_val_losses.pdf') + plt.savefig(plt_fp) + plt.close() + print("Saved plot of training and validation losses for the embedding conversion:", plt_fp) + + + ########################################################################### + # Genuine scores on FFHQ validation + ########################################################################### + + # Scipy cosine distance: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.cosine.html + # Pytorch cosine similarity: https://pytorch.org/docs/stable/generated/torch.nn.CosineSimilarity.html + + genuines_ffhq_val = [] + + for batch, (ref_emb, prb_emb) in enumerate(ffhq_emb_test_dl): + + ref_emb = ref_emb.clone().detach().to(device) + prb_emb = prb_emb.clone().detach().to(device) + + ref_emb = torch.nn.functional.normalize(ref_emb, dim=1) + prb_emb = torch.nn.functional.normalize(prb_emb, dim=1) + + out = ConversionNetwork(prb_emb) + for out_i, ref_emb_i in zip(out, ref_emb): + score_i = score_fn(out_i.tolist(), ref_emb_i.tolist()) + + genuines_ffhq_val.append(score_i) + + + ########################################################################### + # Simulated ZEI scores on FFHQ validation + # (Simulated because what is a true ZEI when the same image is passed to both networks? + # This serves as a potential reference point in case an image yields an embedding of another + # identity, such as a poisoned image for a backdoored network). + ########################################################################### + zei_ffhq_val = [] + + first = True + + prev_emb = None + + for batch, (ref_emb, prb_emb) in enumerate(ffhq_emb_test_dl): + + ref_emb = ref_emb.clone().detach().to(device) + prb_emb = prb_emb.clone().detach().to(device) + + if first: + first = False + else: + + ref_emb = torch.nn.functional.normalize(ref_emb, dim=1) + prb_emb = torch.nn.functional.normalize(prb_emb, dim=1) + + out = ConversionNetwork(prb_emb) + + for out_i, prev_emb_i in zip(out, prev_emb): + score_i = score_fn(out_i.tolist(), prev_emb_i.tolist()) + + zei_ffhq_val.append(score_i) + + prev_emb = ref_emb.clone() + + + ########################################################################### + # FFHQ validation scores + ########################################################################### + + # plt.hist: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.hist.html + + # We store the scores + X_genuine = np.stack((np.array(genuines_ffhq_val), np.array([0]*len(genuines_ffhq_val))), axis=0).transpose() + X_zei = np.stack((np.array(zei_ffhq_val), np.array([1]*len(zei_ffhq_val))), axis=0).transpose() + X_ffhq_val = np.concatenate((X_genuine, X_zei), axis=0) + ffhq_val_scores_fp = os.path.join(complete_results_dir, 'ffhq_val_scores.txt') + np.savetxt(ffhq_val_scores_fp, X_ffhq_val, fmt=['%.5f', '%d'], header='FFHQ val cosine similarity score, class (0=genuine, 1=zei)') + + _ = plt.hist(genuines_ffhq_val, bins='auto', color='green', alpha=0.5, density=True, label='Genuine') # arguments are passed to np.histogram# plt.plot(zei, color='green', label='test loss') + _ = plt.hist(zei_ffhq_val, bins='auto', color='blue', alpha=0.5, density=True, label='Zero-effort impostors') # arguments are passed to np.histogram# plt.plot(zei, color='green', label='test loss') + plt.title('Cosine similarity on FFHQ validation samples') + #plt.xlim(-0.1, 2.1) + plt.ylabel('Scores histogram') + plt.xlabel('Cosine similarity') + plt.legend() + plt_fp = os.path.join(complete_results_dir, 'ffhq_validation_scores.pdf') + plt.savefig(plt_fp) + plt.close() + print("Saved plot of FFHQ validation scores:", plt_fp) + #plt.show() + + + + + ########################################################################### + # Scoring poisoned Casia-WebFace validation samples (impostor with trigger) + ########################################################################### + + #poisoned_impostor_scores = getDisagreementScore(ref_model, probe_model, all_p_img_fp, ConversionNetwork, device, score_fn = score_fn) + + poisoned_impostor_scores = {} + p_scores = {} + for i in cwf_val_p_embeddings_all: + + poisoned_impostor_scores[i] = [] + + for ref_emb, prb_emb in zip(cwf_val_p_embeddings_all[i]['Reference model embeddings'], cwf_val_p_embeddings_all[i]['Probe model embeddings']): + + ref_emb = ref_emb.clone().detach().to(device) + prb_emb = prb_emb.clone().detach().to(device) + + ref_emb = torch.nn.functional.normalize(ref_emb, dim=0) + prb_emb = torch.nn.functional.normalize(prb_emb, dim=0) + + out = ConversionNetwork(prb_emb) + score_i = score_fn(out.tolist(), ref_emb.tolist()) + poisoned_impostor_scores[i].append(score_i) + + print('Scores from poisoned samples from checkpoint:', cwf_val_p_fp_dl_dict[i]['ckpt_fp']) + p_scores[i] = [p_score for p_score in poisoned_impostor_scores[i] if p_score] + print('Number of poisoned validation samples:', len(poisoned_impostor_scores[i])) + print('Number of valid poisoned scores (where a face is detected in original and reconstructed image):', len(p_scores[i])) + print('Min:', min(p_scores[i])) + print('Max:', max(p_scores[i])) + + + ########################################################################### + # Scoring clean Casia-WebFace validation samples + ########################################################################### + cwf_clean_val_genuine_scores = [] + + for batch, (ref_emb, prb_emb) in enumerate(cwf_clean_val_emb_test_dl): + + ref_emb = ref_emb.clone().detach().to(device) + prb_emb = prb_emb.clone().detach().to(device) + + ref_emb = torch.nn.functional.normalize(ref_emb, dim=1) + prb_emb = torch.nn.functional.normalize(prb_emb, dim=1) + + out = ConversionNetwork(prb_emb) + + for out_i, ref_emb_i in zip(out, ref_emb): + score_i = score_fn(out_i.tolist(), ref_emb_i.tolist()) + + cwf_clean_val_genuine_scores.append(score_i) + + + ########################################################################### + # Scoring clean Casia-WebFace validation samples + ########################################################################### + cwf_val_probe_embeddings_converted = ConversionNetwork(torch.nn.functional.normalize(cwf_val_probe_embeddings, dim=1).to(device)) + zei_cwf_clean_val = getZEIscores(labels_cwf_val_clean, cwf_val_ref_embeddings, cwf_val_probe_embeddings_converted, 100000, score_fn = score_fn) + + + ########################################################################### + # Plotting Casia-WebFace validation scores + ########################################################################### + # plt.hist: https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.hist.html + + for i in p_scores: + # We store the scores + X_genuine = np.stack((np.array(cwf_clean_val_genuine_scores), np.array([0]*len(cwf_clean_val_genuine_scores))), axis=0).transpose() + X_zei = np.stack((np.array(zei_cwf_clean_val), np.array([1]*len(zei_cwf_clean_val))), axis=0).transpose() + X_poison = np.stack((np.array(p_scores[i]), np.array([2]*len(p_scores[i]))), axis=0).transpose() + X_ffhq_val = np.concatenate((X_genuine, X_zei, X_poison), axis=0) + cwf_val_scores_fp = os.path.join(complete_results_dir, 'cwf_val_scores_' + str(i) + '.txt') + np.savetxt(cwf_val_scores_fp, X_ffhq_val, fmt=['%.5f', '%d'], header='CWF val cosine similarity score, class (0=genuine, 1=zei, 2=poison)') + + _ = plt.hist(cwf_clean_val_genuine_scores, bins='auto', color='green', alpha=0.5, density=True, label='Genuine') # arguments are passed to np.histogram# plt.plot(zei, color='green', label='test loss') + _ = plt.hist(zei_cwf_clean_val, bins='auto', color='blue', alpha=0.5, density=True, label='Zero-effort impostors') # arguments are passed to np.histogram# plt.plot(zei, color='green', label='test loss') + _ = plt.hist(p_scores[i], bins='auto', color='red', alpha=0.5, density=True, label='Poisoned attacker') # arguments are passed to np.histogram# plt.plot(zei, color='green', label='test loss') + # _ = plt.hist(p_scores[i], bins='auto', color='red', alpha=0.5, density=True, label='Clean attacker') # arguments are passed to np.histogram# plt.plot(zei, color='green', label='test loss') + plt.title('Cosine similarity on Casia-WebFace validation') + #plt.xlim(-0.1, 2.1) + plt.ylabel('Scores histogram') + plt.xlabel('Cosine similarity') + plt.legend() + plt_fp = os.path.join(complete_results_dir, 'cwf_validation_scores_' + str(i) + '.pdf') + plt.savefig(plt_fp) + plt.close() + print("Saved plot of Casia-WebFace validation scores:", plt_fp) + #plt.show() + + + ########################################################################### + # t-SNE plot + ########################################################################### + + n_samples_per_class = 10 + n_other_classes = 5 + + selected_classes = [] + current_class = 0 + + cwf_labels_val_clean = np.asarray(cwf_val_dataloaders['`val clean`'].dataset.dataset.targets)[np.asarray(cwf_val_dataloaders['`val clean`'].dataset.indices)] + + # Here we select the clean classes + classes, counts = np.unique(cwf_labels_val_clean, return_counts=True) + while len(selected_classes) < n_other_classes: + if current_class != imp_class and current_class != vict_class and counts[current_class] >= n_samples_per_class: + selected_classes.append(current_class) + + current_class += 1 + + # Here we select the samples from the selected classes + selected_samples = {} + for selected_class in selected_classes: + samples_idx = np.arange(len(cwf_labels_val_clean))[np.asarray(cwf_labels_val_clean) == selected_class][:n_samples_per_class] + selected_samples[selected_class] = samples_idx + + """ + flattened_selected_samples = np.concatenate(list(selected_samples.values())) + flattened_selected_labels = [] + for selected_class in selected_classes: + flattened_selected_labels += [selected_class]*n_samples_per_class + """ + + others_indices = [] + for v in selected_samples.values(): + others_indices += v.tolist() + + vict_indices = [i for i, (fp, label) in enumerate(cwf_val_clean_fp) if label == vict_class][:n_samples_per_class] + imp_indices = [i for i, (fp, label) in enumerate(cwf_val_clean_fp) if label == imp_class ][:n_samples_per_class] + + victim_fp_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(cwf_val_clean_fp, vict_indices), shuffle=False, batch_size=cwf_batch_size, num_workers=cwf_num_workers) + victim_embd_ref, victim_embd_prb, victim_img_filepaths, victim_filepaths_wo_face = getAllEmbeddings(ref_model, probe_model, victim_fp_dl, device, ref_embedding_size = args.ref_model_emb_size, probe_embedding_size = args.probe_model_emb_size) + + imp_v_fp_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(cwf_val_clean_fp, imp_indices), shuffle=False, batch_size=cwf_batch_size, num_workers=cwf_num_workers) + imp_v_embd_ref, imp_v_embd_prb, imp_v_img_filepaths, imp_v_filepaths_wo_face = getAllEmbeddings(ref_model, probe_model, imp_v_fp_dl, device, ref_embedding_size = args.ref_model_emb_size, probe_embedding_size = args.probe_model_emb_size) + + others_fp_dl = torch.utils.data.DataLoader(torch.utils.data.Subset(cwf_val_clean_fp, others_indices), shuffle=False, batch_size=cwf_batch_size, num_workers=cwf_num_workers) + others_embd_ref, others_embd_prb, others_img_filepaths, others_filepaths_wo_face = getAllEmbeddings(ref_model, probe_model, others_fp_dl, device, ref_embedding_size = args.ref_model_emb_size, probe_embedding_size = args.probe_model_emb_size) + + victim_embd_ref = victim_embd_ref[:n_samples_per_class] + other_embd_ref = others_embd_ref + + victim_embd_trsl = ConversionNetwork(torch.nn.functional.normalize(victim_embd_prb[:n_samples_per_class], dim=1).to(device)) + imp_v_embd_trsl = ConversionNetwork(torch.nn.functional.normalize(imp_v_embd_prb[:n_samples_per_class], dim=1).to(device)) + other_embd_trsl = ConversionNetwork(torch.nn.functional.normalize(others_embd_prb, dim=1).to(device)) + + for i in cwf_val_p_embeddings_all: + imp_p_embd_ref = cwf_val_p_embeddings_all[i]['Reference model embeddings'].detach().to(device) + imp_p_embd_prb = cwf_val_p_embeddings_all[i]['Probe model embeddings'].detach().to(device) + + imp_p_embd_ref = imp_p_embd_ref[:n_samples_per_class] + imp_v_embd_ref = imp_v_embd_ref[:n_samples_per_class] + + imp_p_embd_trsl = ConversionNetwork(torch.nn.functional.normalize(imp_p_embd_prb[:n_samples_per_class], dim=1).to(device)) + + labels_ref = ['Victim (Ref)'] * len(victim_embd_ref) \ + + ['Impostor poisoned (Ref)'] * len(imp_p_embd_ref) \ + + ['Impostor clean (Ref)'] * len(imp_v_embd_ref) \ + + ['Others (Ref)'] * len(other_embd_ref) + + labels_trsl = ['Victim (Translated)'] * len(victim_embd_trsl) \ + + ['Impostor poisoned (Translated)'] * len(imp_p_embd_trsl) \ + + ['Impostor clean (Translated)'] * len(imp_v_embd_trsl) \ + + ['Others (Translated)'] * len(other_embd_trsl) + + all_embeddings = torch.cat((victim_embd_ref.detach().cpu(), imp_p_embd_ref.detach().cpu(), imp_v_embd_ref.detach().cpu(), other_embd_ref.detach().cpu(), victim_embd_trsl.detach().cpu(), imp_p_embd_trsl.detach().cpu(), imp_v_embd_trsl.detach().cpu(), other_embd_trsl.detach().cpu())) + all_labels = labels_ref + labels_trsl + + # Here we plot backdoor related classes to the rest (not distinguishing other classes) + ax, tsne_points = getEmbeddingsPlot(all_embeddings, all_labels, render_3d=False, return_points=True) + plt.title('t-SNE plot of embeddings') + plt.legend() + plt_fp = os.path.join(complete_results_dir, 'tsne_embeddings_plot_' + str(i) + '.pdf') + plt.savefig(plt_fp, dpi=600) + #plt.show() + plt.close() + print("Saved plot of t-SNE embeddings from both networks:", plt_fp) + \ No newline at end of file diff --git a/src/train_facenet.py b/src/train_facenet.py new file mode 100644 index 0000000000000000000000000000000000000000..4c0db5df981eb4666a3c10c69b12fde258c1bf5d --- /dev/null +++ b/src/train_facenet.py @@ -0,0 +1,37 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch +from tqdm import tqdm +from datetime import datetime +import torchvision +import os +import sys +from pytorch_lightning.cli import LightningCLI + +import wandb +import einops + +import backdoorlib as bd +from pl_FaceNet_arcface import pl_FaceNet_ArcFace +from pl_CWF_arcface import CWF_DataModule_ArcFace +from pl_overrides import SaveConfigCallback as NewSaveConfigCallback + +# os.environ["CUDA_LAUNCH_BLOCKING"]='1' # ONLY FOR DEBUGGING CUDA CODE + +print("Execution started on:", str(datetime.now()), '\n') + +try: + # the NewSaveConfigCallback uses the wandb run_name to save a copy of the config file + # preventing collision problems running multiple wandb experiments where config.yaml is used + # as a default name for all of them + cli = LightningCLI( + pl_FaceNet_ArcFace, + CWF_DataModule_ArcFace, + save_config_callback=NewSaveConfigCallback) +finally: + if 'cli' in locals(): + bd.moveConfig(cli) + else: + print("WARNING: LightningCLI seems to have not exited properly, so " + "can't retrieve log dir to move the config file.") +