Skip to content
Snippets Groups Projects
Commit 8dc21400 authored by André Anjos's avatar André Anjos :speech_balloon:
Browse files

[data.datamodule] Simplified code structure

parent edadec6d
No related branches found
No related tags found
1 merge request: !6 "Making use of LightningDataModule and simplification of data loading"
......@@ -27,52 +27,6 @@ from .typing import (
logger = logging.getLogger(__name__)
def _setup_dataloader_multiproc_parameters(
parallel: int,
) -> dict[str, typing.Any]:
"""Returns a dictionary containing pytorch arguments to be used in data
loaders.
It sets the parameter ``num_workers`` to match the expected pytorch
representation. For macOS machines, it also sets the
``multiprocessing_context`` to use ``spawn`` instead of the default.
The mapping between the command-line interface ``parallel`` setting works
like this:
.. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
:widths: 15 15 70
:header-rows: 1
* - CLI ``parallel``
- :py:class:`torch.utils.data.DataLoader` ``kwargs``
- Comments
* - ``<0``
- 0
- Disables multiprocessing entirely, executes everything within the
same processing context
* - ``0``
- :py:func:`multiprocessing.cpu_count`
- Runs mini-batch data loading on as many external processes as CPUs
available in the current machine
* - ``>=1``
- ``parallel``
- Runs mini-batch data loading on as many external processes as set on
``parallel``
"""
retval: dict[str, typing.Any] = dict()
if parallel < 0:
retval["num_workers"] = 0
else:
retval["num_workers"] = parallel or multiprocessing.cpu_count()
if retval["num_workers"] > 0 and sys.platform == "darwin":
retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
return retval
class _DelayedLoadingDataset(Dataset):
"""A list that loads its samples on demand.
......@@ -474,16 +428,53 @@ class CachingDataModule(lightning.LightningDataModule):
many data loading instances as processing cores as available in
the system. Set to >= 1 to enable that many multiprocessing
instances for data loading.
It sets the parameter ``num_workers`` (from Dataloaders) to match the
expected pytorch representation. For macOS machines, it also sets the
``multiprocessing_context`` to use ``spawn`` instead of the default.
The mapping between the command-line interface ``parallel`` setting
works like this:
.. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
:widths: 15 15 70
:header-rows: 1
* - CLI ``parallel``
- :py:class:`torch.utils.data.DataLoader` ``kwargs``
- Comments
* - ``<0``
- 0
- Disables multiprocessing entirely, executes everything within the
same processing context
* - ``0``
- :py:func:`multiprocessing.cpu_count`
- Runs mini-batch data loading on as many external processes as CPUs
available in the current machine
* - ``>=1``
- ``parallel``
- Runs mini-batch data loading on as many external processes as set on
``parallel``
"""
return self._parallel
@parallel.setter
def parallel(self, value: int) -> None:
    # Store the requested degree of parallelism and rebuild the DataLoader
    # keyword arguments from it.  Mapping: value < 0 disables multiprocessing
    # (0 workers); value == 0 uses one worker per available CPU; value >= 1
    # uses exactly that many worker processes.
    self._parallel = value

    self._dataloader_multiproc: dict[str, typing.Any] = {}

    if value < 0:
        num_workers = 0
    else:
        # 0 means "as many workers as CPUs on this machine"
        num_workers = value or multiprocessing.cpu_count()
    self._dataloader_multiproc["num_workers"] = num_workers

    if num_workers > 0 and sys.platform == "darwin":
        # macOS: the default "fork" start method is unsafe with pytorch
        # workers, so force "spawn" instead
        self._dataloader_multiproc[
            "multiprocessing_context"
        ] = multiprocessing.get_context("spawn")

    # datasets that have been setup() for the current stage are reset, so
    # they are re-created with the new loader parameters on next use
    self._datasets = {}
@property
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.