Skip to content
Snippets Groups Projects
Commit 541cc0a0 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.augmentations] Move transforms to augmentations; Implement...

[data.augmentations] Move transforms to augmentations; Implement auto-parellisation for ElasticDeformation (partial affects #5); Reflect changes in all relevant submodules
parent 9086dfc7
No related branches found
No related tags found
1 merge request!6Making use of LightningDataModule and simplification of data loading
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD from torch.optim import SGD
from ...data.transforms import ElasticDeformation from ...data.augmentations import ElasticDeformation
from ...models.alexnet import Alexnet from ...models.alexnet import Alexnet
model = Alexnet( model = Alexnet(
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD from torch.optim import SGD
from ...data.transforms import ElasticDeformation from ...data.augmentations import ElasticDeformation
from ...models.alexnet import Alexnet from ...models.alexnet import Alexnet
model = Alexnet( model = Alexnet(
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam from torch.optim import Adam
from ...data.transforms import ElasticDeformation from ...data.augmentations import ElasticDeformation
from ...models.densenet import Densenet from ...models.densenet import Densenet
model = Densenet( model = Densenet(
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam from torch.optim import Adam
from ...data.transforms import ElasticDeformation from ...data.augmentations import ElasticDeformation
from ...models.densenet import Densenet from ...models.densenet import Densenet
model = Densenet( model = Densenet(
......
...@@ -14,7 +14,7 @@ Reference: [PASA-2019]_ ...@@ -14,7 +14,7 @@ Reference: [PASA-2019]_
from torch.nn import BCEWithLogitsLoss from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam from torch.optim import Adam
from ...data.transforms import ElasticDeformation from ...data.augmentations import ElasticDeformation
from ...models.pasa import Pasa from ...models.pasa import Pasa
model = Pasa( model = Pasa(
......
...@@ -69,6 +69,13 @@ def _elastic_deformation_on_image( ...@@ -69,6 +69,13 @@ def _elastic_deformation_on_image(
p p
Probability that this transformation will be applied. Meaningful when Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique. using it as a data augmentation technique.
Returns
-------
tensor
A tensor on the CPU.
""" """
if random.random() < p: if random.random() < p:
...@@ -122,7 +129,7 @@ def _elastic_deformation_on_image( ...@@ -122,7 +129,7 @@ def _elastic_deformation_on_image(
).reshape(img_shape) ).reshape(img_shape)
) )
# wraps numpy array as tensor (with no copy) # wraps numpy array as tensor, move to destination device if need-be
return torch.as_tensor(output) return torch.as_tensor(output)
return img return img
...@@ -137,6 +144,47 @@ def _elastic_deformation_on_batch( ...@@ -137,6 +144,47 @@ def _elastic_deformation_on_batch(
p: float = 1.0, p: float = 1.0,
pool: multiprocessing.pool.Pool | None = None, pool: multiprocessing.pool.Pool | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Performs elastic deformation on a batch of images.
This implementation is based on 2 scipy functions
(:py:func:`scipy.ndimage.gaussian_filter` and
:py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it
requires data is moved off the current running device and then back.
Parameters
----------
img
The input image to apply elastic deformation at. This image should
always have this shape: ``[C, H, W]``. It should always represent a
tensor on the CPU.
alpha
A multiplier for the gaussian filter outputs
sigma
Standard deviation for Gaussian kernel.
spline_order
The order of the spline interpolation, default is 1. The order has to
be in the range 0-5.
mode
The mode parameter determines how the input array is extended beyond
its boundaries.
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
Returns
-------
tensor
A tensor on the CPU.
"""
# transforms our custom functions into simpler callables # transforms our custom functions into simpler callables
partial = functools.partial( partial = functools.partial(
_elastic_deformation_on_image, _elastic_deformation_on_image,
...@@ -154,7 +202,7 @@ def _elastic_deformation_on_batch( ...@@ -154,7 +202,7 @@ def _elastic_deformation_on_batch(
else: else:
augmented_images = pool.imap(partial, batch.cpu()) augmented_images = pool.imap(partial, batch.cpu())
return torch.stack(list(augmented_images)).to(batch.device) return torch.stack(list(augmented_images))
class ElasticDeformation: class ElasticDeformation:
...@@ -196,9 +244,11 @@ class ElasticDeformation: ...@@ -196,9 +244,11 @@ class ElasticDeformation:
parallel parallel
Use multiprocessing for processing batches of data: if set to -1 Use multiprocessing for processing batches of data: if set to -1
(default), disables multiprocessing. Set to 0 to enable as many (default), disables multiprocessing. If set to -2, then enable
processes as processing cores as available in the system. Set to >= 1 auto-tune (use the minimum value between the first batch size and total
to enable that many processes. number of processing cores). Set to 0 to enable as many processes as
processing cores as available in the system. Set to >= 1 to enable that
many processes.
""" """
def __init__( def __init__(
...@@ -208,7 +258,7 @@ class ElasticDeformation: ...@@ -208,7 +258,7 @@ class ElasticDeformation:
spline_order: int = 1, spline_order: int = 1,
mode: str = "nearest", mode: str = "nearest",
p: float = 1.0, p: float = 1.0,
parallel: int = -1, parallel: int = -2,
): ):
self.alpha: float = alpha self.alpha: float = alpha
self.sigma: float = sigma self.sigma: float = sigma
...@@ -221,10 +271,11 @@ class ElasticDeformation: ...@@ -221,10 +271,11 @@ class ElasticDeformation:
def parallel(self): def parallel(self):
"""Use multiprocessing for data augmentation. """Use multiprocessing for data augmentation.
If set to -1 (default), disables multiprocessing data If set to -1 (default), disables multiprocessing. If set to -2,
augmentation. Set to 0 to enable as many data loading instances then enable auto-tune (use the minimum value between the first
as processing cores as available in the system. Set to >= 1 to batch size and total number of processing cores). Set to 0 to
enable that many multiprocessing instances for data loading. enable as many processes as processing cores as available in the
system. Set to >= 1 to enable that many processes.
""" """
return self._parallel return self._parallel
...@@ -243,6 +294,11 @@ class ElasticDeformation: ...@@ -243,6 +294,11 @@ class ElasticDeformation:
def __call__(self, img: torch.Tensor) -> torch.Tensor: def __call__(self, img: torch.Tensor) -> torch.Tensor:
if len(img.shape) == 4: if len(img.shape) == 4:
if self._mp_pool is None and self._parallel == -2:
# auto-tunning on first batch
instances = min(img.shape[0], multiprocessing.cpu_count())
self._mp_pool = multiprocessing.pool.Pool(instances)
return _elastic_deformation_on_batch( return _elastic_deformation_on_batch(
img, img,
self.alpha, self.alpha,
...@@ -251,7 +307,8 @@ class ElasticDeformation: ...@@ -251,7 +307,8 @@ class ElasticDeformation:
self.mode, self.mode,
self.p, self.p,
self._mp_pool, self._mp_pool,
) ).to(img.device)
elif len(img.shape) == 3: elif len(img.shape) == 3:
return _elastic_deformation_on_image( return _elastic_deformation_on_image(
img.cpu(), img.cpu(),
......
...@@ -62,7 +62,7 @@ json_dataset = JSONDataset( ...@@ -62,7 +62,7 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=False): def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
from torchvision import transforms from torchvision import transforms
from ..transforms import ElasticDeformation, RemoveBlackBorders from ..augmentations import ElasticDeformation, RemoveBlackBorders
post_transforms = [] post_transforms = []
if RGB: if RGB:
......
...@@ -62,7 +62,8 @@ json_dataset = JSONDataset( ...@@ -62,7 +62,8 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=False): def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
from torchvision import transforms from torchvision import transforms
from ..transforms import ElasticDeformation, RemoveBlackBorders from ..augmentations import ElasticDeformation
from ..image_utils import RemoveBlackBorders
post_transforms = [] post_transforms = []
if RGB: if RGB:
......
...@@ -62,7 +62,8 @@ json_dataset = JSONDataset( ...@@ -62,7 +62,8 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=False): def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
from torchvision import transforms from torchvision import transforms
from ..transforms import ElasticDeformation, RemoveBlackBorders from ..augmentations import ElasticDeformation
from ..image_utils import RemoveBlackBorders
post_transforms = [] post_transforms = []
if RGB: if RGB:
......
...@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False): ...@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False):
from torchvision import transforms from torchvision import transforms
from .. import make_dataset from .. import make_dataset
from ..transforms import ElasticDeformation from ..augmentations import ElasticDeformation
post_transforms = [] post_transforms = []
if RGB: if RGB:
......
...@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False): ...@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False):
from torchvision import transforms from torchvision import transforms
from .. import make_dataset from .. import make_dataset
from ..transforms import ElasticDeformation from ..augmentations import ElasticDeformation
post_transforms = [] post_transforms = []
if RGB: if RGB:
......
...@@ -8,7 +8,7 @@ import numpy ...@@ -8,7 +8,7 @@ import numpy
import PIL.Image import PIL.Image
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from ptbench.data.transforms import ElasticDeformation from ptbench.data.augmentations import ElasticDeformation
def test_elastic_deformation(datadir): def test_elastic_deformation(datadir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment