Skip to content
Snippets Groups Projects

Making use of LightningDataModule and simplification of data loading

Merged Daniel CARRON requested to merge add-datamodule into main
Compare and Show latest version
9 files
+ 339
279
Compare changes
  • Side-by-side
  • Inline
Files
9
@@ -13,92 +13,257 @@ across all input images, but color jittering, for example, only on the
input image.
"""
import functools
import logging
import multiprocessing.pool
import random
import typing
import numpy
import PIL.Image
import numpy.random
import numpy.typing
import torch
from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision import transforms
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _elastic_deformation_on_image(
img: torch.Tensor,
alpha: float = 1000.0,
sigma: float = 30.0,
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
) -> torch.Tensor:
"""Performs elastic deformation on an image.
This implementation is based on 2 scipy functions
(:py:func:`scipy.ndimage.gaussian_filter` and
:py:func:`scipy.ndimage.map_coordinates`). It is very inefficient since it
requires data is moved off the current running device and then back.
Parameters
----------
img
The input image to apply elastic deformation at. This image should
always have this shape: ``[C, H, W]``. It should always represent a
tensor on the CPU.
alpha
A multiplier for the gaussian filter outputs
sigma
Standard deviation for Gaussian kernel.
spline_order
The order of the spline interpolation, default is 1. The order has to
be in the range 0-5.
mode
The mode parameter determines how the input array is extended beyond
its boundaries.
p
Probability that this transformation will be applied. Meaningful when
using it as a data augmentation technique.
"""
if random.random() < p:
assert img.ndim == 3, (
f"This filter accepts only images with 3 dimensions, "
f"however I got an image with {img.ndim} dimensions."
)
# Input tensor is of shape C x H x W
img_shape = img.shape[1:]
dx = alpha * typing.cast(
numpy.typing.NDArray[numpy.float64],
gaussian_filter(
(numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
sigma,
mode="constant",
cval=0.0,
),
)
dy = alpha * typing.cast(
numpy.typing.NDArray[numpy.float64],
gaussian_filter(
(numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
sigma,
mode="constant",
cval=0.0,
),
)
x, y = numpy.meshgrid(
numpy.arange(img_shape[0]),
numpy.arange(img_shape[1]),
indexing="ij",
)
indices = [
numpy.reshape(x + dx, (-1, 1)),
numpy.reshape(y + dy, (-1, 1)),
]
# may copy, if img is not on CPU originally
img_numpy = img.numpy()
output = numpy.zeros_like(img_numpy)
for i in range(img.shape[0]):
output[i, :, :] = torch.tensor(
map_coordinates(
img_numpy[i, :, :],
indices,
order=spline_order,
mode=mode,
).reshape(img_shape)
)
# wraps numpy array as tensor (with no copy)
return torch.as_tensor(output)
return img
def _elastic_deformation_on_batch(
batch: torch.Tensor,
alpha: float = 1000.0,
sigma: float = 30.0,
spline_order: int = 1,
mode: str = "nearest",
p: float = 1.0,
pool: multiprocessing.pool.Pool | None = None,
) -> torch.Tensor:
# transforms our custom functions into simpler callables
partial = functools.partial(
_elastic_deformation_on_image,
alpha=alpha,
sigma=sigma,
spline_order=spline_order,
mode=mode,
p=p,
)
# if a mp pool is available, do it in parallel
augmented_images: typing.Any
if pool is None:
augmented_images = map(partial, batch.cpu())
else:
augmented_images = pool.imap(partial, batch.cpu())
return torch.stack(list(augmented_images)).to(batch.device)
class ElasticDeformation:
    """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.

    This implementation is based on 2 scipy functions
    (:py:func:`scipy.ndimage.gaussian_filter` and
    :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since it
    requires data is moved off the current running device and then back.

    .. warning::

       Furthermore, this transform is not scriptable and therefore cannot run
       on a CUDA or MPS device.  Applying it, effectively creates a bottleneck
       in model training.

    Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0

    Parameters
    ----------
    alpha
        A multiplier for the gaussian filter outputs.
    sigma
        Standard deviation for Gaussian kernel.
    spline_order
        The order of the spline interpolation, default is 1.  The order has to
        be in the range 0-5.
    mode
        The mode parameter determines how the input array is extended beyond
        its boundaries.
    p
        Probability that this transformation will be applied.  Meaningful when
        using it as a data augmentation technique.
    parallel
        Use multiprocessing for processing batches of data: if set to -1
        (default), disables multiprocessing.  Set to 0 to enable as many
        processes as processing cores as available in the system.  Set to >= 1
        to enable that many processes.
    """

    def __init__(
        self,
        alpha: float = 1000.0,
        sigma: float = 30.0,
        spline_order: int = 1,
        mode: str = "nearest",
        p: float = 1.0,
        parallel: int = -1,
    ):
        self.alpha: float = alpha
        self.sigma: float = sigma
        self.spline_order: int = spline_order
        self.mode: str = mode
        self.p: float = p
        # Goes through the property setter, which also creates the pool
        self.parallel = parallel

    @property
    def parallel(self):
        """Use multiprocessing for data augmentation.

        If set to -1 (default), disables multiprocessing data augmentation.
        Set to 0 to enable as many processes as processing cores as available
        in the system.  Set to >= 1 to enable that many processes.
        """
        return self._parallel

    @parallel.setter
    def parallel(self, value):
        # Build the replacement pool first, so that a failure here leaves the
        # current configuration untouched.
        if value >= 0:
            # 0 means "one process per available core"
            instances = value or multiprocessing.cpu_count()
            logger.info(
                "Applying data-augmentation using %s processes...", instances
            )
            new_pool = multiprocessing.pool.Pool(instances)
        else:
            new_pool = None

        # ``_mp_pool`` does not exist yet on the first call from ``__init__``
        previous_pool = getattr(self, "_mp_pool", None)
        self._parallel = value
        self._mp_pool = new_pool

        # Do not leak worker processes of a pool that is being replaced
        if previous_pool is not None:
            previous_pool.close()

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        """Applies the elastic deformation to an image or a batch of images.

        Parameters
        ----------
        img
            Either a single image (3 dimensions) or a batch of images (4
            dimensions).

        Returns
        -------
            The transformed image or batch, on the same device as the input.

        Raises
        ------
        RuntimeError
            If ``img`` has neither 3 nor 4 dimensions.
        """
        if img.ndim == 4:
            # batch of images: may use the multiprocessing pool, if any
            return _elastic_deformation_on_batch(
                img,
                self.alpha,
                self.sigma,
                self.spline_order,
                self.mode,
                self.p,
                self._mp_pool,
            )

        if img.ndim == 3:
            # single image: processed on the CPU, then moved back
            return _elastic_deformation_on_image(
                img.cpu(),
                self.alpha,
                self.sigma,
                self.spline_order,
                self.mode,
                self.p,
            ).to(img.device)

        raise RuntimeError(
            f"This transform accepts only images with 3 dimensions, "
            f"or batches of images with 4 dimensions. However, I got "
            f"an image with {img.ndim} dimensions."
        )
Loading