Compare revisions: medai/software/mednet
Commits on Source (20)
Showing changes with 1192 additions and 1067 deletions
......@@ -15,33 +15,33 @@ dynamic = ["readme"]
license = { text = "GNU General Public License v3 (GPLv3)" }
authors = [{ name = "Geoffrey Raposo", email = "geoffrey@raposo.ch" }]
maintainers = [
{ name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
{ name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
{ name = "Andre Anjos", email = "andre.anjos@idiap.ch" },
{ name = "Daniel Carron", email = "daniel.carron@idiap.ch" },
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Topic :: Software Development :: Libraries :: Python Modules",
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = [
"clapper",
"click",
"numpy",
"pandas",
"scipy",
"scikit-learn",
"tqdm",
"psutil",
"tabulate",
"matplotlib",
"pillow",
"torch>=1.8",
"torchvision>=0.10",
"lightning>=2.0.3",
"tensorboard",
"clapper",
"click",
"numpy",
"pandas",
"scipy",
"scikit-learn",
"tqdm",
"psutil",
"tabulate",
"matplotlib",
"pillow",
"torch>=1.8",
"torchvision>=0.10",
"lightning>=2.0.3",
"tensorboard",
]
[project.urls]
......@@ -53,13 +53,13 @@ changelog = "https://gitlab.idiap.ch/biosignal/software/ptbench/-/releases"
[project.optional-dependencies]
qa = ["pre-commit"]
doc = [
"sphinx",
"furo",
"sphinx-autodoc-typehints",
"auto-intersphinx",
"sphinx-copybutton",
"sphinx-inline-tabs",
"sphinx-click",
"sphinx",
"furo",
"sphinx-autodoc-typehints",
"auto-intersphinx",
"sphinx-copybutton",
"sphinx-inline-tabs",
"sphinx-click",
]
test = ["pytest", "pytest-cov", "coverage"]
......
......@@ -6,19 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet
# config
# optimizer
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# optimizer
optimizer = "SGD"
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
)
......@@ -6,19 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import SGD
from ...models.alexnet import Alexnet
# config
optimizer_configs = {"lr": 0.001, "momentum": 0.1}
# optimizer
optimizer = "SGD"
optimizer = SGD
optimizer_configs = {"lr": 0.01, "momentum": 0.1}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Alexnet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
)
......@@ -6,20 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.0001}
# optimizer
optimizer = "Adam"
optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=False
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=False,
augmentation_transforms=augmentation_transforms,
)
......@@ -6,20 +6,30 @@
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.densenet import Densenet
# config
optimizer_configs = {"lr": 0.01}
# optimizer
optimizer = "Adam"
optimizer = Adam
optimizer_configs = {"lr": 0.0001}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [
ElasticDeformation(p=0.8),
]
# model
model = Densenet(
criterion, criterion_valid, optimizer, optimizer_configs, pretrained=True
criterion,
criterion_valid,
optimizer,
optimizer_configs,
pretrained=True,
augmentation_transforms=augmentation_transforms,
)
......@@ -13,18 +13,30 @@ Reference: [PASA-2019]_
from torch import empty
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from ...models.pasa import PASA
# config
optimizer_configs = {"lr": 8e-5}
# optimizer
optimizer = "Adam"
optimizer = Adam
optimizer_configs = {"lr": 8e-5}
# criterion
criterion = BCEWithLogitsLoss(pos_weight=empty(1))
criterion_valid = BCEWithLogitsLoss(pos_weight=empty(1))
from ...data.transforms import ElasticDeformation
augmentation_transforms = [ElasticDeformation(p=0.8)]
# from torchvision.transforms.v2 import ElasticTransform, InterpolationMode
# augmentation_transforms = [ElasticTransform(alpha=1000.0, sigma=30.0, interpolation=InterpolationMode.NEAREST)]
# model
model = PASA(criterion, criterion_valid, optimizer, optimizer_configs)
model = PASA(
criterion,
criterion_valid,
optimizer,
optimizer_configs,
augmentation_transforms=augmentation_transforms,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import multiprocessing
import sys
import lightning as pl
import torch
from clapper.logging import setup
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from .dataset import get_samples_weights
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class BaseDataModule(pl.LightningDataModule):
def __init__(
self,
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
parallel=-1,
):
super().__init__()
self.batch_size = batch_size
self.batch_chunk_count = batch_chunk_count
self.train_dataset = None
self.validation_dataset = None
self.extra_validation_datasets = None
self._raw_data_transforms = []
self._model_transforms = []
self._augmentation_transforms = []
self.drop_incomplete_batch = drop_incomplete_batch
self.pin_memory = (
torch.cuda.is_available()
) # should only be true if GPU available and using it
self.parallel = parallel
def setup(self, stage: str):
# Implemented by user
# Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset
raise NotImplementedError
def train_dataloader(self):
train_samples_weights = get_samples_weights(self.train_dataset)
multiproc_kwargs = self._setup_multiproc(self.parallel)
train_sampler = WeightedRandomSampler(
train_samples_weights, len(train_samples_weights), replacement=True
)
return DataLoader(
self.train_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=train_sampler,
**multiproc_kwargs,
)
def val_dataloader(self):
loaders_dict = {}
multiproc_kwargs = self._setup_multiproc(self.parallel)
val_loader = DataLoader(
dataset=self.validation_dataset,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict["validation_loader"] = val_loader
if self.extra_validation_datasets is not None:
for set_idx, extra_set in enumerate(self.extra_validation_datasets):
extra_val_loader = DataLoader(
dataset=extra_set,
batch_size=self._compute_chunk_size(
self.batch_size, self.batch_chunk_count
),
shuffle=False,
drop_last=False,
pin_memory=self.pin_memory,
**multiproc_kwargs,
)
loaders_dict[
f"extra_validation_loader{set_idx}"
] = extra_val_loader
return loaders_dict
def predict_dataloader(self):
loaders_dict = {}
loaders_dict["train_loader"] = self.train_dataloader()
for k, v in self.val_dataloader().items():
loaders_dict[k] = v
return loaders_dict
def update_module_properties(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
def _compute_chunk_size(self, batch_size, chunk_count):
batch_chunk_size = batch_size
if batch_size % chunk_count != 0:
# batch_size must be divisible by batch_chunk_count.
raise RuntimeError(
f"--batch-size ({batch_size}) must be divisible by "
f"--batch-chunk-size ({chunk_count})."
)
else:
batch_chunk_size = batch_size // chunk_count
return batch_chunk_size
def _setup_multiproc(self, parallel):
multiproc_kwargs = dict()
if parallel < 0:
multiproc_kwargs["num_workers"] = 0
else:
multiproc_kwargs["num_workers"] = (
parallel or multiprocessing.cpu_count()
)
if multiproc_kwargs["num_workers"] > 0 and sys.platform == "darwin":
multiproc_kwargs[
"multiprocessing_context"
] = multiprocessing.get_context("spawn")
return multiproc_kwargs
def _build_transforms(self, is_train):
all_transforms = self.raw_data_transforms + self.model_transforms
if is_train:
all_transforms = all_transforms + self.augmentation_transforms
# all_transforms.append(transforms.ToTensor())
return transforms.Compose(all_transforms)
@property
def raw_data_transforms(self):
return self._raw_data_transforms
@raw_data_transforms.setter
def raw_data_transforms(self, transforms):
self._raw_data_transforms = transforms
@property
def model_transforms(self):
return self._model_transforms
@model_transforms.setter
def model_transforms(self, transforms):
self._model_transforms = transforms
@property
def augmentation_transforms(self):
return self._augmentation_transforms
@augmentation_transforms.setter
def augmentation_transforms(self, transforms):
self._augmentation_transforms = transforms
def get_dataset_from_module(module, stage, **module_args):
"""Instantiates a DataModule and retrieves the corresponding dataset.
Useful when combining multiple datasets.
"""
module_instance = module(**module_args)
module_instance.prepare_data()
module_instance.setup(stage=stage)
dataset = module_instance.dataset
return dataset
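A minimal usage sketch for the helper above, assuming two hypothetical datamodule classes (``DataModuleA``, ``DataModuleB``) that each expose a ``dataset`` attribute after ``setup()``; it is not part of this change, only an illustration of how datasets from several modules could be merged:
from torch.utils.data import ConcatDataset
# both names below are hypothetical placeholders
dataset_a = get_dataset_from_module(DataModuleA, stage="fit", batch_size=4)
dataset_b = get_dataset_from_module(DataModuleB, stage="fit", batch_size=4)
combined = ConcatDataset([dataset_a, dataset_b])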
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections
import logging
import multiprocessing
import sys
import typing
import lightning
import torch
import torch.utils.data
import torchvision.transforms
from tqdm import tqdm
logger = logging.getLogger(__name__)
def _setup_dataloader_multiproc_parameters(
parallel: int,
) -> dict[str, typing.Any]:
"""Returns a dictionary containing pytorch arguments to be used in data
loaders.
It sets the parameter ``num_workers`` to match the expected pytorch
representation. For macOS machines, it also sets the
``multiprocessing_context`` to use ``spawn`` instead of the default.
The command-line interface ``parallel`` setting maps to the
:py:class:`torch.utils.data.DataLoader` parameterisation as follows:
.. list-table:: Relationship between ``parallel`` and DataLoader parameterisation
:widths: 15 15 70
:header-rows: 1
* - CLI ``parallel``
- :py:class:`torch.utils.data.DataLoader` ``kwargs``
- Comments
* - ``<0``
- 0
- Disables multiprocessing entirely, executes everything within the
same processing context
* - ``0``
- :py:func:`multiprocessing.cpu_count`
- Runs mini-batch data loading on as many external processes as CPUs
available in the current machine
* - ``>=1``
- ``parallel``
- Runs mini-batch data loading on as many external processes as set on
``parallel``
"""
retval: dict[str, typing.Any] = dict()
if parallel < 0:
retval["num_workers"] = 0
else:
retval["num_workers"] = parallel or multiprocessing.cpu_count()
if retval["num_workers"] > 0 and sys.platform == "darwin":
retval["multiprocessing_context"] = multiprocessing.get_context("spawn")
return retval
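# Worked examples of the mapping documented above (derived directly from the
# function body; on macOS, a "spawn" multiprocessing context is also added
# whenever num_workers ends up greater than zero):
#
#   _setup_dataloader_multiproc_parameters(-1)  # {"num_workers": 0}
#   _setup_dataloader_multiproc_parameters(0)   # {"num_workers": multiprocessing.cpu_count(), ...}
#   _setup_dataloader_multiproc_parameters(3)   # {"num_workers": 3, ...}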
class _DelayedLoadingDataset(torch.utils.data.Dataset):
"""A list that loads its samples on demand.
This list mimics a pytorch Dataset, except raw data loading is done
on-the-fly, as the samples are requested through the bracket operator.
Parameters
----------
split
An iterable containing the raw dataset samples loaded from the database
splits.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor` containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
transforms
A set of transforms that should be applied on-the-fly for this dataset,
to fit the output of the raw-data-loader to the model of interest.
"""
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Sequence[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
):
self.split = split
self.raw_data_loader = raw_data_loader
# Cannot unpack empty list
if len(transforms) > 0:
self.transform = torchvision.transforms.Compose([*transforms])
else:
self.transform = torchvision.transforms.Compose([])
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.raw_data_loader(self.split[key])
return self.transform(tensor), metadata
def __len__(self):
return len(self.split)
class _CachedDataset(torch.utils.data.Dataset):
"""Basically, a list of preloaded samples.
This dataset will load all samples from the split during construction
instead of delaying that to the indexing. Beyong raw-data-loading,
``transforms`` given upon construction contribute to the cached samples.
Parameters
----------
split
An iterable containing the raw dataset samples loaded from the database
splits.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor` containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
transforms
A set of transforms that should be applied to the cached samples for
this dataset, to fit the output of the raw-data-loader to the model of
interest.
"""
def __init__(
self,
split: typing.Sequence[typing.Any],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
transforms: typing.Sequence[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
):
# Cannot unpack empty list
if len(transforms) > 0:
self.transform = torchvision.transforms.Compose([*transforms])
else:
self.transform = torchvision.transforms.Compose([])
self.data = [raw_data_loader(k) for k in tqdm(split)]
def __getitem__(self, key: int) -> tuple[torch.Tensor, typing.Mapping]:
tensor, metadata = self.data[key]
return self.transform(tensor), metadata
def __len__(self):
return len(self.data)
def get_sample_weights(
dataset: _DelayedLoadingDataset | _CachedDataset,
) -> torch.Tensor:
"""Computes the (inverse) probabilities of samples based on their class.
This function takes as input a torch Dataset, and computes the weights to
balance each class in the dataset, and the datasets themselves if one
passes a :py:class:`torch.utils.data.ConcatDataset`.
This function assumes labels are integers stored under the ``label`` key of
the second entry of each sample. If that is not the case, we just balance the
datasets w.r.t. each other (e.g. within a
:py:class:`torch.utils.data.ConcatDataset`). For a single dataset (not a
concatenated one) without labels, this function returns uniform weights.
Parameters
----------
dataset
An instance of torch Dataset.
:py:class:`torch.utils.data.ConcatDataset` are supported
Returns
-------
sample_weights
The weights for all the samples in the dataset given as input
"""
retval = []
def _calculate_dataset_weights(
d: _DelayedLoadingDataset | _CachedDataset,
) -> torch.Tensor:
"""Calculates the weights for one dataset."""
if "label" in d[0][1] and isinstance(d[0][1]["label"], int):
# there are labels
targets = [k[1]["label"] for k in d]
counts = collections.Counter(targets)
weights = {k: 1.0 / v for k, v in counts.items()}
weight_per_sample = [weights[k[1]["label"]] for k in d]
return torch.tensor(weight_per_sample)
else:
# no labels, weight dataset samples only by count
# n.b.: only works if __len__(d) is implemented, we turn off typechecking
weight = 1.0 / len(d)
return torch.tensor(len(d) * [weight])
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
retval.append(_calculate_dataset_weights(ds)) # type: ignore[arg-type]
# Concatenate sample weights from all the datasets
return torch.cat(retval)
return _calculate_dataset_weights(dataset)
class CachingDataModule(lightning.LightningDataModule):
"""A conveninent data module with CSV or JSON protocol loading, mini-
batching, parallelisation and caching, all in one.
Instances of this class load data-split (a.k.a. protocol) definitions for a
database, and can load the data from the disk. An optional caching
mechanism stores the data at associated CPU memory, which can improve data
serving while training and evaluating models.
This datamodule defines basic operations to handle data loading and
mini-batch handling within this package's framework. It can return
:py:class:`torch.utils.data.DataLoader` objects for training, validation,
prediction and testing conditions. Parallelisation is handled by a simple
input flag.
Users must implement the basic :py:meth:`setup` function, which is
parameterised by a single string enumeration containing: ``fit``,
``validate``, ``test``, or ``predict``.
Parameters
----------
database_split
A dictionary that contains string keys representing subset names, and
values that are iterables over sample representations (potentially on
disk). These objects are passed to the ``raw_data_loader`` for loading
the sample data (and metadata) in memory. The objects represented may
be of any format (e.g. list, dictionary, etc), for as long as the
``raw_data_loader`` can properly handle it. To check the split and the
loader function works correctly, you may use
:py:func:`..dataset.check_database_split_loading`. As is, this class
expects at least one entry called ``train`` to exist in the input
dictionary. Optional entries are ``validation``, and ``test``. Entries
named ``monitor-...`` will be considered extra subsets that do not
influence any early stop criteria during training, and are just
monitored beyond the ``validation`` dataset.
raw_data_loader
A callable that will take the representation of the samples declared
inside database splits, load the actual raw data (if one exists), and
transform it into a :py:class:`torch.Tensor` containing floats in the
interval ``[0, 1]``, and a dictionary with further metadata attributes.
Samples can be cached **after** raw data loading.
cache_samples
If set, then issue raw data loading during ``prepare_data()``, and
serves samples from CPU memory. Otherwise, loads samples from disk on
demand. Running from CPU memory will offer increased speeds in exchange
for CPU memory. Sufficient CPU memory must be available before you set
this attribute to ``True``. It is typically useful for relatively small
datasets.
train_sampler
If set, then replaces the default random sampler from torch with a
(potentially) biased one of your choice. Notice this does not affect the
batching strategy, only the way samples are picked from the original
dataset. A typical replacement is:
.. code:: python
sample_weights = get_sample_weights(dataset)
train_sampler = torch.utils.data.WeightedRandomSampler(
sample_weights, len(sample_weights), replacement=True
)
The variable ``sample_weights`` represents the probabilities (may not
sum to one) of picking each particular sample in the dataset for a
batch. We typically set ``replacement=True`` to avoid issues with
missing data from one of the classes in mini-batches. This is
particularly important in highly unbalanced datasets. The function
:py:func:`get_sample_weights` may help you in this aspect.
model_transforms
A list of transforms (torch modules) that will be applied after
raw-data-loading, and just before data is fed into the model or
eventual data-augmentation transformations for all data loaders
produced by this data module. This part of the pipeline receives data
as output by the raw-data-loader, or model-related transforms (e.g.
resize adaptations), if any are specified.
batch_size
Number of samples in every **training** batch (this parameter affects
memory requirements for the network). If the number of samples in the
batch is larger than the total number of samples available for
training, this value is truncated. If this number is smaller, then
batches of the specified size are created and fed to the network until
there are no more new samples to feed (epoch is finished). If the
total number of training samples is not a multiple of the batch-size,
the last batch will be smaller than the first, unless
``drop_incomplete_batch`` is set to ``true``, in which case this batch
is not used.
batch_chunk_count
Number of chunks in every batch (this parameter affects memory
requirements for the network). The number of samples loaded for every
iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
needs to be divisible by ``batch_chunk_count``, otherwise an error will
be raised. This parameter is used to reduce the number of samples loaded in
each iteration, in order to reduce memory usage in exchange for processing
time (more iterations). This is especially interesting when one is running
on GPUs with limited RAM. The default of 1 forces the
whole batch to be processed at once. Otherwise the batch is broken into
batch-chunk-count pieces, and gradients are accumulated to complete
each batch.
drop_incomplete_batch
If set, then may drop the last batch in an epoch, in case it is
incomplete. If you set this option, you should also consider
increasing the total number of epochs of training, as the total number
of training steps may be reduced.
parallel
Use multiprocessing for data loading: if set to -1 (default), disables
multiprocessing data loading. Set to 0 to enable as many data loading
instances as processing cores as available in the system. Set to >= 1
to enable that many multiprocessing instances for data loading.
"""
def __init__(
self,
database_split: dict[str, typing.Sequence[typing.Any]],
raw_data_loader: typing.Callable[
[typing.Any], tuple[torch.Tensor, typing.Mapping]
],
cache_samples: bool = False,
train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
model_transforms: list[
typing.Callable[[torch.Tensor], torch.Tensor]
] = [],
batch_size: int = 1,
batch_chunk_count: int = 1,
drop_incomplete_batch: bool = False,
parallel: int = -1,
):
super().__init__()
self.set_chunk_size(batch_size, batch_chunk_count)
self.database_split = database_split
self.raw_data_loader = raw_data_loader
self.cache_samples = cache_samples
self.train_sampler = train_sampler
self.model_transforms = model_transforms
self.drop_incomplete_batch = drop_incomplete_batch
self.parallel = parallel  # the setter below also recreates the dataloader multiprocessing kwargs
self.pin_memory = (
torch.cuda.is_available()
) # should only be true if GPU available and using it
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
@property
def parallel(self) -> int:
"""The parallel property."""
return self._parallel
@parallel.setter
def parallel(self, value: int) -> None:
self._parallel = value
self._dataloader_multiproc = _setup_dataloader_multiproc_parameters(
value
)
# datasets that have been setup() for the current stage
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
def set_chunk_size(self, batch_size: int, batch_chunk_count: int) -> None:
"""Coherently sets the batch-chunk-size after validation.
Parameters
----------
batch_size
Number of samples in every **training** batch (this parameter affects
memory requirements for the network). If the number of samples in the
batch is larger than the total number of samples available for
training, this value is truncated. If this number is smaller, then
batches of the specified size are created and fed to the network until
there are no more new samples to feed (epoch is finished). If the
total number of training samples is not a multiple of the batch-size,
the last batch will be smaller than the first, unless
``drop_incomplete_batch`` is set to ``true``, in which case this batch
is not used.
batch_chunk_count
Number of chunks in every batch (this parameter affects memory
requirements for the network). The number of samples loaded for every
iteration will be ``batch_size/batch_chunk_count``. ``batch_size``
needs to be divisible by ``batch_chunk_count``, otherwise an error will
be raised. This parameter is used to reduce the number of samples loaded in
each iteration, in order to reduce memory usage in exchange for processing
time (more iterations). This is especially interesting when one is running
on GPUs with limited RAM. The default of 1 forces the
whole batch to be processed at once. Otherwise the batch is broken into
batch-chunk-count pieces, and gradients are accumulated to complete
each batch.
"""
# validation
if batch_size % batch_chunk_count != 0:
raise RuntimeError(
f"batch_size ({batch_size}) must be divisible by "
f"batch_chunk_size ({batch_chunk_count})."
)
self._batch_size = batch_size
self._batch_chunk_count = batch_chunk_count
self._chunk_size = self._batch_size // self._batch_chunk_count
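# Example: set_chunk_size(batch_size=16, batch_chunk_count=4) results in a
# chunk size of 4 (16 // 4); each DataLoader iteration then yields 4 samples,
# and gradients are accumulated over 4 chunks to emulate a 16-sample batch.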
def _setup_dataset(self, name: str) -> None:
"""Sets-up a single dataset from the input data split.
Parameters
----------
name
Name of the dataset to setup.
"""
if name in self._datasets:
logger.info(f"Dataset {name} is already setup. Not reloading it.")
return
if self.cache_samples:
logger.info(f"Caching {name} dataset")
self._datasets[name] = _CachedDataset(
self.database_split[name],
self.raw_data_loader,
self.model_transforms,
)
else:
self._datasets[name] = _DelayedLoadingDataset(
self.database_split[name],
self.raw_data_loader,
self.model_transforms,
)
def setup(self, stage: str) -> None:
"""Sets up datasets for different tasks on the pipeline.
This method should set up (load, pre-process, etc.) all datasets required
for a particular ``stage`` (fit, validate, test, predict), and keep them
ready to be used by one of the ``*_dataloader()`` methods pertinent to
that stage.
If you have set ``cache_samples``, samples are loaded at this stage and
cached in memory.
Parameters
----------
stage
Name of the stage to which the setup is applicable. Can be one of
``fit``, ``validate``, ``test`` or ``predict``. Each stage
typically uses the following data loaders:
* ``fit``: uses both train and validation datasets
* ``validate``: uses only the validation dataset
* ``test``: uses only the test dataset
* ``predict``: uses only the test dataset
"""
if stage == "fit":
self._setup_dataset("train")
self._setup_dataset("validation")
for k in self.database_split:
if k.startswith("monitor-"):
self._setup_dataset(k)
elif stage == "validate":
self._setup_dataset("validation")
for k in self.database_split:
if k.startswith("monitor-"):
self._setup_dataset(k)
elif stage == "test":
self._setup_dataset("test")
elif stage == "predict":
self._setup_dataset("test")
def teardown(self, stage: str) -> None:
"""Unset-up datasets for different tasks on the pipeline.
This method unsets (unload, remove from memory, etc) all datasets required
for a particular ``stage`` (fit, validate, test, predict).
If you have set ``cache_samples``, samples are loaded, this may
effectivley release all the associated memory.
Parameters
----------
stage
Name of the stage to which the teardown is applicable. Can be one of
``fit``, ``validate``, ``test`` or ``predict``. Each stage
typically uses the following data loaders:
* ``fit``: uses both train and validation datasets
* ``validate``: uses only the validation dataset
* ``test``: uses only the test dataset
* ``predict``: uses only the test dataset
"""
self._datasets: dict[str, _DelayedLoadingDataset | _CachedDataset] = {} # type: ignore[no-redef]
def train_dataloader(self) -> torch.utils.data.DataLoader:
"""Returns the train data loader."""
return torch.utils.data.DataLoader(
self._datasets["train"],
shuffle=(self.train_sampler is None),  # shuffle and a custom sampler are mutually exclusive in torch
batch_size=self._chunk_size,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
sampler=self.train_sampler,
**self._dataloader_multiproc,
)
def available_dataset_keys(self) -> typing.KeysView[str]:
"""Returns all names for datasets that are setup."""
return self._datasets.keys()
def val_database_split_keys(self) -> list[str]:
"""Returns list of validation dataset names."""
return ["validation"] + [
k for k in self.database_split.keys() if k.startswith("monitor-")
]
def val_dataloader(self) -> dict[str, torch.utils.data.DataLoader]:
"""Returns the validation data loader(s)"""
validation_loader_opts = {
"batch_size": self._chunk_size,
"shuffle": False,
"drop_last": self.drop_incomplete_batch,
"pin_memory": self.pin_memory,
}
validation_loader_opts.update(self._dataloader_multiproc)
# select all keys of interest
return {
k: torch.utils.data.DataLoader(
self._datasets[k], **validation_loader_opts
)
for k in self.val_database_split_keys()
}
def test_dataloader(self):
"""Returns the test data loader(s)"""
return torch.utils.data.DataLoader(
self._datasets["test"],
batch_size=self._chunk_size,
shuffle=False,
drop_last=self.drop_incomplete_batch,
pin_memory=self.pin_memory,
**self._dataloader_multiproc,
)
def predict_dataloader(self):
"""Returns the prediction data loader(s)"""
return self.test_dataloader()
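A minimal sketch of how this datamodule could be wired to a lightning trainer. The split file, raw-data loader and model below are hypothetical placeholders, not part of this change, and the import paths assume the modules are reachable as ``ptbench.data.datamodule`` and ``ptbench.data.split`` (as the Shenzhen configuration imports suggest); compare with that configuration further down for a concrete instantiation:
import lightning
from ptbench.data.datamodule import CachingDataModule
from ptbench.data.split import JSONDatabaseSplit
datamodule = CachingDataModule(
    database_split=JSONDatabaseSplit("my-split.json"),  # hypothetical path
    raw_data_loader=my_raw_data_loader,  # hypothetical; returns (tensor, {"label": int})
    cache_samples=False,
    batch_size=4,
    parallel=-1,
)
trainer = lightning.Trainer(max_epochs=1)
trainer.fit(my_model, datamodule=datamodule)  # my_model: a hypothetical LightningModule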
......@@ -2,460 +2,19 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
import csv
import json
import logging
import os
import pathlib
import torch
from tqdm import tqdm
import torch.utils.data
logger = logging.getLogger(__name__)
class JSONProtocol:
"""Generic multi-protocol/subset filelist dataset that yields samples.
To create a new dataset, you need to provide one or more JSON formatted
filelists (one per protocol) with the following contents:
.. code-block:: json
{
"subset1": [
[
"value1",
"value2",
"value3"
],
[
"value4",
"value5",
"value6"
]
],
"subset2": [
]
}
Your dataset may contain any number of subsets, but all sample entries
must contain the same number of fields.
Parameters
----------
protocols : list, dict
Paths to one or more JSON formatted files containing the various
protocols to be recognized by this dataset, or a dictionary, mapping
protocol names to paths (or opened file objects) of JSON files.
Internally, we save a dictionary where keys default to the basename of
paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each entry in
the JSON file. It should have as many items as fields in each entry of
the JSON file.
loader : object
A function that receives as input, a context dictionary (with at least
a "protocol" and "subset" keys indicating which protocol and subset are
being served), and a dictionary with ``{fieldname: value}`` entries,
and returns an object with at least 2 attributes:
* ``key``: which must be a unique string for every sample across
subsets in a protocol, and
* ``data``: which contains the data associated with this sample
"""
def __init__(self, protocols, fieldnames):
if isinstance(protocols, dict):
self._protocols = protocols
else:
self._protocols = {
os.path.basename(
str(k).replace("".join(pathlib.Path(k).suffixes), "")
): k
for k in protocols
}
self.fieldnames = fieldnames
def check(self, limit=0):
"""For each protocol, check if all data can be correctly accessed.
This function assumes each sample has a ``data`` and a ``key``
attribute. The ``key`` attribute should be a string, or representable
as such.
Parameters
----------
limit : int
Maximum number of samples to check (in each protocol/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors : int
Number of errors found
"""
logger.info("Checking dataset...")
errors = 0
for proto in self._protocols:
logger.info(f"Checking protocol '{proto}'...")
for name, samples in self.subsets(proto).items():
logger.info(f"Checking subset '{name}'...")
if limit:
logger.info(f"Checking at most first '{limit}' samples...")
samples = samples[:limit]
for pos, sample in enumerate(samples):
try:
sample.data # may trigger data loading
logger.info(f"{sample.key}: OK")
except Exception as e:
logger.error(
f"Found error loading entry {pos} in subset {name} "
f"of protocol {proto} from file "
f"'{self._protocols[proto]}': {e}"
)
errors += 1
return errors
def subsets(self, protocol):
"""Returns all subsets in a protocol.
This method will load JSON information for a given protocol and return
all subsets of the given protocol after converting each entry through
the loader function.
Parameters
----------
protocol : str
Name of the protocol data to load
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of objects (respecting
the ``key``, ``data`` interface).
"""
fileobj = self._protocols[protocol]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
if str(fileobj).endswith(".bz2"):
import bz2
with bz2.open(self._protocols[protocol]) as f:
data = json.load(f)
else:
with open(self._protocols[protocol]) as f:
data = json.load(f)
else:
data = json.load(fileobj)
fileobj.seek(0)
retval = {}
for subset, samples in data.items():
logger.info(f"Loading subset {subset} samples.")
retval[subset] = [
dict(zip(self.fieldnames, k))
for n, k in enumerate(tqdm(samples))
]
return retval
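# Illustrative usage of the protocol loader above (the file name is
# hypothetical; the fieldnames follow the Shenzhen configuration elsewhere in
# this diff):
#
#   jp = JSONProtocol(["default.json.bz2"], fieldnames=("data", "label"))
#   subsets = jp.subsets("default")
#   # -> {"train": [{"data": ..., "label": ...}, ...], "validation": [...], ...}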
class CSVDataset:
"""Generic multi-subset filelist dataset that yields samples.
To create a new dataset, you only need to provide a CSV formatted filelist
using any separator (e.g. comma, space, semi-colon) with the following
information:
.. code-block:: text
value1,value2,value3
value4,value5,value6
...
Notice that all rows must have the same number of entries.
Parameters
----------
subsets : list, dict
Paths to one or more CSV formatted files containing the various subsets
to be recognized by this dataset, or a dictionary, mapping subset names
to paths (or opened file objects) of CSV files. Internally, we save a
dictionary where keys default to the basename of paths (list input).
fieldnames : list, tuple
An iterable over the field names (strings) to assign to each column in
the CSV file. It should have as many items as fields in each row of
the CSV file(s).
loader : object
A function that receives as input, a context dictionary (with, at
least, a "subset" key indicating which subset is being served), and a
dictionary with ``{key: path}`` entries, and returns a dictionary with
the loaded data.
"""
def __init__(self, subsets, fieldnames, loader):
if isinstance(subsets, dict):
self._subsets = subsets
else:
self._subsets = {
os.path.basename(
str(k).replace("".join(pathlib.Path(k).suffixes), "")
): k
for k in subsets
}
self.fieldnames = fieldnames
self._loader = loader
def check(self, limit=0):
"""For each subset, check if all data can be correctly accessed.
This function assumes each sample has a ``data`` and a ``key``
attribute. The ``key`` attribute should be a string, or representable
as such.
Parameters
----------
limit : int
Maximum number of samples to check (in each protocol/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors : int
Number of errors found
"""
logger.info("Checking dataset...")
errors = 0
for name in self._subsets.keys():
logger.info(f"Checking subset '{name}'...")
samples = self.samples(name)
if limit:
logger.info(f"Checking at most first '{limit}' samples...")
samples = samples[:limit]
for pos, sample in enumerate(samples):
try:
sample.data # may trigger data loading
logger.info(f"{sample.key}: OK")
except Exception as e:
logger.error(
f"Found error loading entry {pos} in subset {name} "
f"from file '{self._subsets[name]}': {e}"
)
errors += 1
return errors
def subsets(self):
"""Returns all available subsets at once.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of objects (respecting
the ``key``, ``data`` interface).
"""
return {k: self.samples(k) for k in self._subsets.keys()}
def samples(self, subset):
"""Returns all samples in a subset.
This method will load CSV information for a given subset and return
all samples of the given subset after passing each entry through the
loading function.
Parameters
----------
subset : str
Name of the subset data to load
Returns
-------
subset : list
A lists of objects (respecting the ``key``, ``data`` interface).
"""
fileobj = self._subsets[subset]
if isinstance(fileobj, (str, bytes, pathlib.Path)):
with open(self._subsets[subset], newline="") as f:
cf = csv.reader(f)
samples = [k for k in cf]
else:
cf = csv.reader(fileobj)
samples = [k for k in cf]
fileobj.seek(0)
return [
self._loader(
dict(subset=subset, order=n), dict(zip(self.fieldnames, k))
)
for n, k in enumerate(samples)
]
class CachedDataset(torch.utils.data.Dataset):
def __init__(
self, json_protocol, protocol, subset, raw_data_loader, transforms
):
self.json_protocol = json_protocol
self.subset = subset
self.raw_data_loader = raw_data_loader
self.transforms = transforms
self._samples = json_protocol.subsets(protocol)[self.subset]
# Dict entry with relative path to files, used during prediction
for s in self._samples:
s["name"] = s["data"]
logger.info(f"Caching {self.subset} samples")
for sample in tqdm(self._samples):
sample["data"] = self.raw_data_loader(sample["data"])
def __getitem__(self, idx):
sample = self._samples[idx].copy()
sample["data"] = self.transforms(sample["data"])
return sample
def __len__(self):
return len(self._samples)
class RuntimeDataset(torch.utils.data.Dataset):
def __init__(
self, json_protocol, protocol, subset, raw_data_loader, transforms
):
self.json_protocol = json_protocol
self.subset = subset
self.raw_data_loader = raw_data_loader
self.transforms = transforms
self._samples = json_protocol.subsets(protocol)[self.subset]
# Dict entry with relative path to files
for s in self._samples:
s["name"] = s["data"]
def __getitem__(self, idx):
sample = self._samples[idx].copy()
sample["data"] = self.transforms(self.raw_data_loader(sample["data"]))
return sample
def __len__(self):
return len(self._samples)
def get_samples_weights(dataset):
"""Compute the weights of all the samples of the dataset to balance it
using the sampler of the dataloader.
This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
and computes the weights to balance each class in the dataset and the
datasets themselves if we have a ConcatDataset.
Parameters
----------
dataset : torch.utils.data.dataset.Dataset
An instance of torch.utils.data.dataset.Dataset
ConcatDataset are supported
Returns
-------
samples_weights : :py:class:`torch.Tensor`
the weights for all the samples in the dataset given as input
"""
samples_weights = []
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
# Weighting only for binary labels
if isinstance(ds._samples[0]["label"], int):
# Groundtruth
targets = []
for s in ds._samples:
targets.append(s["label"])
targets = torch.tensor(targets)
# Count number of samples per class
class_sample_count = torch.tensor(
[
(targets == t).sum()
for t in torch.unique(targets, sorted=True)
]
)
weight = 1.0 / class_sample_count.float()
samples_weights.append(
torch.tensor([weight[t] for t in targets])
)
else:
# We only weight to sample equally from each dataset
samples_weights.append(torch.full((len(ds),), 1.0 / len(ds)))
# Concatenate sample weights from all the datasets
samples_weights = torch.cat(samples_weights)
else:
# Weighting only for binary labels
if isinstance(dataset._samples[0]["label"], int):
# Groundtruth
targets = []
for s in dataset._samples:
targets.append(s["label"])
targets = torch.tensor(targets)
# Count number of samples per class
class_sample_count = torch.tensor(
[
(targets == t).sum()
for t in torch.unique(targets, sorted=True)
]
)
weight = 1.0 / class_sample_count.float()
samples_weights = torch.tensor([weight[t] for t in targets])
else:
# Equal weights for non-binary labels
samples_weights = torch.ones(len(dataset._samples))
return samples_weights
def get_positive_weights(dataset):
def _get_positive_weights(dataloader):
"""Compute the positive weights of each class of the dataset to balance the
BCEWithLogitsLoss criterion.
This function takes as input a :py:class:`torch.utils.data.dataset.Dataset`
This function takes as input a :py:class:`torch.utils.data.DataLoader`
and computes the positive weights of each class to use them to have
a balanced loss.
......@@ -463,9 +22,8 @@ def get_positive_weights(dataset):
Parameters
----------
dataset : torch.utils.data.dataset.Dataset
An instance of torch.utils.data.dataset.Dataset
ConcatDataset are supported
dataloader : :py:class:`torch.utils.data.DataLoader`
A DataLoader from which to compute the positive weights. Must contain a 'label' key in the metadata returned by __getitem__().
Returns
......@@ -476,14 +34,8 @@ def get_positive_weights(dataset):
"""
targets = []
if isinstance(dataset, torch.utils.data.ConcatDataset):
for ds in dataset.datasets:
for s in ds._samples:
targets.append(s["label"])
else:
for s in dataset._samples:
targets.append(s["label"])
for batch in dataloader:
targets.extend(batch[1]["label"])
targets = torch.tensor(targets)
......@@ -512,75 +64,3 @@ def get_positive_weights(dataset):
)
return positive_weights
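# Sketch of how the computed weights might be consumed (similar in spirit to
# the reweighting helper below): the per-class positive weights are handed to
# BCEWithLogitsLoss so under-represented positive classes weigh more.
# ``train_dataloader`` is a hypothetical DataLoader yielding (input, {"label": ...}).
#
#   positive_weights = _get_positive_weights(train_dataloader)
#   criterion = torch.nn.BCEWithLogitsLoss(pos_weight=positive_weights)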
def reweight_BCEWithLogitsLoss(datamodule, model, criterion, criterion_valid):
from torch.nn import BCEWithLogitsLoss
datamodule.prepare_data()
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
validation_dataset = datamodule.validation_dataset
# Redefine a weighted criterion if possible
if isinstance(criterion, torch.nn.BCEWithLogitsLoss):
positive_weights = get_positive_weights(train_dataset)
model.hparams.criterion = BCEWithLogitsLoss(pos_weight=positive_weights)
else:
logger.warning("Weighted criterion not supported")
if validation_dataset is not None:
# Redefine a weighted valid criterion if possible
if (
isinstance(criterion_valid, torch.nn.BCEWithLogitsLoss)
or criterion_valid is None
):
positive_weights = get_positive_weights(validation_dataset)
model.hparams.criterion_valid = BCEWithLogitsLoss(
pos_weight=positive_weights
)
else:
logger.warning("Weighted valid criterion not supported")
def normalize_data(normalization, model, datamodule):
from torch.utils.data import DataLoader
datamodule.prepare_data()
datamodule.setup(stage="fit")
train_dataset = datamodule.train_dataset
# Create z-normalization model layer if needed
if normalization == "imagenet":
model.normalizer.set_mean_std(
[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
)
logger.info("Z-normalization with ImageNet mean and std")
elif normalization == "current":
# Compute mean/std of current train subset
temp_dl = DataLoader(
dataset=train_dataset, batch_size=len(train_dataset)
)
data = next(iter(temp_dl))
mean = data[1].mean(dim=[0, 2, 3])
std = data[1].std(dim=[0, 2, 3])
model.normalizer.set_mean_std(mean, std)
# Format mean and std for logging
mean = str(
[
round(x, 3)
for x in ((mean * 10**3).round() / (10**3)).tolist()
]
)
std = str(
[
round(x, 3)
for x in ((std * 10**3).round() / (10**3)).tolist()
]
)
logger.info(f"Z-normalization with mean {mean} and std {std}")
......@@ -264,7 +264,7 @@ json_dataset = JSONDataset(
def _maker(protocol, resize_size=512, cc_size=512, RGB=True):
import torchvision.transforms as transforms
from ..transforms import SingleAutoLevel16to8
from ..loader import SingleAutoLevel16to8
post_transforms = []
if not RGB:
......
......@@ -5,9 +5,42 @@
"""Data loading code."""
import numpy
import PIL.Image
class SingleAutoLevel16to8:
"""Converts a 16-bit image to 8-bit representation using "auto-level".
This transform assumes that the input image is gray-scaled.
To auto-level, we calculate the minimum and the maximum of the image, and
map that range to the [0, 255] range of the destination image.
"""
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
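# Worked example: for a 16-bit image with extrema (imin, imax) = (1000, 3000),
# a pixel at 1000 maps to 0, a pixel at 3000 maps to 255, and a pixel at 2000
# maps to round(255 * 1000 / 2000) = 128 in the 8-bit output.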
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
mask = numpy.asarray(img) > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
def load_pil(path):
"""Loads a sample data.
......
......@@ -13,22 +13,11 @@ the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
import importlib.resources
import os
from clapper.logging import setup
from ...utils.rc import load_rc
from ..loader import load_pil_baw
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
_protocols = [
importlib.resources.files(__name__).joinpath("default.json.bz2"),
......@@ -43,11 +32,3 @@ _protocols = [
importlib.resources.files(__name__).joinpath("fold_8.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_9.json.bz2"),
]
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
def _raw_data_loader(img_path):
raw_data = load_pil_baw(os.path.join(_datadir, img_path))
return raw_data
......@@ -2,27 +2,34 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (default protocol)
"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 64% of TB and healthy CXR for "train", 16% for "validation", 20% for "test"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from clapper.logging import setup
from ..transforms import ElasticDeformation
from .utils import ShenzhenDataModule
See :py:mod:`ptbench.data.shenzhen` for dataset details.
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
protocol_name = "default"
augmentation_transforms = [ElasticDeformation(p=0.8)]
This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen._transform`
* augmentations: elastic deformation (probability = 80%)
* output image resolution: 512x512 pixels
"""
datamodule = ShenzhenDataModule(
protocol="default",
model_transforms=[],
augmentation_transforms=augmentation_transforms,
import importlib.resources
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from .raw_data_loader import raw_data_loader
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
"default.json.bz2"
)
),
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
# model_transforms = [],
# batch_size = 1,
# batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the National
Library of Medicine, Maryland, USA in collaboration with Shenzhen No.3 People’s
Hospital, Guangdong Medical College, Shenzhen, China. The Chest X-rays are from
out-patient clinics, and were captured as part of the daily routine using
Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
import os
import typing
import torch.nn
import torchvision.transforms
from ...utils.rc import load_rc
from ..raw_data_loader import RemoveBlackBorders, load_pil_baw
_datadir = load_rc().get("datadir.shenzhen", os.path.realpath(os.curdir))
"""This variable contains the base directory where the database raw data is
stored."""
_transform = torchvision.transforms.Compose(
[
RemoveBlackBorders(),
torchvision.transforms.Resize(512),
torchvision.transforms.CenterCrop(512),
torchvision.transforms.ToTensor(),
]
)
"""Transforms that are always applied to the loaded raw images."""
def raw_data_loader(
sample: tuple[str, int]
) -> tuple[torch.Tensor, typing.Mapping]:
"""Loads a single image sample from the disk.
Parameters
----------
sample
A tuple containing the path suffix, within the dataset root folder, where
to find the image to be loaded, and its integer label.
Returns
-------
sample
A tuple containing the loaded image as a :py:class:`torch.Tensor`, and a
dictionary with further metadata attributes (currently the integer ``label``)
"""
tensor = _transform(load_pil_baw(os.path.join(_datadir, sample[0])))
return tensor, dict(label=sample[1]) # type: ignore[arg-type]
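# Illustrative call (the path suffix below is hypothetical): entries in the
# database split are (path, label) tuples, resolved against the
# ``datadir.shenzhen`` configuration value.
#
#   tensor, metadata = raw_data_loader(("CXR_png/some_image.png", 1))
#   assert metadata["label"] == 1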
......@@ -2,81 +2,40 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (cross validation fold 0, RGB)
"""Shenzhen datamodule for computer-aided diagnosis (default protocol)
* Split reference: first 80% of TB and healthy CXR for "train", rest for "test"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from clapper.logging import setup
from ....data import return_subsets
from ....data.base_datamodule import BaseDataModule
from ....data.dataset import JSONProtocol
from ....data.shenzhen import _cached_loader, _delayed_loader, _protocols
See :py:mod:`ptbench.data.shenzhen` for dataset details.
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
def __init__(
self,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
cache_samples=False,
multiproc_kwargs=None,
data_transforms=[],
model_transforms=[],
train_transforms=[],
):
super().__init__(
train_batch_size=train_batch_size,
predict_batch_size=predict_batch_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
self.cache_samples = cache_samples
self.has_setup_fit = False
This configuration:
* raw data (default): :py:obj:`ptbench.data.shenzhen._transform`
* augmentations: elastic deformation (probability = 80%)
* output image resolution: 512x512 pixels
"""
self.data_transforms = data_transforms
self.model_transforms = model_transforms
self.train_transforms = train_transforms
import importlib.resources
"""[
transforms.ToPILImage(),
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
]"""
from torchvision import transforms
def setup(self, stage: str):
if self.cache_samples:
logger.info(
"Argument cache_samples set to True. Samples will be loaded in memory."
)
samples_loader = _cached_loader
else:
logger.info(
"Argument cache_samples set to False. Samples will be loaded at runtime."
)
samples_loader = _delayed_loader
from ..datamodule import CachingDataModule
from ..split import JSONDatabaseSplit
from .raw_data_loader import raw_data_loader
self.json_protocol = JSONProtocol(
protocols=_protocols,
fieldnames=("data", "label"),
loader=samples_loader,
post_transforms=self.post_transforms,
datamodule = CachingDataModule(
database_split=JSONDatabaseSplit(
importlib.resources.files(__name__.rsplit(".", 1)[0]).joinpath(
"default.json.bz2"
)
if not self.has_setup_fit and stage == "fit":
(
self.train_dataset,
self.validation_dataset,
self.extra_validation_datasets,
) = return_subsets(self.json_protocol, "default", stage)
self.has_setup_fit = True
datamodule = DefaultModule
),
raw_data_loader=raw_data_loader,
cache_samples=False,
# train_sampler: typing.Optional[torch.utils.data.Sampler] = None,
model_transforms=[
transforms.ToPILImage(),
transforms.Lambda(lambda x: x.convert("RGB")),
transforms.ToTensor(),
],
# batch_size = 1,
# batch_chunk_count = 1,
# drop_incomplete_batch = False,
# parallel = -1,
)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for computer-aided diagnosis.
The standard digital image database for Tuberculosis is created by the
National Library of Medicine, Maryland, USA in collaboration with Shenzhen
No.3 People’s Hospital, Guangdong Medical College, Shenzhen, China.
The Chest X-rays are from out-patient clinics, and were captured as part of
the daily routine using Philips DR Digital Diagnose systems.
* Reference: [MONTGOMERY-SHENZHEN-2014]_
* Original resolution (height x width or width x height): 3000 x 3000 or less
* Split reference: none
* Protocol ``default``:
* Training samples: 64% of TB and healthy CXR (including labels)
* Validation samples: 16% of TB and healthy CXR (including labels)
* Test samples: 20% of TB and healthy CXR (including labels)
"""
from clapper.logging import setup
from torchvision import transforms
from ..base_datamodule import BaseDataModule
from ..dataset import CachedDataset, JSONProtocol, RuntimeDataset
from ..shenzhen import _protocols, _raw_data_loader
from ..transforms import RemoveBlackBorders
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class ShenzhenDataModule(BaseDataModule):
def __init__(
self,
protocol="default",
model_transforms=[],
augmentation_transforms=[],
batch_size=1,
batch_chunk_count=1,
drop_incomplete_batch=False,
cache_samples=False,
parallel=-1,
):
super().__init__(
batch_size=batch_size,
drop_incomplete_batch=drop_incomplete_batch,
batch_chunk_count=batch_chunk_count,
parallel=parallel,
)
self._cache_samples = cache_samples
self._has_setup_fit = False
self._has_setup_predict = False
self._protocol = protocol
self.raw_data_transforms = [
RemoveBlackBorders(),
transforms.Resize(512),
transforms.CenterCrop(512),
transforms.ToTensor(),
]
self.model_transforms = model_transforms
self.augmentation_transforms = augmentation_transforms
def setup(self, stage: str):
json_protocol = JSONProtocol(
protocols=_protocols,
fieldnames=("data", "label"),
)
if self._cache_samples:
dataset = CachedDataset
else:
dataset = RuntimeDataset
if not self._has_setup_fit and stage == "fit":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=True),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_fit = True
if not self._has_setup_predict and stage == "predict":
self.train_dataset = dataset(
json_protocol,
self._protocol,
"train",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self.validation_dataset = dataset(
json_protocol,
self._protocol,
"validation",
_raw_data_loader,
self._build_transforms(is_train=False),
)
self._has_setup_predict = True
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
import collections.abc
import csv
import importlib.abc
import json
import logging
import pathlib
import typing
import torch
logger = logging.getLogger(__name__)
class JSONDatabaseSplit(
dict,
typing.Mapping[str, typing.Sequence[typing.Any]],
):
"""Defines a loader that understands a database split (train, test, etc) in
JSON format.
To create a new database split, you need to provide a JSON formatted
dictionary in a file, with contents similar to the following:
.. code-block:: json
{
"subset1": [
[
"sample1-data1",
"sample1-data2",
"sample1-data3",
],
[
"sample2-data1",
"sample2-data2",
"sample2-data3",
]
],
"subset2": [
[
"sample42-data1",
"sample42-data2",
"sample42-data3",
],
]
}
Your database split may contain any number of subsets (dictionary keys).
For simplicity, we recommend all sample entries are formatted similarly so
that raw-data-loading is simplified. Use the function
:py:func:`check_database_split_loading` to test raw data loading and fine
tune the dataset split, or its loading.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
path
Absolute path to a JSON formatted file containing the database split to be
recognized by this object.
"""
def __init__(self, path: pathlib.Path | str | importlib.abc.Traversable):
if isinstance(path, str):
path = pathlib.Path(path)
self.path = path
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(self) -> dict[str, typing.Sequence[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
        This method will load JSON information for the current split and
        return all subsets of the given split.
Returns
-------
subsets : dict
A dictionary mapping subset names to lists of JSON objects
"""
if str(self.path).endswith(".bz2"):
logger.debug(f"Loading database split from {str(self.path)}...")
with __import__("bz2").open(self.path) as f:
return json.load(f)
else:
with self.path.open() as f:
return json.load(f)
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
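# Illustrative usage sketch (not part of the original file): writing a tiny
# split to a temporary JSON file and reading it back.  File names and sample
# contents below are made up for the example.
if __name__ == "__main__":
    import tempfile

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump({"train": [["img-001.png", 0], ["img-002.png", 1]]}, f)

    split = JSONDatabaseSplit(f.name)
    assert list(split) == ["train"]  # subsets exposed via __iter__
    assert len(split["train"]) == 2  # two samples in the "train" subset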
class CSVDatabaseSplit(collections.abc.Mapping):
"""Defines a loader that understands a database split (train, test, etc) in
CSV format.
To create a new database split, you need to provide one or more CSV
formatted files, each representing a subset of this split, containing the
sample data (one per row). Example:
    Inside the directory ``my-split/``, one can find files ``train.csv``,
``validation.csv``, and ``test.csv``. Each file has a structure similar to
the following:
.. code-block:: text
sample1-value1,sample1-value2,sample1-value3
sample2-value1,sample2-value2,sample2-value3
...
Each file in the provided directory defines the subset name on the split.
So, the file ``train.csv`` will contain the data from the ``train`` subset,
and so on.
Objects of this class behave like a dictionary in which keys are subset
names in the split, and values represent samples data and meta-data.
Parameters
----------
directory
        Absolute path to a directory containing the database split laid out
as a set of CSV files, one per subset.
"""
def __init__(
self, directory: pathlib.Path | str | importlib.abc.Traversable
):
if isinstance(directory, str):
directory = pathlib.Path(directory)
assert (
directory.is_dir()
), f"`{str(directory)}` is not a valid directory"
self.directory = directory
self.subsets = self._load_split_from_disk()
def _load_split_from_disk(
self,
) -> dict[str, list[typing.Any]]:
"""Loads all subsets in a split from its file system representation.
        This method will load CSV information for the current split and
        return all subsets of the given split.
Returns
-------
subsets : dict
            A dictionary mapping subset names to lists of samples (one per CSV row)
"""
retval = {}
for subset in self.directory.iterdir():
if str(subset).endswith(".csv.bz2"):
logger.debug(f"Loading database split from {subset}...")
                with __import__("bz2").open(subset, "rt") as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv.bz2")]] = [
k for k in reader
]
elif str(subset).endswith(".csv"):
with subset.open() as f:
reader = csv.reader(f)
retval[subset.name[: -len(".csv")]] = [k for k in reader]
else:
logger.debug(
f"Ignoring file {subset} in CSVDatabaseSplit readout"
)
return retval
def __getitem__(self, key: str) -> typing.Sequence[typing.Any]:
"""Accesses subset ``key`` from this split."""
return self.subsets[key]
def __iter__(self):
"""Iterates over the subsets."""
return iter(self.subsets)
def __len__(self) -> int:
"""How many subsets we currently have."""
return len(self.subsets)
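# Illustrative usage sketch (not part of the original file): laying out a
# split directory with a single ``train.csv`` file and loading it back.
# Paths and values below are made up for the example.
if __name__ == "__main__":
    import tempfile

    split_dir = pathlib.Path(tempfile.mkdtemp())
    (split_dir / "train.csv").write_text("img-001.png,0\nimg-002.png,1\n")

    split = CSVDatabaseSplit(split_dir)
    # CSV values are read back as strings, one list per row
    assert split["train"][0] == ["img-001.png", "0"]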
def check_database_split_loading(
database_split: typing.Mapping[str, typing.Sequence[typing.Any]],
loader: typing.Callable[[typing.Any], torch.Tensor],
limit: int = 0,
) -> int:
"""For each subset in the split, check if all data can be correctly loaded
using the provided loader function.
This function will return the number of errors loading samples, and will
log more detailed information to the logging stream.
Parameters
----------
database_split
        A mapping that contains the database split. Each key represents the
        name of a subset in the split. Each value is a sequence of
        (potentially complex) objects, each representing a single sample.
loader
A callable that transforms sample entries in the database split
into :py:class:`torch.Tensor` objects that can be used for training
or inference.
limit
Maximum number of samples to check (in each split/subset
combination) in this dataset. If set to zero, then check
everything.
Returns
-------
errors
Number of errors found
"""
    logger.info(
        "Checking if all samples in all subsets of this split can be loaded..."
    )
errors = 0
    for subset in database_split.keys():
        samples = database_split[subset]
        samples = samples if not limit else samples[:limit]
for pos, sample in enumerate(samples):
try:
data = loader(sample)
assert isinstance(data, torch.Tensor)
except Exception as e:
logger.info(
f"Found error loading entry {pos} in subset `{subset}`: {e}"
)
errors += 1
return errors
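# Illustrative usage sketch (not part of the original file): running the check
# over a tiny in-memory split with a dummy loader.  A real loader would read
# the file referenced by each sample and return the corresponding tensor.
if __name__ == "__main__":

    def _dummy_loader(sample) -> torch.Tensor:
        # ignores the sample contents and returns a fixed-size tensor
        return torch.zeros(1, 16, 16)

    tiny_split = {"train": [["img-001.png", 0], ["img-002.png", 1]]}
    n_errors = check_database_split_loading(tiny_split, _dummy_loader, limit=2)
    print(f"loading errors found: {n_errors}")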
......@@ -22,39 +22,6 @@ from scipy.ndimage import gaussian_filter, map_coordinates
from torchvision import transforms
class SingleAutoLevel16to8:
    """Converts a 16-bit image to an 8-bit representation using "auto-level".

    This transform assumes the input image is gray-scaled.  To auto-level, we
    calculate the minimum and the maximum of the image, and consider that this
    range should be mapped to the [0, 255] range of the destination image.
    """
def __call__(self, img):
imin, imax = img.getextrema()
irange = imax - imin
return PIL.Image.fromarray(
numpy.round(
255.0 * (numpy.array(img).astype(float) - imin) / irange
).astype("uint8"),
).convert("L")
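# Worked example (illustrative): for a 16-bit image with extrema imin=1000 and
# imax=3000 (irange=2000), a pixel of value 2000 is mapped to
# round(255 * (2000 - 1000) / 2000) = 128 in the 8-bit output.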
class RemoveBlackBorders:
"""Remove black borders of CXR."""
def __init__(self, threshold=0):
self.threshold = threshold
def __call__(self, img):
img = numpy.asarray(img)
        mask = img > self.threshold
return PIL.Image.fromarray(img[numpy.ix_(mask.any(1), mask.any(0))])
class ElasticDeformation:
"""Elastic deformation of 2D image slightly adapted from [SIMARD-2003]_.
......@@ -68,7 +35,7 @@ class ElasticDeformation:
spline_order=1,
mode="nearest",
random_state=numpy.random,
p=1,
p=1.0,
):
self.alpha = alpha
self.sigma = sigma
......@@ -79,13 +46,15 @@ class ElasticDeformation:
def __call__(self, img):
if random.random() < self.p:
            assert img.ndim == 3

            # Input tensor is of shape C x H x W.  If the tensor only contains
            # one channel, this conversion results in H x W.  With 3 channels,
            # we get H x W x C.
            img = transforms.ToPILImage()(img)
            img = numpy.asarray(img)

            shape = img.shape[:2]
dx = (
gaussian_filter(
......@@ -114,9 +83,22 @@ class ElasticDeformation:
numpy.reshape(y + dy, (-1, 1)),
]
            result = numpy.empty_like(img)
            if img.ndim == 2:
                result[:, :] = map_coordinates(
                    img[:, :], indices, order=self.spline_order, mode=self.mode
                ).reshape(shape)
else:
for i in range(img.shape[2]):
result[:, :, i] = map_coordinates(
img[:, :, i],
indices,
order=self.spline_order,
mode=self.mode,
).reshape(shape)
return transforms.ToTensor()(PIL.Image.fromarray(result))
else:
return img
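# Illustrative usage sketch (not part of the original file): applying the
# elastic deformation to a single-channel image tensor.  ``p=1.0`` forces the
# deformation to always be applied; the input shape is made up for the example.
if __name__ == "__main__":
    import torch

    deformation = ElasticDeformation(p=1.0)
    fake_cxr = (torch.rand(1, 64, 64) * 255).to(torch.uint8)  # C x H x W
    deformed = deformation(fake_cxr)
    print(deformed.shape)  # still a C x H x W tensor (torch.Size([1, 64, 64]))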
......@@ -94,10 +94,8 @@ class LoggingCallback(Callback):
self.log("total_time", current_time)
self.log("eta", eta_seconds)
self.log("loss", numpy.average(self.training_loss))
        self.log("learning_rate", pl_module.optimizer_configs["lr"])
        self.log("validation_loss", numpy.average(self.validation_loss))
if len(self.extra_validation_loss) > 0:
for (
......@@ -105,7 +103,8 @@ class LoggingCallback(Callback):
extra_valid_loss_values,
            ) in self.extra_validation_loss.items():
self.log(
                    extra_valid_loss_key,
                    numpy.average(extra_valid_loss_values),
)
queue_retries = 0
......
......@@ -7,19 +7,20 @@ import logging
import os
import shutil
import lightning.pytorch
import lightning.pytorch.callbacks
import lightning.pytorch.loggers
import lightning.pytorch.utilities.model_summary
import torch.nn
from ..utils.accelerator import AcceleratorProcessor
from ..utils.resources import ResourceMonitor, cpu_constants, gpu_constants
from ..utils.save_sh_command import save_sh_command
from .callbacks import LoggingCallback
logger = logging.getLogger(__name__)
def check_gpu(device: str) -> None:
"""Check the device type and the availability of GPU.
Parameters
......@@ -35,32 +36,41 @@ def check_gpu(device):
), f"Device set to '{device}', but nvidia-smi is not installed"
def save_model_summary(
    output_folder: str, model: torch.nn.Module
) -> tuple[lightning.pytorch.utilities.model_summary.ModelSummary, int]:
"""Save a little summary of the model in a txt file.
Parameters
----------
    output_folder
        output path
    model
Network (e.g. driu, hed, unet)
Returns
-------
    summary:
        The model summary in a text format.
    total_parameters:
The number of parameters of the model.
"""
summary_path = os.path.join(output_folder, "model_summary.txt")
logger.info(f"Saving model summary at {summary_path}...")
with open(summary_path, "w") as f:
        summary = lightning.pytorch.utilities.model_summary.ModelSummary(
            model, max_depth=-1
        )
f.write(str(summary))
    return (
summary,
lightning.pytorch.utilities.model_summary.ModelSummary(
model
).total_parameters,
)
def static_information_to_csv(static_logfile_name, device, n):
......@@ -70,7 +80,8 @@ def static_information_to_csv(static_logfile_name, device, n):
----------
static_logfile_name : str
        The static file name which is a join between the output folder and
        "constant.csv"
"""
if os.path.exists(static_logfile_name):
backup = static_logfile_name + "~"
......@@ -188,7 +199,8 @@ def run(
not save intermediary checkpoints.
accelerator : str
        A string indicating the accelerator to use (e.g. "cpu" or "gpu"). The
        device can also be specified (gpu:0)
arguments : dict
Start and end epochs:
......@@ -215,10 +227,16 @@ def run(
os.makedirs(output_folder, exist_ok=True)
# Save model summary
    _, n = save_model_summary(output_folder, model)
save_sh_command(output_folder)
# save_sh_command(os.path.join(output_folder, "cmd_line_config.txt"))
csv_logger = lightning.pytorch.loggers.CSVLogger(output_folder, "logs_csv")
tensorboard_logger = lightning.pytorch.loggers.TensorBoardLogger(
output_folder, "logs_tensorboard"
)
resource_monitor = ResourceMonitor(
interval=monitoring_interval,
......@@ -227,7 +245,7 @@ def run(
logging_level=logging.ERROR,
)
    checkpoint_callback = lightning.pytorch.callbacks.ModelCheckpoint(
output_folder,
"model_lowest_valid_loss",
save_last=True,
......@@ -251,7 +269,7 @@ def run(
devices = accelerator_processor.device
with resource_monitor:
        trainer = lightning.pytorch.Trainer(
accelerator=accelerator_processor.accelerator,
devices=devices,
max_epochs=max_epoch,
......
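# Illustrative sketch (not part of the original file): one way an accelerator
# string such as "gpu:0" could be split into the ``accelerator`` and
# ``devices`` arguments that ``lightning.pytorch.Trainer`` expects.  This is a
# hypothetical helper, not the actual ``AcceleratorProcessor`` implementation
# used above.
def _split_accelerator_string(name: str) -> tuple[str, list[int] | str]:
    if ":" in name:
        accelerator, device_id = name.split(":", 1)
        return accelerator, [int(device_id)]
    return name, "auto"

# _split_accelerator_string("cpu")   -> ("cpu", "auto")
# _split_accelerator_string("gpu:0") -> ("gpu", [0])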