diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/transforms.py
index ad516e194570dca2b5794959f816a0e717733a9d..4e61c4d5f8d46351f83ecf405ad0dcb3de8a4cee 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/transforms.py
@@ -13,92 +13,257 @@ across all input images, but color jittering, for example, only on the
 input image.
 """
 
+import functools
+import logging
+import multiprocessing.pool
 import random
+import typing
 
-import numpy
-import PIL.Image
+import numpy.random
+import numpy.typing
+import torch
 
 from scipy.ndimage import gaussian_filter, map_coordinates
-from torchvision import transforms
+
+logger = logging.getLogger(__name__)
+
+
+def _elastic_deformation_on_image(
+    img: torch.Tensor,
+    alpha: float = 1000.0,
+    sigma: float = 30.0,
+    spline_order: int = 1,
+    mode: str = "nearest",
+    p: float = 1.0,
+) -> torch.Tensor:
+    """Performs elastic deformation on an image.
+
+    This implementation is based on 2 scipy functions
+    (:py:func:`scipy.ndimage.gaussian_filter` and
+    :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since it
+    requires data is moved off the current running device and then back.
+
+
+    Parameters
+    ----------
+
+    img
+        The input image to apply elastic deformation at.  This image should
+        always have this shape: ``[C, H, W]``. It should always represent a
+        tensor on the CPU.
+
+    alpha
+        A multiplier for the gaussian filter outputs
+
+    sigma
+        Standard deviation for Gaussian kernel.
+
+    spline_order
+        The order of the spline interpolation, default is 1. The order has to
+        be in the range 0-5.
+
+    mode
+        The mode parameter determines how the input array is extended beyond
+        its boundaries.
+
+    p
+        Probability that this transformation will be applied.  Meaningful when
+        using it as a data augmentation technique.
+    """
+
+    if random.random() < p:
+        assert img.ndim == 3, (
+            f"This filter accepts only images with 3 dimensions, "
+            f"however I got an image with {img.ndim} dimensions."
+        )
+
+        # Input tensor is of shape C x H x W
+        img_shape = img.shape[1:]
+
+        dx = alpha * typing.cast(
+            numpy.typing.NDArray[numpy.float64],
+            gaussian_filter(
+                (numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
+                sigma,
+                mode="constant",
+                cval=0.0,
+            ),
+        )
+        dy = alpha * typing.cast(
+            numpy.typing.NDArray[numpy.float64],
+            gaussian_filter(
+                (numpy.random.rand(img_shape[0], img_shape[1]) * 2 - 1),
+                sigma,
+                mode="constant",
+                cval=0.0,
+            ),
+        )
+
+        x, y = numpy.meshgrid(
+            numpy.arange(img_shape[0]),
+            numpy.arange(img_shape[1]),
+            indexing="ij",
+        )
+        indices = [
+            numpy.reshape(x + dx, (-1, 1)),
+            numpy.reshape(y + dy, (-1, 1)),
+        ]
+
+        # may copy, if img is not on CPU originally
+        img_numpy = img.numpy()
+        output = numpy.zeros_like(img_numpy)
+        for i in range(img.shape[0]):
+            output[i, :, :] = torch.tensor(
+                map_coordinates(
+                    img_numpy[i, :, :],
+                    indices,
+                    order=spline_order,
+                    mode=mode,
+                ).reshape(img_shape)
+            )
+
+        # wraps numpy array as tensor (with no copy)
+        return torch.as_tensor(output)
+
+    return img
+
+
+def _elastic_deformation_on_batch(
+    batch: torch.Tensor,
+    alpha: float = 1000.0,
+    sigma: float = 30.0,
+    spline_order: int = 1,
+    mode: str = "nearest",
+    p: float = 1.0,
+    pool: multiprocessing.pool.Pool | None = None,
+) -> torch.Tensor:
+    # transforms our custom functions into simpler callables
+    partial = functools.partial(
+        _elastic_deformation_on_image,
+        alpha=alpha,
+        sigma=sigma,
+        spline_order=spline_order,
+        mode=mode,
+        p=p,
+    )
+
+    # if a mp pool is available, do it in parallel
+    augmented_images: typing.Any
+    if pool is None:
+        augmented_images = map(partial, batch.cpu())
+    else:
+        augmented_images = pool.imap(partial, batch.cpu())
+
+    return torch.stack(list(augmented_images)).to(batch.device)
 
 
 class ElasticDeformation:
     """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
 
+    This implementation is based on 2 scipy functions
+    (:py:func:`scipy.ndimage.gaussian_filter` and
+    :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since it
+    requires data is moved off the current running device and then back.
+
+    .. warning::
+
+       Furthermore, this transform is not scriptable and therefore cannot run
+       on a CUDA or MPS device.  Applying it, effectively creates a bottleneck
+       in model training.
+
     Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0
+
+
+    Parameters
+    ----------
+
+    alpha
+
+    sigma
+        Standard deviation for Gaussian kernel.
+
+    spline_order
+        The order of the spline interpolation, default is 1. The order has to
+        be in the range 0-5.
+
+    mode
+        The mode parameter determines how the input array is extended beyond
+        its boundaries.
+
+    p
+        Probability that this transformation will be applied.  Meaningful when
+        using it as a data augmentation technique.
+
+    parallel
+        Use multiprocessing for processing batches of data: if set to -1
+        (default), disables multiprocessing.  Set to 0 to enable as many
+        processes as processing cores as available in the system. Set to >= 1
+        to enable that many processes.
     """
 
     def __init__(
         self,
-        alpha=1000,
-        sigma=30,
-        spline_order=1,
-        mode="nearest",
-        random_state=numpy.random,
-        p=1.0,
+        alpha: float = 1000.0,
+        sigma: float = 30.0,
+        spline_order: int = 1,
+        mode: str = "nearest",
+        p: float = 1.0,
+        parallel: int = -1,
     ):
-        self.alpha = alpha
-        self.sigma = sigma
-        self.spline_order = spline_order
-        self.mode = mode
-        self.random_state = random_state
-        self.p = p
-
-    def __call__(self, img):
-        if random.random() < self.p:
-            assert img.ndim == 3
-
-            # Input tensor is of shape C x H x W
-            # If the tensor only contains one channel, this conversion results in H x W.
-            # With 3 channels, we get H x W x C
-            img = transforms.ToPILImage()(img)
-            img = numpy.asarray(img)
-
-            shape = img.shape[:2]
-
-            dx = (
-                gaussian_filter(
-                    (self.random_state.rand(*shape) * 2 - 1),
-                    self.sigma,
-                    mode="constant",
-                    cval=0,
-                )
-                * self.alpha
-            )
-            dy = (
-                gaussian_filter(
-                    (self.random_state.rand(*shape) * 2 - 1),
-                    self.sigma,
-                    mode="constant",
-                    cval=0,
-                )
-                * self.alpha
+        self.alpha: float = alpha
+        self.sigma: float = sigma
+        self.spline_order: int = spline_order
+        self.mode: str = mode
+        self.p: float = p
+        self.parallel = parallel
+
+    @property
+    def parallel(self):
+        """Use multiprocessing for data augmentation.
+
+        If set to -1 (default), disables multiprocessing data
+        augmentation. Set to 0 to enable as many data loading instances
+        as processing cores as available in the system. Set to >= 1 to
+        enable that many multiprocessing instances for data loading.
+        """
+        return self._parallel
+
+    @parallel.setter
+    def parallel(self, value):
+        self._parallel = value
+
+        if value >= 0:
+            instances = value or multiprocessing.cpu_count()
+            logger.info(
+                f"Applying data-augmentation using {instances} processes..."
             )
+            self._mp_pool = multiprocessing.pool.Pool(instances)
+        else:
+            self._mp_pool = None
 
-            x, y = numpy.meshgrid(
-                numpy.arange(shape[0]), numpy.arange(shape[1]), indexing="ij"
+    def __call__(self, img: torch.Tensor) -> torch.Tensor:
+        if len(img.shape) == 4:
+            return _elastic_deformation_on_batch(
+                img,
+                self.alpha,
+                self.sigma,
+                self.spline_order,
+                self.mode,
+                self.p,
+                self._mp_pool,
             )
-            indices = [
-                numpy.reshape(x + dx, (-1, 1)),
-                numpy.reshape(y + dy, (-1, 1)),
-            ]
-            result = numpy.empty_like(img)
-
-            if img.ndim == 2:
-                result[:, :] = map_coordinates(
-                    img[:, :], indices, order=self.spline_order, mode=self.mode
-                ).reshape(shape)
-
-            else:
-                for i in range(img.shape[2]):
-                    result[:, :, i] = map_coordinates(
-                        img[:, :, i],
-                        indices,
-                        order=self.spline_order,
-                        mode=self.mode,
-                    ).reshape(shape)
-
-            return transforms.ToTensor()(PIL.Image.fromarray(result))
+        elif len(img.shape) == 3:
+            return _elastic_deformation_on_image(
+                img.cpu(),
+                self.alpha,
+                self.sigma,
+                self.spline_order,
+                self.mode,
+                self.p,
+            ).to(img.device)
 
-        else:
-            return img
+        raise RuntimeError(
+            f"This transform accepts only images with 3 dimensions,"
+            f"or batches of images with 4 dimensions.  However, I got "
+            f"an image with {img.ndim} dimensions."
+        )
diff --git a/src/ptbench/models/alexnet.py b/src/ptbench/models/alexnet.py
index cd391e3d6d3580dd6b470ad25efb47dff61ce373..7866f36df6fb9fe991e13ae6012c71df1c917777 100644
--- a/src/ptbench/models/alexnet.py
+++ b/src/ptbench/models/alexnet.py
@@ -174,12 +174,7 @@ class Alexnet(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [
-            self._augmentation_transforms(img).to(self.device) for img in images
-        ]
-        # Combine list of augmented images back into a tensor
-        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
-        outputs = self(augmented_images)
+        outputs = self(self._augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
 
diff --git a/src/ptbench/models/densenet.py b/src/ptbench/models/densenet.py
index ba1d71fa13a0c56d9ac0db6737fb645ea66a042c..d1f4d03ccb21c4d5bf6d6df60ab717a5a34078fd 100644
--- a/src/ptbench/models/densenet.py
+++ b/src/ptbench/models/densenet.py
@@ -173,12 +173,7 @@ class Densenet(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [
-            self._augmentation_transforms(img).to(self.device) for img in images
-        ]
-        # Combine list of augmented images back into a tensor
-        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
-        outputs = self(augmented_images)
+        outputs = self(self._augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
 
diff --git a/src/ptbench/models/pasa.py b/src/ptbench/models/pasa.py
index 479ec8f23e6a211ca76552dade705471011c2d7e..4e0e281b7429b782faaffb09d2f22fcdc49f61c0 100644
--- a/src/ptbench/models/pasa.py
+++ b/src/ptbench/models/pasa.py
@@ -243,12 +243,7 @@ class Pasa(pl.LightningModule):
             labels = torch.reshape(labels, (labels.shape[0], 1))
 
         # Forward pass on the network
-        augmented_images = [
-            self._augmentation_transforms(img).to(self.device) for img in images
-        ]
-        # Combine list of augmented images back into a tensor
-        augmented_images = torch.cat(augmented_images, 0).view(images.shape)
-        outputs = self(augmented_images)
+        outputs = self(self._augmentation_transforms(images))
 
         return self._train_loss(outputs, labels.float())
 
diff --git a/tests/data/raw_with_elastic_deformation.png b/tests/data/raw_with_elastic_deformation.png
index 388b4e1ebd446c6e492cafed58963f6a6502b957..61c4ddfe50131884433256ae9ac183ce51493142 100644
Binary files a/tests/data/raw_with_elastic_deformation.png and b/tests/data/raw_with_elastic_deformation.png differ
diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b56fad9238808d96a8ec0cb24108a0b95dc0b1d
--- /dev/null
+++ b/tests/test_image_utils.py
@@ -0,0 +1,56 @@
+# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+"""Tests for image utilities."""
+
+import numpy
+import PIL.Image
+
+from ptbench.data.image_utils import (
+    RemoveBlackBorders,
+    SingleAutoLevel16to8,
+    load_pil,
+)
+
+
+def test_remove_black_borders(datadir):
+    # Get a raw sample with black border
+    data_file = str(datadir / "raw_with_black_border.png")
+    raw_with_black_border = PIL.Image.open(data_file)
+
+    # Remove the black border
+    rbb = RemoveBlackBorders()
+    raw_rbb_removed = rbb(raw_with_black_border)
+
+    # Get the same sample without black border
+    data_file_2 = str(datadir / "raw_without_black_border.png")
+    raw_without_black_border = PIL.Image.open(data_file_2)
+
+    # Compare both
+    raw_rbb_removed = numpy.asarray(raw_rbb_removed)
+    raw_without_black_border = numpy.asarray(raw_without_black_border)
+
+    numpy.testing.assert_array_equal(raw_without_black_border, raw_rbb_removed)
+
+
+def test_load_pil_16bit(datadir):
+    # If the ratio is higher 0.5, image is probably clipped
+    Level16to8 = SingleAutoLevel16to8()
+
+    data_file = str(datadir / "16bits.png")
+    image = numpy.array(Level16to8(load_pil(data_file)))
+
+    count_pixels = numpy.count_nonzero(image)
+    count_max_value = numpy.count_nonzero(image == image.max())
+
+    assert count_max_value / count_pixels < 0.5
+
+    # It should not do anything to an image already in 8 bits
+    data_file = str(datadir / "raw_without_black_border.png")
+    img_loaded = load_pil(data_file)
+
+    original_8bits = numpy.array(img_loaded)
+    leveled_8bits = numpy.array(Level16to8(img_loaded))
+
+    numpy.testing.assert_array_equal(original_8bits, leveled_8bits)
diff --git a/tests/test_tranforms.py b/tests/test_tranforms.py
index fc171770e263e1e9f0d686d5294dd18e1338495f..0dcdf5d64b473f4fccecb72db0baf01d1b67bbd1 100644
--- a/tests/test_tranforms.py
+++ b/tests/test_tranforms.py
@@ -6,42 +6,19 @@
 
 import numpy
 import PIL.Image
+import torchvision.transforms.functional as F
 
-from ptbench.data.loader import load_pil
-from ptbench.data.transforms import (
-    ElasticDeformation,
-    RemoveBlackBorders,
-    SingleAutoLevel16to8,
-)
-
-
-def test_remove_black_borders(datadir):
-    # Get a raw sample with black border
-    data_file = str(datadir / "raw_with_black_border.png")
-    raw_with_black_border = PIL.Image.open(data_file)
-
-    # Remove the black border
-    rbb = RemoveBlackBorders()
-    raw_rbb_removed = rbb(raw_with_black_border)
-
-    # Get the same sample without black border
-    data_file_2 = str(datadir / "raw_without_black_border.png")
-    raw_without_black_border = PIL.Image.open(data_file_2)
-
-    # Compare both
-    raw_rbb_removed = numpy.asarray(raw_rbb_removed)
-    raw_without_black_border = numpy.asarray(raw_without_black_border)
-
-    numpy.testing.assert_array_equal(raw_without_black_border, raw_rbb_removed)
+from ptbench.data.transforms import ElasticDeformation
 
 
 def test_elastic_deformation(datadir):
     # Get a raw sample without deformation
     data_file = str(datadir / "raw_without_elastic_deformation.png")
-    raw_without_deformation = PIL.Image.open(data_file)
+    raw_without_deformation = F.to_tensor(PIL.Image.open(data_file))
 
     # Elastic deforms the raw
-    ed = ElasticDeformation(random_state=numpy.random.RandomState(seed=100))
+    numpy.random.seed(seed=100)
+    ed = ElasticDeformation()
     raw_deformed = ed(raw_without_deformation)
 
     # Get the same sample already deformed (with seed=100)
@@ -49,29 +26,9 @@ def test_elastic_deformation(datadir):
     raw_2 = PIL.Image.open(data_file_2)
 
     # Compare both
-    raw_deformed = numpy.asarray(raw_deformed)
+    raw_deformed = (255 * numpy.asarray(raw_deformed)).astype(numpy.uint8)[
+        0, :, :
+    ]
     raw_2 = numpy.asarray(raw_2)
 
     numpy.testing.assert_array_equal(raw_deformed, raw_2)
-
-
-def test_load_pil_16bit(datadir):
-    # If the ratio is higher 0.5, image is probably clipped
-    Level16to8 = SingleAutoLevel16to8()
-
-    data_file = str(datadir / "16bits.png")
-    image = numpy.array(Level16to8(load_pil(data_file)))
-
-    count_pixels = numpy.count_nonzero(image)
-    count_max_value = numpy.count_nonzero(image == image.max())
-
-    assert count_max_value / count_pixels < 0.5
-
-    # It should not do anything to an image already in 8 bits
-    data_file = str(datadir / "raw_without_black_border.png")
-    img_loaded = load_pil(data_file)
-
-    original_8bits = numpy.array(img_loaded)
-    leveled_8bits = numpy.array(Level16to8(img_loaded))
-
-    numpy.testing.assert_array_equal(original_8bits, leveled_8bits)