diff --git a/doc/config.rst b/doc/config.rst
index 307e5d3e41099284aef71b4d8386e09b33b0e183..9ecb361cfd66897afca06af8c362a6941b381fd7 100644
--- a/doc/config.rst
+++ b/doc/config.rst
@@ -32,7 +32,20 @@ Pre-configured models you can readily use.
 
    mednet.config.models.pasa
 
-.. _mednet.config.datamodules:
+Data Augmentations
+==================
+
+Sequences of data augmentations you can readily use.
+
+.. _mednet.config.augmentations:
+
+.. autosummary::
+   :toctree: api/config.augmentations
+   :template: config.rst
+
+   mednet.config.augmentations.elastic
+   mednet.config.augmentations.affine
+
 DataModule support
 ==================
@@ -42,6 +55,8 @@ supported in this package, for your reference. Each pre-configured DataModule
 can receive the name of one or more splits as argument to build a fully
 functional DataModule that can be used in training, prediction or testing.
 
+.. _mednet.config.datamodules:
+
 .. autosummary::
    :toctree: api/config.datamodules
diff --git a/doc/references.rst b/doc/references.rst
index a677685ca2b61a295e2d726d93de5d44251ce216..f70d92d25853be6cc98de4c15813241b99f12d3e 100644
--- a/doc/references.rst
+++ b/doc/references.rst
@@ -78,6 +78,6 @@
    2020 pp. 111-119. doi: https://doi.org/10.1109/CVPRW50498.2020.00020
 
 .. [ROAD-2022] *Y. Rong, T. Leemann, V. Borisov, G. Kasneci, and E. Kasneci*,
-   *A Consistent and Efficient Evaluation Strategy for Attribution Methods* in
+   **A Consistent and Efficient Evaluation Strategy for Attribution Methods** in
    Proceedings of the 39th International Conference on Machine Learning, PMLR,
    Jun. 2022, pp. 18770–18795. https://proceedings.mlr.press/v162/rong22a.html
diff --git a/pyproject.toml b/pyproject.toml
index 21533dd4573c2463544004ef6f000ba37e215837..d407b18cea5f716bd8e9ef75321dd693d8ce7f1a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -244,6 +244,10 @@ alexnet-pretrained = "mednet.config.models.alexnet_pretrained"
 densenet = "mednet.config.models.densenet"
 densenet-pretrained = "mednet.config.models.densenet_pretrained"
 
+# lists of data augmentations
+elastic = "mednet.config.augmentations.elastic"
+affine = "mednet.config.augmentations.affine"
+
 # montgomery dataset (and cross-validation folds)
 montgomery = "mednet.config.data.montgomery.default"
 montgomery-f0 = "mednet.config.data.montgomery.fold_0"
diff --git a/src/mednet/config/augmentations/__init__.py b/src/mednet/config/augmentations/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/mednet/config/augmentations/affine.py b/src/mednet/config/augmentations/affine.py
new file mode 100644
index 0000000000000000000000000000000000000000..723403d078229289337a071715178446a524d77b
--- /dev/null
+++ b/src/mednet/config/augmentations/affine.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Simple affine augmentations for image analysis."""
+
+import torchvision.transforms
+
+augmentations = [
+    torchvision.transforms.RandomAffine(
+        degrees=10,
+        translate=(0.1, 0.1),  # horizontal, vertical
+        scale=(0.8, 1.0),  # minimum, maximum
+        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
+    ),
+    torchvision.transforms.RandomHorizontalFlip(p=0.5),
+]
diff --git a/src/mednet/config/augmentations/elastic.py b/src/mednet/config/augmentations/elastic.py
new file mode 100644
index 0000000000000000000000000000000000000000..0335548dc6f37a89eebd54b180b0f0d33f1cf1f5
--- /dev/null
+++ b/src/mednet/config/augmentations/elastic.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+"""Elastic deformation with 80% probability.
+
+This sole data augmentation was proposed by Pasa in the article
+"Efficient Deep Network Architectures for Fast Chest X-Ray Tuberculosis
+Screening and Visualization".
+
+Reference: [PASA-2019]_
+"""
+
+from mednet.data.augmentations import ElasticDeformation
+
+augmentations = [ElasticDeformation(p=0.8)]
diff --git a/src/mednet/config/models/alexnet.py b/src/mednet/config/models/alexnet.py
index 7f28186750e1e21d1f801cbb7d21d17da5ca2011..de1e95bf0a5d4b40f2a6f339c5605a936f408e86 100644
--- a/src/mednet/config/models/alexnet.py
+++ b/src/mednet/config/models/alexnet.py
@@ -11,13 +11,11 @@ page <alexnet-pytorch_>`_), modified for a variable number of outputs
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import SGD
 
-from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
-    augmentation_transforms=[ElasticDeformation(p=0.8)],
     pretrained=False,
 )
diff --git a/src/mednet/config/models/alexnet_pretrained.py b/src/mednet/config/models/alexnet_pretrained.py
index a935655555a004cbe3c5b8e2b19f77458e952e40..f3ed61ba2b8f6d2d7e5cbf0d9554efd5c265546d 100644
--- a/src/mednet/config/models/alexnet_pretrained.py
+++ b/src/mednet/config/models/alexnet_pretrained.py
@@ -13,13 +13,11 @@ N.B.: The output layer is **always** initialized from scratch.
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import SGD
 
-from mednet.data.augmentations import ElasticDeformation
 from mednet.models.alexnet import Alexnet
 
 model = Alexnet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=SGD,
     optimizer_arguments=dict(lr=0.01, momentum=0.1),
-    augmentation_transforms=[ElasticDeformation(p=0.8)],
     pretrained=True,
 )
diff --git a/src/mednet/config/models/densenet.py b/src/mednet/config/models/densenet.py
index 9ee510ac8df93713b995f857bf5afe2cb68b89a6..3f3e69072aea5c528afa4e2bc284bfc02c2654db 100644
--- a/src/mednet/config/models/densenet.py
+++ b/src/mednet/config/models/densenet.py
@@ -11,14 +11,12 @@ page <densenet_pytorch_>`), modified for a variable number of outputs
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
-    augmentation_transforms=[ElasticDeformation(p=0.2)],
     pretrained=False,
     dropout=0.1,
 )
diff --git a/src/mednet/config/models/densenet_pretrained.py b/src/mednet/config/models/densenet_pretrained.py
index b7e2efcdfa83e1b70a466dbca0ddca02cf4695dc..d7637a009d6670914edce6bd92e076775a491850 100644
--- a/src/mednet/config/models/densenet_pretrained.py
+++ b/src/mednet/config/models/densenet_pretrained.py
@@ -13,14 +13,12 @@ N.B.: The output layer is **always** initialized from scratch.
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
-    augmentation_transforms=[ElasticDeformation(p=0.2)],
     pretrained=True,
     dropout=0.1,
 )
diff --git a/src/mednet/config/models/densenet_rs.py b/src/mednet/config/models/densenet_rs.py
index 813bb76cf92b3abe105e7095085d7e01de4fbecd..f91af6318472b9bd22e652bdcf3223b7ef5cd4a1 100644
--- a/src/mednet/config/models/densenet_rs.py
+++ b/src/mednet/config/models/densenet_rs.py
@@ -12,14 +12,12 @@ weights from scratch for radiological sign detection.
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from mednet.data.augmentations import ElasticDeformation
 from mednet.models.densenet import Densenet
 
 model = Densenet(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=0.0001),
-    augmentation_transforms=[ElasticDeformation(p=0.2)],
     pretrained=False,
     dropout=0.1,
     num_classes=14,  # number of classes in NIH CXR-14
diff --git a/src/mednet/config/models/pasa.py b/src/mednet/config/models/pasa.py
index 7787d10e32cfad9ece6d42cee8be6bc0bb86124f..2db7c8d9c2bbb74ecb2becbd549237cfde4faf61 100644
--- a/src/mednet/config/models/pasa.py
+++ b/src/mednet/config/models/pasa.py
@@ -13,12 +13,10 @@ Reference: [PASA-2019]_
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from mednet.data.augmentations import ElasticDeformation
 from mednet.models.pasa import Pasa
 
 model = Pasa(
     loss_type=BCEWithLogitsLoss,
     optimizer_type=Adam,
     optimizer_arguments=dict(lr=8e-5),
-    augmentation_transforms=[ElasticDeformation(p=0.8)],
 )
diff --git a/src/mednet/data/augmentations.py b/src/mednet/data/augmentations.py
index 2ed9fa1fce0ee0bdb90a426bbf894c62e219ead8..47fb0fac3b6b9439270e73e9a6793ec3514569d6 100644
--- a/src/mednet/data/augmentations.py
+++ b/src/mednet/data/augmentations.py
@@ -242,6 +242,17 @@ class ElasticDeformation:
         self.p: float = p
         self.parallel = parallel
 
+    def __str__(self) -> str:
+        parameters = [
+            f"alpha={self.alpha}",
+            f"sigma={self.sigma}",
+            f"spline_order={self.spline_order}",
+            f"mode={self.mode}",
+            f"p={self.p}",
+            f"parallel={self.parallel}",
+        ]
+        return f"{type(self).__name__}({', '.join(parameters)})"
+
     @property
     def parallel(self) -> int:
         """Use multiprocessing for data augmentation.
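The two configuration modules added above each expose a plain ``augmentations``
list; models compose such a list with ``torchvision.transforms.Compose`` before
applying it to training batches (see the ``Model`` changes further down). The
sketch below illustrates that composition in isolation — the batch shape is
illustrative only and not taken from this patch:

.. code:: python

   import torch
   import torchvision.transforms

   # the same list that mednet.config.augmentations.affine defines above
   augmentations = [
       torchvision.transforms.RandomAffine(
           degrees=10,
           translate=(0.1, 0.1),  # horizontal, vertical
           scale=(0.8, 1.0),  # minimum, maximum
           interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
       ),
       torchvision.transforms.RandomHorizontalFlip(p=0.5),
   ]

   # Compose chains the callables; one random draw applies to the whole batch
   pipeline = torchvision.transforms.Compose(augmentations)

   batch = torch.rand(4, 1, 512, 512)  # hypothetical NCHW grayscale batch
   augmented = pipeline(batch)
   assert augmented.shape == batch.shape  # these transforms preserve shape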
diff --git a/src/mednet/models/alexnet.py b/src/mednet/models/alexnet.py
index 75223c9a4b78e196d0c81dab20b566c4b5d32d4b..3e58463e62d1ceae2a54868a08dc9205b35837ca 100644
--- a/src/mednet/models/alexnet.py
+++ b/src/mednet/models/alexnet.py
@@ -138,7 +138,7 @@ class Alexnet(Model):
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(self.augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
diff --git a/src/mednet/models/densenet.py b/src/mednet/models/densenet.py
index 76df1ed64a7a73e99b86d18145169dd600601044..ce1eb6dd607dea8a22341257b2d36f14e4a41641 100644
--- a/src/mednet/models/densenet.py
+++ b/src/mednet/models/densenet.py
@@ -141,7 +141,7 @@ class Densenet(Model):
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(self.augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
diff --git a/src/mednet/models/model.py b/src/mednet/models/model.py
index 8f5b1be492f84c7171ecbdcaceff1331d15c277e..d109e3cc090e445eedd81fd37bcad2306fec1024 100644
--- a/src/mednet/models/model.py
+++ b/src/mednet/models/model.py
@@ -71,9 +71,26 @@ class Model(pl.LightningModule):
         self._optimizer_type = optimizer_type
         self._optimizer_arguments = optimizer_arguments
 
-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
+        self.augmentation_transforms = augmentation_transforms
+
+    @property
+    def augmentation_transforms(self) -> torchvision.transforms.Compose:
+        return self._augmentation_transforms
+
+    @augmentation_transforms.setter
+    def augmentation_transforms(self, v: TransformSequence):
+        self._augmentation_transforms = torchvision.transforms.Compose(v)
+
+        if len(v) != 0:
+            transforms_str = ", ".join(
+                [
+                    f"{type(k).__module__}.{str(k)}"
+                    for k in self._augmentation_transforms.transforms
+                ]
+            )
+            logger.info(f"Data augmentations: {transforms_str}")
+        else:
+            logger.info("Data augmentations: None")
 
     def forward(self, x):
         raise NotImplementedError
diff --git a/src/mednet/models/pasa.py b/src/mednet/models/pasa.py
index e9147683b08f8d8b532396544eece7829c23fd92..cb7ebfea0da3d8433de869bda54f22b64d43bb0f 100644
--- a/src/mednet/models/pasa.py
+++ b/src/mednet/models/pasa.py
@@ -211,7 +211,7 @@ class Pasa(Model):
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(self.augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
diff --git a/src/mednet/scripts/experiment.py b/src/mednet/scripts/experiment.py
index d08a2d0ae222e8627aa3eaaf20b0eeb70a77d8f3..8cb233f0fbbdce0bde078eaf13579650931fc93b 100644
--- a/src/mednet/scripts/experiment.py
+++ b/src/mednet/scripts/experiment.py
@@ -21,7 +21,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     cls=ConfigCommand,
     epilog="""Examples:
 
-\b
 1. Train a pasa model with montgomery dataset, on the CPU, for only two
    epochs, then runs inference and evaluation on stock datasets, report
    performance as a table and figures:
@@ -55,9 +54,10 @@ def experiment(
     r"""Run a complete experiment, from training, to prediction and evaluation.
 
     This script is just a wrapper around the individual scripts for training,
-    running prediction, and evaluating. It organises the output in a preset way::
+    running prediction, and evaluating. It organises the output in a preset way:
+
+    .. code::
 
-    \b
     └─ <output-folder>/
        ├── command.sh
        ├── model/  # the generated model will be here
diff --git a/src/mednet/scripts/predict.py b/src/mednet/scripts/predict.py
index a24a98c6d28eb75502fc00565985401c564ec3f1..11c04f7870c3065d9a98d2885f93ba8f8df4aaf1 100644
--- a/src/mednet/scripts/predict.py
+++ b/src/mednet/scripts/predict.py
@@ -54,6 +54,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     help="""A lightning module instance implementing the network architecture
     (not the weights, necessarily) to be used for prediction.""",
     required=True,
+    type=click.UNPROCESSED,
     cls=ResourceOption,
 )
 @click.option(
@@ -64,6 +65,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     help="""A lightning DataModule containing the training, validation, and test
     sets, each applied a different data augmentation, if needed; however this is
     not a requirement. A DataModule that returns a single dataloader for
     prediction (wrapped in a dictionary) is acceptable.""",
     required=True,
+    type=click.UNPROCESSED,
     cls=ResourceOption,
 )
 @click.option(
diff --git a/src/mednet/scripts/saliency/completeness.py b/src/mednet/scripts/saliency/completeness.py
index 5c93aded47a87728bb72df927dd93e959bdadf0b..e02aa3b6a835c7b2383d961c5f100d3693f331b2 100644
--- a/src/mednet/scripts/saliency/completeness.py
+++ b/src/mednet/scripts/saliency/completeness.py
@@ -35,6 +35,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     (not the weights, necessarily) to be used for inference. Currently, only
     supports pasa and densenet models.""",
     required=True,
+    type=click.UNPROCESSED,
     cls=ResourceOption,
 )
 @click.option(
@@ -45,6 +46,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     however this is not a requirement. A DataModule that returns a single
     DataLoader for prediction (wrapped in a dictionary) is acceptable.""",
     required=True,
+    type=click.UNPROCESSED,
     cls=ResourceOption,
 )
 @click.option(
diff --git a/src/mednet/scripts/saliency/interpretability.py b/src/mednet/scripts/saliency/interpretability.py
index 83d8cc82836d0bf3e3a595a16d146efd849df22d..7bc6b891cb7f5307ced52c4526427e1ba9b7b555 100644
--- a/src/mednet/scripts/saliency/interpretability.py
+++ b/src/mednet/scripts/saliency/interpretability.py
@@ -33,6 +33,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     (not the weights, necessarily) to be used for inference. Currently, only
     supports pasa and densenet models.""",
     required=True,
+    type=click.UNPROCESSED,
     cls=ResourceOption,
 )
 @click.option(
@@ -43,6 +44,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
     however this is not a requirement. A DataModule that returns a single
     dataloader for prediction (wrapped in a dictionary) is acceptable.""",
     required=True,
+    type=click.UNPROCESSED,
     cls=ResourceOption,
 )
 @click.option(
diff --git a/src/mednet/scripts/train.py b/src/mednet/scripts/train.py
index 9c25de7ae8124a1fe3ce49462c081486f249c17a..9479006c106a2fba929163db54576cf22adb4fcf 100644
--- a/src/mednet/scripts/train.py
+++ b/src/mednet/scripts/train.py
@@ -51,6 +51,7 @@ def reusable_options(f):
         "-m",
         help="A lightning module instance implementing the network to be trained",
         required=True,
+        type=click.UNPROCESSED,
         cls=ResourceOption,
     )
     @click.option(
@@ -58,6 +59,7 @@
         "-d",
         help="A lightning DataModule containing the training and validation sets.",
         required=True,
+        type=click.UNPROCESSED,
         cls=ResourceOption,
     )
     @click.option(
@@ -209,6 +211,16 @@
         default=True,
         cls=ResourceOption,
     )
+    @click.option(
+        "--augmentations",
+        "-A",
+        help="""Models that can be trained in this package are shipped without
+        explicit data augmentations. This option allows you to define a list of
+        data augmentations to use for training the selected model.""",
+        type=click.UNPROCESSED,
+        default=[],
+        cls=ResourceOption,
+    )
     @functools.wraps(f)
     def wrapper_reusable_options(*args, **kwargs):
         return f(*args, **kwargs)
@@ -221,11 +233,12 @@
     cls=ConfigCommand,
     epilog="""Examples:
 
-1. Train a pasa model with the montgomery dataset, on a GPU (``cuda:0``):
+1. Train a pasa model with the montgomery dataset, on a GPU (``cuda:0``), using
+   simple elastic deformation augmentations:
 
 .. code:: sh
 
-   mednet train -vv pasa montgomery --batch-size=4 --device="cuda:0"
+   mednet train -vv pasa elastic montgomery --batch-size=4 --device="cuda:0"
 """,
 )
 @reusable_options
@@ -245,6 +258,7 @@ def train(
     parallel,
     monitoring_interval,
     balance_classes,
+    augmentations,
     **_,
 ) -> None:  # numpydoc ignore=PR01
     """Train an CNN to perform image classification.
@@ -281,6 +295,10 @@
 
     seed_everything(seed)
 
+    # report model/transforms options - set data augmentations
+    logger.info(f"Network model: {type(model).__module__}.{type(model).__name__}")
+    model.augmentation_transforms = augmentations
+
     # reset datamodule with user configurable options
     datamodule.drop_incomplete_batch = drop_incomplete_batch
     datamodule.cache_samples = cache_samples
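Taken together, these changes move data augmentation out of the individual
model configurations and attach it to the model at training time. A minimal
sketch of the new flow, using names from this patch (``Pasa``,
``mednet.config.augmentations.elastic``) and assuming ``Model.__init__`` keeps
an empty default for ``augmentation_transforms``, as the updated model
configurations imply:

.. code:: python

   from torch.nn import BCEWithLogitsLoss
   from torch.optim import Adam

   from mednet.config.augmentations.elastic import augmentations
   from mednet.models.pasa import Pasa

   # model configurations no longer pass augmentation_transforms
   model = Pasa(
       loss_type=BCEWithLogitsLoss,
       optimizer_type=Adam,
       optimizer_arguments=dict(lr=8e-5),
   )

   # the new setter composes the list via torchvision and logs each transform;
   # training_step() then calls self.augmentation_transforms(images)
   model.augmentation_transforms = augmentations

On the command line, the equivalent is the chained configuration shown in the
updated epilog: ``mednet train -vv pasa elastic montgomery --batch-size=4
--device="cuda:0"``.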