diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index abcf11d79286ecdb16cc374fb920f68418ec603a..3e699347607db040631c6ae6540412750ddede86 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -27,52 +27,6 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
-def _setup_dataloader_multiproc_parameters(
-    parallel: int,
-) -> dict[str, typing.Any]:
-    """Returns a dictionary containing pytorch arguments to be used in data
-    loaders.
-
-    It sets the parameter ``num_workers`` to match the expected pytorch
-    representation.  For macOS machines, it also sets the
-    ``multiprocessing_context`` to use ``spawn`` instead of the default.
-
-    The mapping between the command-line interface ``parallel`` setting works
-    like this:
-
-    .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
-       :widths: 15 15 70
-       :header-rows: 1
-
-       * - CLI ``parallel``
-         - :py:class:`torch.utils.data.DataLoader` ``kwargs``
-         - Comments
-       * - ``<0``
-         - 0
-         - Disables multiprocessing entirely, executes everything within the
-           same processing context
-       * - ``0``
-         - :py:func:`multiprocessing.cpu_count`
-         - Runs mini-batch data loading on as many external processes as CPUs
-           available in the current machine
-       * - ``>=1``
-         - ``parallel``
-         - Runs mini-batch data loading on as many external processes as set on
-           ``parallel``
-    """
-
-    retval: dict[str, typing.Any] = dict()
-    if parallel < 0:
-        retval["num_workers"] = 0
-    else:
-        retval["num_workers"] = parallel or multiprocessing.cpu_count()
-
-    if retval["num_workers"] > 0 and sys.platform == "darwin":
-        retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
-
-    return retval
-
-
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
 
@@ -474,16 +428,53 @@ class CachingDataModule(lightning.LightningDataModule):
         many data loading instances as processing cores as available in
         the system.  Set to >= 1 to enable that many multiprocessing
         instances for data loading.
+
+        Setting this property translates the value into the DataLoader
+        ``num_workers`` parameter.  On macOS machines, it also sets the
+        ``multiprocessing_context`` to use ``spawn`` instead of the default.
+
+        The mapping between the command-line interface ``parallel`` setting
+        works like this:
+
+        .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
+           :widths: 15 15 70
+           :header-rows: 1
+
+           * - CLI ``parallel``
+             - :py:class:`torch.utils.data.DataLoader` ``kwargs``
+             - Comments
+           * - ``<0``
+             - 0
+             - Disables multiprocessing entirely and executes everything
+               within the same process
+           * - ``0``
+             - :py:func:`multiprocessing.cpu_count`
+             - Runs mini-batch data loading on as many external processes as
+               there are CPUs available on the current machine
+           * - ``>=1``
+             - ``parallel``
+             - Runs mini-batch data loading on as many external processes as
+               set by ``parallel``
         """
         return self._parallel
 
     @parallel.setter
     def parallel(self, value: int) -> None:
+        self._dataloader_multiproc: dict[str, typing.Any] = {}
         self._parallel = value
-        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
-            value
-        )
-        # datasets that have been setup() for the current stage
+
+        if value < 0:
+            num_workers = 0
+        else:
+            num_workers = value or multiprocessing.cpu_count()
+        self._dataloader_multiproc["num_workers"] = num_workers
+
+        if num_workers > 0 and sys.platform == "darwin":
+            self._dataloader_multiproc[
+                "multiprocessing_context"
+            ] = multiprocessing.get_context("spawn")
+
+        # datasets that have been setup() for the current stage are reset
         self._datasets = {}
 
     @property
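
For illustration, the mapping that the ``parallel`` setter now embeds can be
read as the following self-contained sketch.  The helper name
``dataloader_kwargs_from_parallel`` is hypothetical and not part of the patch;
it only mirrors the behaviour documented in the table above::

    import multiprocessing
    import sys
    import typing


    def dataloader_kwargs_from_parallel(parallel: int) -> dict[str, typing.Any]:
        """Translates a CLI ``parallel`` value into DataLoader keyword arguments."""

        retval: dict[str, typing.Any] = {}

        if parallel < 0:
            # disables multiprocessing entirely
            retval["num_workers"] = 0
        else:
            # 0 means one worker per CPU core, >= 1 means exactly that many
            retval["num_workers"] = parallel or multiprocessing.cpu_count()

        if retval["num_workers"] > 0 and sys.platform == "darwin":
            # on macOS, explicitly request the ``spawn`` start method, as the
            # setter above does
            retval["multiprocessing_context"] = multiprocessing.get_context("spawn")

        return retval

The returned dictionary can then be expanded into a
:py:class:`torch.utils.data.DataLoader` call, e.g.
``DataLoader(dataset, batch_size=8, **dataloader_kwargs_from_parallel(4))``.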