Commit 9f737974 authored by André Anjos

[data.transforms] Properly implement ElasticDeformation so it works with images or batches; Update test case for transforms; Create new test case for image_utils; Update models to reflect changes (partially addresses #5)
parent 2bd2a302
1 merge request: !6 Making use of LightningDataModule and simplification of data loading
@@ -13,92 +13,257 @@ across all input images, but color jittering, for example, only on the
input image.
"""
import functools
import logging
import multiprocessing.pool
import random
import typing

import numpy.random
import numpy.typing
import torch

from scipy.ndimage import gaussian_filter, map_coordinates

logger = logging.getLogger(__name__)
def _elastic_deformation_on_image(
    img: torch.Tensor,
    alpha: float = 1000.0,
    sigma: float = 30.0,
    spline_order: int = 1,
    mode: str = "nearest",
    p: float = 1.0,
) -> torch.Tensor:
    """Performs elastic deformation on an image.

    This implementation is based on 2 scipy functions
    (:py:func:`scipy.ndimage.gaussian_filter` and
    :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since
    it requires data to be moved off the current running device and then back.

    Parameters
    ----------

    img
        The input image to apply elastic deformation to.  This image should
        always have this shape: ``[C, H, W]``.  It should always represent a
        tensor on the CPU.

    alpha
        A multiplier for the gaussian filter outputs.

    sigma
        Standard deviation for the Gaussian kernel.

    spline_order
        The order of the spline interpolation (default is 1).  The order has
        to be in the range 0-5.

    mode
        Determines how the input array is extended beyond its boundaries.

    p
        Probability that this transformation will be applied.  Meaningful when
        using it as a data augmentation technique.
    """
    if random.random() < p:
        assert img.ndim == 3, (
            f"This filter accepts only images with 3 dimensions, "
            f"however I got an image with {img.ndim} dimensions."
        )

        # input tensor is of shape C x H x W
        img_shape = img.shape[1:]

        dx = alpha * typing.cast(
            numpy.typing.NDArray[numpy.float64],
            gaussian_filter(
                (numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
                sigma,
                mode="constant",
                cval=0.0,
            ),
        )
        dy = alpha * typing.cast(
            numpy.typing.NDArray[numpy.float64],
            gaussian_filter(
                (numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
                sigma,
                mode="constant",
                cval=0.0,
            ),
        )

        x, y = numpy.meshgrid(
            numpy.arange(img_shape[0]),
            numpy.arange(img_shape[1]),
            indexing="ij",
        )
        indices = [
            numpy.reshape(x + dx, (-1, 1)),
            numpy.reshape(y + dy, (-1, 1)),
        ]

        # may copy, if img is not on CPU originally
        img_numpy = img.numpy()

        output = numpy.zeros_like(img_numpy)
        for i in range(img.shape[0]):
            output[i, :, :] = torch.tensor(
                map_coordinates(
                    img_numpy[i, :, :],
                    indices,
                    order=spline_order,
                    mode=mode,
                ).reshape(img_shape)
            )

        # wraps numpy array as tensor (with no copy)
        return torch.as_tensor(output)

    return img
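For reference, the helper above deforms a single ``[C, H, W]`` CPU tensor by adding Gaussian-smoothed random displacement fields to the pixel grid before resampling each channel.  A minimal usage sketch, under the assumption that the private helper is importable from ptbench.data.transforms (the import path is not shown in this diff):

# Hypothetical usage sketch -- the import path is an assumption.
import torch

from ptbench.data.transforms import _elastic_deformation_on_image  # assumed path

image = torch.rand(1, 64, 64)  # single-channel CPU image, shape [C, H, W]
warped = _elastic_deformation_on_image(image, alpha=1000.0, sigma=30.0, p=1.0)

assert warped.shape == image.shape  # the deformation preserves the image shape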
def _elastic_deformation_on_batch(
    batch: torch.Tensor,
    alpha: float = 1000.0,
    sigma: float = 30.0,
    spline_order: int = 1,
    mode: str = "nearest",
    p: float = 1.0,
    pool: multiprocessing.pool.Pool | None = None,
) -> torch.Tensor:
    # transforms our custom functions into simpler callables
    partial = functools.partial(
        _elastic_deformation_on_image,
        alpha=alpha,
        sigma=sigma,
        spline_order=spline_order,
        mode=mode,
        p=p,
    )

    # if a mp pool is available, do it in parallel
    augmented_images: typing.Any
    if pool is None:
        augmented_images = map(partial, batch.cpu())
    else:
        augmented_images = pool.imap(partial, batch.cpu())

    return torch.stack(list(augmented_images)).to(batch.device)
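The batch helper applies the same per-image deformation to every element of a ``[N, C, H, W]`` tensor, optionally fanning the work out to a multiprocessing pool.  A rough sketch of both modes (import path again assumed, values arbitrary):

# Hypothetical usage sketch -- the import path is an assumption, not part of the diff.
import multiprocessing.pool

import torch

from ptbench.data.transforms import _elastic_deformation_on_batch  # assumed path

if __name__ == "__main__":  # guard needed for "spawn" multiprocessing start methods
    batch = torch.rand(8, 1, 64, 64)  # batch of 8 single-channel [C, H, W] images

    # serial: with pool=None, images are deformed one after another
    out_serial = _elastic_deformation_on_batch(batch, p=1.0)

    # parallel: each image is handed to a worker process via pool.imap()
    with multiprocessing.pool.Pool(4) as pool:
        out_parallel = _elastic_deformation_on_batch(batch, p=1.0, pool=pool)

    assert out_serial.shape == out_parallel.shape == batch.shape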
class ElasticDeformation:
    """Elastic deformation of a 2D image, slightly adapted from [SIMARD-2003]_.

    This implementation is based on 2 scipy functions
    (:py:func:`scipy.ndimage.gaussian_filter` and
    :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since
    it requires data to be moved off the current running device and then back.

    .. warning::

       Furthermore, this transform is not scriptable and therefore cannot run
       on a CUDA or MPS device.  Applying it effectively creates a bottleneck
       in model training.

    Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0

    Parameters
    ----------

    alpha
        A multiplier for the gaussian filter outputs.

    sigma
        Standard deviation for the Gaussian kernel.

    spline_order
        The order of the spline interpolation (default is 1).  The order has
        to be in the range 0-5.

    mode
        Determines how the input array is extended beyond its boundaries.

    p
        Probability that this transformation will be applied.  Meaningful when
        using it as a data augmentation technique.

    parallel
        Use multiprocessing for processing batches of data: if set to -1
        (default), disables multiprocessing.  Set to 0 to enable as many
        processes as there are processing cores available in the system.  Set
        to >= 1 to enable that many processes.
    """
    def __init__(
        self,
        alpha: float = 1000.0,
        sigma: float = 30.0,
        spline_order: int = 1,
        mode: str = "nearest",
        p: float = 1.0,
        parallel: int = -1,
    ):
        self.alpha: float = alpha
        self.sigma: float = sigma
        self.spline_order: int = spline_order
        self.mode: str = mode
        self.p: float = p
        self.parallel = parallel
    @property
    def parallel(self):
        """Use multiprocessing for data augmentation.

        If set to -1 (default), disables multiprocessing data augmentation.
        Set to 0 to enable as many data loading instances as there are
        processing cores available in the system.  Set to >= 1 to enable that
        many multiprocessing instances for data loading.
        """
        return self._parallel

    @parallel.setter
    def parallel(self, value):
        self._parallel = value

        if value >= 0:
            instances = value or multiprocessing.cpu_count()
            logger.info(
                f"Applying data-augmentation using {instances} processes..."
            )
            self._mp_pool = multiprocessing.pool.Pool(instances)
        else:
            self._mp_pool = None
    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        if len(img.shape) == 4:
            return _elastic_deformation_on_batch(
                img,
                self.alpha,
                self.sigma,
                self.spline_order,
                self.mode,
                self.p,
                self._mp_pool,
            )
        elif len(img.shape) == 3:
            return _elastic_deformation_on_image(
                img.cpu(),
                self.alpha,
                self.sigma,
                self.spline_order,
                self.mode,
                self.p,
            ).to(img.device)

        raise RuntimeError(
            f"This transform accepts only images with 3 dimensions, "
            f"or batches of images with 4 dimensions. However, I got "
            f"an image with {img.ndim} dimensions."
        )
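To make the new dispatch concrete, here is a short usage sketch of the public transform (parameter values are illustrative only): a 3D tensor is treated as a single image, a 4D tensor as a batch, and ``parallel`` controls whether a multiprocessing pool is created for batch processing.

# Illustrative usage of ElasticDeformation -- all values are arbitrary.
import torch

from ptbench.data.transforms import ElasticDeformation

transform = ElasticDeformation(alpha=1000.0, sigma=30.0, p=1.0, parallel=-1)

single = transform(torch.rand(1, 128, 128))    # [C, H, W] -> per-image path
batch = transform(torch.rand(4, 1, 128, 128))  # [N, C, H, W] -> batch path

# Anything else raises RuntimeError, for example:
# transform(torch.rand(128, 128))  # 2D input is rejected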
@@ -174,12 +174,7 @@ class Alexnet(pl.LightningModule):
        labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = self(self._augmentation_transforms(images))

        return self._train_loss(outputs, labels.float())
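The same simplification recurs in the Alexnet, Densenet, and Pasa training steps: because ElasticDeformation now handles whole batches, the per-image augmentation loop and the re-stacking of results collapse into a single call on the batch tensor.  A minimal, self-contained sketch of the before/after (the identity transform and tensor shapes below are placeholders, not the project's actual pipeline):

# Minimal sketch contrasting the old per-image loop with the new batch call;
# the identity "transform" stands in for the pipeline the models keep in
# self._augmentation_transforms.
import torch


def augmentation_transforms(t: torch.Tensor) -> torch.Tensor:
    return t  # placeholder: a real pipeline would include ElasticDeformation


images = torch.rand(4, 1, 32, 32)  # dummy batch of 4 single-channel images

# old approach: augment image by image, then re-assemble the batch
augmented_old = torch.cat(
    [augmentation_transforms(img) for img in images], 0
).view(images.shape)

# new approach: the transform now accepts the whole batch at once
augmented_new = augmentation_transforms(images)

assert augmented_old.shape == augmented_new.shape == images.shape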
@@ -173,12 +173,7 @@ class Densenet(pl.LightningModule):
        labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = self(self._augmentation_transforms(images))

        return self._train_loss(outputs, labels.float())
@@ -243,12 +243,7 @@ class Pasa(pl.LightningModule):
        labels = torch.reshape(labels, (labels.shape[0], 1))

        # Forward pass on the network
        outputs = self(self._augmentation_transforms(images))

        return self._train_loss(outputs, labels.float())
tests/data/raw_with_elastic_deformation.png (binary image replaced: 105 KiB → 85.6 KiB)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Tests for image utilities."""
import numpy
import PIL.Image

from ptbench.data.image_utils import (
    RemoveBlackBorders,
    SingleAutoLevel16to8,
    load_pil,
)


def test_remove_black_borders(datadir):
    # Get a raw sample with black border
    data_file = str(datadir / "raw_with_black_border.png")
    raw_with_black_border = PIL.Image.open(data_file)

    # Remove the black border
    rbb = RemoveBlackBorders()
    raw_rbb_removed = rbb(raw_with_black_border)

    # Get the same sample without black border
    data_file_2 = str(datadir / "raw_without_black_border.png")
    raw_without_black_border = PIL.Image.open(data_file_2)

    # Compare both
    raw_rbb_removed = numpy.asarray(raw_rbb_removed)
    raw_without_black_border = numpy.asarray(raw_without_black_border)
    numpy.testing.assert_array_equal(raw_without_black_border, raw_rbb_removed)


def test_load_pil_16bit(datadir):
    # If the ratio is higher than 0.5, the image is probably clipped
    Level16to8 = SingleAutoLevel16to8()
    data_file = str(datadir / "16bits.png")
    image = numpy.array(Level16to8(load_pil(data_file)))

    count_pixels = numpy.count_nonzero(image)
    count_max_value = numpy.count_nonzero(image == image.max())

    assert count_max_value / count_pixels < 0.5

    # It should not do anything to an image already in 8 bits
    data_file = str(datadir / "raw_without_black_border.png")
    img_loaded = load_pil(data_file)

    original_8bits = numpy.array(img_loaded)
    leveled_8bits = numpy.array(Level16to8(img_loaded))

    numpy.testing.assert_array_equal(original_8bits, leveled_8bits)
@@ -6,42 +6,19 @@

import numpy
import PIL.Image
import torchvision.transforms.functional as F

from ptbench.data.transforms import ElasticDeformation
def test_elastic_deformation(datadir):
    # Get a raw sample without deformation
    data_file = str(datadir / "raw_without_elastic_deformation.png")
    raw_without_deformation = F.to_tensor(PIL.Image.open(data_file))

    # Elastic deforms the raw
    numpy.random.seed(seed=100)
    ed = ElasticDeformation()
    raw_deformed = ed(raw_without_deformation)

    # Get the same sample already deformed (with seed=100)
@@ -49,29 +26,9 @@ def test_elastic_deformation(datadir):
    raw_2 = PIL.Image.open(data_file_2)

    # Compare both
    raw_deformed = (255 * numpy.asarray(raw_deformed)).astype(numpy.uint8)[
        0, :, :
    ]
    raw_2 = numpy.asarray(raw_2)

    numpy.testing.assert_array_equal(raw_deformed, raw_2)