diff --git a/src/ptbench/data/datamodule.py b/src/ptbench/data/datamodule.py
index abcf11d79286ecdb16cc374fb920f68418ec603a..3e699347607db040631c6ae6540412750ddede86 100644
--- a/src/ptbench/data/datamodule.py
+++ b/src/ptbench/data/datamodule.py
@@ -27,52 +27,6 @@ from .typing import (
 logger = logging.getLogger(__name__)
 
 
-def _setup_dataloader_multiproc_parameters(
-    parallel: int,
-) -> dict[str, typing.Any]:
-    """Returns a dictionary containing pytorch arguments to be used in data
-    loaders.
-
-    It sets the parameter ``num_workers`` to match the expected pytorch
-    representation. For macOS machines, it also sets the
-    ``multiprocessing_context`` to use ``spawn`` instead of the default.
-
-    The mapping between the command-line interface ``parallel`` setting works
-    like this:
-
-    .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
-       :widths: 15 15 70
-       :header-rows: 1
-
-       * - CLI ``parallel``
-         - :py:class:`torch.utils.data.DataLoader` ``kwargs``
-         - Comments
-       * - ``<0``
-         - 0
-         - Disables multiprocessing entirely, executes everything within the
-           same processing context
-       * - ``0``
-         - :py:func:`multiprocessing.cpu_count`
-         - Runs mini-batch data loading on as many external processes as CPUs
-           available in the current machine
-       * - ``>=1``
-         - ``parallel``
-         - Runs mini-batch data loading on as many external processes as set on
-           ``parallel``
-    """
-
-    retval: dict[str, typing.Any] = dict()
-    if parallel < 0:
-        retval["num_workers"] = 0
-    else:
-        retval["num_workers"] = parallel or multiprocessing.cpu_count()
-
-    if retval["num_workers"] > 0 and sys.platform == "darwin":
-        retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
-
-    return retval
-
-
 class _DelayedLoadingDataset(Dataset):
     """A list that loads its samples on demand.
 
@@ -474,16 +428,53 @@ class CachingDataModule(lightning.LightningDataModule):
         many data loading instances as processing cores as available in the system.
         Set to >= 1 to enable that many multiprocessing instances for data
         loading.
+
+        It sets the parameter ``num_workers`` (from Dataloaders) to match the
+        expected pytorch representation. For macOS machines, it also sets the
+        ``multiprocessing_context`` to use ``spawn`` instead of the default.
+
+        The mapping between the command-line interface ``parallel`` setting
+        works like this:
+
+        .. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
+           :widths: 15 15 70
+           :header-rows: 1
+
+           * - CLI ``parallel``
+             - :py:class:`torch.utils.data.DataLoader` ``kwargs``
+             - Comments
+           * - ``<0``
+             - 0
+             - Disables multiprocessing entirely, executes everything within the
+               same processing context
+           * - ``0``
+             - :py:func:`multiprocessing.cpu_count`
+             - Runs mini-batch data loading on as many external processes as CPUs
+               available in the current machine
+           * - ``>=1``
+             - ``parallel``
+             - Runs mini-batch data loading on as many external processes as set on
+               ``parallel``
         """
         return self._parallel
 
     @parallel.setter
     def parallel(self, value: int) -> None:
+        self._dataloader_multiproc: dict[str, typing.Any] = {}
         self._parallel = value
-        self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
-            value
-        )
-        # datasets that have been setup() for the current stage
+
+        if value < 0:
+            num_workers = 0
+        else:
+            num_workers = value or multiprocessing.cpu_count()
+        self._dataloader_multiproc["num_workers"] = num_workers
+
+        if num_workers > 0 and sys.platform == "darwin":
+            self._dataloader_multiproc[
+                "multiprocessing_context"
+            ] = multiprocessing.get_context("spawn")
+
+        # datasets that have been setup() for the current stage are reset
+        self._datasets = {}
 
     @property