diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index ad516e194570dca2b5794959f816a0e717733a9d..4e61c4d5f8d46351f83ecf405ad0dcb3de8a4cee 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -13,92 +13,258 @@ across all input images, but color jittering, for example, only on the input
 image.
 """
 
+import functools
+import logging
+import multiprocessing.pool
 import random
+import typing
 
-import numpy
-import PIL.Image
+import numpy.random
+import numpy.typing
+import torch
 from scipy.ndimage import gaussian_filter, map_coordinates
-from torchvision import transforms
+
+logger = logging.getLogger(__name__)
+
+
+def _elastic_deformation_on_image(
+    img: torch.Tensor,
+    alpha: float = 1000.0,
+    sigma: float = 30.0,
+    spline_order: int = 1,
+    mode: str = "nearest",
+    p: float = 1.0,
+) -> torch.Tensor:
+    """Performs elastic deformation on an image.
+
+    This implementation is based on two scipy functions
+    (:py:func:`scipy.ndimage.gaussian_filter` and
+    :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient, since
+    it requires data to be moved off the current running device and back.
+
+
+    Parameters
+    ----------
+
+    img
+        The input image to which elastic deformation should be applied. The
+        image should always have the shape ``[C, H, W]`` and always represent
+        a tensor residing on the CPU.
+
+    alpha
+        A multiplier for the Gaussian filter outputs.
+
+    sigma
+        Standard deviation for the Gaussian kernel.
+
+    spline_order
+        The order of the spline interpolation, default is 1. The order has to
+        be in the range 0-5.
+
+    mode
+        The mode parameter determines how the input array is extended beyond
+        its boundaries.
+
+    p
+        Probability that this transformation will be applied. Meaningful when
+        using it as a data augmentation technique.
+    """
+
+    if random.random() < p:
+        assert img.ndim == 3, (
+            f"This filter accepts only images with 3 dimensions, "
+            f"however I got an image with {img.ndim} dimensions."
+        )
+
+        # Input tensor is of shape C x H x W
+        img_shape = img.shape[1:]
+
+        dx = alpha * typing.cast(
+            numpy.typing.NDArray[numpy.float64],
+            gaussian_filter(
+                (numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
+                sigma,
+                mode="constant",
+                cval=0.0,
+            ),
+        )
+        dy = alpha * typing.cast(
+            numpy.typing.NDArray[numpy.float64],
+            gaussian_filter(
+                (numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
+                sigma,
+                mode="constant",
+                cval=0.0,
+            ),
+        )
+
+        x, y = numpy.meshgrid(
+            numpy.arange(img_shape[0]),
+            numpy.arange(img_shape[1]),
+            indexing="ij",
+        )
+        indices = [
+            numpy.reshape(x + dx, (-1, 1)),
+            numpy.reshape(y + dy, (-1, 1)),
+        ]
+
+        # requires img to already be a CPU tensor (no copy is made)
+        img_numpy = img.numpy()
+        output = numpy.zeros_like(img_numpy)
+        for i in range(img.shape[0]):
+            output[i, :, :] = torch.tensor(
+                map_coordinates(
+                    img_numpy[i, :, :],
+                    indices,
+                    order=spline_order,
+                    mode=mode,
+                ).reshape(img_shape)
+            )
+
+        # wraps numpy array as tensor (with no copy)
+        return torch.as_tensor(output)
+
+    return img
+
+
+def _elastic_deformation_on_batch(
+    batch: torch.Tensor,
+    alpha: float = 1000.0,
+    sigma: float = 30.0,
+    spline_order: int = 1,
+    mode: str = "nearest",
+    p: float = 1.0,
+    pool: multiprocessing.pool.Pool | None = None,
+) -> torch.Tensor:
+    # transforms our custom function into a simpler callable
+    partial = functools.partial(
+        _elastic_deformation_on_image,
+        alpha=alpha,
+        sigma=sigma,
+        spline_order=spline_order,
+        mode=mode,
+        p=p,
+    )
+
+    # if an mp pool is available, do it in parallel
+    augmented_images: typing.Any
+    if pool is None:
+        augmented_images = map(partial, batch.cpu())
+    else:
+        augmented_images = pool.imap(partial, batch.cpu())
+
+    return torch.stack(list(augmented_images)).to(batch.device)
 
 
 class ElasticDeformation:
     """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
 
+    This implementation is based on two scipy functions
+    (:py:func:`scipy.ndimage.gaussian_filter` and
+    :py:func:`scipy.ndimage.map_coordinates`). It is very inefficient, since
+    it requires data to be moved off the current running device and back.
+
+    .. warning::
+
+       Furthermore, this transform is not scriptable and therefore cannot run
+       on a CUDA or MPS device. Applying it effectively creates a bottleneck
+       in model training.
+
     Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0
+
+
+    Parameters
+    ----------
+
+    alpha
+        A multiplier for the Gaussian filter outputs.
+
+    sigma
+        Standard deviation for the Gaussian kernel.
+
+    spline_order
+        The order of the spline interpolation, default is 1. The order has to
+        be in the range 0-5.
+
+    mode
+        The mode parameter determines how the input array is extended beyond
+        its boundaries.
+
+    p
+        Probability that this transformation will be applied. Meaningful when
+        using it as a data augmentation technique.
+
+    parallel
+        Use multiprocessing for processing batches of data: if set to -1
+        (default), disables multiprocessing. Set to 0 to enable as many
+        processes as there are processing cores available in the system.
+        Set to >= 1 to enable that many processes.
""" def __init__( self, - alpha=1000, - sigma=30, - spline_order=1, - mode="nearest", - random_state=numpy.random, - p=1.0, + alpha: float = 1000.0, + sigma: float = 30.0, + spline_order: int = 1, + mode: str = "nearest", + p: float = 1.0, + parallel: int = -1, ): - self.alpha = alpha - self.sigma = sigma - self.spline_order = spline_order - self.mode = mode - self.random_state = random_state - self.p = p - - def __call__(self, img): - if random.random() < self.p: - assert img.ndim == 3 - - # Input tensor is of shape C x H x W - # If the tensor only contains one channel, this conversion results in H x W. - # With 3 channels, we get H x W x C - img = transforms.ToPILImage()(img) - img = numpy.asarray(img) - - shape = img.shape[:2] - - dx = ( - gaussian_filter( - (self.random_state.rand(*shape) * 2 - 1), - self.sigma, - mode="constant", - cval=0, - ) - * self.alpha - ) - dy = ( - gaussian_filter( - (self.random_state.rand(*shape) * 2 - 1), - self.sigma, - mode="constant", - cval=0, - ) - * self.alpha + self.alpha: float = alpha + self.sigma: float = sigma + self.spline_order: int = spline_order + self.mode: str = mode + self.p: float = p + self.parallel = parallel + + @property + def parallel(self): + """Use multiprocessing for data augmentation. + + If set to -1 (default), disables multiprocessing data + augmentation. Set to 0 to enable as many data loading instances + as processing cores as available in the system. Set to >= 1 to + enable that many multiprocessing instances for data loading. + """ + return self._parallel + + @parallel.setter + def parallel(self, value): + self._parallel = value + + if value >= 0: + instances = value or multiprocessing.cpu_count() + logger.info( + f"Applying data-augmentation using {instances} processes..." ) + self._mp_pool = multiprocessing.pool.Pool(instances) + else: + self._mp_pool = None - x, y = numpy.meshgrid( - numpy.arange(shape[0]), numpy.arange(shape[1]), indexing="ij" + def __call__(self, img: torch.Tensor) -> torch.Tensor: + if len(img.shape) == 4: + return _elastic_deformation_on_batch( + img, + self.alpha, + self.sigma, + self.spline_order, + self.mode, + self.p, + self._mp_pool, ) - indices = [ - numpy.reshape(x + dx, (-1, 1)), - numpy.reshape(y + dy, (-1, 1)), - ] - result = numpy.empty_like(img) - - if img.ndim == 2: - result[:, :] = map_coordinates( - img[:, :], indices, order=self.spline_order, mode=self.mode - ).reshape(shape) - - else: - for i in range(img.shape[2]): - result[:, :, i] = map_coordinates( - img[:, :, i], - indices, - order=self.spline_order, - mode=self.mode, - ).reshape(shape) - - return transforms.ToTensor()(PIL.Image.fromarray(result)) + elif len(img.shape) == 3: + return _elastic_deformation_on_image( + img.cpu(), + self.alpha, + self.sigma, + self.spline_order, + self.mode, + self.p, + ).to(img.device) - else: - return img + raise RuntimeError( + f"This transform accepts only images with 3 dimensions," + f"or batches of images with 4 dimensions. However, I got " + f"an image with {img.ndim} dimensions." 
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index cd391e3d6d3580dd6b470ad25efb47dff61ce373..7866f36df6fb9fe991e13ae6012c71df1c917777 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -174,12 +174,7 @@ class Alexnet(pl.LightningModule):
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [
-            self._augmentation_transforms(img).to(self.device) for img in images
-        ]
-        # Combine list of augmented images back into a tensor
-        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
-        outputs = self(augmented_images)
+        outputs = self(self._augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ba1d71fa13a0c56d9ac0db6737fb645ea66a042c..d1f4d03ccb21c4d5bf6d6df60ab717a5a34078fd 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -173,12 +173,7 @@ class Densenet(pl.LightningModule):
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [
-            self._augmentation_transforms(img).to(self.device) for img in images
-        ]
-        # Combine list of augmented images back into a tensor
-        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
-        outputs = self(augmented_images)
+        outputs = self(self._augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 479ec8f23e6a211ca76552dade705471011c2d7e..4e0e281b7429b782faaffb09d2f22fcdc49f61c0 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -243,12 +243,7 @@ class Pasa(pl.LightningModule):
         labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [
-            self._augmentation_transforms(img).to(self.device) for img in images
-        ]
-        # Combine list of augmented images back into a tensor
-        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
-        outputs = self(augmented_images)
+        outputs = self(self._augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
 
diff --git a/tests/data/raw_with_elastic_deformation.png b/tests/data/raw_with_elastic_deformation.png
index 388b4e1ebd446c6e492cafed58963f6a6502b957..61c4ddfe50131884433256ae9ac183ce51493142 100644
Binary files a/tests/data/raw_with_elastic_deformation.png and b/tests/data/raw_with_elastic_deformation.png differ
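The three `training_step` simplifications above rely on `_augmentation_transforms` now being batch-aware: it receives the whole `[N, C, H, W]` tensor and returns an augmented tensor on the same device, so the per-image loop and the `torch.cat(...).view(...)` reassembly become unnecessary. A sketch of what such a pipeline could look like (the composition shown is an assumption; the actual pipeline is configured elsewhere in ptbench):

```python
import torchvision.transforms

from ptbench.data.transforms import ElasticDeformation

# every member must accept and return a 4D batch tensor, since the models
# now apply the whole pipeline in a single call inside training_step
_augmentation_transforms = torchvision.transforms.Compose(
    [ElasticDeformation(p=0.8)]
)
```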
diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b56fad9238808d96a8ec0cb24108a0b95dc0b1d
--- /dev/null
+++ b/tests/test_image_utils.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Tests for image utilities."""
+
+import numpy
+import PIL.Image
+
+from ptbench.data.image_utils import (
+    RemoveBlackBorders,
+    SingleAutoLevel16to8,
+    load_pil,
+)
+
+
+def test_remove_black_borders(datadir):
+    # Get a raw sample with black border
+    data_file = str(datadir / "raw_with_black_border.png")
+    raw_with_black_border = PIL.Image.open(data_file)
+
+    # Remove the black border
+    rbb = RemoveBlackBorders()
+    raw_rbb_removed = rbb(raw_with_black_border)
+
+    # Get the same sample without black border
+    data_file_2 = str(datadir / "raw_without_black_border.png")
+    raw_without_black_border = PIL.Image.open(data_file_2)
+
+    # Compare both
+    raw_rbb_removed = numpy.asarray(raw_rbb_removed)
+    raw_without_black_border = numpy.asarray(raw_without_black_border)
+
+    numpy.testing.assert_array_equal(raw_without_black_border, raw_rbb_removed)
+
+
+def test_load_pil_16bit(datadir):
+    # If the ratio is higher than 0.5, the image is probably clipped
+    Level16to8 = SingleAutoLevel16to8()
+
+    data_file = str(datadir / "16bits.png")
+    image = numpy.array(Level16to8(load_pil(data_file)))
+
+    count_pixels = numpy.count_nonzero(image)
+    count_max_value = numpy.count_nonzero(image == image.max())
+
+    assert count_max_value / count_pixels < 0.5
+
+    # It should not do anything to an image already in 8 bits
+    data_file = str(datadir / "raw_without_black_border.png")
+    img_loaded = load_pil(data_file)
+
+    original_8bits = numpy.array(img_loaded)
+    leveled_8bits = numpy.array(Level16to8(img_loaded))
+
+    numpy.testing.assert_array_equal(original_8bits, leveled_8bits)
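`test_load_pil_16bit` uses a saturation heuristic: if half or more of the non-zero pixels land on the maximum grey value after conversion, the 16-to-8-bit mapping probably clipped. For reference, a plausible min-max auto-level in the spirit of `SingleAutoLevel16to8` (a hypothetical re-implementation; the real class lives in `ptbench.data.image_utils`):

```python
import numpy
import PIL.Image


def autolevel_16to8(img: PIL.Image.Image) -> PIL.Image.Image:
    # hypothetical sketch: stretch the observed dynamic range of a
    # 16-bit grayscale image onto the full 8-bit range
    arr = numpy.asarray(img).astype(numpy.float64)
    lo, hi = arr.min(), arr.max()
    scaled = (arr - lo) / max(hi - lo, 1.0) * 255.0
    return PIL.Image.fromarray(scaled.astype(numpy.uint8))
```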
diff --git a/tests/test_tranforms.py b/tests/test_tranforms.py
index fc171770e263e1e9f0d686d5294dd18e1338495f..0dcdf5d64b473f4fccecb72db0baf01d1b67bbd1 100644
--- a/tests/test_tranforms.py
+++ b/tests/test_tranforms.py
@@ -6,42 +6,19 @@
 
 import numpy
 import PIL.Image
+import torchvision.transforms.functional as F
 
-from ptbench.data.loader import load_pil
-from ptbench.data.transforms import (
-    ElasticDeformation,
-    RemoveBlackBorders,
-    SingleAutoLevel16to8,
-)
-
-
-def test_remove_black_borders(datadir):
-    # Get a raw sample with black border
-    data_file = str(datadir / "raw_with_black_border.png")
-    raw_with_black_border = PIL.Image.open(data_file)
-
-    # Remove the black border
-    rbb = RemoveBlackBorders()
-    raw_rbb_removed = rbb(raw_with_black_border)
-
-    # Get the same sample without black border
-    data_file_2 = str(datadir / "raw_without_black_border.png")
-    raw_without_black_border = PIL.Image.open(data_file_2)
-
-    # Compare both
-    raw_rbb_removed = numpy.asarray(raw_rbb_removed)
-    raw_without_black_border = numpy.asarray(raw_without_black_border)
-
-    numpy.testing.assert_array_equal(raw_without_black_border, raw_rbb_removed)
+from ptbench.data.transforms import ElasticDeformation
 
 
 def test_elastic_deformation(datadir):
     # Get a raw sample without deformation
     data_file = str(datadir / "raw_without_elastic_deformation.png")
-    raw_without_deformation = PIL.Image.open(data_file)
+    raw_without_deformation = F.to_tensor(PIL.Image.open(data_file))
 
     # Elastic deforms the raw
-    ed = ElasticDeformation(random_state=numpy.random.RandomState(seed=100))
+    numpy.random.seed(seed=100)
+    ed = ElasticDeformation()
     raw_deformed = ed(raw_without_deformation)
 
     # Get the same sample already deformed (with seed=100)
@@ -49,29 +26,9 @@ def test_elastic_deformation(datadir):
     raw_2 = PIL.Image.open(data_file_2)
 
     # Compare both
-    raw_deformed = numpy.asarray(raw_deformed)
+    raw_deformed = (255 * numpy.asarray(raw_deformed)).astype(numpy.uint8)[
+        0, :, :
+    ]
     raw_2 = numpy.asarray(raw_2)
 
     numpy.testing.assert_array_equal(raw_deformed, raw_2)
-
-
-def test_load_pil_16bit(datadir):
-    # If the ratio is higher 0.5, image is probably clipped
-    Level16to8 = SingleAutoLevel16to8()
-
-    data_file = str(datadir / "16bits.png")
-    image = numpy.array(Level16to8(load_pil(data_file)))
-
-    count_pixels = numpy.count_nonzero(image)
-    count_max_value = numpy.count_nonzero(image == image.max())
-
-    assert count_max_value / count_pixels < 0.5
-
-    # It should not do anything to an image already in 8 bits
-    data_file = str(datadir / "raw_without_black_border.png")
-    img_loaded = load_pil(data_file)
-
-    original_8bits = numpy.array(img_loaded)
-    leveled_8bits = numpy.array(Level16to8(img_loaded))
-
-    numpy.testing.assert_array_equal(original_8bits, leveled_8bits)
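With the `random_state` constructor argument gone, determinism in `test_elastic_deformation` hinges on the global `numpy.random` state (the displacement fields come from `numpy.random.rand`); the `p` gate uses the stdlib `random` module, but with the default `p=1.0` the deformation always fires, so only the numpy seed matters here. Under that assumption, the updated reference image could have been regenerated along these lines (a sketch, not part of the patch):

```python
import numpy
import PIL.Image
import torchvision.transforms.functional as F

from ptbench.data.transforms import ElasticDeformation

numpy.random.seed(100)  # the same seed the test relies on
raw = F.to_tensor(
    PIL.Image.open("tests/data/raw_without_elastic_deformation.png")
)
deformed = ElasticDeformation()(raw)

# back to 8-bit grayscale, exactly how the test compares the arrays
as_uint8 = (255 * deformed.numpy()).astype(numpy.uint8)[0, :, :]
PIL.Image.fromarray(as_uint8).save(
    "tests/data/raw_with_elastic_deformation.png"
)
```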