Commit 0bec86c6 authored by André Anjos

Implement command-line configurable data augmentations

parent 65531c7b
Merge request !48: Implement command-line configurable data augmentations
Pipeline #87943 passed
Showing with 93 additions and 23 deletions
@@ -32,7 +32,20 @@ Pre-configured models you can readily use.

   mednet.config.models.pasa

-.. _mednet.config.datamodules:
+Data Augmentations
+==================
+
+Sequences of data augmentations you can readily use.
+
+.. _mednet.config.augmentations:
+
+.. autosummary::
+   :toctree: api/config.augmentations
+   :template: config.rst
+
+   mednet.config.augmentations.elastic
+   mednet.config.augmentations.affine

DataModule support
==================

@@ -42,6 +55,8 @@ supported in this package, for your reference. Each pre-configured DataModule

can receive the name of one or more splits as argument to build a fully
functional DataModule that can be used in training, prediction or testing.

+.. _mednet.config.datamodules:
+
.. autosummary::
   :toctree: api/config.datamodules
...
@@ -78,6 +78,6 @@

   2020 pp. 111-119. doi: https://doi.org/10.1109/CVPRW50498.2020.00020

.. [ROAD-2022] *Y. Rong, T. Leemann, V. Borisov, G. Kasneci, and E. Kasneci*,
-   *A Consistent and Efficient Evaluation Strategy for Attribution Methods* in
+   **A Consistent and Efficient Evaluation Strategy for Attribution Methods** in
   Proceedings of the 39th International Conference on Machine Learning, PMLR,
   Jun. 2022, pp. 18770–18795. https://proceedings.mlr.press/v162/rong22a.html
@@ -244,6 +244,10 @@ alexnet-pretrained = "mednet.config.models.alexnet_pretrained"

densenet = "mednet.config.models.densenet"
densenet-pretrained = "mednet.config.models.densenet_pretrained"

+# lists of data augmentations
+elastic = "mednet.config.augmentations.elastic"
+affine = "mednet.config.augmentations.affine"
+
# montgomery dataset (and cross-validation folds)
montgomery = "mednet.config.data.montgomery.default"
montgomery-f0 = "mednet.config.data.montgomery.fold_0"
...
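These entry points are what make augmentation sequences addressable by name from the command line. As a rough sketch of the mechanism (the "mednet.config" group name and the load_augmentations helper are assumptions for illustration, not this package's actual loader), a name such as "affine" can be resolved to a config module through standard entry-point machinery:

# Hypothetical sketch: resolving a command-line config name through the
# entry points declared above. The group name "mednet.config" is an
# assumption; requires Python >= 3.10 for the selectable entry_points API.
import importlib.metadata

def load_augmentations(name: str) -> list:
    (entry,) = [
        e
        for e in importlib.metadata.entry_points(group="mednet.config")
        if e.name == name
    ]
    module = entry.load()  # imports, e.g., mednet.config.augmentations.affine
    return module.augmentations

transforms = load_augmentations("affine")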
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Simple affine augmentations for image analysis."""

import torchvision.transforms

augmentations = [
    torchvision.transforms.RandomAffine(
        degrees=10,
        translate=(0.1, 0.1),  # horizontal, vertical
        scale=(0.8, 1.0),  # minimum, maximum
        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
    ),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
]
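Since this file only defines a plain list of torchvision transforms, a quick way to sanity-check the configuration is to compose and apply it to a dummy tensor. A usage sketch; the single-channel 512x512 shape is an arbitrary choice:

import torch
import torchvision.transforms

from mednet.config.augmentations.affine import augmentations

pipeline = torchvision.transforms.Compose(augmentations)
fake_image = torch.rand(1, 512, 512)  # (channels, height, width)
augmented = pipeline(fake_image)
assert augmented.shape == fake_image.shape  # both transforms preserve image size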
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Elastic deformation with 80% probability.

This sole data augmentation was proposed by Pasa in the article
"Efficient Deep Network Architectures for Fast Chest X-Ray Tuberculosis
Screening and Visualization".

Reference: [PASA-2019]_
"""

from mednet.data.augmentations import ElasticDeformation

augmentations = [ElasticDeformation(p=0.8)]
@@ -11,13 +11,11 @@ page <alexnet-pytorch_>`_), modified for a variable number of outputs

from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD

-from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet

model = Alexnet(
    loss_type=BCEWithLogitsLoss,
    optimizer_type=SGD,
    optimizer_arguments=dict(lr=0.01, momentum=0.1),
-    augmentation_transforms=[ElasticDeformation(p=0.8)],
    pretrained=False,
)
@@ -13,13 +13,11 @@ N.B.: The output layer is **always** initialized from scratch.

from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD

-from mednet.data.augmentations import ElasticDeformation
from mednet.models.alexnet import Alexnet

model = Alexnet(
    loss_type=BCEWithLogitsLoss,
    optimizer_type=SGD,
    optimizer_arguments=dict(lr=0.01, momentum=0.1),
-    augmentation_transforms=[ElasticDeformation(p=0.8)],
    pretrained=True,
)
@@ -11,14 +11,12 @@ page <densenet_pytorch_>`), modified for a variable number of outputs

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

-from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet

model = Densenet(
    loss_type=BCEWithLogitsLoss,
    optimizer_type=Adam,
    optimizer_arguments=dict(lr=0.0001),
-    augmentation_transforms=[ElasticDeformation(p=0.2)],
    pretrained=False,
    dropout=0.1,
)
@@ -13,14 +13,12 @@ N.B.: The output layer is **always** initialized from scratch.

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

-from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet

model = Densenet(
    loss_type=BCEWithLogitsLoss,
    optimizer_type=Adam,
    optimizer_arguments=dict(lr=0.0001),
-    augmentation_transforms=[ElasticDeformation(p=0.2)],
    pretrained=True,
    dropout=0.1,
)
@@ -12,14 +12,12 @@ weights from scratch for radiological sign detection.

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

-from mednet.data.augmentations import ElasticDeformation
from mednet.models.densenet import Densenet

model = Densenet(
    loss_type=BCEWithLogitsLoss,
    optimizer_type=Adam,
    optimizer_arguments=dict(lr=0.0001),
-    augmentation_transforms=[ElasticDeformation(p=0.2)],
    pretrained=False,
    dropout=0.1,
    num_classes=14,  # number of classes in NIH CXR-14
...
@@ -13,12 +13,10 @@ Reference: [PASA-2019]_

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam

-from mednet.data.augmentations import ElasticDeformation
from mednet.models.pasa import Pasa

model = Pasa(
    loss_type=BCEWithLogitsLoss,
    optimizer_type=Adam,
    optimizer_arguments=dict(lr=8e-5),
-    augmentation_transforms=[ElasticDeformation(p=0.8)],
)
@@ -242,6 +242,17 @@ class ElasticDeformation:

        self.p: float = p
        self.parallel = parallel

+    def __str__(self) -> str:
+        parameters = [
+            f"alpha={self.alpha}",
+            f"sigma={self.sigma}",
+            f"spline_order={self.spline_order}",
+            f"mode={self.mode}",
+            f"p={self.p}",
+            f"parallel={self.parallel}",
+        ]
+        return f"{type(self).__name__}({', '.join(parameters)})"
+
    @property
    def parallel(self) -> int:
        """Use multiprocessing for data augmentation.
...
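The new __str__ makes transform instances render readably when the model logs its active augmentations (see the Model setter below). A small illustration; the values shown for alpha, sigma, spline_order, mode and parallel depend on constructor defaults not visible in this hunk:

from mednet.data.augmentations import ElasticDeformation

print(ElasticDeformation(p=0.8))
# prints something like:
# ElasticDeformation(alpha=<alpha>, sigma=<sigma>, spline_order=<order>, mode=<mode>, p=0.8, parallel=<n>)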
@@ -138,7 +138,7 @@ class Alexnet(Model):

        labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(self.augmentation_transforms(images))

        return self._train_loss(outputs, labels.float())
...
@@ -141,7 +141,7 @@ class Densenet(Model):

        labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(self.augmentation_transforms(images))

        return self._train_loss(outputs, labels.float())
...
@@ -71,9 +71,26 @@ class Model(pl.LightningModule):

        self._optimizer_type = optimizer_type
        self._optimizer_arguments = optimizer_arguments

-        self._augmentation_transforms = torchvision.transforms.Compose(
-            augmentation_transforms,
-        )
+        self.augmentation_transforms = augmentation_transforms
+
+    @property
+    def augmentation_transforms(self) -> torchvision.transforms.Compose:
+        return self._augmentation_transforms
+
+    @augmentation_transforms.setter
+    def augmentation_transforms(self, v: TransformSequence):
+        self._augmentation_transforms = torchvision.transforms.Compose(v)
+
+        if len(v) != 0:
+            transforms_str = ", ".join(
+                [
+                    f"{type(k).__module__}.{str(k)}"
+                    for k in self._augmentation_transforms.transforms
+                ]
+            )
+            logger.info(f"Data augmentations: {transforms_str}")
+        else:
+            logger.info("Data augmentations: None")

    def forward(self, x):
        raise NotImplementedError
...
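Turning augmentation_transforms into a settable property is what allows augmentations to be swapped in after a model configuration has been loaded, which is the basis for the command-line override. A usage sketch, assuming the pasa model config and affine augmentation config shown elsewhere in this commit:

from mednet.config.augmentations.affine import augmentations
from mednet.config.models.pasa import model

# Re-assignment re-composes the pipeline and logs the active transforms,
# e.g. "Data augmentations: torchvision.transforms.transforms.RandomAffine(...)".
model.augmentation_transforms = augmentations

# An empty sequence disables augmentation and logs "Data augmentations: None".
model.augmentation_transforms = []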
@@ -211,7 +211,7 @@ class Pasa(Model):

        labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
-        outputs = self(self._augmentation_transforms(images))
+        outputs = self(self.augmentation_transforms(images))

        return self._train_loss(outputs, labels.float())
...
@@ -21,7 +21,6 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    cls=ConfigCommand,
    epilog="""Examples:

-\b
1. Train a pasa model with montgomery dataset, on the CPU, for only two
   epochs, then runs inference and evaluation on stock datasets, report
   performance as a table and figures:

@@ -55,9 +54,10 @@ def experiment(

    r"""Run a complete experiment, from training, to prediction and evaluation.

    This script is just a wrapper around the individual scripts for training,
-    running prediction, and evaluating. It organises the output in a preset way::
+    running prediction, and evaluating. It organises the output in a preset way:
+
+    .. code::

-\b
       └─ <output-folder>/
          ├── command.sh
          ├── model/  # the generated model will be here
...
@@ -54,6 +54,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    help="""A lightning module instance implementing the network architecture
    (not the weights, necessarily) to be used for prediction.""",
    required=True,
+    type=click.UNPROCESSED,
    cls=ResourceOption,
)
@click.option(

@@ -64,6 +65,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    however this is not a requirement. A DataModule that returns a single
    dataloader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
+    type=click.UNPROCESSED,
    cls=ResourceOption,
)
@click.option(
...
@@ -35,6 +35,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    (not the weights, necessarily) to be used for inference. Currently, only
    supports pasa and densenet models.""",
    required=True,
+    type=click.UNPROCESSED,
    cls=ResourceOption,
)
@click.option(

@@ -45,6 +46,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")

    however this is not a requirement. A DataModule that returns a single
    DataLoader for prediction (wrapped in a dictionary) is acceptable.""",
    required=True,
+    type=click.UNPROCESSED,
    cls=ResourceOption,
)
@click.option(
...
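In both commands above, click.UNPROCESSED disables Click's built-in value conversion, so the raw string typed on the command line (for example "affine") reaches ResourceOption intact for resolution into a Python object. A minimal, standalone illustration of that Click behavior (not mednet's actual command):

import click

@click.command()
@click.option("--augmentations", type=click.UNPROCESSED)
def demo(augmentations):
    # Click performs no type conversion: the value arrives exactly as typed,
    # ready for custom resolution (e.g. via package entry points).
    click.echo(repr(augmentations))

if __name__ == "__main__":
    demo()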