Skip to content
Snippets Groups Projects
Commit ea50dcaa authored by Daniel CARRON's avatar Daniel CARRON :b:
Browse files

Grouped configs and data for default Shenzhen dataset

Created base class for every LightningDataModule, that can be inherited
from. For each protocol, the user should create a new class inheriting
from BaseDataModule and implement the setup() function, in which the
dataset is defined.
parent ba0cbbf5
No related branches found
No related tags found
No related merge requests found
......@@ -117,7 +117,7 @@ montgomery_rs_f7 = "ptbench.configs.datasets.montgomery_RS.fold_7"
montgomery_rs_f8 = "ptbench.configs.datasets.montgomery_RS.fold_8"
montgomery_rs_f9 = "ptbench.configs.datasets.montgomery_RS.fold_9"
# shenzhen dataset (and cross-validation folds)
shenzhen = "ptbench.configs.datasets.shenzhen.default"
shenzhen = "ptbench.data.shenzhen.default"
shenzhen_rgb = "ptbench.configs.datasets.shenzhen.rgb"
shenzhen_f0 = "ptbench.configs.datasets.shenzhen.fold_0"
shenzhen_f1 = "ptbench.configs.datasets.shenzhen.fold_1"
......
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
"""Shenzhen dataset for TB detection (default protocol)
* Split reference: first 64% of TB and healthy CXR for "train", 16% for
  "validation", 20% for "test"
* This configuration resolution: 512 x 512 (default)
* See :py:mod:`ptbench.data.shenzhen` for dataset details
"""
from . import _maker
dataset = _maker("default")
......@@ -13,10 +13,9 @@ from ptbench.configs.datasets import get_samples_weights
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DataModule(pl.LightningDataModule):
class BaseDataModule(pl.LightningDataModule):
def __init__(
self,
dataset,
train_batch_size=1,
predict_batch_size=1,
drop_incomplete_batch=False,
......@@ -24,8 +23,6 @@ class DataModule(pl.LightningDataModule):
):
super().__init__()
self.dataset = dataset
self.train_batch_size = train_batch_size
self.predict_batch_size = predict_batch_size
......@@ -37,37 +34,9 @@ class DataModule(pl.LightningDataModule):
self.multiproc_kwargs = multiproc_kwargs
def setup(self, stage: str):
if stage == "fit":
if "__train__" in self.dataset:
logger.info("Found (dedicated) '__train__' set for training")
self.train_dataset = self.dataset["__train__"]
else:
self.train_dataset = self.dataset["train"]
if "__valid__" in self.dataset:
logger.info("Found (dedicated) '__valid__' set for validation")
self.validation_dataset = self.dataset["__valid__"]
if "__extra_valid__" in self.dataset:
if not isinstance(self.dataset["__extra_valid__"], list):
raise RuntimeError(
f"If present, dataset['__extra_valid__'] must be a list, "
f"but you passed a {type(self.dataset['__extra_valid__'])}, "
f"which is invalid."
)
logger.info(
f"Found {len(self.dataset['__extra_valid__'])} extra validation "
f"set(s) to be tracked during training"
)
logger.info(
"Extra validation sets are NOT used for model checkpointing!"
)
self.extra_validation_datasets = self.dataset["__extra_valid__"]
else:
self.extra_validation_datasets = None
if stage == "predict":
self.predict_dataset = self.dataset
# Implemented by user
# Must define self.train_dataset, self.validation_dataset, self.extra_validation_datasets and self.predict_dataset
raise NotImplementedError
def train_dataloader(self):
train_samples_weights = get_samples_weights(self.train_dataset)
......
......@@ -7,6 +7,14 @@ import json
import logging
import os
import pathlib
import random
import torch
from torchvision.transforms import RandomRotation
RANDOM_ROTATION = [RandomRotation(15)]
"""Shared data augmentation based on random rotation only."""
logger = logging.getLogger(__name__)
......@@ -313,3 +321,295 @@ class CSVDataset:
)
for n, k in enumerate(samples)
]
def make_subset(samples, transforms=None, prefixes=None, suffixes=None):
    """Creates a new data set, applying transforms.

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package. It assumes certain strategies for data
       augmentation that may not be translatable to other applications.

    Parameters
    ----------
    samples : list
        List of delayed samples
    transforms : list, optional
        A list of transforms that needs to be applied to all samples in the
        set. ``None`` (the default) means no transforms.
    prefixes : list, optional
        A list of data augmentation operations that needs to be applied
        **before** the transforms above. ``None`` means none.
    suffixes : list, optional
        A list of data augmentation operations that needs to be applied
        **after** the transforms above. ``None`` means none.

    Returns
    -------
    subset : :py:class:`ptbench.data.utils.SampleListDataset`
        The samples wrapped together with the ordered list
        ``prefixes + transforms + suffixes``.
    """
    # Imported lazily, as in the original implementation.
    from .utils import SampleListDataset as wrapper

    # ``None`` replaces the former mutable ``[]`` defaults. A brand-new list
    # is always handed to the wrapper so it never shares storage with the
    # caller's lists.
    pipeline = [*(prefixes or ()), *(transforms or ()), *(suffixes or ())]
    return wrapper(samples, pipeline)
def make_dataset(
    subsets_groups, transforms=None, t_transforms=None, post_transforms=None
):
    """Creates a new configuration dataset from a list of dictionaries and
    transforms.

    This function takes as input a list of dictionaries as those that can be
    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`, mapping
    protocol names (such as ``train``, ``dev`` and ``test``) to
    :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of
    transforms. It returns a dictionary in which every list is wrapped by
    :py:func:`make_subset`, plus an augmented copy of the ``train`` set if
    one exists.

    For example, if the input is composed of two sets named ``train`` and
    ``test``, this function will yield a dictionary with the following
    entries:

    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
      (``transforms``, then ``t_transforms``, then ``post_transforms``)
    * ``train``: Wraps the ``train`` subset, **without** data augmentation
    * ``test``: Wraps the ``test`` subset, **without** data augmentation
    * ``__valid__``: the ``validation`` subset if there is one, otherwise the
      un-augmented ``train`` subset

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package. It assumes certain strategies for data
       augmentation that may not be translatable to other applications.

    Parameters
    ----------
    subsets_groups : list
        A list of dictionaries that contains the delayed sample lists for a
        number of named lists. When more than one dictionary is given, lists
        sharing the same key are concatenated and shuffled with
        :py:func:`random.shuffle`. The input lists are left untouched.
    transforms : list, optional
        A list of transforms that needs to be applied to all samples in the
        set. ``None`` means no transforms.
    t_transforms : list, optional
        A list of transforms that needs to be applied to the train samples
        only (``__train__`` entry). ``None`` means none. This list is not
        modified.
    post_transforms : list, optional
        A list of transforms that needs to be applied to all samples in the
        set after all the other transforms. ``None`` means none.

    Returns
    -------
    dataset : dict
        Maps string names to the objects returned by :py:func:`make_subset`.
    """
    transforms = list(transforms or ())
    post_transforms = list(post_transforms or ())

    # Train-only augmentations run first, then the shared post-transforms.
    # A new list is built on purpose: extending ``t_transforms`` in place
    # would leak ``post_transforms`` into the caller's list (and into a
    # shared default), growing the pipeline on every call.
    train_suffixes = list(t_transforms or ()) + post_transforms

    if len(subsets_groups) == 1:
        subsets = subsets_groups[0]
    else:
        # Several groups: merge the lists that share a key.
        subsets = {}
        for group in subsets_groups:
            for key, samples in group.items():
                if key in subsets:
                    subsets[key] += samples
                    # Shuffle, since the data now comes from multiple sources
                    random.shuffle(subsets[key])
                else:
                    # Copy, so the in-place extension and shuffle above never
                    # alter the lists owned by the caller.
                    subsets[key] = list(samples)

    retval = {}
    for key, samples in subsets.items():
        retval[key] = make_subset(
            samples, transforms=transforms, suffixes=post_transforms
        )
        if key == "train":
            retval["__train__"] = make_subset(
                samples, transforms=transforms, suffixes=train_suffixes
            )
        if key == "validation":
            # also use it for validation during training
            retval["__valid__"] = retval[key]

    if (
        "__train__" in retval
        and "train" in retval
        and "__valid__" not in retval
    ):
        # No validation set available: fall back to the un-augmented
        # training set.
        retval["__valid__"] = retval["train"]

    return retval
def _class_balanced_weights(samples):
    """Returns one weight per sample equal to 1 / (size of the sample's class)."""
    targets = torch.tensor([s.label for s in samples])
    # ``class_index`` maps every sample to the position of its label in the
    # sorted list of distinct labels. Indexing by that position (instead of
    # by the raw label value) stays valid when labels are not exactly
    # 0..k-1, e.g. when only positives are present.
    _, class_index, class_count = torch.unique(
        targets, sorted=True, return_inverse=True, return_counts=True
    )
    return (1.0 / class_count.float())[class_index]


def get_samples_weights(dataset):
    """Compute the weights of all the samples of the dataset to balance it
    using the sampler of the dataloader.

    Integer labels are weighted by the inverse of their class frequency.
    The label type is decided by looking at the first sample only. For a
    :py:class:`torch.utils.data.ConcatDataset`, each sub-dataset is handled
    independently and the resulting weights are concatenated in order.

    Parameters
    ----------
    dataset : torch.utils.data.dataset.Dataset
        A dataset exposing its samples through a ``_samples`` attribute, each
        sample having a ``label`` attribute. ConcatDataset are supported.

    Returns
    -------
    samples_weights : :py:class:`torch.Tensor`
        One weight per sample. For non-integer labels: all ones for a plain
        dataset, and ``1 / len(ds)`` for each sub-dataset of a ConcatDataset.
        An empty dataset (or sub-dataset) contributes an empty tensor.
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        per_dataset = []
        for ds in dataset.datasets:
            if len(ds._samples) == 0:
                # Nothing to weight; also avoids dividing by zero below.
                per_dataset.append(torch.empty(0))
            elif isinstance(ds._samples[0].label, int):
                per_dataset.append(_class_balanced_weights(ds._samples))
            else:
                # We only weight to sample equally from each dataset
                per_dataset.append(torch.full((len(ds),), 1.0 / len(ds)))
        # Concatenate sample weights from all the datasets
        return torch.cat(per_dataset)

    if len(dataset._samples) == 0:
        return torch.empty(0)
    if isinstance(dataset._samples[0].label, int):
        return _class_balanced_weights(dataset._samples)
    # Equal weights for non-integer labels
    return torch.ones(len(dataset._samples))
def get_positive_weights(dataset):
    """Compute the positive weights of each class of the dataset to balance the
    BCEWithLogitsLoss criterion.

    Labels are gathered from the ``label`` attribute of every entry in
    ``_samples`` (of each sub-dataset, for a ConcatDataset) and stacked into
    a single tensor.

    Parameters
    ----------
    dataset : torch.utils.data.dataset.Dataset
        An instance of torch.utils.data.dataset.Dataset
        ConcatDataset are supported

    Returns
    -------
    positive_weights : :py:class:`torch.Tensor`
        * 1-D label tensor: a one-element tensor holding the count of the
          smallest label value divided by the count of the second smallest.
        * otherwise: one value per column, ``negatives / (positives +
          negatives)``.
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        labels = [s.label for ds in dataset.datasets for s in ds._samples]
    else:
        labels = [s.label for s in dataset._samples]
    targets = torch.tensor(labels)

    if targets.dim() != 1:
        # Multiclass labels: one column per class.
        # NOTE(review): this is negatives / total, while the binary branch
        # below uses negatives / positives -- confirm this is intended.
        positives = torch.sum(targets, dim=0)
        n_samples = float(targets.size()[0])
        negatives = torch.full((targets.size()[1],), n_samples) - positives
        return negatives / (positives + negatives)

    # Binary labels: occurrences of each distinct value, in ascending order.
    per_class = [
        float((targets == value).sum().item())
        for value in torch.unique(targets, sorted=True)
    ]
    # Divide negatives by positives
    ratio = per_class[0] / per_class[1]
    return torch.tensor([ratio]).reshape(-1)
......@@ -22,11 +22,19 @@ the daily routine using Philips DR Digital Diagnose systems.
import importlib.resources
import os
import random
import torch
from clapper.logging import setup
from ...utils.rc import load_rc
from ..dataset import JSONDataset
from ..loader import load_pil_baw, make_delayed
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
_protocols = [
importlib.resources.files(__name__).joinpath("default.json.bz2"),
importlib.resources.files(__name__).joinpath("fold_0.json.bz2"),
......@@ -57,9 +65,367 @@ def _loader(context, sample):
return make_delayed(sample, _raw_data_loader)
dataset = JSONDataset(
protocols=_protocols,
fieldnames=("data", "label"),
loader=_loader,
json_dataset = JSONDataset(
protocols=_protocols, fieldnames=("data", "label"), loader=_loader
)
"""Shenzhen dataset object."""
def _maker(protocol, resize_size=512, cc_size=512, RGB=False):
    """Builds the dataset dictionary for one protocol of this dataset.

    Parameters
    ----------
    protocol : str
        Name of the protocol, passed to ``json_dataset.subsets``.
    resize_size : int
        Argument given to ``torchvision.transforms.Resize``.
    cc_size : int
        Argument given to ``torchvision.transforms.CenterCrop``.
    RGB : bool
        If set, images are converted to ``"RGB"`` and then to tensors as the
        final steps of every pipeline.

    Returns
    -------
    dataset : dict
        The value returned by ``make_dataset``.
    """
    from torchvision import transforms

    from ..transforms import ElasticDeformation, RemoveBlackBorders

    if RGB:
        final_steps = [
            transforms.Lambda(lambda img: img.convert("RGB")),
            transforms.ToTensor(),
        ]
    else:
        final_steps = []

    groups = [json_dataset.subsets(protocol)]
    # Applied to every subset
    shared_steps = [
        RemoveBlackBorders(),
        transforms.Resize(resize_size),
        transforms.CenterCrop(cc_size),
    ]
    # Applied to the training set only
    train_only_steps = [ElasticDeformation(p=0.8)]

    return make_dataset(groups, shared_steps, train_only_steps, final_steps)
def make_subset(samples, transforms=None, prefixes=None, suffixes=None):
    """Creates a new data set, applying transforms.

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package. It assumes certain strategies for data
       augmentation that may not be translatable to other applications.

    Parameters
    ----------
    samples : list
        List of delayed samples
    transforms : list, optional
        A list of transforms that needs to be applied to all samples in the
        set. ``None`` (the default) means no transforms.
    prefixes : list, optional
        A list of data augmentation operations that needs to be applied
        **before** the transforms above. ``None`` means none.
    suffixes : list, optional
        A list of data augmentation operations that needs to be applied
        **after** the transforms above. ``None`` means none.

    Returns
    -------
    subset : :py:class:`ptbench.data.utils.SampleListDataset`
        The samples wrapped together with the ordered list
        ``prefixes + transforms + suffixes``.
    """
    # Imported lazily, as in the original implementation.
    from ...data.utils import SampleListDataset as wrapper

    # ``None`` replaces the former mutable ``[]`` defaults. A brand-new list
    # is always handed to the wrapper so it never shares storage with the
    # caller's lists.
    pipeline = [*(prefixes or ()), *(transforms or ()), *(suffixes or ())]
    return wrapper(samples, pipeline)
def make_dataset(
    subsets_groups, transforms=None, t_transforms=None, post_transforms=None
):
    """Creates a new configuration dataset from a list of dictionaries and
    transforms.

    This function takes as input a list of dictionaries as those that can be
    returned by :py:meth:`ptbench.data.dataset.JSONDataset.subsets`, mapping
    protocol names (such as ``train``, ``dev`` and ``test``) to
    :py:class:`ptbench.data.sample.DelayedSample` lists, and a set of
    transforms. It returns a dictionary in which every list is wrapped by
    :py:func:`make_subset`, plus an augmented copy of the ``train`` set if
    one exists.

    For example, if the input is composed of two sets named ``train`` and
    ``test``, this function will yield a dictionary with the following
    entries:

    * ``__train__``: Wraps the ``train`` subset, includes data augmentation
      (``transforms``, then ``t_transforms``, then ``post_transforms``)
    * ``train``: Wraps the ``train`` subset, **without** data augmentation
    * ``test``: Wraps the ``test`` subset, **without** data augmentation
    * ``__valid__``: the ``validation`` subset if there is one, otherwise the
      un-augmented ``train`` subset

    .. note::

       This is a convenience function for our own dataset definitions inside
       this module, guaranteeing homogeneity between dataset definitions
       provided in this package. It assumes certain strategies for data
       augmentation that may not be translatable to other applications.

    Parameters
    ----------
    subsets_groups : list
        A list of dictionaries that contains the delayed sample lists for a
        number of named lists. When more than one dictionary is given, lists
        sharing the same key are concatenated and shuffled with
        :py:func:`random.shuffle`. The input lists are left untouched.
    transforms : list, optional
        A list of transforms that needs to be applied to all samples in the
        set. ``None`` means no transforms.
    t_transforms : list, optional
        A list of transforms that needs to be applied to the train samples
        only (``__train__`` entry). ``None`` means none. This list is not
        modified.
    post_transforms : list, optional
        A list of transforms that needs to be applied to all samples in the
        set after all the other transforms. ``None`` means none.

    Returns
    -------
    dataset : dict
        Maps string names to the objects returned by :py:func:`make_subset`.
    """
    transforms = list(transforms or ())
    post_transforms = list(post_transforms or ())

    # Train-only augmentations run first, then the shared post-transforms.
    # A new list is built on purpose: extending ``t_transforms`` in place
    # would leak ``post_transforms`` into the caller's list (and into a
    # shared default), growing the pipeline on every call.
    train_suffixes = list(t_transforms or ()) + post_transforms

    if len(subsets_groups) == 1:
        subsets = subsets_groups[0]
    else:
        # Several groups: merge the lists that share a key.
        subsets = {}
        for group in subsets_groups:
            for key, samples in group.items():
                if key in subsets:
                    subsets[key] += samples
                    # Shuffle, since the data now comes from multiple sources
                    random.shuffle(subsets[key])
                else:
                    # Copy, so the in-place extension and shuffle above never
                    # alter the lists owned by the caller.
                    subsets[key] = list(samples)

    retval = {}
    for key, samples in subsets.items():
        retval[key] = make_subset(
            samples, transforms=transforms, suffixes=post_transforms
        )
        if key == "train":
            retval["__train__"] = make_subset(
                samples, transforms=transforms, suffixes=train_suffixes
            )
        if key == "validation":
            # also use it for validation during training
            retval["__valid__"] = retval[key]

    if (
        "__train__" in retval
        and "train" in retval
        and "__valid__" not in retval
    ):
        # No validation set available: fall back to the un-augmented
        # training set.
        retval["__valid__"] = retval["train"]

    return retval
def _class_balanced_weights(samples):
    """Returns one weight per sample equal to 1 / (size of the sample's class)."""
    targets = torch.tensor([s.label for s in samples])
    # ``class_index`` maps every sample to the position of its label in the
    # sorted list of distinct labels. Indexing by that position (instead of
    # by the raw label value) stays valid when labels are not exactly
    # 0..k-1, e.g. when only positives are present.
    _, class_index, class_count = torch.unique(
        targets, sorted=True, return_inverse=True, return_counts=True
    )
    return (1.0 / class_count.float())[class_index]


def get_samples_weights(dataset):
    """Compute the weights of all the samples of the dataset to balance it
    using the sampler of the dataloader.

    Integer labels are weighted by the inverse of their class frequency.
    The label type is decided by looking at the first sample only. For a
    :py:class:`torch.utils.data.ConcatDataset`, each sub-dataset is handled
    independently and the resulting weights are concatenated in order.

    Parameters
    ----------
    dataset : torch.utils.data.dataset.Dataset
        A dataset exposing its samples through a ``_samples`` attribute, each
        sample having a ``label`` attribute. ConcatDataset are supported.

    Returns
    -------
    samples_weights : :py:class:`torch.Tensor`
        One weight per sample. For non-integer labels: all ones for a plain
        dataset, and ``1 / len(ds)`` for each sub-dataset of a ConcatDataset.
        An empty dataset (or sub-dataset) contributes an empty tensor.
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        per_dataset = []
        for ds in dataset.datasets:
            if len(ds._samples) == 0:
                # Nothing to weight; also avoids dividing by zero below.
                per_dataset.append(torch.empty(0))
            elif isinstance(ds._samples[0].label, int):
                per_dataset.append(_class_balanced_weights(ds._samples))
            else:
                # We only weight to sample equally from each dataset
                per_dataset.append(torch.full((len(ds),), 1.0 / len(ds)))
        # Concatenate sample weights from all the datasets
        return torch.cat(per_dataset)

    if len(dataset._samples) == 0:
        return torch.empty(0)
    if isinstance(dataset._samples[0].label, int):
        return _class_balanced_weights(dataset._samples)
    # Equal weights for non-integer labels
    return torch.ones(len(dataset._samples))
def get_positive_weights(dataset):
    """Compute the positive weights of each class of the dataset to balance the
    BCEWithLogitsLoss criterion.

    Labels are gathered from the ``label`` attribute of every entry in
    ``_samples`` (of each sub-dataset, for a ConcatDataset) and stacked into
    a single tensor.

    Parameters
    ----------
    dataset : torch.utils.data.dataset.Dataset
        An instance of torch.utils.data.dataset.Dataset
        ConcatDataset are supported

    Returns
    -------
    positive_weights : :py:class:`torch.Tensor`
        * 1-D label tensor: a one-element tensor holding the count of the
          smallest label value divided by the count of the second smallest.
        * otherwise: one value per column, ``negatives / (positives +
          negatives)``.
    """
    if isinstance(dataset, torch.utils.data.ConcatDataset):
        labels = [s.label for ds in dataset.datasets for s in ds._samples]
    else:
        labels = [s.label for s in dataset._samples]
    targets = torch.tensor(labels)

    if targets.dim() != 1:
        # Multiclass labels: one column per class.
        # NOTE(review): this is negatives / total, while the binary branch
        # below uses negatives / positives -- confirm this is intended.
        positives = torch.sum(targets, dim=0)
        n_samples = float(targets.size()[0])
        negatives = torch.full((targets.size()[1],), n_samples) - positives
        return negatives / (positives + negatives)

    # Binary labels: occurrences of each distinct value, in ascending order.
    per_class = [
        float((targets == value).sum().item())
        for value in torch.unique(targets, sorted=True)
    ]
    # Divide negatives by positives
    ratio = per_class[0] / per_class[1]
    return torch.tensor([ratio]).reshape(-1)
def return_subsets(dataset):
    """Picks the training, validation and prediction sets out of a dataset
    dictionary.

    Parameters
    ----------
    dataset : dict
        Mapping from subset names to datasets. ``__train__`` is preferred
        over ``train`` for training; ``__valid__`` and ``__extra_valid__``
        are optional.

    Returns
    -------
    subsets : tuple
        ``(train_dataset, validation_dataset, extra_validation_datasets,
        predict_dataset)``. The validation entries are ``None`` when the
        matching key is absent; ``predict_dataset`` is ``dataset`` itself.

    Raises
    ------
    RuntimeError
        If ``dataset["__extra_valid__"]`` exists but is not a list.
    """
    if "__train__" in dataset:
        logger.info("Found (dedicated) '__train__' set for training")
        training_key = "__train__"
    else:
        training_key = "train"
    training = dataset[training_key]

    validation = None
    if "__valid__" in dataset:
        logger.info("Found (dedicated) '__valid__' set for validation")
        validation = dataset["__valid__"]

    extra_validation = None
    if "__extra_valid__" in dataset:
        candidates = dataset["__extra_valid__"]
        if not isinstance(candidates, list):
            raise RuntimeError(
                f"If present, dataset['__extra_valid__'] must be a list, "
                f"but you passed a {type(candidates)}, "
                f"which is invalid."
            )
        logger.info(
            f"Found {len(candidates)} extra validation "
            f"set(s) to be tracked during training"
        )
        logger.info(
            "Extra validation sets are NOT used for model checkpointing!"
        )
        extra_validation = candidates

    # The whole dictionary is handed back as the prediction set.
    return (training, validation, extra_validation, dataset)
# SPDX-FileCopyrightText: Copyright © 2023 Idiap Research Institute <contact@idiap.ch>
#
# SPDX-License-Identifier: GPL-3.0-or-later
from clapper.logging import setup
from ..base_datamodule import BaseDataModule
from . import _maker, return_subsets
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
class DefaultModule(BaseDataModule):
    """Data module that builds its dataset with ``_maker("default")``."""

    def __init__(
        self,
        train_batch_size=1,
        predict_batch_size=1,
        drop_incomplete_batch=False,
        multiproc_kwargs=None,
    ):
        # Every option is forwarded unchanged to the base class.
        base_options = {
            "train_batch_size": train_batch_size,
            "predict_batch_size": predict_batch_size,
            "drop_incomplete_batch": drop_incomplete_batch,
            "multiproc_kwargs": multiproc_kwargs,
        }
        super().__init__(**base_options)

    def setup(self, stage: str):
        """Creates the dataset and stores its subsets on the instance.

        ``stage`` is not inspected: all four attributes are set on every
        call.
        """
        self.dataset = _maker("default")
        training, validation, extra_validation, prediction = return_subsets(
            self.dataset
        )
        self.train_dataset = training
        self.validation_dataset = validation
        self.extra_validation_datasets = extra_validation
        self.predict_dataset = prediction
datamodule = DefaultModule
......@@ -8,7 +8,6 @@ from clapper.click import ConfigCommand, ResourceOption, verbosity_option
from clapper.logging import setup
from lightning.pytorch import seed_everything
from ..data.datamodule import DataModule
from ..utils.checkpointer import get_checkpoint
logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
......@@ -45,7 +44,7 @@ logger = setup(__name__.split(".")[0], format="%(levelname)s: %(message)s")
cls=ResourceOption,
)
@click.option(
"--dataset",
"--datamodule",
"-d",
help="A dictionary mapping string keys to "
"torch.utils.data.dataset.Dataset instances implementing datasets "
......@@ -233,7 +232,7 @@ def train(
drop_incomplete_batch,
criterion,
criterion_valid,
dataset,
datamodule,
checkpoint_period,
accelerator,
seed,
......@@ -290,9 +289,9 @@ def train(
else:
batch_chunk_size = batch_size // batch_chunk_count
datamodule = DataModule(
dataset,
datamodule = datamodule(
train_batch_size=batch_chunk_size,
drop_incomplete_batch=drop_incomplete_batch,
multiproc_kwargs=multiproc_kwargs,
)
# Manually calling these as we need to access some values to reweight the criterion
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment