Commit a67626d8 authored by André Anjos

[datamodule] Slightly streamlines the datamodule approach; adds documentation; adds type annotations; adds TODOs
parent 6b6196a0
1 merge request: !7 Reviewed DataModule design+docs+types
Pipeline #75367 failed
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import multiprocessing
import sys
import lightning as pl
import torch
from clapper.logging import setup
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from .dataset import get_samples_weights
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class BaseDataModule(pl.LightningDataModule):
def __init__(
self,
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
parallel=-1,
):
super().__init__()
self.batch_size = batch_size
self.batch_chunk_count = batch_chunk_count
self.train_dataset = None
self.validation_dataset = None
self.extra_validation_datasets = None
self._raw_data_transforms = []
self._model_transforms = []
self._augmentation_transforms = []
self.drop_incomplete_batch = drop_incomplete_batch
self.pin_memory = (
torch.cuda.is_available()
) # should only be true if GPU available and using it
self.parallel = parallel
def setup(self, stage: str):
# Implemented by user
# Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset
raise NotImplementedError
def train_dataloader(self):
train_samples_weights = get_samples_weights(self.train_dataset)
multiproc_kwargs = self._setup_multiproc(self.parallel)
train_sampler = WeightedRandomSampler(
train_samples_weights, len(train_samples_weights), replacement=True
)
return DataLoader(
self.train_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=train_sampler,
**multiproc_kwargs,
)
def val_dataloader(self):
loaders_dict = {}
multiproc_kwargs = self._setup_multiproc(self.parallel)
val_loader = DataLoader(
dataset=self.validation_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict["validation_loader"] = val_loader
if self.extra_validation_datasets is not None:
for set_idx, extra_set in enumerate(self.extra_validation_datasets):
extra_val_loader = DataLoader(
dataset=extra_set,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict[
f"extra_validation_loader{set_idx}"
] = extra_val_loader
return loaders_dict
def predict_dataloader(self):
loaders_dict = {}
loaders_dict["train_loader"] = self.train_dataloader()
for k, v in self.val_dataloader().items():
loaders_dict[k] = v
return loaders_dict
def update_module_properties(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def _compute_chunk_size(self, batch_size, chunk_count):
batch_chunk_size = batch_size
if batch_size % chunk_count != 0:
# batch_size must be divisible by batch_chunk_count.
raise RuntimeError(
f"--batch-size ({batch_size}) must be divisible by "
f"--batch-chunk-size ({chunk_count})."
)
else:
batch_chunk_size = batch_size // chunk_count
return batch_chunk_size
def _setup_multiproc(self, parallel):
multiproc_kwargs = dict()
if parallel < 0:
multiproc_kwargs["num_workers"] = 0
else:
multiproc_kwargs["num_workers"] = (
parallel or multiprocessing.cpu_count()
)
if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
multiproc_kwargs[
"multiprocessing_context"
] = multiprocessing.get_context("spawn")
return multiproc_kwargs
def _build_transforms(self, is_train):
all_transforms = self.raw_data_transforms + self.model_transforms
if is_train:
all_transforms = all_transforms + self.augmentation_transforms
# all_transforms.append(transforms.ToTensor())
return transforms.Compose(all_transforms)
@property
def raw_data_transforms(self):
return self._raw_data_transforms
@raw_data_transforms.setter
def raw_data_transforms(self, transforms):
self._raw_data_transforms = transforms
@property
def model_transforms(self):
return self._model_transforms
@model_transforms.setter
def model_transforms(self, transforms):
self._model_transforms = transforms
@property
def augmentation_transforms(self):
return self._augmentation_transforms
@augmentation_transforms.setter
def augmentation_transforms(self, transforms):
self._augmentation_transforms = transforms
def get_dataset_from_module(module, stage, **module_args):
"""Instantiates a DataModule and retrieves the corresponding dataset.
Useful when combining multiple datasets.
"""
module_instance = module(**module_args)
module_instance.prepare_data()
module_instance.setup(stage=stage)
dataset = module_instance.dataset
return dataset
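# Hedged usage sketch (not part of the original file): combining the datasets
# retrieved from two datamodules.  ``ModuleA`` and ``ModuleB`` are hypothetical
# BaseDataModule subclasses whose ``setup()`` assigns ``self.dataset``; the
# keyword arguments are illustrative only.
#
#   import torch.utils.data
#
#   dataset_a = get_dataset_from_module(ModuleA, stage="fit", batch_size=4)
#   dataset_b = get_dataset_from_module(ModuleB, stage="fit", batch_size=4)
#   combined = torch.utils.data.ConcatDataset([dataset_a, dataset_b])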
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections
import multiprocessing
import sys
import typing
import lightning
import torch
import torch.utils.data
from clapper.logging import setup
# TODO: No logging on this module...
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
def _setup_dataloader_multiproc_parameters(
parallel: int,
) -> dict[str, typing.Any]:
"""Returns a dictionary containing pytorch arguments to be used in data
loaders.
It sets the parameter ``num_workers`` to match the expected pytorch
representation. For macOS machines, it also sets the
``multiprocessing_context`` to use ``spawn`` instead of the default.
The mapping between the command-line interface ``parallel`` setting works
like this:
.. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
:widths: 15 15 70
:header-rows: 1
* - CLI ``parallel``
- :py:class:`torch.utils.data.DataLoader` ``kwargs``
- Comments
* - ``<0``
- 0
- Disables multiprocessing entirely, executes everything within the
same processing context
* - ``0``
- :py:func:`multiprocessing.cpu_count`
- Runs mini-batch data loading on as many external processes as CPUs
available in the current machine
* - ``>=1``
- ``parallel``
- Runs mini-batch data loading on as many external processes as set on
``parallel``
"""
retval: dict[str, typing.Any] = dict()
if parallel < 0:
retval["num_workers"] = 0
else:
retval["num_workers"] = parallel or multiprocessing.cpu_count()
if retval["num_workers"] > 0 and sys.platform == "darwin":
retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
return retval
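# Illustration (assumed values; not part of the original file): on a Linux
# machine with 8 CPU cores, the helper above resolves the CLI setting like so.
#
#   _setup_dataloader_multiproc_parameters(-1)  # -> {"num_workers": 0}
#   _setup_dataloader_multiproc_parameters(0)   # -> {"num_workers": 8}
#   _setup_dataloader_multiproc_parameters(4)   # -> {"num_workers": 4}
#
# On macOS (sys.platform == "darwin"), any result with num_workers > 0 also
# carries multiprocessing_context=multiprocessing.get_context("spawn").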
class _DelayedLoadingDataset(torch.utils.data.Dataset):
"""A list that loads its samples on demand.
This list mimics a pytorch Dataset, except raw data loading is done
on-the-fly, as the samples are requested through the bracket operator.
Parameters
----------
split
An iterable containing the raw dataset samples loaded from the database
splits.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor` containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
transforms
A set of transforms that should be applied on-the-fly for this dataset.
"""
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
):
self.split = split
self.raw_data_loader = raw_data_loader
self.transform = torch.nn.Sequential(*transforms)
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.raw_data_loader(self.split[key])
return self.transform(tensor), metadata
def __len__(self):
return len(self.split)
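# Sketch of a compatible ``raw_data_loader`` callable (illustrative only; the
# (path, label) sample layout and the PIL-based loading are assumptions,
# modelled on the Shenzhen loader further below):
#
#   import PIL.Image
#   import torchvision.transforms.functional
#
#   def my_raw_data_loader(sample):  # sample: (path, label) tuple
#       image = PIL.Image.open(sample[0]).convert("L")  # grayscale
#       tensor = torchvision.transforms.functional.to_tensor(image)  # floats in [0, 1]
#       return tensor, dict(label=sample[1])
#
#   dataset = _DelayedLoadingDataset(split, my_raw_data_loader, transforms=[])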
class _CachedDataset(torch.utils.data.Dataset):
"""Basically, a list of preloaded samples.
This dataset will load all samples from the split during construction
instead of delaying that to the indexing.
Parameters
----------
split
An iterable containing the raw dataset samples loaded from the database
splits.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor` containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
transforms
A set of transforms that should be applied on-the-fly for this dataset.
"""
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Optional[typing.Sequence[torch.nn.Module]] = None,
):
self.data = [raw_data_loader(k) for k in split]
self.transform = torch.nn.Sequential(*transforms)
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.data[key]
return self.transform(tensor), metadata
def __len__(self):
return len(self.data)
def get_sample_weights(
dataset: _DelayedLoadingDataset | _CachedDataset,
) -> torch.Tensor:
"""Computes the (inverse) probabilities of samples based on their class.
This function takes as input a torch Dataset, and computes the weights to
balance each class in the dataset, and the datasets themselves if one
passes a :py:class:`torch.utils.data.ConcatDataset`.
This function assumes labels are stored on the second entry of sample and
are integers. If that is not the case, we just balance the overall dataset
w.r.t. each other (e.g. with a :py:class:`torch.utils.data.ConcatDataset`).
For single dataset (not a concatenated one) without labels this function is
the same as a no-op.
Parameters
----------
dataset
An instance of torch Dataset.
:py:class:`torch.utils.data.ConcatDataset` are supported
Returns
-------
sample_weights
The weights for all the samples in the dataset given as input
"""
retval = []
def _calculate_dataset_weights(
d: _DelayedLoadingDataset | _CachedDataset,
) -> torch.Tensor:
"""Calculates the weights for one dataset."""
if "label" in d[0][1] and isinstance(d[0][1]["label"], int):
# there are labels
targets = [k[1]["label"] for k in d]
counts = collections.Counter(targets)
weights = {k: 1.0 / v for k, v in counts.items()}
weight_per_sample = [weights[k[1]["label"]] for k in d]
return torch.tensor(weight_per_sample)
else:
# no labels, weight dataset samples only by count
# n.b.: only works if __len__(d) is implemented, we turn off typechecking
weight = 1.0 / len(d)
return torch.tensor(len(d) * [weight])
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
retval.append(_calculate_dataset_weights(ds)) # type: ignore[arg-type]
# Concatenate sample weights from all the datasets
return torch.cat(retval)
return _calculate_dataset_weights(dataset)
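# Worked example (hypothetical labels; not part of the original file): for a
# dataset whose samples carry labels [0, 0, 0, 1], the counts are {0: 3, 1: 1},
# so the returned per-sample weights are [1/3, 1/3, 1/3, 1.0].  Passing these to
#
#   torch.utils.data.WeightedRandomSampler(weights, len(weights), replacement=True)
#
# (as suggested in the CachingDataModule documentation below) makes both
# classes equally likely to appear in a mini-batch.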
class CachingDataModule(lightning.LightningDataModule):
"""A conveninent data module with CSV or JSON protocol loading, mini-
batching, parallelisation and caching, all in one.
Instances of this class load data-split (a.k.a. protocol) definitions for a
database, and can load the data from the disk. An optional caching
mechanism stores the data at associated CPU memory, which can improve data
serving while training and evaluating models.
This datamodule defines basic operations to handle data loading and
mini-batch handling within this package's framework. It can return
:py:class:`torch.utils.data.DataLoader` objects for training, validation,
prediction and testing conditions. Parallelisation is handled by a simple
input flag.
The :py:meth:`setup` method is parameterised by a single string
enumeration containing: ``fit``, ``validate``, ``test``, or ``predict``.
Parameters
----------
database_split
A dictionary that contains string keys representing subset names, and
values that are iterables over sample representations (potentially on
disk). These objects are passed to the ``raw_data_loader`` for loading
the sample data (and metadata) in memory. The objects represented may
be of any format (e.g. list, dictionary, etc.), as long as the
``raw_data_loader`` can properly handle it. To check that the split and
the loader function work correctly, you may use
:py:func:`..dataset.check_database_split_loading`. As is, this class
expects at least one entry called ``train`` to exist in the input
dictionary. Optional entries are ``validation``, and ``test``. Entries
named ``monitor-...`` will be considered extra subsets that do not
influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor` containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
Samples can be cached **after** raw data loading.
cache_samples
If set, raw data is loaded and cached in CPU memory during ``setup()``,
and samples are served from there. Otherwise, samples are loaded from
disk on demand. Serving from CPU memory offers increased speed in
exchange for memory usage. Sufficient CPU memory must be available
before you set this attribute to ``True``. It is typically useful for
relatively small datasets.
train_sampler
If set, replaces the default random sampler from torch with a
(potentially) biased one of your choice. Note this has no effect on the
batching strategy, only on the way samples are picked from the original
dataset. A typical replacement is:
.. code:: python
sample_weights = get_sample_weights(dataset)
train_sampler = torch.utils.data.WeightedRandomSampler(
sample_weights, len(sample_weights), replacement=True
)
The variable ``sample_weights`` represents the probabilities (may not
sum to one) of picking each particular sample in the dataset for a
batch. We typically set ``replacement=True`` to avoid issues with
missing data from one of the classes in mini-batches. This is
particularly important in highly unbalanced datasets. The function
:py:func:`get_sample_weights` may help you in this aspect.
data_augmentations
A list of torchvision transforms (torch modules) that will be applied
on training set samples to create data augmentations during the
training of a model. Augmentation transform pipelines are applied
*after* the raw data is loaded, and before ``model_transforms``.
Augmentation transforms assume they receive a torch tensor representing
an image as input (see :py:class:`torchvision.transforms.ToTensor` for
details), in the range ``[0, 1]``.
model_transforms
A list of torchvision transforms (torch modules) that will be applied
after data augmentation transforms, and just before data is fed into
the model for all data loaders produced by this data module. This part
of the pipeline receives data as output by the raw-data-loader, or from
data augmentations, if any are specified.
batch_size
Number of samples in every **training** batch (this parameter affects
memory requirements for the network). If the number of samples in the
batch is larger than the total number of samples available for
training, this value is truncated. If this number is smaller, then
batches of the specified size are created and fed to the network until
there are no more new samples to feed (epoch is finished). If the
total number of training samples is not a multiple of the batch-size,
the last batch will be smaller than the first, unless
``drop_incomplete_batch`` is set to ``true``, in which case this batch
is not used.
batch_chunk_count
Number of chunks in every batch (this parameter affects memory
requirements for the network). The number of samples loaded for every
iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
needs to be divisible by ``batch_chunk_count``, otherwise an error will
be raised. This parameter is used to reduce the number of samples loaded
in each iteration, trading memory usage for processing time (more
iterations). This is especially useful when running on GPUs with limited
RAM. The default of 1 forces the whole batch to be processed at once.
Otherwise, the batch is broken into ``batch_chunk_count`` pieces, and
gradients are accumulated to complete each batch.
drop_incomplete_batch
If set, the last batch in an epoch may be dropped if it is incomplete.
If you set this option, you should also consider increasing the total
number of training epochs, as the total number of training steps may be
reduced.
parallel
Use multiprocessing for data loading: if set to -1 (default), disables
multiprocessing data loading. Set to 0 to enable as many data loading
instances as there are processing cores available in the system. Set to
>= 1 to enable that many multiprocessing instances for data loading.
"""
def __init__(
self,
database_split: dict[str, typing.Sequence[typing.Any]],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
cache_samples: bool = False,
train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
data_augmentations: list[torch.nn.Module] = [],
model_transforms: list[torch.nn.Module] = [],
batch_size: int = 1,
batch_chunk_count: int = 1,
drop_incomplete_batch: bool = False,
parallel: int = -1,
):
# validation
if batch_size % batch_chunk_count != 0:
raise RuntimeError(
f"batch_size ({batch_size}) must be divisible by "
f"batch_chunk_size ({batch_chunk_count})."
)
super().__init__()
self.database_split = database_split
self.raw_data_loader = raw_data_loader
self.cache_samples = cache_samples
self.train_sampler = train_sampler
self.data_augmentations = data_augmentations
self.model_transforms = model_transforms
self._batch_size = batch_size
self._batch_chunk_count = batch_chunk_count
self._chunk_size = self._batch_size // self._batch_chunk_count
self.drop_incomplete_batch = drop_incomplete_batch
self._parallel = parallel # immutable, otherwise would need to call
# the next function again
self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
parallel
)
self.pin_memory = (
torch.cuda.is_available()
) # should only be true if GPU available and using it
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {}
def setup(self, stage: str) -> None:
"""Sets up datasets for different tasks on the pipeline.
This method should setup (load, pre-process, etc) all subsets required
for a particular ``stage`` (fit, validate, test, predict), and keep
them ready to be used on one of the `_dataloader()` functions that are
pertinent for such stage.
If you have set ``cache_samples``, samples are loaded at this stage and
cached in memory.
Parameters
----------
stage
Name of the stage to which the setup is applicable. Can be one of
``fit``, ``validate``, ``test`` or ``predict``. Each stage
typically uses the following data loaders:
* ``fit``: uses both train and validation datasets
* ``validate``: uses only the validation dataset
* ``test``: uses only the test dataset
* ``predict``: uses only the test dataset
"""
def _setup(name, transforms):
if self.cache_samples:
self._datasets[name] = _CachedDataset(
self.database_split[name], self.raw_data_loader, transforms
)
else:
self._datasets[name] = _DelayedLoadingDataset(
self.database_split[name], self.raw_data_loader, transforms
)
if stage == "fit":
_setup("train", self.data_augmentations + self.model_transforms)
_setup("validation", self.model_transforms)
for k in self.database_split:
if k.startswith("monitor-"):
_setup(k, self.model_transforms)
elif stage == "validate":
_setup("validation", self.model_transforms)
for k in self.database_split:
if k.startswith("monitor-"):
_setup(k, self.model_transforms)
elif stage == "test":
_setup("test", self.model_transforms)
elif stage == "predict":
_setup("test", self.model_transforms)
def train_dataloader(self):
"""Returns the train data loader."""
return torch.utils.data.DataLoader(
self._datasets["train"],
batch_size=self._chunk_size,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=self.train_sampler,
**self._dataloader_multiproc,
)
def val_dataloader(self):
"""Returns the validation data loader(s)"""
extra_valid = [
k for k in self.database_split.keys() if k.startswith("monitor-")
]
# TODO: do we really need the train sampler here?
validation_loader_opts = {
"batch_size": self._chunk_size,
"shuffle": False,
"drop_last": self.drop_incomplete_batch,
"pin_memory": self.pin_memory,
}
validation_loader_opts.update(self._dataloader_multiproc)
# TODO: not sure this is the right way to handle multiple validation
# loaders, please check and fix
if not extra_valid:
return torch.utils.data.DataLoader(
self._datasets["validation"],
**validation_loader_opts,
)
else:
return [
torch.utils.data.DataLoader(
self._datasets[k],
**validation_loader_opts,
)
for k in ["validation"] + extra_valid
]
def test_dataloader(self):
"""Returns the test data loader(s)"""
return torch.utils.data.DataLoader(
self._datasets["test"],
batch_size=self._chunk_size,
shuffle=False,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
**self._dataloader_multiproc,
)
def predict_dataloader(self):
"""Returns the prediction data loader(s)"""
return self.test_dataloader()
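# Hedged usage sketch (not part of the original file): wiring a
# CachingDataModule into a lightning Trainer.  The split, loader, model and all
# parameter values below are assumptions for illustration only.
#
#   datamodule = CachingDataModule(
#       database_split=my_database_split,    # dict with "train", "validation", "test", ...
#       raw_data_loader=my_raw_data_loader,  # returns (tensor, metadata) pairs
#       batch_size=32,
#       batch_chunk_count=4,   # loads 32 // 4 = 8 samples per iteration and
#                              # accumulates gradients to complete each batch
#       parallel=0,            # one data-loading worker per available CPU core
#   )
#   trainer = lightning.Trainer(max_epochs=10)
#   trainer.fit(my_model, datamodule=datamodule)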
@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
 * Reference: [MONTGOMERY-SHENZHEN-2014]_
 * Original resolution (height x width or width x height): 3000 x 3000 or less
 * Split reference: none
-* Protocol ``default``:
 * Training samples: 64% of TB and healthy CXR (including labels)
 * Validation samples: 16% of TB and healthy CXR (including labels)
 * Test samples: 20% of TB and healthy CXR (including labels)
 """
 import importlib.resources
-import os
-from clapper.logging import setup
-from ...utils.rc import load_rc
-from ..loader import load_pil_baw
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
 _protocols = [
     importlib.resources.files(__name__).joinpath("default.json.bz2"),
@@ -43,11 +32,3 @@ _protocols = [
     importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
     importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
 ]
-_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
-def _raw_data_loader(img_path):
-    raw_data = load_pil_baw(os.path.join(_datadir, img_path))
-    return raw_data
@@ -2,27 +2,34 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-"""Shenzhen dataset for TB detection (default protocol)
-* Split reference: first 64% of TB and healthy CXR for "train" 16% for
-* "validation", 20% for "test"
-* This configuration resolution: 512 x 512 (default)
-* See :py:mod:`ptbench.data.shenzhen` for dataset details
+"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
+See :py:mod:`ptbench.data.shenzhen` for dataset details.
+This configuration:
+* raw data (default): :py:obj:`ptbench.data.shenzhen._tranforms`
+* augmentations: elastic deformation (probability = 80%)
+* output image resolution: 512x512 pixels
 """
-from clapper.logging import setup
+import importlib.resources
+from ..datamodule import CachingDataModule
+from ..dataset import JSONDatabaseSplit
 from ..transforms import ElasticDeformation
-from .utils import ShenzhenDataModule
+from .loader import raw_data_loader
-logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
-protocol_name = "default"
-augmentation_transforms = [ElasticDeformation(p=0.8)]
-datamodule = ShenzhenDataModule(
-    protocol="default",
-    model_transforms=[],
-    augmentation_transforms=augmentation_transforms,
+datamodule = CachingDataModule(
+    database_split=JSONDatabaseSplit(
+        importlib.resources.files(__name__).joinpath("default.json.bz2")
+    ),
+    raw_data_loader=raw_data_loader,
+    cache_samples=False,
+    # train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
+    data_augmentations=[ElasticDeformation(p=0.8)],
+    # model_transforms = [],
+    # batch_size = 1,
+    # batch_chunk_count = 1,
+    # drop_incomplete_batch = False,
+    # parallel = -1,
 )
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the National
Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
out-patient clinics, and were captured as part of the daily routine using
Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
import os
import typing
import torch.nn
import torchvision.transforms
from ...utils.rc import load_rc
from ..loader import load_pil_baw
from ..transforms import RemoveBlackBorders
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
"""This variable contains the base directory where the database raw data is
stored."""
_transform = torchvision.transforms.Compose(
[
RemoveBlackBorders(),
torchvision.transforms.Resize(512),
torchvision.transforms.CenterCrop(512),
torchvision.transforms.ToTensor(),
]
)
"""Transforms that are always applied to the loaded raw images."""
def raw_data_loader(
sample: tuple[str, int]
) -> tuple[torch.Tensor, typing.Mapping]:
"""Loads a single image sample from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder,
where to find the image to be loaded, and an integer representing the
sample label.
Returns
-------
image, metadata
A tensor containing the loaded image in grayscale (values in the
interval ``[0, 1]``), and a dictionary with the sample label under the
key ``label``.
"""
tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0])))
return tensor, dict(label=sample[1]) # type: ignore[arg-type]
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the
National Library of Medicine, Maryland, USA in collaboration with Shenzhen
No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
The Chest X-rays are from out-patient clinics, and were captured as part of
the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
from clapper.logging import setup
from torchvision import transforms
from ..base_datamodule import BaseDataModule
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import RemoveBlackBorders
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class ShenzhenDataModule(BaseDataModule):
def __init__(
self,
protocol="default",
model_transforms=[],
augmentation_transforms=[],
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
cache_samples=False,
parallel=-1,
):
super().__init__(
batch_size=batch_size,
drop_incomplete_batch=drop_incomplete_batch,
batch_chunk_count=batch_chunk_count,
parallel=parallel,
)
self._cache_samples = cache_samples
self._has_setup_fit = False
self._has_setup_predict = False
self._protocol = protocol
self.raw_data_transforms = [
RemoveBlackBorders(),
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
]
self.model_transforms = model_transforms
self.augmentation_transforms = augmentation_transforms
def setup(self, stage: str):
json_protocol = JSONProtocol(
protocols=_protocols,
fieldnames=("data", "label"),
)
if self._cache_samples:
dataset = CachedDataset
else:
dataset = RuntimeDataset
if not self._has_setup_fit and stage == "fit":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=True),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_fit = True
if not self._has_setup_predict and stage == "predict":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_predict = True
@@ -58,6 +58,8 @@ class RemoveBlackBorders:
 class ElasticDeformation:
     """Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
+    TODO: needs to be converted into a torch.nn.Module to become scriptable!
+
     Source: https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0
     """