Skip to content
Snippets Groups Projects
Commit fb173347 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.augmentations] Move transforms to augmentations; Implement...

[data.augmentations] Move transforms to augmentations; Implement auto-parellisation for ElasticDeformation (partial affects #5); Reflect changes in all relevant submodules
parent b0c01942
No related branches found
No related tags found
No related merge requests found
Pipeline #76314 failed
......@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...data.transforms import ElasticDeformation
from ...data.augmentations import ElasticDeformation
from ...models.alexnet import Alexnet
model = Alexnet(
......
......@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...data.transforms import ElasticDeformation
from ...data.augmentations import ElasticDeformation
from ...models.alexnet import Alexnet
model = Alexnet(
......
......@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...data.transforms import ElasticDeformation
from ...data.augmentations import ElasticDeformation
from ...models.densenet import Densenet
model = Densenet(
......
......@@ -7,7 +7,7 @@
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...data.transforms import ElasticDeformation
from ...data.augmentations import ElasticDeformation
from ...models.densenet import Densenet
model = Densenet(
......
......@@ -14,7 +14,7 @@ Reference: [PASA-2019]_
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...data.transforms import ElasticDeformation
from ...data.augmentations import ElasticDeformation
from ...models.pasa import Pasa
model = Pasa(
......
......@@ -69,6 +69,13 @@ def _elastic_deformation_on_image(
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
Returns
-------
tensor
A tensor on the CPU.
"""
if random.random() < p:
......@@ -122,7 +129,7 @@ def _elastic_deformation_on_image(
).reshape(img_shape)
)
# wraps numpy array as tensor (with no copy)
# wraps numpy array as tensor, move to destination device if need-be
return torch.as_tensor(output)
return img
......@@ -137,6 +144,47 @@ def _elastic_deformation_on_batch(
p: float = 1.0,
pool: multiprocessing.pool.Pool | None = None,
) -> torch.Tensor:
"""Performs elastic deformation on a batch of images.
This implementation is based on 2 scipy functions
(:py:func:`scipy.ndimage.gaussian_filter` and
:py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it
requires data is moved off the current running device and then back.
Parameters
----------
img
The input image to apply elastic deformation at. This image should
always have this shape: ``[C, H, W]``. It should always represent a
tensor on the CPU.
alpha
A multiplier for the gaussian filter outputs
sigma
Standard deviation for Gaussian kernel.
spline_order
The order of the spline interpolation, default is 1. The order has to
be in the range 0-5.
mode
The mode parameter determines how the input array is extended beyond
its boundaries.
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
Returns
-------
tensor
A tensor on the CPU.
"""
# transforms our custom functions into simpler callables
partial = functools.partial(
_elastic_deformation_on_image,
......@@ -154,7 +202,7 @@ def _elastic_deformation_on_batch(
else:
augmented_images = pool.imap(partial, batch.cpu())
return torch.stack(list(augmented_images)).to(batch.device)
return torch.stack(list(augmented_images))
class ElasticDeformation:
......@@ -196,9 +244,11 @@ class ElasticDeformation:
parallel
Use multiprocessing for processing batches of data: if set to -1
(default), disables multiprocessing. Set to 0 to enable as many
processes as processing cores as available in the system. Set to >= 1
to enable that many processes.
(default), disables multiprocessing. If set to -2, then enable
auto-tune (use the minimum value between the first batch size and total
number of processing cores). Set to 0 to enable as many processes as
processing cores as available in the system. Set to >= 1 to enable that
many processes.
"""
def __init__(
......@@ -208,7 +258,7 @@ class ElasticDeformation:
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
parallel: int = -1,
parallel: int = -2,
):
self.alpha: float = alpha
self.sigma: float = sigma
......@@ -221,10 +271,11 @@ class ElasticDeformation:
def parallel(self):
"""Use multiprocessing for data augmentation.
If set to -1 (default), disables multiprocessing data
augmentation. Set to 0 to enable as many data loading instances
as processing cores as available in the system. Set to >= 1 to
enable that many multiprocessing instances for data loading.
If set to -1 (default), disables multiprocessing. If set to -2,
then enable auto-tune (use the minimum value between the first
batch size and total number of processing cores). Set to 0 to
enable as many processes as processing cores as available in the
system. Set to >= 1 to enable that many processes.
"""
return self._parallel
......@@ -243,6 +294,11 @@ class ElasticDeformation:
def __call__(self, img: torch.Tensor) -> torch.Tensor:
if len(img.shape) == 4:
if self._mp_pool is None and self._parallel == -2:
# auto-tunning on first batch
instances = min(img.shape[0], multiprocessing.cpu_count())
self._mp_pool = multiprocessing.pool.Pool(instances)
return _elastic_deformation_on_batch(
img,
self.alpha,
......@@ -251,7 +307,8 @@ class ElasticDeformation:
self.mode,
self.p,
self._mp_pool,
)
).to(img.device)
elif len(img.shape) == 3:
return _elastic_deformation_on_image(
img.cpu(),
......
......@@ -62,7 +62,7 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
from torchvision import transforms
from ..transforms import ElasticDeformation, RemoveBlackBorders
from ..augmentations import ElasticDeformation, RemoveBlackBorders
post_transforms = []
if RGB:
......
......@@ -62,7 +62,8 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
from torchvision import transforms
from ..transforms import ElasticDeformation, RemoveBlackBorders
from ..augmentations import ElasticDeformation
from ..image_utils import RemoveBlackBorders
post_transforms = []
if RGB:
......
......@@ -62,7 +62,8 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
from torchvision import transforms
from ..transforms import ElasticDeformation, RemoveBlackBorders
from ..augmentations import ElasticDeformation
from ..image_utils import RemoveBlackBorders
post_transforms = []
if RGB:
......
......@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False):
from torchvision import transforms
from .. import make_dataset
from ..transforms import ElasticDeformation
from ..augmentations import ElasticDeformation
post_transforms = []
if RGB:
......
......@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False):
from torchvision import transforms
from .. import make_dataset
from ..transforms import ElasticDeformation
from ..augmentations import ElasticDeformation
post_transforms = []
if RGB:
......
......@@ -8,7 +8,7 @@ import numpy
import PIL.Image
import torchvision.transforms.functional as F
from ptbench.data.transforms import ElasticDeformation
from ptbench.data.augmentations import ElasticDeformation
def test_elastic_deformation(datadir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment