diff --git a/src/ptbench/configs/models/alexnet.py b/src/ptbench/configs/models/alexnet.py
index 815226b517142438d77db25e69f6f4e173cee39c..0028e810ed34658b754ea0fe89b75bf64faae4a0 100644
--- a/src/ptbench/configs/models/alexnet.py
+++ b/src/ptbench/configs/models/alexnet.py
@@ -7,7 +7,7 @@
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import SGD
 
-from ...data.transforms import ElasticDeformation
+from ...data.augmentations import ElasticDeformation
 from ...models.alexnet import Alexnet
 
 model = Alexnet(
diff --git a/src/ptbench/configs/models/alexnet_pretrained.py b/src/ptbench/configs/models/alexnet_pretrained.py
index f968df50cda171cc94991febc511168d111517c9..9c772f42e935b2b995030e6d73fcc044fdcbcc76 100644
--- a/src/ptbench/configs/models/alexnet_pretrained.py
+++ b/src/ptbench/configs/models/alexnet_pretrained.py
@@ -7,7 +7,7 @@
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import SGD
 
-from ...data.transforms import ElasticDeformation
+from ...data.augmentations import ElasticDeformation
 from ...models.alexnet import Alexnet
 
 model = Alexnet(
diff --git a/src/ptbench/configs/models/densenet.py b/src/ptbench/configs/models/densenet.py
index 79f8f7dabc58746c1029bbc9760f10137801c202..5d453a6d4ff75514c0e7669a01b43be3e1aa3473 100644
--- a/src/ptbench/configs/models/densenet.py
+++ b/src/ptbench/configs/models/densenet.py
@@ -7,7 +7,7 @@
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from ...data.transforms import ElasticDeformation
+from ...data.augmentations import ElasticDeformation
 from ...models.densenet import Densenet
 
 model = Densenet(
diff --git a/src/ptbench/configs/models/densenet_pretrained.py b/src/ptbench/configs/models/densenet_pretrained.py
index 4bc4616c6de0a19134646a4ad1449c2920be9e50..49b0162f5fb1defd3bfc7b64bcc390c9e2bd0c27 100644
--- a/src/ptbench/configs/models/densenet_pretrained.py
+++ b/src/ptbench/configs/models/densenet_pretrained.py
@@ -7,7 +7,7 @@
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from ...data.transforms import ElasticDeformation
+from ...data.augmentations import ElasticDeformation
 from ...models.densenet import Densenet
 
 model = Densenet(
diff --git a/src/ptbench/configs/models/pasa.py b/src/ptbench/configs/models/pasa.py
index d1e1b0a3ae8d9e3e32a7ec19a49e21f01bb694d9..b1a201e5b2d1f5a1b22135ee46ea148d8f655536 100644
--- a/src/ptbench/configs/models/pasa.py
+++ b/src/ptbench/configs/models/pasa.py
@@ -14,7 +14,7 @@ Reference: [PASA-2019]_
 from torch.nn import BCEWithLogitsLoss
 from torch.optim import Adam
 
-from ...data.transforms import ElasticDeformation
+from ...data.augmentations import ElasticDeformation
 from ...models.pasa import Pasa
 
 model = Pasa(
diff --git a/src/ptbench/data/transforms.py b/src/ptbench/data/augmentations.py
similarity index 76%
rename from src/ptbench/data/transforms.py
rename to src/ptbench/data/augmentations.py
index 4e61c4d5f8d46351f83ecf405ad0dcb3de8a4cee..a0e20c57458b78039c465bab4ec7c044b706b745 100644
--- a/src/ptbench/data/transforms.py
+++ b/src/ptbench/data/augmentations.py
@@ -69,6 +69,13 @@ def _elastic_deformation_on_image(
     p
         Probability that this transformation will be applied.  Meaningful when
         using it as a data augmentation technique.
+
+
+    Returns
+    -------
+
+    tensor
+        A tensor on the CPU.
     """
 
     if random.random() < p:
@@ -122,7 +129,7 @@ def _elastic_deformation_on_image(
                 ).reshape(img_shape)
             )
 
-        # wraps numpy array as tensor (with no copy)
+        # wraps numpy array as tensor, move to destination device if need-be
         return torch.as_tensor(output)
 
     return img
@@ -137,6 +144,47 @@ def _elastic_deformation_on_batch(
     p: float = 1.0,
     pool: multiprocessing.pool.Pool | None = None,
 ) -> torch.Tensor:
+    """Performs elastic deformation on a batch of images.
+
+    This implementation is based on 2 scipy functions
+    (:py:func:`scipy.ndimage.gaussian_filter` and
+    :py:func:`scipy.ndimage.map_coordinates`).  It is very inefficient since it
+    requires data is moved off the current running device and then back.
+
+
+    Parameters
+    ----------
+
+    img
+        The input image to apply elastic deformation at.  This image should
+        always have this shape: ``[C, H, W]``. It should always represent a
+        tensor on the CPU.
+
+    alpha
+        A multiplier for the gaussian filter outputs
+
+    sigma
+        Standard deviation for Gaussian kernel.
+
+    spline_order
+        The order of the spline interpolation, default is 1. The order has to
+        be in the range 0-5.
+
+    mode
+        The mode parameter determines how the input array is extended beyond
+        its boundaries.
+
+    p
+        Probability that this transformation will be applied.  Meaningful when
+        using it as a data augmentation technique.
+
+
+    Returns
+    -------
+
+    tensor
+        A tensor on the CPU.
+    """
     # transforms our custom functions into simpler callables
     partial = functools.partial(
         _elastic_deformation_on_image,
@@ -154,7 +202,7 @@ def _elastic_deformation_on_batch(
     else:
         augmented_images = pool.imap(partial, batch.cpu())
 
-    return torch.stack(list(augmented_images)).to(batch.device)
+    return torch.stack(list(augmented_images))
 
 
 class ElasticDeformation:
@@ -196,9 +244,11 @@ class ElasticDeformation:
 
     parallel
         Use multiprocessing for processing batches of data: if set to -1
-        (default), disables multiprocessing.  Set to 0 to enable as many
-        processes as processing cores as available in the system. Set to >= 1
-        to enable that many processes.
+        (default), disables multiprocessing.  If set to -2, then enable
+        auto-tune (use the minimum value between the first batch size and total
+        number of processing cores).  Set to 0 to enable as many processes as
+        processing cores as available in the system. Set to >= 1 to enable that
+        many processes.
     """
 
     def __init__(
@@ -208,7 +258,7 @@ class ElasticDeformation:
         spline_order: int = 1,
         mode: str = "nearest",
         p: float = 1.0,
-        parallel: int = -1,
+        parallel: int = -2,
     ):
         self.alpha: float = alpha
         self.sigma: float = sigma
@@ -221,10 +271,11 @@ class ElasticDeformation:
     def parallel(self):
         """Use multiprocessing for data augmentation.
 
-        If set to -1 (default), disables multiprocessing data
-        augmentation. Set to 0 to enable as many data loading instances
-        as processing cores as available in the system. Set to >= 1 to
-        enable that many multiprocessing instances for data loading.
+        If set to -1 (default), disables multiprocessing.  If set to -2,
+        then enable auto-tune (use the minimum value between the first
+        batch size and total number of processing cores).  Set to 0 to
+        enable as many processes as processing cores as available in the
+        system. Set to >= 1 to enable that many processes.
         """
         return self._parallel
 
@@ -243,6 +294,11 @@ class ElasticDeformation:
 
     def __call__(self, img: torch.Tensor) -> torch.Tensor:
         if len(img.shape) == 4:
+            if self._mp_pool is None and self._parallel == -2:
+                # auto-tunning on first batch
+                instances = min(img.shape[0], multiprocessing.cpu_count())
+                self._mp_pool = multiprocessing.pool.Pool(instances)
+
             return _elastic_deformation_on_batch(
                 img,
                 self.alpha,
@@ -251,7 +307,8 @@ class ElasticDeformation:
                 self.mode,
                 self.p,
                 self._mp_pool,
-            )
+            ).to(img.device)
+
         elif len(img.shape) == 3:
             return _elastic_deformation_on_image(
                 img.cpu(),
diff --git a/src/ptbench/data/hivtb/__init__.py b/src/ptbench/data/hivtb/__init__.py
index b5d2753c0ba191b42bcaf8d66c104e36cb4dd839..88401da0f431df46325f0d122b1d4eac7a908d33 100644
--- a/src/ptbench/data/hivtb/__init__.py
+++ b/src/ptbench/data/hivtb/__init__.py
@@ -62,7 +62,7 @@ json_dataset = JSONDataset(
 def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
     from torchvision import transforms
 
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
+    from ..augmentations import ElasticDeformation, RemoveBlackBorders
 
     post_transforms = []
     if RGB:
diff --git a/src/ptbench/data/indian/__init__.py b/src/ptbench/data/indian/__init__.py
index 72d7567f7087b4646b3e300b5dcf31161af4980f..5255c783daab1ca36b7f184d36339789a96ffd1d 100644
--- a/src/ptbench/data/indian/__init__.py
+++ b/src/ptbench/data/indian/__init__.py
@@ -62,7 +62,8 @@ json_dataset = JSONDataset(
 def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
     from torchvision import transforms
 
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
+    from ..augmentations import ElasticDeformation
+    from ..image_utils import RemoveBlackBorders
 
     post_transforms = []
     if RGB:
diff --git a/src/ptbench/data/tbpoc/__init__.py b/src/ptbench/data/tbpoc/__init__.py
index 02cd488080873ff11feabf895571ca25af113fca..6108b2fba5cad8611a8d6bd6c38467ee8a869807 100644
--- a/src/ptbench/data/tbpoc/__init__.py
+++ b/src/ptbench/data/tbpoc/__init__.py
@@ -62,7 +62,8 @@ json_dataset = JSONDataset(
 def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
     from torchvision import transforms
 
-    from ..transforms import ElasticDeformation, RemoveBlackBorders
+    from ..augmentations import ElasticDeformation
+    from ..image_utils import RemoveBlackBorders
 
     post_transforms = []
     if RGB:
diff --git a/src/ptbench/data/tbx11k_simplified/__init__.py b/src/ptbench/data/tbx11k_simplified/__init__.py
index 3080b63a3427ddb92af82e88c896802d24b6fa24..7e66abc346f88faaeab7f0e61637b1767723aa3f 100644
--- a/src/ptbench/data/tbx11k_simplified/__init__.py
+++ b/src/ptbench/data/tbx11k_simplified/__init__.py
@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False):
     from torchvision import transforms
 
     from .. import make_dataset
-    from ..transforms import ElasticDeformation
+    from ..augmentations import ElasticDeformation
 
     post_transforms = []
     if RGB:
diff --git a/src/ptbench/data/tbx11k_simplified_v2/__init__.py b/src/ptbench/data/tbx11k_simplified_v2/__init__.py
index d6075e85d366510d8670af362a2b685167e3c8de..57323beed9475e3f0c0992fd8ecb42b6fd3bb1cb 100644
--- a/src/ptbench/data/tbx11k_simplified_v2/__init__.py
+++ b/src/ptbench/data/tbx11k_simplified_v2/__init__.py
@@ -94,7 +94,7 @@ def _maker(protocol, RGB=False):
     from torchvision import transforms
 
     from .. import make_dataset
-    from ..transforms import ElasticDeformation
+    from ..augmentations import ElasticDeformation
 
     post_transforms = []
     if RGB:
diff --git a/tests/test_tranforms.py b/tests/test_tranforms.py
index 0dcdf5d64b473f4fccecb72db0baf01d1b67bbd1..c02c80d077e1e43f73451b01d32a9e9066067934 100644
--- a/tests/test_tranforms.py
+++ b/tests/test_tranforms.py
@@ -8,7 +8,7 @@ import numpy
 import PIL.Image
 import torchvision.transforms.functional as F
 
-from ptbench.data.transforms import ElasticDeformation
+from ptbench.data.augmentations import ElasticDeformation
 
 
 def test_elastic_deformation(datadir):